diff --git a/src/remove_ai_watermarks/noai/watermark_remover.py b/src/remove_ai_watermarks/noai/watermark_remover.py index b2725af..678789e 100644 --- a/src/remove_ai_watermarks/noai/watermark_remover.py +++ b/src/remove_ai_watermarks/noai/watermark_remover.py @@ -490,6 +490,27 @@ class WatermarkRemover: load_kwargs["token"] = self.hf_token return load_kwargs + def _load_from_pretrained(self, cls: Any, model_id: str, **load_kwargs: Any) -> Any: + """Call ``cls.from_pretrained`` reading the fp16 weight VARIANT when on fp16. + + When ``torch_dtype`` is float16 (the CUDA/XPU SDXL default), pass + ``variant="fp16"`` so diffusers fetches/reads the half-precision weight files + (~half the bytes of the fp32 defaults) instead of reading the full fp32 files + and downcasting in memory. On a warm weights cache this roughly halves the + cold-start weight read + host->device transfer, which a phase-timed Modal run + measured as ~half of the ~25s cold start. Not every checkpoint publishes an + fp16 variant (a custom ``--model``, or the canny ControlNet if it ships only + default files), so fall back to the default weights if the variant is missing - + worst case is the prior behavior (read fp32, downcast). fp32 (cpu/mps) and bf16 + (qwen) never request the variant. + """ + if self.torch_dtype == torch.float16: + try: + return cls.from_pretrained(model_id, variant="fp16", **load_kwargs) + except Exception as exc: + logger.info("No fp16 weight variant for %s (%s); loading default weights", model_id, exc) + return cls.from_pretrained(model_id, **load_kwargs) + def _load_pipeline(self) -> AutoImg2ImgPipeline: """Load the plain SDXL img2img pipeline lazily.""" if self._pipeline is None: @@ -501,7 +522,7 @@ class WatermarkRemover: load_kwargs["requires_safety_checker"] = False self._maybe_add_fp16_vae(load_kwargs) - pipeline = AutoImg2ImgPipeline.from_pretrained(self.model_id, **load_kwargs) # type: ignore + pipeline = self._load_from_pretrained(AutoImg2ImgPipeline, self.model_id, **load_kwargs) # type: ignore self._pipeline = self._move_to_device_and_optimize(pipeline) logger.info("Model loaded successfully") @@ -522,14 +543,18 @@ class WatermarkRemover: logger.info("Loading SDXL + ControlNet (%s) on %s...", CONTROLNET_CANNY_MODEL, self.device) self._set_progress(f"Loading ControlNet: {CONTROLNET_CANNY_MODEL}") - controlnet = ControlNetModel.from_pretrained(CONTROLNET_CANNY_MODEL, torch_dtype=self.torch_dtype) + controlnet = self._load_from_pretrained( + ControlNetModel, CONTROLNET_CANNY_MODEL, torch_dtype=self.torch_dtype + ) load_kwargs = self._base_load_kwargs() load_kwargs["controlnet"] = controlnet self._maybe_add_fp16_vae(load_kwargs) self._set_progress(f"Loading model weights: {self.model_id}") - pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(self.model_id, **load_kwargs) + pipeline = self._load_from_pretrained( + StableDiffusionXLControlNetImg2ImgPipeline, self.model_id, **load_kwargs + ) pipeline = self._move_to_device_and_optimize(pipeline) with contextlib.suppress(Exception): pipeline.set_progress_bar_config(disable=True) diff --git a/tests/test_platform.py b/tests/test_platform.py index 94d1f9f..f3b2fde 100644 --- a/tests/test_platform.py +++ b/tests/test_platform.py @@ -126,6 +126,58 @@ class TestModelProfiles: assert normalize_profile("CONTROLNET") == "controlnet" +class TestFp16WeightVariant: + """_load_from_pretrained reads the fp16 weight variant on fp16, with a fallback. + + Loading the fp16 ``variant`` reads the half-precision weight files (~half the bytes) + instead of the fp32 defaults + a downcast, which roughly halves the cold-start weight + read. fp32 (cpu/mps) and bf16 (qwen) must never request the variant; a checkpoint + without fp16 files must fall back to the default weights (prior behavior). + """ + + def _remover(self, dtype: object): + if not is_watermark_removal_available(): + pytest.skip("torch/diffusers not installed") + from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover + + # device="cpu" alone would force fp32; the explicit torch_dtype override lets us + # exercise the fp16 path with no GPU (construction loads no weights). + return WatermarkRemover(device="cpu", torch_dtype=dtype) + + def test_fp16_requests_variant(self): + import torch + + remover = self._remover(torch.float16) + cls = MagicMock() + cls.from_pretrained.return_value = "PIPE" + out = remover._load_from_pretrained(cls, "some/model", token="t") + assert out == "PIPE" + cls.from_pretrained.assert_called_once_with("some/model", variant="fp16", token="t") + + def test_fp16_falls_back_when_variant_missing(self): + import torch + + remover = self._remover(torch.float16) + cls = MagicMock() + cls.from_pretrained.side_effect = [OSError("no fp16 weight files"), "PIPE"] + out = remover._load_from_pretrained(cls, "some/model", token="t") + assert out == "PIPE" + assert cls.from_pretrained.call_count == 2 + first, second = cls.from_pretrained.call_args_list + assert first.kwargs.get("variant") == "fp16" + assert "variant" not in second.kwargs # the fallback drops the variant + + def test_fp32_never_requests_variant(self): + import torch + + remover = self._remover(torch.float32) + cls = MagicMock() + cls.from_pretrained.return_value = "PIPE" + remover._load_from_pretrained(cls, "some/model") + cls.from_pretrained.assert_called_once_with("some/model") + assert "variant" not in cls.from_pretrained.call_args.kwargs + + class _StubImage: """Minimal PIL.Image stand-in: just the ``width``/``height`` the pure helper reads."""