mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-01 00:31:35 +02:00
Test the untested pure logic: MPS fallback, tiling, isobmff/c2pa edges
Coverage audit (pytest --cov) found real, non-model logic at 0%/low cover. Add unit tests that need no model download: - img2img_runner.py 0% -> 100%: the MPS->CPU fallback orchestration, mocked via injected load_pipeline/reload_on_cpu callables. Guards the production behavior hit this session (native-res SDXL OOMs on MPS, must retry on CPU; non-MPS errors must propagate; "mps"-worded error on a cpu device must not reload). - ctrlregen/tiling.py 0% -> 40%: the pure tile math (tile_positions, make_blend_weight, resize_center_crop) that decides how large images are split and blended. (run_tiled stays model-bound, untested.) - isobmff.py 93% -> 100%: size==0 (box-to-EOF) and truncated 64-bit largesize parsing branches for AVIF/HEIF/JXL C2PA stripping. - c2pa.py: non-PNG-signed .png reads as clean (has_c2pa_metadata / extract_c2pa_chunk) instead of mis-parsing. 309 tests pass (+23). Document in CLAUDE.md that these pure helpers are unit-tested without downloads so future sessions don't skip them as "ML". No src/ change, no release. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -20,7 +20,7 @@ You are a **principal Python engineer** maintaining a CLI tool and library for r
|
||||
## Configuration
|
||||
|
||||
- GPU/ML modules (invisible_engine, ctrlregen, watermark_remover) are optional — guard imports with `is_available()` checks
|
||||
- Tests for ML modules are limited to availability checks (require multi-GB downloads)
|
||||
- Tests for the *model-running* paths are limited to availability checks (multi-GB downloads). But the **pure helpers inside ML-adjacent modules are unit-tested without any download** and must stay that way: `_target_size` (native-vs-downscale, `test_invisible_engine.py`), the MPS->CPU fallback control flow via mocked pipelines (`test_img2img_runner.py`, 100% coverage), and the tiling math `tile_positions`/`make_blend_weight`/`resize_center_crop` (`test_tiling.py`; uses `pytest.importorskip("torch")` since `tiling.py` imports torch at module top). Don't skip these as "ML, needs a model" — only `run_tiled`/`remove_watermark`/the diffusion bodies actually need one.
|
||||
|
||||
## Key modules
|
||||
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Unit tests for the MPS->CPU fallback orchestration (no GPU/model required).
|
||||
|
||||
``img2img_runner`` has no torch import at module top -- the pipeline is
|
||||
injected as a plain callable -- so the fallback control flow is fully
|
||||
mockable. This guards the exact behavior hit in production on Apple Silicon:
|
||||
a native-resolution SDXL run that OOMs on MPS must transparently retry on CPU,
|
||||
while any non-MPS error must propagate unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from remove_ai_watermarks.noai import img2img_runner
|
||||
from remove_ai_watermarks.noai.img2img_runner import run_img2img, run_img2img_with_mps_fallback
|
||||
|
||||
# Error text for a simulated "out of memory on MPS" failure; the tests below
# raise it as a RuntimeError from the mocked ``run_img2img``.
_MPS_OOM = "MPS backend out of memory (MPS allocated: 17.21 GiB, max allowed: 20.13 GiB)"
|
||||
|
||||
|
||||
def _result(image: object) -> Mock:
|
||||
"""A stand-in for a diffusers pipeline output object (has .images)."""
|
||||
out = Mock()
|
||||
out.images = [image]
|
||||
return out
|
||||
|
||||
|
||||
class TestMpsFallback:
    """Exercise ``run_img2img_with_mps_fallback`` with ``run_img2img`` swapped for a mock."""

    @staticmethod
    def _invoke(load_pipeline: Mock, generator: object, device: str, reload_on_cpu: Mock):
        """Call the fallback runner with the fixed positional arguments every test shares.

        Only the pipeline loader, generator, device and CPU reloader vary per test.
        """
        return run_img2img_with_mps_fallback(
            load_pipeline,
            object(),
            0.05,
            50,
            7.5,
            generator,
            device,
            lambda _m: None,
            reload_on_cpu=reload_on_cpu,
        )

    def test_mps_error_reloads_on_cpu_and_retries(self, monkeypatch: pytest.MonkeyPatch):
        expected = object()
        cpu_reloader = Mock(return_value="cpu_pipe")
        # First invocation raises the MPS OOM error; the second one succeeds.
        fake_run = Mock(side_effect=[RuntimeError(_MPS_OOM), expected])
        monkeypatch.setattr(img2img_runner, "run_img2img", fake_run)

        img, device = self._invoke(Mock(return_value="gpu_pipe"), "gen", "mps", cpu_reloader)

        assert img is expected
        assert device == "cpu"
        cpu_reloader.assert_called_once()
        assert fake_run.call_count == 2
        # The retry has to run on the reloaded CPU pipeline with device "cpu",
        # and with the MPS generator dropped (generator=None).
        _first, retry = fake_run.call_args_list
        assert retry.args[0] == "cpu_pipe"
        assert retry.args[5] is None  # generator
        assert retry.args[6] == "cpu"  # device

    def test_happy_path_returns_original_device_without_reload(self, monkeypatch: pytest.MonkeyPatch):
        expected = object()
        cpu_reloader = Mock()
        monkeypatch.setattr(img2img_runner, "run_img2img", Mock(return_value=expected))

        img, device = self._invoke(Mock(return_value="gpu_pipe"), "gen", "mps", cpu_reloader)

        assert img is expected
        assert device == "mps"
        cpu_reloader.assert_not_called()

    def test_non_mps_runtime_error_propagates(self, monkeypatch: pytest.MonkeyPatch):
        cpu_reloader = Mock()
        failing_run = Mock(side_effect=RuntimeError("CUDA out of memory"))
        monkeypatch.setattr(img2img_runner, "run_img2img", failing_run)

        with pytest.raises(RuntimeError, match="CUDA"):
            self._invoke(Mock(return_value="gpu_pipe"), "gen", "mps", cpu_reloader)

        # Checked after the ``raises`` block so the assertion is actually reached.
        cpu_reloader.assert_not_called()

    def test_mps_error_on_non_mps_device_propagates(self, monkeypatch: pytest.MonkeyPatch):
        # An "mps"-worded error raised while the device is "cpu" must not
        # trigger the CPU reload path.
        cpu_reloader = Mock()
        failing_run = Mock(side_effect=RuntimeError(_MPS_OOM))
        monkeypatch.setattr(img2img_runner, "run_img2img", failing_run)

        with pytest.raises(RuntimeError, match="MPS backend"):
            self._invoke(Mock(return_value="cpu_pipe"), None, "cpu", cpu_reloader)

        cpu_reloader.assert_not_called()
|
||||
|
||||
|
||||
class TestRunImg2Img:
    """Direct checks of ``run_img2img`` against a mocked pipeline callable."""

    @staticmethod
    def _call(pipeline: Mock) -> object:
        """Run ``run_img2img`` on *pipeline* with the fixed arguments both tests use."""
        return run_img2img(pipeline, object(), 0.05, 50, 7.5, None, "cpu", lambda _m: None)

    def test_returns_first_image_from_pipeline_result(self):
        expected = object()
        fake_pipeline = Mock(return_value=_result(expected))

        assert self._call(fake_pipeline) is expected

    def test_typeerror_on_callback_retries_without_callback(self):
        # Older diffusers versions raise TypeError on the progress-callback
        # kwarg; run_img2img is expected to retry once without it.
        expected = object()
        fake_pipeline = Mock(side_effect=[TypeError("unexpected keyword 'callback'"), _result(expected)])

        assert self._call(fake_pipeline) is expected
        assert fake_pipeline.call_count == 2
        # The callback kwarg is present on the first attempt and absent on the retry.
        first_call, second_call = fake_pipeline.call_args_list
        assert "callback" in first_call.kwargs
        assert "callback" not in second_call.kwargs
|
||||
@@ -294,3 +294,31 @@ class TestISOBMFF:
|
||||
cleaned, stripped = strip_c2pa_boxes(FTYP + b"\x00\x00\x00\x04XXXX")
|
||||
assert stripped == 0
|
||||
assert cleaned.startswith(FTYP)
|
||||
|
||||
def test_size_zero_box_runs_to_eof(self):
    """A ``free`` box with a 32-bit size of 0 is kept and the data round-trips."""
    # A size field of 0 means "this box extends to end of file".
    eof_box = b"".join((struct.pack(">I", 0), b"free", bytes(4)))
    payload = FTYP + eof_box

    cleaned, stripped = strip_c2pa_boxes(payload)

    assert stripped == 0
    assert cleaned == payload
|
||||
|
||||
def test_truncated_largesize_terminates_safely(self):
    """A box announcing a 64-bit largesize that is missing is dropped without error."""
    # size32 == 1 promises an 8-byte largesize after the type, but the data
    # stops right after the 8-byte header; parsing must stop there instead
    # of reading past EOF.
    truncated_box = b"\x00\x00\x00\x01uuid"

    cleaned, stripped = strip_c2pa_boxes(FTYP + truncated_box)

    assert stripped == 0
    assert cleaned == FTYP
|
||||
|
||||
|
||||
class TestC2PAInvalidSignature:
    """Files named ``.png`` whose bytes lack a PNG signature must be reported as clean."""

    # Content that does not begin with the PNG signature.
    _NOT_PNG = b"\xff\xd8\xff\xe0 not a png at all, just garbage bytes"

    def _write_fake_png(self, directory: Path) -> Path:
        """Create ``fake.png`` in *directory* holding the non-PNG payload and return its path."""
        target = directory / "fake.png"
        target.write_bytes(self._NOT_PNG)
        return target

    def test_has_c2pa_false_for_non_png_bytes(self, tmp_path: Path):
        bogus = self._write_fake_png(tmp_path)
        assert has_c2pa_metadata(bogus) is False

    def test_extract_chunk_none_for_non_png_bytes(self, tmp_path: Path):
        bogus = self._write_fake_png(tmp_path)
        assert extract_c2pa_chunk(bogus) is None
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Unit tests for the pure tiling helpers (no GPU/model required).
|
||||
|
||||
``tiling.py`` imports torch at module top, so skip cleanly when torch is
|
||||
absent. The helpers themselves are pure numpy/PIL/math -- they decide how a
|
||||
large image is split into overlapping tiles and blended back, so a regression
|
||||
here would seam or crop the CtrlRegen output wrongly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("torch")
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from remove_ai_watermarks.noai.ctrlregen.tiling import (
|
||||
make_blend_weight,
|
||||
resize_center_crop,
|
||||
tile_positions,
|
||||
)
|
||||
|
||||
|
||||
class TestTilePositions:
    """Checks on the start offsets returned by ``tile_positions(total, tile, overlap)``."""

    def test_image_smaller_than_tile_single_position(self):
        starts = tile_positions(500, 512, 64)
        assert starts == [0]

    def test_image_equal_to_tile_single_position(self):
        starts = tile_positions(512, 512, 64)
        assert starts == [0]

    def test_first_is_zero_last_is_total_minus_tile(self):
        # Full coverage of the span: the first tile starts at 0 and the last
        # one starts at total - tile, so it ends exactly on the far edge.
        total, tile = 2000, 512
        starts = tile_positions(total, tile, 64)
        assert starts[0] == 0
        assert starts[-1] == total - tile

    def test_overlap_positions_are_monotonic_and_exact(self):
        starts = tile_positions(1000, 512, 64)
        assert starts == [0, 244, 488]

    def test_zero_overlap_tiles_are_contiguous(self):
        # Width 1024 with 512-wide tiles and no overlap: two tiles meeting at 512.
        starts = tile_positions(1024, 512, 0)
        assert starts == [0, 512]
|
||||
|
||||
|
||||
class TestMakeBlendWeight:
    """Checks on the 2-D weight mask produced by ``make_blend_weight``."""

    def test_zero_overlap_is_all_ones(self):
        mask = make_blend_weight(8, 8, 0)
        assert mask.shape == (8, 8)
        assert mask.dtype == np.float64
        assert (mask == 1.0).all()

    def test_overlap_ramps_corners_to_zero_center_to_one(self):
        mask = make_blend_weight(16, 16, 4)
        corner, center = mask[0, 0], mask[8, 8]
        assert corner == 0.0  # the ramp begins at 0 in the corner
        assert center == 1.0  # the middle of the tile carries full weight
        assert mask.max() == 1.0
        assert mask.min() == 0.0

    def test_weight_is_point_symmetric(self):
        # Reversing both axes is a 180-degree rotation; the mask must be
        # unchanged by it so opposite seams are blended identically.
        mask = make_blend_weight(16, 16, 4)
        rotated = mask[::-1, ::-1]
        assert np.allclose(mask, rotated)
|
||||
|
||||
|
||||
class TestResizeCenterCrop:
    """Checks that ``resize_center_crop`` returns a square image of the requested side."""

    @pytest.mark.parametrize(("width", "height"), [(400, 800), (800, 400), (300, 300), (1000, 1001)])
    def test_output_is_always_square_of_requested_size(self, width: int, height: int):
        # Portrait, landscape, square and near-square inputs all yield 256x256.
        source = Image.new("RGB", (width, height))
        cropped = resize_center_crop(source, 256)
        assert cropped.size == (256, 256)

    def test_default_size_is_512(self):
        source = Image.new("RGB", (640, 480))
        cropped = resize_center_crop(source)
        assert cropped.size == (512, 512)
|
||||
Reference in New Issue
Block a user