mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-07-04 23:47:49 +02:00
perf(invisible): read the fp16 weight variant to halve the cold-start weight load
InvisibleEngine loads SDXL/ControlNet in fp16 on CUDA/XPU but called from_pretrained without variant="fp16", so it read the full fp32 weight files (~7 GB) and downcast in memory. _load_from_pretrained now passes variant="fp16" when torch_dtype is float16, reading the half-precision files (~3.5 GB) instead - roughly halving the cold-start weight read + host->device transfer (a phase-timed Modal run measured weight load as ~half of the ~25s cold start). Falls back to the default weights when a checkpoint ships no fp16 variant (a custom --model), so the worst case is the prior behavior. fp32 (cpu/mps) and bf16 (qwen) never request the variant. Tests: TestFp16WeightVariant (variant requested on fp16, fallback on missing, never on fp32). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -490,6 +490,27 @@ class WatermarkRemover:
|
||||
load_kwargs["token"] = self.hf_token
|
||||
return load_kwargs
|
||||
|
||||
def _load_from_pretrained(self, cls: Any, model_id: str, **load_kwargs: Any) -> Any:
|
||||
"""Call ``cls.from_pretrained`` reading the fp16 weight VARIANT when on fp16.
|
||||
|
||||
When ``torch_dtype`` is float16 (the CUDA/XPU SDXL default), pass
|
||||
``variant="fp16"`` so diffusers fetches/reads the half-precision weight files
|
||||
(~half the bytes of the fp32 defaults) instead of reading the full fp32 files
|
||||
and downcasting in memory. On a warm weights cache this roughly halves the
|
||||
cold-start weight read + host->device transfer, which a phase-timed Modal run
|
||||
measured as ~half of the ~25s cold start. Not every checkpoint publishes an
|
||||
fp16 variant (a custom ``--model``, or the canny ControlNet if it ships only
|
||||
default files), so fall back to the default weights if the variant is missing -
|
||||
worst case is the prior behavior (read fp32, downcast). fp32 (cpu/mps) and bf16
|
||||
(qwen) never request the variant.
|
||||
"""
|
||||
if self.torch_dtype == torch.float16:
|
||||
try:
|
||||
return cls.from_pretrained(model_id, variant="fp16", **load_kwargs)
|
||||
except Exception as exc:
|
||||
logger.info("No fp16 weight variant for %s (%s); loading default weights", model_id, exc)
|
||||
return cls.from_pretrained(model_id, **load_kwargs)
|
||||
|
||||
def _load_pipeline(self) -> AutoImg2ImgPipeline:
|
||||
"""Load the plain SDXL img2img pipeline lazily."""
|
||||
if self._pipeline is None:
|
||||
@@ -501,7 +522,7 @@ class WatermarkRemover:
|
||||
load_kwargs["requires_safety_checker"] = False
|
||||
self._maybe_add_fp16_vae(load_kwargs)
|
||||
|
||||
pipeline = AutoImg2ImgPipeline.from_pretrained(self.model_id, **load_kwargs) # type: ignore
|
||||
pipeline = self._load_from_pretrained(AutoImg2ImgPipeline, self.model_id, **load_kwargs) # type: ignore
|
||||
self._pipeline = self._move_to_device_and_optimize(pipeline)
|
||||
|
||||
logger.info("Model loaded successfully")
|
||||
@@ -522,14 +543,18 @@ class WatermarkRemover:
|
||||
|
||||
logger.info("Loading SDXL + ControlNet (%s) on %s...", CONTROLNET_CANNY_MODEL, self.device)
|
||||
self._set_progress(f"Loading ControlNet: {CONTROLNET_CANNY_MODEL}")
|
||||
controlnet = ControlNetModel.from_pretrained(CONTROLNET_CANNY_MODEL, torch_dtype=self.torch_dtype)
|
||||
controlnet = self._load_from_pretrained(
|
||||
ControlNetModel, CONTROLNET_CANNY_MODEL, torch_dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
load_kwargs = self._base_load_kwargs()
|
||||
load_kwargs["controlnet"] = controlnet
|
||||
self._maybe_add_fp16_vae(load_kwargs)
|
||||
|
||||
self._set_progress(f"Loading model weights: {self.model_id}")
|
||||
pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(self.model_id, **load_kwargs)
|
||||
pipeline = self._load_from_pretrained(
|
||||
StableDiffusionXLControlNetImg2ImgPipeline, self.model_id, **load_kwargs
|
||||
)
|
||||
pipeline = self._move_to_device_and_optimize(pipeline)
|
||||
with contextlib.suppress(Exception):
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -126,6 +126,58 @@ class TestModelProfiles:
|
||||
assert normalize_profile("CONTROLNET") == "controlnet"
|
||||
|
||||
|
||||
class TestFp16WeightVariant:
|
||||
"""_load_from_pretrained reads the fp16 weight variant on fp16, with a fallback.
|
||||
|
||||
Loading the fp16 ``variant`` reads the half-precision weight files (~half the bytes)
|
||||
instead of the fp32 defaults + a downcast, which roughly halves the cold-start weight
|
||||
read. fp32 (cpu/mps) and bf16 (qwen) must never request the variant; a checkpoint
|
||||
without fp16 files must fall back to the default weights (prior behavior).
|
||||
"""
|
||||
|
||||
def _remover(self, dtype: object):
|
||||
if not is_watermark_removal_available():
|
||||
pytest.skip("torch/diffusers not installed")
|
||||
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover
|
||||
|
||||
# device="cpu" alone would force fp32; the explicit torch_dtype override lets us
|
||||
# exercise the fp16 path with no GPU (construction loads no weights).
|
||||
return WatermarkRemover(device="cpu", torch_dtype=dtype)
|
||||
|
||||
def test_fp16_requests_variant(self):
|
||||
import torch
|
||||
|
||||
remover = self._remover(torch.float16)
|
||||
cls = MagicMock()
|
||||
cls.from_pretrained.return_value = "PIPE"
|
||||
out = remover._load_from_pretrained(cls, "some/model", token="t")
|
||||
assert out == "PIPE"
|
||||
cls.from_pretrained.assert_called_once_with("some/model", variant="fp16", token="t")
|
||||
|
||||
def test_fp16_falls_back_when_variant_missing(self):
|
||||
import torch
|
||||
|
||||
remover = self._remover(torch.float16)
|
||||
cls = MagicMock()
|
||||
cls.from_pretrained.side_effect = [OSError("no fp16 weight files"), "PIPE"]
|
||||
out = remover._load_from_pretrained(cls, "some/model", token="t")
|
||||
assert out == "PIPE"
|
||||
assert cls.from_pretrained.call_count == 2
|
||||
first, second = cls.from_pretrained.call_args_list
|
||||
assert first.kwargs.get("variant") == "fp16"
|
||||
assert "variant" not in second.kwargs # the fallback drops the variant
|
||||
|
||||
def test_fp32_never_requests_variant(self):
|
||||
import torch
|
||||
|
||||
remover = self._remover(torch.float32)
|
||||
cls = MagicMock()
|
||||
cls.from_pretrained.return_value = "PIPE"
|
||||
remover._load_from_pretrained(cls, "some/model")
|
||||
cls.from_pretrained.assert_called_once_with("some/model")
|
||||
assert "variant" not in cls.from_pretrained.call_args.kwargs
|
||||
|
||||
|
||||
class _StubImage:
|
||||
"""Minimal PIL.Image stand-in: just the ``width``/``height`` the pure helper reads."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user