diff --git a/src/remove_ai_watermarks/instantid_restore.py b/src/remove_ai_watermarks/instantid_restore.py index 921b928..b2dfec3 100644 --- a/src/remove_ai_watermarks/instantid_restore.py +++ b/src/remove_ai_watermarks/instantid_restore.py @@ -233,7 +233,20 @@ def _get_pipeline() -> Any: pipe.to(device) # IP-Adapter weights that wire the ArcFace embedding into cross-attention. ip_adapter_path = hf_hub_download(repo_id=_INSTANTID_REPO, filename=_INSTANTID_IP_ADAPTER) - pipe.load_ip_adapter_instantid(ip_adapter_path) + # IP-Adapter scale (the weight on the ArcFace cross-attention branch) is + # set at load time, not at call time. 0.8 mirrors the upstream demo. + pipe.load_ip_adapter_instantid(ip_adapter_path, scale=0.8) + # Diffusers 0.38 vs InstantID upstream compat patch: InstantID's __call__ + # calls ``self.check_inputs(...)`` POSITIONALLY (signature from ~v0.29), + # but diffusers 0.38 added two new params (``ip_adapter_image``, + # ``ip_adapter_image_embeds``) BEFORE ``controlnet_conditioning_scale`` in + # the parent's signature. That shifts every argument by two, so + # ``control_guidance_end`` (which InstantID converts to ``[1.0]`` for the + # single-controlnet case before this point) lands in the slot the parent + # validates as ``controlnet_conditioning_scale`` and trips + # ``TypeError("must be type float")``. Our inputs are programmatic and + # already validated by our own callers, so neutralising the check is safe. + pipe.check_inputs = lambda *_a, **_k: None _pipeline = pipe return _pipeline @@ -298,7 +311,6 @@ def restore_faces_instantid( cleaned_bgr: NDArray[Any], num_inference_steps: int = 30, guidance_scale: float = 5.0, - ip_adapter_scale: float = 0.8, controlnet_conditioning_scale: float = 0.8, seed: int | None = None, detect_faces_fn: Any | None = None, @@ -387,7 +399,6 @@ def restore_faces_instantid( image_embeds=face_emb, image=landmark_img, controlnet_conditioning_scale=controlnet_conditioning_scale, - ip_adapter_scale=ip_adapter_scale, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator,