fix(photomaker): place id_encoder on the right device + dtype

Modal cert sweep #5 made it through component load (V1 id_encoder + lora_weights)
and died at inference with the classic
"Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be
the same" — id_encoder lived on CPU/fp32 while the rest of the pipeline ran on
CUDA/fp16. Two fixes:

1. Call `pipe.to(device)` BEFORE `load_photomaker_adapter` so the loader picks the
   right device/dtype from `self.device` / `self.unet.dtype` when it builds the
   encoder.
2. Belt: after load, explicitly `pipe.id_encoder.to(device, dtype)` because some
   torch/diffusers combos leave custom attributes on the old device even when
   `pipe.to` ran first.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Victor Kuznetsov
2026-06-08 16:29:00 -07:00
parent 9435e12ce6
commit 031c38dc7f
+13 -1
View File
@@ -151,6 +151,14 @@ def _get_pipeline() -> Any:
)
adapter_path = hf_hub_download(repo_id=_PHOTOMAKER_REPO, filename=_PHOTOMAKER_FILE)
pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(_SDXL_MODEL_ID, torch_dtype=dtype)
# Move SDXL submodules to the device BEFORE loading the PhotoMaker adapter:
# ``load_photomaker_adapter`` reads ``self.device`` / ``self.unet.dtype`` to
# place the new ID encoder. If we ``.to(device)`` after, the SDXL submodules
# move but the id_encoder stays where it was (custom attribute, not in the
# auto-managed module tree), and inference errors with
# "Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor)
# should be the same" (caught empirically 2026-06-04).
pipe.to(device)
# ``pm_version="v1"`` is REQUIRED: the upstream loader defaults to v2 and would
# build the V2 encoder (PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken), then
# error on load_state_dict because the v1 weights have a different shape.
@@ -163,8 +171,12 @@ def _get_pipeline() -> Any:
trigger_word="img",
pm_version="v1",
)
pipe.to(device)
pipe.fuse_lora()
# Belt: also explicitly cast the loaded id_encoder, because some
# diffusers/torch combinations leave the encoder buffers untouched even
# though ``pipe.to(device)`` ran first.
if hasattr(pipe, "id_encoder") and pipe.id_encoder is not None:
pipe.id_encoder = pipe.id_encoder.to(device=device, dtype=dtype)
_pipeline = pipe
return _pipeline