mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-05 02:28:00 +02:00
29da3c52b6
* Raise default SynthID-removal strength 0.05 -> 0.10 (current Google SynthID) The old default (0.04/0.05) no longer removes the CURRENT Google SynthID (Nano Banana / Gemini 3): verified 2026-05-30 via the Gemini 'Verify with SynthID' oracle on a real image -- 0.05 still detected, 0.10 not detected (OpenAI's was already cleared at 0.05). Add DEFAULT_STRENGTH=0.10 in watermark_profiles, route the engine + CLI defaults to it. At 0.10 small text deforms more, which is why text protection (_run_region_hires) runs by default. CLAUDE.md SynthID note corrected. CAVEAT: n=1 Google + n=1 OpenAI; broad corpus oracle validation pending (task tracked). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * Drop unused LOW/MEDIUM/HIGH strength profiles; CLI --strength defaults to DEFAULT_STRENGTH The fixed strength presets (and get_recommended_strength) were dead -- nothing in the pipeline used them, only tests. One knob now: DEFAULT_STRENGTH (0.10), overridable per-call via the CLI --strength flag, which now defaults to that constant (single source of truth). Removed the WatermarkRemover.LOW/MEDIUM/HIGH class attrs and the get_recommended_strength tests. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
207 lines
7.9 KiB
Python
207 lines
7.9 KiB
Python
"""Tests for cross-platform and cross-device compatibility.
|
|
|
|
Verifies that device detection, MPS fallback, and platform-specific
|
|
code paths work correctly on CPU, MPS (macOS), and CUDA (Linux/Windows).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from remove_ai_watermarks.noai.progress import is_mps_error
|
|
from remove_ai_watermarks.noai.utils import get_image_format, is_supported_format
|
|
from remove_ai_watermarks.noai.watermark_profiles import (
|
|
detect_model_profile,
|
|
get_model_id_for_profile,
|
|
)
|
|
from remove_ai_watermarks.noai.watermark_remover import get_device, is_watermark_removal_available
|
|
|
|
# ── Device detection ────────────────────────────────────────────────
|
|
|
|
|
|
class TestDeviceDetection:
    """Exercise get_device() and the device-dependent helpers around it."""

    def test_returns_valid_device(self):
        """get_device() answers with one of the four recognised backend names."""
        known_backends = ("cpu", "mps", "cuda", "xpu")
        assert get_device() in known_backends

    def test_cpu_fallback_when_no_gpu(self):
        """On CI / machines without GPU, should fall back to cpu or mps."""
        detected = get_device()
        # Smoke check only: the call must not crash and must hand back a string.
        assert isinstance(detected, str)

    @patch("remove_ai_watermarks.noai.watermark_remover._HAS_TORCH", False)
    def test_no_torch_returns_cpu(self):
        """With the module's _HAS_TORCH flag patched to False the answer is 'cpu'."""
        assert get_device() == "cpu"

    def test_xpu_selected_when_available(self):
        """A torch that reports XPU but no CUDA is routed to the 'xpu' backend.

        The module-level ``torch`` is swapped for a MagicMock so every smoke-test
        op succeeds without real hardware; ``cuda.is_available`` is pinned to
        False so the cuda branch cannot win first.
        """
        torch_stub = MagicMock()
        torch_stub.xpu.is_available.return_value = True
        torch_stub.cuda.is_available.return_value = False

        with patch("remove_ai_watermarks.noai.watermark_remover.torch", torch_stub):
            selected = get_device()
            assert selected == "xpu"
            # The last tensor created on the stub must have targeted the xpu device.
            torch_stub.tensor.assert_called_with([1.0], device="xpu")

    def test_init_accepts_xpu_and_selects_fp16(self):
        """WatermarkRemover(device='xpu') keeps the device and uses float16."""
        if not is_watermark_removal_available():
            pytest.skip("torch/diffusers not installed")

        # Imported lazily: both are only importable when the skip above did not fire.
        import torch

        from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover

        xpu_remover = WatermarkRemover(device="xpu")
        assert xpu_remover.device == "xpu"
        assert xpu_remover.torch_dtype == torch.float16

    def test_seed_generator_falls_back_to_cpu_when_device_rng_unsupported(self):
        """_make_seed_generator retries on CPU when the device generator raises.

        The stubbed ``torch.Generator`` rejects 'xpu' with a RuntimeError and
        otherwise returns a mock whose ``manual_seed`` echoes the device it was
        built for, which makes the chosen device visible in the return value.
        """
        from remove_ai_watermarks.noai import watermark_remover as wr

        def fake_generator(device="cpu"):
            # Guard clause: emulate a build without an RNG backend for xpu.
            if device == "xpu":
                raise RuntimeError("Device type xpu is not supported for torch.Generator()")
            return MagicMock(**{"manual_seed.return_value": f"gen:{device}"})

        torch_stub = MagicMock()
        torch_stub.Generator.side_effect = fake_generator

        with patch.object(wr, "torch", torch_stub):
            xpu_result = wr._make_seed_generator("xpu", 123)
            assert xpu_result == "gen:cpu"
            cuda_result = wr._make_seed_generator("cuda", 123)
            assert cuda_result == "gen:cuda"
|
|
|
|
|
|
class TestMpsErrorDetection:
    """Check is_mps_error() against MPS, CUDA and unrelated runtime errors."""

    def test_detects_mps_error(self):
        """A message mentioning the MPS backend is classified as an MPS error."""
        assert is_mps_error(RuntimeError("MPS backend out of memory")) is True

    def test_non_mps_error(self):
        """A CUDA out-of-memory message is not classified as an MPS error."""
        assert is_mps_error(RuntimeError("CUDA out of memory")) is False

    def test_generic_error(self):
        """A message naming no backend at all is not classified as an MPS error."""
        assert is_mps_error(RuntimeError("something went wrong")) is False
|
|
|
|
|
|
# ── Model profiles ──────────────────────────────────────────────────
|
|
|
|
|
|
class TestModelProfiles:
    """Round-trip checks for the profile <-> model-id helpers in watermark_profiles.py."""

    def test_default_profile(self):
        """The 'default' profile resolves to the SDXL base model id."""
        resolved = get_model_id_for_profile("default")
        assert resolved == "stabilityai/stable-diffusion-xl-base-1.0"

    def test_ctrlregen_profile(self):
        """The 'ctrlregen' profile resolves to the ctrlregen model id."""
        resolved = get_model_id_for_profile("ctrlregen")
        assert resolved == "yepengliu/ctrlregen"

    def test_unknown_profile_raises(self):
        """An unregistered profile name is rejected with a ValueError."""
        with pytest.raises(ValueError, match="Unknown model profile"):
            get_model_id_for_profile("nonexistent")

    def test_detect_default(self):
        """The SDXL base model id is detected as the 'default' profile."""
        profile = detect_model_profile("stabilityai/stable-diffusion-xl-base-1.0")
        assert profile == "default"

    def test_detect_ctrlregen(self):
        """The ctrlregen model id is detected as the 'ctrlregen' profile."""
        profile = detect_model_profile("yepengliu/ctrlregen")
        assert profile == "ctrlregen"
|
|
|
|
|
|
# ── Format utilities ────────────────────────────────────────────────
|
|
|
|
|
|
class TestFormatUtils:
    """Checks for the extension-based format helpers in utils.py.

    Every test only builds a path under ``tmp_path``; no file is created.
    """

    def test_supported_png(self, tmp_path):
        """A .png path is reported as supported."""
        candidate = tmp_path / "test.png"
        assert is_supported_format(candidate)

    def test_supported_jpg(self, tmp_path):
        """A .jpg path is reported as supported."""
        candidate = tmp_path / "test.jpg"
        assert is_supported_format(candidate)

    def test_supported_jpeg(self, tmp_path):
        """A .jpeg path is reported as supported."""
        candidate = tmp_path / "test.jpeg"
        assert is_supported_format(candidate)

    def test_supported_webp(self, tmp_path):
        """A .webp path is reported as supported."""
        candidate = tmp_path / "test.webp"
        assert is_supported_format(candidate)

    def test_unsupported_bmp(self, tmp_path):
        """A .bmp path is reported as unsupported."""
        candidate = tmp_path / "test.bmp"
        assert not is_supported_format(candidate)

    def test_unsupported_gif(self, tmp_path):
        """A .gif path is reported as unsupported."""
        candidate = tmp_path / "test.gif"
        assert not is_supported_format(candidate)

    def test_get_format_png(self, tmp_path):
        """A .png path maps to the 'PNG' format name."""
        fmt = get_image_format(tmp_path / "x.png")
        assert fmt == "PNG"

    def test_get_format_jpg(self, tmp_path):
        """A .jpg path maps to the 'JPEG' format name."""
        fmt = get_image_format(tmp_path / "x.jpg")
        assert fmt == "JPEG"

    def test_get_format_jpeg(self, tmp_path):
        """A .jpeg path maps to the 'JPEG' format name."""
        fmt = get_image_format(tmp_path / "x.jpeg")
        assert fmt == "JPEG"

    def test_get_format_webp_defaults_png(self, tmp_path):
        """A .webp path currently maps to 'PNG'."""
        # Pins today's behaviour: .webp has no mapping of its own and lands on PNG.
        fmt = get_image_format(tmp_path / "x.webp")
        assert fmt == "PNG"
|
|
|
|
|
|
# ── Availability checks ────────────────────────────────────────────
|
|
|
|
|
|
class TestAvailability:
    """The availability helpers must mirror what is importable in this environment."""

    def test_watermark_removal_available(self):
        """is_watermark_removal_available() is True iff torch and diffusers resolve."""
        # The expectation is derived from the live environment rather than
        # hard-coded: the core+dev CI env has no diffusers (the gpu extra), so
        # the full stack must not be assumed present.
        import importlib.util

        required = ("torch", "diffusers")
        expected = not any(importlib.util.find_spec(name) is None for name in required)
        assert is_watermark_removal_available() is expected

    def test_invisible_is_available(self):
        """invisible_engine.is_available() follows the same torch + diffusers rule."""
        import importlib.util

        from remove_ai_watermarks.invisible_engine import is_available

        required = ("torch", "diffusers")
        expected = not any(importlib.util.find_spec(name) is None for name in required)
        assert is_available() is expected
|
|
|
|
|
|
# ── Platform-specific path handling ─────────────────────────────────
|
|
|
|
|
|
class TestPlatformPaths:
    """Path construction and asset loading must work on the current platform."""

    def test_pathlib_works_for_assets(self):
        """Both Gemini background PNGs are found under the src-layout assets dir."""
        from pathlib import Path

        # Two levels above this test file, then down into the package's assets.
        project_root = Path(__file__).parent.parent
        asset_dir = project_root.joinpath("src", "remove_ai_watermarks", "assets")

        for asset_name in ("gemini_bg_48.png", "gemini_bg_96.png"):
            assert (asset_dir / asset_name).exists()

    def test_asset_loading_works(self):
        """GeminiEngine() builds and exposes 48x48 and 96x96 alpha maps.

        Guards asset loading, which is critical for packaging.
        """
        from remove_ai_watermarks.gemini_engine import GeminiEngine

        gemini = GeminiEngine()
        # Reaching this point already proves construction (and asset loading) succeeded.
        small_shape = gemini._alpha_small.shape
        assert small_shape == (48, 48)
        large_shape = gemini._alpha_large.shape
        assert large_shape == (96, 96)
|