diff --git a/src/remove_ai_watermarks/photomaker_restore.py b/src/remove_ai_watermarks/photomaker_restore.py index 6b6a466..19495d0 100644 --- a/src/remove_ai_watermarks/photomaker_restore.py +++ b/src/remove_ai_watermarks/photomaker_restore.py @@ -151,6 +151,14 @@ def _get_pipeline() -> Any: ) adapter_path = hf_hub_download(repo_id=_PHOTOMAKER_REPO, filename=_PHOTOMAKER_FILE) pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(_SDXL_MODEL_ID, torch_dtype=dtype) + # Move SDXL submodules to the device BEFORE loading the PhotoMaker adapter: + # ``load_photomaker_adapter`` reads ``self.device`` / ``self.unet.dtype`` to + # place the new ID encoder. If we ``.to(device)`` after, the SDXL submodules + # move but the id_encoder stays where it was (custom attribute, not in the + # auto-managed module tree), and inference errors with + # "Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) + # should be the same" (caught empirically 2026-06-04). + pipe.to(device) # ``pm_version="v1"`` is REQUIRED: the upstream loader defaults to v2 and would # build the V2 encoder (PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken), then # error on load_state_dict because the v1 weights have a different shape. @@ -163,8 +171,12 @@ def _get_pipeline() -> Any: trigger_word="img", pm_version="v1", ) - pipe.to(device) pipe.fuse_lora() + # Belt: also explicitly cast the loaded id_encoder, because some + # diffusers/torch combinations leave the encoder buffers untouched even + # though ``pipe.to(device)`` ran first. + if hasattr(pipe, "id_encoder") and pipe.id_encoder is not None: + pipe.id_encoder = pipe.id_encoder.to(device=device, dtype=dtype) _pipeline = pipe return _pipeline