mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-05 02:28:00 +02:00
Fix #29 black output: use fp16-fixed SDXL VAE on fp16 GPUs
The stock SDXL VAE overflows to NaN in fp16, so the plain img2img path decodes to an all-black image on a CUDA/XPU fp16 backend. This is the raiw.cc black result HitaoLin reported (a 1086x1448 input came back uniformly black). cpu/mps run fp32 and never hit it, and the differential / region-hires pipeline already upcasts the VAE itself, so only the plain path on a fp16 GPU was exposed. `_load_pipeline` now loads `madebyollin/sdxl-vae-fp16-fix` for the default SDXL checkpoint when running fp16, gated by the pure helper `_needs_fp16_vae_fix`. A custom non-SDXL model keeps its own VAE. The decision logic is unit-tested without a download (TestFp16VaeFix). The black->clean recovery itself needs a CUDA GPU and was not verifiable on this MPS machine; it must be confirmed on the backend. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -204,3 +204,28 @@ class TestPlatformPaths:
|
||||
# If we get here without error, asset loading works
|
||||
assert engine._alpha_small.shape == (48, 48)
|
||||
assert engine._alpha_large.shape == (96, 96)
|
||||
|
||||
|
||||
class TestFp16VaeFix:
|
||||
"""The plain SDXL img2img pipeline must swap in the fp16-fixed VAE on fp16
|
||||
GPUs to avoid the NaN/all-black decode (issue #29). Pure decision logic, no
|
||||
torch or model download needed."""
|
||||
|
||||
DEFAULT = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
def test_default_sdxl_on_fp16_needs_fix(self):
|
||||
from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix
|
||||
|
||||
assert _needs_fp16_vae_fix(self.DEFAULT, self.DEFAULT, is_fp16=True) is True
|
||||
|
||||
def test_fp32_does_not_need_fix(self):
|
||||
"""cpu/mps run fp32, where the stock SDXL VAE is fine."""
|
||||
from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix
|
||||
|
||||
assert _needs_fp16_vae_fix(self.DEFAULT, self.DEFAULT, is_fp16=False) is False
|
||||
|
||||
def test_non_default_model_keeps_own_vae(self):
|
||||
"""A custom (non-SDXL) checkpoint must not get the SDXL-specific VAE."""
|
||||
from remove_ai_watermarks.noai.watermark_remover import _needs_fp16_vae_fix
|
||||
|
||||
assert _needs_fp16_vae_fix("runwayml/stable-diffusion-v1-5", self.DEFAULT, is_fp16=True) is False
|
||||
|
||||
Reference in New Issue
Block a user