fix(invisible): retry in fp32 on a degenerate fp16 output (#41)

The fp16-fix VAE swap (#29) is gated to the default SDXL checkpoint, so a
custom model_id, a stale pre-fix install, or a fal/custom loader can still
decode to an all-black/NaN frame in fp16 (reporter: gpt-image 1448x1086,
the `image_processor.py invalid value encountered in cast` warning).

Add a model-agnostic backstop in remove_watermark: after generation, if the
run was fp16 and the output is degenerate (_is_degenerate_image: near-zero
mean and variance), rebuild the pipeline in fp32 on the same device and
re-run once. fp32 is the verified-clean path, so a black image is never
returned regardless of model_id or version. Mirrors the MPS->CPU fallback's
self-mutation pattern; batch inherits it. Verified e2e on MPS by forcing
fp16 with the swap disabled (first pass black, guard fired, retry clean).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Victor Kuznetsov
2026-06-04 17:43:27 -07:00
parent ec549b5c55
commit 6f4aa4c7b1
3 changed files with 70 additions and 17 deletions
+28
View File
@@ -9,7 +9,9 @@ from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from PIL import Image
from remove_ai_watermarks.noai.progress import is_mps_error
from remove_ai_watermarks.noai.utils import get_image_format, is_supported_format
@@ -298,3 +300,29 @@ class TestFp16VaeFix:
from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix
assert _needs_fp16_vae_fix("runwayml/stable-diffusion-v1-5", self.DEFAULT, is_fp16=True) is False
class TestDegenerateOutputGuard:
"""The fp16 black-output safety net (#29/#41): detect an all-black/NaN frame so
``remove_watermark`` can retry in fp32. Pure image statistics, no model needed."""
def test_all_black_is_degenerate(self):
from remove_ai_watermarks.noai.watermark_remover import _is_degenerate_image
black = Image.fromarray(np.zeros((64, 64, 3), np.uint8))
assert _is_degenerate_image(black) is True
def test_normal_image_is_not_degenerate(self):
from remove_ai_watermarks.noai.watermark_remover import _is_degenerate_image
rng = np.random.default_rng(0)
normal = Image.fromarray(rng.integers(0, 256, (64, 64, 3), dtype=np.uint8))
assert _is_degenerate_image(normal) is False
def test_dark_but_textured_image_is_not_degenerate(self):
"""A legitimately dark photo with real detail must NOT be flagged (variance guard)."""
from remove_ai_watermarks.noai.watermark_remover import _is_degenerate_image
rng = np.random.default_rng(1)
dark = Image.fromarray(rng.integers(0, 40, (64, 64, 3), dtype=np.uint8))
assert _is_degenerate_image(dark) is False