From 031c38dc7f367f825c690d826c47ca46da0a8889 Mon Sep 17 00:00:00 2001 From: Victor Kuznetsov Date: Mon, 8 Jun 2026 16:29:00 -0700 Subject: [PATCH] fix(photomaker): place id_encoder on the right device + dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modal cert sweep #5 made it through component load (V1 id_encoder + lora_weights) and died at inference with the classic "Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) should be the same" — id_encoder lived on CPU/fp32 while the rest of the pipeline ran on CUDA/fp16. Two fixes: 1. Call `pipe.to(device)` BEFORE `load_photomaker_adapter` so the loader picks the right device/dtype from `self.device` / `self.unet.dtype` when it builds the encoder. 2. Belt: after load, explicitly `pipe.id_encoder.to(device, dtype)` because some torch/diffusers combos leave custom attributes on the old device even when `pipe.to` ran first. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/remove_ai_watermarks/photomaker_restore.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/remove_ai_watermarks/photomaker_restore.py b/src/remove_ai_watermarks/photomaker_restore.py index 6b6a466..19495d0 100644 --- a/src/remove_ai_watermarks/photomaker_restore.py +++ b/src/remove_ai_watermarks/photomaker_restore.py @@ -151,6 +151,14 @@ def _get_pipeline() -> Any: ) adapter_path = hf_hub_download(repo_id=_PHOTOMAKER_REPO, filename=_PHOTOMAKER_FILE) pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(_SDXL_MODEL_ID, torch_dtype=dtype) + # Move SDXL submodules to the device BEFORE loading the PhotoMaker adapter: + # ``load_photomaker_adapter`` reads ``self.device`` / ``self.unet.dtype`` to + # place the new ID encoder. If we ``.to(device)`` after, the SDXL submodules + # move but the id_encoder stays where it was (custom attribute, not in the + # auto-managed module tree), and inference errors with + # "Input type (torch.cuda.HalfTensor) and weight type (torch.HalfTensor) + # should be the same" (caught empirically 2026-06-04). + pipe.to(device) # ``pm_version="v1"`` is REQUIRED: the upstream loader defaults to v2 and would # build the V2 encoder (PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken), then # error on load_state_dict because the v1 weights have a different shape. @@ -163,8 +171,12 @@ def _get_pipeline() -> Any: trigger_word="img", pm_version="v1", ) - pipe.to(device) pipe.fuse_lora() + # Belt: also explicitly cast the loaded id_encoder, because some + # diffusers/torch combinations leave the encoder buffers untouched even + # though ``pipe.to(device)`` ran first. + if hasattr(pipe, "id_encoder") and pipe.id_encoder is not None: + pipe.id_encoder = pipe.id_encoder.to(device=device, dtype=dtype) _pipeline = pipe return _pipeline