mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-10 12:53:56 +02:00
fix(photomaker): place id_encoder on the right device + dtype
Modal cert sweep #5 made it through component load (V1 id_encoder + lora_weights) and died at inference with the classic "Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same" — id_encoder lived on CPU/fp32 while the rest of the pipeline ran on CUDA/fp16. Two fixes: 1. Call `pipe.to(device)` BEFORE `load_photomaker_adapter` so the loader picks the right device/dtype from `self.device` / `self.unet.dtype` when it builds the encoder. 2. Belt: after load, explicitly `pipe.id_encoder.to(device, dtype)` because some torch/diffusers combos leave custom attributes on the old device even when `pipe.to` ran first. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -151,6 +151,14 @@ def _get_pipeline() -> Any:
|
||||
)
|
||||
adapter_path = hf_hub_download(repo_id=_PHOTOMAKER_REPO, filename=_PHOTOMAKER_FILE)
|
||||
pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(_SDXL_MODEL_ID, torch_dtype=dtype)
|
||||
# Move SDXL submodules to the device BEFORE loading the PhotoMaker adapter:
|
||||
# ``load_photomaker_adapter`` reads ``self.device`` / ``self.unet.dtype`` to
|
||||
# place the new ID encoder. If we ``.to(device)`` after, the SDXL submodules
|
||||
# move but the id_encoder stays where it was (custom attribute, not in the
|
||||
# auto-managed module tree), and inference errors with
|
||||
# "Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor)
|
||||
# should be the same" (caught empirically 2026-06-04).
|
||||
pipe.to(device)
|
||||
# ``pm_version="v1"`` is REQUIRED: the upstream loader defaults to v2 and would
|
||||
# build the V2 encoder (PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken), then
|
||||
# error on load_state_dict because the v1 weights have a different shape.
|
||||
@@ -163,8 +171,12 @@ def _get_pipeline() -> Any:
|
||||
trigger_word="img",
|
||||
pm_version="v1",
|
||||
)
|
||||
pipe.to(device)
|
||||
pipe.fuse_lora()
|
||||
# Belt: also explicitly cast the loaded id_encoder, because some
|
||||
# diffusers/torch combinations leave the encoder buffers untouched even
|
||||
# though ``pipe.to(device)`` ran first.
|
||||
if hasattr(pipe, "id_encoder") and pipe.id_encoder is not None:
|
||||
pipe.id_encoder = pipe.id_encoder.to(device=device, dtype=dtype)
|
||||
_pipeline = pipe
|
||||
return _pipeline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user