mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-05 02:28:00 +02:00
Fix #29 black output: use fp16-fixed SDXL VAE on fp16 GPUs
The stock SDXL VAE overflows to NaN in fp16, so the plain img2img path decodes to an all-black image on a CUDA/XPU fp16 backend. This is the raiw.cc black result HitaoLin reported (a 1086x1448 input came back uniformly black). cpu/mps run fp32 and never hit it, and the differential / region-hires pipeline already upcasts the VAE itself, so only the plain path on a fp16 GPU was exposed.

`_load_pipeline` now loads `madebyollin/sdxl-vae-fp16-fix` for the default SDXL checkpoint when running fp16, gated by the pure helper `_needs_fp16_vae_fix`. A custom non-SDXL model keeps its own VAE.

The decision logic is unit-tested without a download (TestFp16VaeFix). The black->clean recovery itself needs a CUDA GPU and was not verifiable on this MPS machine; it must be confirmed on the backend.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@@ -74,6 +74,7 @@ Who embeds what, and whether it is locally detectable (so we know which gaps are
## Known limitations
- `invisible` pipeline processes at **native resolution by default** (`max_resolution=0`), matching the hosted raiw.cc backend (fal fast-sdxl, no pre-downscale). The old forced downscale-to-1024 -> upscale-back round-trip was the main quality loss (issue #10) and is gone; at strength ~0.05 SDXL img2img does not need the ~1024 downscale. `--max-resolution N` re-introduces an opt-in long-side cap purely to bound GPU/MPS memory on very large inputs (it reintroduces the lossy round-trip). For huge images that OOM at native, tile-based diffusion is still the proper long-term fix. **Concrete MPS data point (verified 2026-05-25 on a 1254x1254 gpt-image SDXL run, fp32, 20 GB MPS ceiling):** native res OOMs at the *UNet* step (peak ~17 GiB), not only the VAE decode, and the auto-fallback in `img2img_runner` reloads on CPU and finishes (slow, ~13 min) -- the output is still weight-identical and defeats SynthID, so "looks hung/crashed" on Mac is usually this CPU fallback, not a pipeline error. Adding `enable_vae_tiling()` alone does NOT prevent it (the peak is the UNet, not the VAE). The fast Mac workarounds are fp16 on MPS (roughly halves memory) or `--max-resolution` to cap the long side; neither is wired as the default. The native-vs-downscale decision lives in the pure helper `invisible_engine._target_size(w, h, max_resolution)` (returns `None` for native, a clamped target tuple otherwise) so it is unit-tested (`tests/test_invisible_engine.py::TestTargetSize`, the #10/#15 regression guard) without loading the model -- keep that logic in the helper, don't re-inline it.
- **fp16 VAE black-output fix (issue #29, 2026-05-30):** on a **CUDA/XPU fp16** backend the stock SDXL VAE overflows to NaN and the *plain* img2img path decodes to an **all-black** image (reproduced on the raiw.cc result: a 1086x1448 input -> a uniformly black 4.6 KB PNG, mean 0). `watermark_remover._load_pipeline` now swaps in the fp16-fixed SDXL VAE (`madebyollin/sdxl-vae-fp16-fix` = `_SDXL_FP16_VAE_ID`) when `_needs_fp16_vae_fix(model_id, DEFAULT_MODEL_ID, is_fp16)` is true -- only the default SDXL checkpoint on fp16. **cpu/mps run fp32** (the stock VAE is fine there, which is why the bug never reproduces on Mac), and the **differential / region-hires** pipeline already upcasts the VAE itself (see the `text_protector` bullet). A custom non-SDXL `model_id` keeps its own VAE (the fp16-fix VAE is SDXL-architecture-specific). The decision is a pure helper, unit-tested without a download (`tests/test_platform.py::TestFp16VaeFix`); the actual black->clean recovery needs a CUDA GPU and was NOT verifiable on this MPS machine -- confirm on the backend / an NVIDIA box.
- Pyright first run is slow (2-3 min) due to ML deps (torch/diffusers/transformers stubs); full-project `uv run pyright` can stall for many minutes — scope it to changed files.
- `ultralytics` monkey-patches `PIL.Image.open` and tries to autoload `pi_heif`. When `pi_heif` is missing, opening files raises `ModuleNotFoundError`, not `UnidentifiedImageError`. Code that opens user-supplied or unknown-format files should `except Exception`, not just `OSError`/`UnidentifiedImageError`.
- **rich `console.print` parses `[word]` as a style tag and silently drops unknown ones.** A literal bracketed token in a print string disappears: `pip install 'remove-ai-watermarks[gpu]'` rendered as `...remove-ai-watermarks'` (the `[gpu]` extra eaten), which sent users a broken install command (surfaced via #19). Escape the literal bracket as `\[gpu]` (in a normal Python string that is `"\\[gpu]"`) in any rich string carrying user-facing brackets. Regression-guarded by `tests/test_cli.py::TestGpuHintMarkup`.
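The fp16 overflow behind the issue #29 bullet above can be illustrated without a GPU. The sketch below is a toy, stdlib-only illustration in Python, not the project's code: `to_fp16` is a hypothetical helper, and it only shows that values past the half-precision range become `inf` and that `inf - inf` then yields NaN. It does not reproduce the VAE's actual activations.

```python
import math
import struct

# Largest finite IEEE 754 half-precision value (bit pattern 0x7BFF).
FP16_MAX = struct.unpack("<e", b"\xff\x7b")[0]


def to_fp16(x: float) -> float:
    """Round-trip a Python float through half precision (hypothetical helper).

    struct refuses to pack out-of-range values, so overflow is mapped to
    +/-inf by hand, which is the IEEE 754 round-to-nearest overflow result.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


overflowed = to_fp16(70000.0)          # beyond FP16_MAX, becomes inf
poisoned = overflowed - overflowed     # inf - inf is NaN

print("fp16 max:", FP16_MAX)
print("70000.0 in fp16:", overflowed)
print("inf - inf:", poisoned)
```

Once a NaN appears in an intermediate tensor it propagates through later arithmetic, which is consistent with the uniformly black decode described in the bullet. Running the same arithmetic in fp32, as cpu/mps do, stays well inside the representable range.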
@@ -61,6 +61,24 @@ def is_watermark_removal_available() -> bool:
    return _HAS_TORCH and _HAS_DIFFUSERS


# Drop-in fp16-safe replacement for the SDXL VAE. The stock SDXL VAE overflows
# to NaN in fp16 and decodes to an all-black image (issue #29: the raiw.cc black
# result on a CUDA fp16 backend). This community VAE is numerically rescaled to
# stay in fp16 range. SDXL-architecture only.
_SDXL_FP16_VAE_ID = "madebyollin/sdxl-vae-fp16-fix"


def _needs_fp16_vae_fix(model_id: str, default_model_id: str, is_fp16: bool) -> bool:
    """Whether the plain img2img pipeline must swap in the fp16-fixed SDXL VAE.

    Gated to the default SDXL checkpoint running in fp16: cpu/mps run fp32 (the
    stock VAE is fine there) and the differential pipeline upcasts the VAE on its
    own, so only this path on a fp16 GPU (CUDA/XPU) hits the NaN/black decode.
    A custom non-SDXL ``model_id`` keeps its own VAE (the fix is SDXL-specific).
    """
    return is_fp16 and model_id == default_model_id


_CUDA_FIX_ENV_KEY = "NOAI_CUDA_FIXED"
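
The gate defined by `_needs_fp16_vae_fix` above is small enough to enumerate. The following sketch copies the predicate under a public name so it runs standalone; the default id is the one used in `TestFp16VaeFix`, and `CUSTOM_MODEL_ID` is just the non-default id from the same test, chosen for illustration.

```python
from itertools import product

DEFAULT_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
CUSTOM_MODEL_ID = "runwayml/stable-diffusion-v1-5"  # illustrative non-default id


def needs_fp16_vae_fix(model_id: str, default_model_id: str, is_fp16: bool) -> bool:
    # Same predicate as the helper above, copied so the sketch runs standalone.
    return is_fp16 and model_id == default_model_id


# Enumerate every (model, dtype) combination the gate can see.
table = {
    (model_id, is_fp16): needs_fp16_vae_fix(model_id, DEFAULT_MODEL_ID, is_fp16)
    for model_id, is_fp16 in product((DEFAULT_MODEL_ID, CUSTOM_MODEL_ID), (True, False))
}

for (model_id, is_fp16), swap in table.items():
    dtype = "fp16" if is_fp16 else "fp32"
    print(f"{model_id:45s} {dtype}  swap_vae={swap}")
```

Only the default checkpoint in fp16 triggers the swap. Because the predicate compares ids rather than architectures, any non-default id returns `False`, which includes a custom SDXL-architecture checkpoint running in fp16.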
@@ -370,6 +388,14 @@ class WatermarkRemover:
        if self.hf_token:
            load_kwargs["token"] = self.hf_token

        # Avoid the SDXL fp16 NaN/all-black decode (issue #29) by loading the
        # fp16-fixed VAE for the default SDXL checkpoint on a fp16 GPU.
        if _needs_fp16_vae_fix(self.model_id, self.DEFAULT_MODEL_ID, self.torch_dtype == torch.float16):
            from diffusers import AutoencoderKL

            self._set_progress("Loading fp16-fixed SDXL VAE (avoids black output)...")
            load_kwargs["vae"] = AutoencoderKL.from_pretrained(_SDXL_FP16_VAE_ID, torch_dtype=torch.float16)

        self._pipeline = AutoImg2ImgPipeline.from_pretrained(  # type: ignore
            self.model_id,
            **load_kwargs,
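The commit message notes that the black->clean recovery could not be verified on the MPS machine. One way to check a result on a CUDA backend is a simple uniform-black test on the decoded pixel values. This is a minimal sketch, not part of the repository: the helper name `looks_all_black` and the `max_mean` threshold are assumptions, motivated only by the reported failing output having mean 0.

```python
from typing import Iterable


def looks_all_black(pixels: Iterable[int], max_mean: float = 1.0) -> bool:
    """Return True when 8-bit pixel values average at or below max_mean.

    Hypothetical helper: the failing output in issue #29 had mean 0, so a
    very low mean is used as the signal. The 1.0 default is an assumption.
    """
    values = list(pixels)
    if not values:
        raise ValueError("no pixel data")
    return sum(values) / len(values) <= max_mean


# A uniformly black buffer is flagged; a buffer with real content is not.
print(looks_all_black([0] * 16))
print(looks_all_black([0, 255, 128, 64]))
```

In practice the input would be the flattened grayscale pixel values of the pipeline's output image, checked once with the stock VAE and once with the fp16-fixed VAE on the same fp16 GPU.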
@@ -204,3 +204,28 @@ class TestPlatformPaths:
        # If we get here without error, asset loading works
        assert engine._alpha_small.shape == (48, 48)
        assert engine._alpha_large.shape == (96, 96)


class TestFp16VaeFix:
    """The plain SDXL img2img pipeline must swap in the fp16-fixed VAE on fp16
    GPUs to avoid the NaN/all-black decode (issue #29). Pure decision logic, no
    torch or model download needed."""

    DEFAULT = "stabilityai/stable-diffusion-xl-base-1.0"

    def test_default_sdxl_on_fp16_needs_fix(self):
        from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix

        assert _needs_fp16_vae_fix(self.DEFAULT, self.DEFAULT, is_fp16=True) is True

    def test_fp32_does_not_need_fix(self):
        """cpu/mps run fp32, where the stock SDXL VAE is fine."""
        from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix

        assert _needs_fp16_vae_fix(self.DEFAULT, self.DEFAULT, is_fp16=False) is False

    def test_non_default_model_keeps_own_vae(self):
        """A custom (non-SDXL) checkpoint must not get the SDXL-specific VAE."""
        from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix

        assert _needs_fp16_vae_fix("runwayml/stable-diffusion-v1-5", self.DEFAULT, is_fp16=True) is False