fix(invisible): retry in fp32 on a degenerate fp16 output (#41)

The fp16-fix VAE swap (#29) is gated to the default SDXL checkpoint, so a
custom model_id, a stale pre-fix install, or a fal/custom loader can still
decode to an all-black/NaN frame in fp16 (reporter's case: gpt-image,
1448x1086, with the `invalid value encountered in cast` warning from
image_processor.py).

Add a model-agnostic backstop in remove_watermark: after generation, if the
run was fp16 and the output is degenerate (_is_degenerate_image: near-zero
mean and variance), rebuild the pipeline in fp32 on the same device and
re-run once. fp32 is the verified-clean path, so a black image is never
returned regardless of model_id or version. Mirrors the MPS->CPU fallback's
self-mutation pattern; batch inherits it. Verified e2e on MPS by forcing
fp16 with the swap disabled (first pass black, guard fired, retry clean).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Victor Kuznetsov
2026-06-04 17:43:27 -07:00
parent ec549b5c55
commit 6f4aa4c7b1
3 changed files with 70 additions and 17 deletions
@@ -79,6 +79,28 @@ def _needs_fp16_vae_fix(model_id: str, default_model_id: str, is_fp16: bool) ->
     return is_fp16 and model_id == default_model_id
+# An fp16 VAE/UNet overflow decodes to NaN, which diffusers' postprocess casts to 0
+# -> a uniform all-black frame (issues #29, #41). The VAE swap above prevents it for
+# the default checkpoint, but a custom model_id, a stale install, or a fal/custom
+# loader can still bypass it. Detecting a degenerate output and retrying in fp32 (the
+# path verified clean) is the model-agnostic safety net: never hand back a black image.
+# One threshold serves both guards: a NaN->0 collapse drives mean and variance to ~0.
+_DEGENERATE_THRESHOLD = 1.0
+def _is_degenerate_image(image: Image.Image) -> bool:
+    """True if a generated image collapsed to an all-black/NaN frame (#29/#41).
+    A NaN fp16 decode casts to 0, so the output is a uniform near-zero image: an
+    extremely low mean AND near-zero variance. The variance guard keeps a
+    legitimately dark-but-textured photo (low mean, real detail) from being flagged.
+    """
+    import numpy as np
+    arr = np.asarray(image.convert("RGB"), dtype=np.float32)
+    return float(arr.mean()) < _DEGENERATE_THRESHOLD and float(arr.std()) < _DEGENERATE_THRESHOLD
 _CUDA_FIX_ENV_KEY = "NOAI_CUDA_FIXED"
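
A minimal sketch of how the mean and spread guards interact. It is not the
code from this commit: it works on flat lists of 0-255 pixel values with the
standard library instead of PIL/NumPy, and `is_degenerate_pixels`,
`DEGENERATE_THRESHOLD` and the three sample frames are hypothetical names
chosen for illustration.

```python
from statistics import fmean, pstdev

# Assumed to mirror _DEGENERATE_THRESHOLD from the hunk above.
DEGENERATE_THRESHOLD = 1.0


def is_degenerate_pixels(pixels: list[float]) -> bool:
    """Hypothetical stand-in: flag a frame whose mean AND spread are both ~0."""
    # pstdev is the population standard deviation, the same quantity that
    # ndarray.std() computes with its default ddof=0.
    return (
        fmean(pixels) < DEGENERATE_THRESHOLD
        and pstdev(pixels) < DEGENERATE_THRESHOLD
    )


all_black = [0.0] * 1024                     # NaN -> 0 collapse
dark_textured = [0.0] * 1000 + [40.0] * 24   # mean below 1, but real detail
uniform_grey = [128.0] * 1024                # flat, but clearly not black

for name, frame in [
    ("all_black", all_black),
    ("dark_textured", dark_textured),
    ("uniform_grey", uniform_grey),
]:
    print(name, is_degenerate_pixels(frame))
```

Only the all-black frame is flagged. The dark-but-textured frame passes the
mean test yet fails the spread test, which is the case the second guard is
there to protect.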
@@ -513,22 +535,25 @@ class WatermarkRemover:
         _total_start = time.monotonic()
-        if self.model_profile == "controlnet":
-            cleaned_image = self._run_controlnet(
-                init_image,
-                strength,
-                num_inference_steps,
-                guidance_scale,
-                generator,
-            )
-        else:
-            cleaned_image = self._run_img2img(
-                init_image,
-                strength,
-                num_inference_steps,
-                guidance_scale,
-                generator,
-            )
+        def _generate() -> Image.Image:
+            if self.model_profile == "controlnet":
+                return self._run_controlnet(init_image, strength, num_inference_steps, guidance_scale, generator)
+            return self._run_img2img(init_image, strength, num_inference_steps, guidance_scale, generator)
+        cleaned_image = _generate()
+        # Safety net for the fp16 all-black/NaN decode (#29/#41): if an fp16 run
+        # produced a degenerate (uniform black) frame -- the VAE swap did not engage
+        # for this model/version -- retry once in fp32 on the same device (verified
+        # clean) so the user never gets a black image. Skipped when an MPS->CPU
+        # fallback already moved us to fp32.
+        if self.torch_dtype == torch.float16 and _is_degenerate_image(cleaned_image):
+            logger.warning("fp16 output was degenerate (all-black/NaN, #29/#41); retrying in fp32 on %s.", self.device)
+            self._set_progress("Output was black (fp16 overflow); retrying in fp32...")
+            self.torch_dtype = torch.float32
+            self._pipeline = None
+            self._controlnet_pipeline = None
+            cleaned_image = _generate()
         self._set_progress(f"Regeneration complete · Output: {w}x{h}px {cleaned_image.mode}")
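
The retry-once, self-mutating shape of the guard above can be sketched
without torch or diffusers. Everything here is a hypothetical stand-in:
`FakeRemover`, its string dtypes and its simulated `_generate` are
assumptions for illustration, not the real `WatermarkRemover`.

```python
from statistics import fmean, pstdev


class FakeRemover:
    """Hypothetical stand-in for WatermarkRemover; no real pipeline involved."""

    def __init__(self) -> None:
        self.dtype = "float16"
        self.pipeline: object | None = object()  # pretend one is already loaded
        self.calls: list[str] = []               # dtype used by each generation

    def _generate(self) -> list[float]:
        # Simulated failure mode: fp16 collapses to black, fp32 is clean.
        self.calls.append(self.dtype)
        if self.dtype == "float16":
            return [0.0] * 64
        return [10.0, 200.0] * 32

    @staticmethod
    def _is_degenerate(pixels: list[float]) -> bool:
        return fmean(pixels) < 1.0 and pstdev(pixels) < 1.0

    def remove_watermark(self) -> list[float]:
        result = self._generate()
        if self.dtype == "float16" and self._is_degenerate(result):
            # Switch the instance to fp32 and drop the cached pipeline (the
            # real code would rebuild it; this sketch only clears it), then
            # retry exactly once.
            self.dtype = "float32"
            self.pipeline = None
            result = self._generate()
        return result


remover = FakeRemover()
first = remover.remove_watermark()   # fp16 pass is black, fp32 retry is clean
second = remover.remove_watermark()  # already fp32, so no retry is needed
print(remover.calls)
```

Because the dtype change is stored on the instance, later calls start in
fp32 directly. That is how a batch run would inherit the fallback after the
first degenerate frame.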