mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-05 10:38:00 +02:00
Fix #29 black output: use fp16-fixed SDXL VAE on fp16 GPUs
The stock SDXL VAE overflows to NaN in fp16, so the plain img2img path decodes to an all-black image on a CUDA/XPU fp16 backend. This is the raiw.cc black result HitaoLin reported (a 1086x1448 input came back uniformly black). CPU/MPS run in fp32 and never hit it, and the differential / region-hires pipeline already upcasts the VAE itself, so only the plain path on an fp16 GPU was exposed. `_load_pipeline` now loads `madebyollin/sdxl-vae-fp16-fix` for the default SDXL checkpoint when running in fp16, gated by the pure helper `_needs_fp16_vae_fix`. A custom model (any `model_id` other than the default) keeps its own VAE, because the fix is SDXL-specific. The decision logic is unit-tested without a download (TestFp16VaeFix). The black-to-clean recovery itself needs a CUDA GPU and was not verifiable on this MPS machine; it must be confirmed on the backend. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -61,6 +61,24 @@ def is_watermark_removal_available() -> bool:
|
||||
return _HAS_TORCH and _HAS_DIFFUSERS
|
||||
|
||||
|
||||
# Drop-in fp16-safe replacement for the SDXL VAE. The stock SDXL VAE overflows
|
||||
# to NaN in fp16 and decodes to an all-black image (issue #29: the raiw.cc black
|
||||
# result on a CUDA fp16 backend). This community VAE is numerically rescaled to
|
||||
# stay in fp16 range. SDXL-architecture only.
|
||||
_SDXL_FP16_VAE_ID = "madebyollin/sdxl-vae-fp16-fix"
|
||||
|
||||
|
||||
def _needs_fp16_vae_fix(model_id: str, default_model_id: str, is_fp16: bool) -> bool:
|
||||
"""Whether the plain img2img pipeline must swap in the fp16-fixed SDXL VAE.
|
||||
|
||||
Gated to the default SDXL checkpoint running in fp16: cpu/mps run fp32 (the
|
||||
stock VAE is fine there) and the differential pipeline upcasts the VAE on its
|
||||
own, so only this path on a fp16 GPU (CUDA/XPU) hits the NaN/black decode.
|
||||
A custom non-SDXL ``model_id`` keeps its own VAE (the fix is SDXL-specific).
|
||||
"""
|
||||
return is_fp16 and model_id == default_model_id
|
||||
|
||||
|
||||
# NOTE(review): the name suggests this is an environment-variable key that
# records a CUDA fix having been applied; its use is outside this excerpt, so
# confirm against the code that reads or sets it.
_CUDA_FIX_ENV_KEY = "NOAI_CUDA_FIXED"
|
||||
|
||||
|
||||
@@ -370,6 +388,14 @@ class WatermarkRemover:
|
||||
if self.hf_token:
|
||||
load_kwargs["token"] = self.hf_token
|
||||
|
||||
# Avoid the SDXL fp16 NaN/all-black decode (issue #29) by loading the
|
||||
# fp16-fixed VAE for the default SDXL checkpoint on a fp16 GPU.
|
||||
if _needs_fp16_vae_fix(self.model_id, self.DEFAULT_MODEL_ID, self.torch_dtype == torch.float16):
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
self._set_progress("Loading fp16-fixed SDXL VAE (avoids black output)...")
|
||||
load_kwargs["vae"] = AutoencoderKL.from_pretrained(_SDXL_FP16_VAE_ID, torch_dtype=torch.float16)
|
||||
|
||||
self._pipeline = AutoImg2ImgPipeline.from_pretrained( # type: ignore
|
||||
self.model_id,
|
||||
**load_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user