mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-10 12:53:56 +02:00
fix(invisible): retry in fp32 on a degenerate fp16 output (#41)
The fp16-fix VAE swap (#29) is gated to the default SDXL checkpoint, so a custom model_id, a stale pre-fix install, or a fal/custom loader can still decode to an all-black/NaN frame in fp16 (reporter: gpt-image 1448x1086, the `image_processor.py invalid value encountered in cast` warning). Add a model-agnostic backstop in remove_watermark: after generation, if the run was fp16 and the output is degenerate (_is_degenerate_image: near-zero mean and variance), rebuild the pipeline in fp32 on the same device and re-run once. fp32 is the verified-clean path, so a black image is never returned regardless of model_id or version. Mirrors the MPS->CPU fallback's self-mutation pattern; batch inherits it. Verified e2e on MPS by forcing fp16 with the swap disabled (first pass black, guard fired, retry clean). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -79,6 +79,28 @@ def _needs_fp16_vae_fix(model_id: str, default_model_id: str, is_fp16: bool) ->
|
||||
return is_fp16 and model_id == default_model_id
|
||||
|
||||
|
||||
# An fp16 VAE/UNet overflow decodes to NaN, which diffusers' postprocess casts to 0
|
||||
# -> a uniform all-black frame (issues #29, #41). The VAE swap above prevents it for
|
||||
# the default checkpoint, but a custom model_id, a stale install, or a fal/custom
|
||||
# loader can still bypass it. Detecting a degenerate output and retrying in fp32 (the
|
||||
# path verified clean) is the model-agnostic safety net: never hand back a black image.
|
||||
# One threshold serves both checks (pixel mean and standard deviation of the RGB
# values): a NaN->0 collapse drives both of them to ~0.
|
||||
_DEGENERATE_THRESHOLD = 1.0
|
||||
|
||||
|
||||
def _is_degenerate_image(image: Image.Image) -> bool:
    """Report whether ``image`` looks like a collapsed all-black/NaN frame (#29/#41).

    A NaN fp16 decode is cast to 0, which yields a uniform, near-zero picture.
    The image is converted to RGB and inspected as float32 pixel values; it is
    considered degenerate only when BOTH the mean and the standard deviation of
    those values fall below ``_DEGENERATE_THRESHOLD``. Requiring a low spread as
    well as a low mean keeps a dark photo that still carries real texture from
    being flagged.

    Args:
        image: The generated image to inspect.

    Returns:
        True when the image is uniformly near-black, False otherwise.
    """
    import numpy as np

    pixels = np.asarray(image.convert("RGB"), dtype=np.float32)

    # Brightness is checked first: a frame that is not near-black cannot be
    # degenerate, so the spread is only computed when it is needed.
    if not (float(pixels.mean()) < _DEGENERATE_THRESHOLD):
        return False

    return float(pixels.std()) < _DEGENERATE_THRESHOLD
|
||||
|
||||
|
||||
_CUDA_FIX_ENV_KEY = "NOAI_CUDA_FIXED"
|
||||
|
||||
|
||||
@@ -513,22 +535,25 @@ class WatermarkRemover:
|
||||
|
||||
_total_start = time.monotonic()
|
||||
|
||||
if self.model_profile == "controlnet":
|
||||
cleaned_image = self._run_controlnet(
|
||||
init_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
generator,
|
||||
)
|
||||
else:
|
||||
cleaned_image = self._run_img2img(
|
||||
init_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
generator,
|
||||
)
|
||||
def _generate() -> Image.Image:
|
||||
if self.model_profile == "controlnet":
|
||||
return self._run_controlnet(init_image, strength, num_inference_steps, guidance_scale, generator)
|
||||
return self._run_img2img(init_image, strength, num_inference_steps, guidance_scale, generator)
|
||||
|
||||
cleaned_image = _generate()
|
||||
|
||||
# Safety net for the fp16 all-black/NaN decode (#29/#41): if an fp16 run
|
||||
# produced a degenerate (uniform black) frame -- the VAE swap did not engage
|
||||
# for this model/version -- retry once in fp32 on the same device (verified
|
||||
# clean) so the user never gets a black image. Skipped when an MPS->CPU
|
||||
# fallback already moved us to fp32.
|
||||
if self.torch_dtype == torch.float16 and _is_degenerate_image(cleaned_image):
|
||||
logger.warning("fp16 output was degenerate (all-black/NaN, #29/#41); retrying in fp32 on %s.", self.device)
|
||||
self._set_progress("Output was black (fp16 overflow); retrying in fp32...")
|
||||
self.torch_dtype = torch.float32
|
||||
self._pipeline = None
|
||||
self._controlnet_pipeline = None
|
||||
cleaned_image = _generate()
|
||||
|
||||
self._set_progress(f"Regeneration complete · Output: {w}x{h}px {cleaned_image.mode}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user