diff --git a/src/remove_ai_watermarks/cli.py b/src/remove_ai_watermarks/cli.py index 2eef982..1f70878 100644 --- a/src/remove_ai_watermarks/cli.py +++ b/src/remove_ai_watermarks/cli.py @@ -12,6 +12,7 @@ import json import logging import time from pathlib import Path +from typing import TYPE_CHECKING import click from rich.console import Console @@ -21,6 +22,11 @@ from rich.table import Table from remove_ai_watermarks import __version__ +if TYPE_CHECKING: + import numpy as np + + from remove_ai_watermarks.gemini_engine import DetectionResult + console = Console() SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp"} @@ -56,6 +62,74 @@ def _validate_image(path: Path) -> Path: return path +_ALPHA_FORMATS = {".png", ".webp"} + + +def _watermark_region(det: DetectionResult, width: int, height: int) -> tuple[int, int, int, int]: + """Pick a watermark bbox: detector's region if confident, else the default config slot.""" + if det.confidence > 0.15: + return det.region + from remove_ai_watermarks.gemini_engine import get_watermark_config + + config = get_watermark_config(width, height) + px, py = config.get_position(width, height) + return (px, py, config.logo_size, config.logo_size) + + +def _read_bgr_and_alpha(path: Path) -> tuple[np.ndarray | None, np.ndarray | None]: + """Read an image preserving its alpha channel separately. + + Returns ``(bgr, alpha)`` where ``alpha`` is a single-channel ndarray when the + source has transparency, else ``None``. Greyscale inputs are promoted to BGR. + Returns ``(None, None)`` if the image cannot be decoded. + """ + import cv2 + + image = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) + if image is None: + return None, None + if image.ndim == 2: + return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR), None + if image.shape[2] == 4: + return image[:, :, :3].copy(), image[:, :, 3].copy() + return image, None + + +def _write_bgr_with_alpha( + path: Path, + bgr: np.ndarray, + alpha: np.ndarray | None, + clear_region: tuple[int, int, int, int] | None = None, + pad: int = 6, +) -> None: + """Write BGR (with optional alpha) to ``path``. + + When ``alpha`` is provided and the output extension supports it, writes a + 4-channel image. If ``clear_region`` is given as ``(x, y, w, h)``, alpha is + forced to 0 inside that bbox (expanded by ``pad`` px) so the watermark area + becomes fully transparent in the saved file. + """ + import cv2 + import numpy as np + + if alpha is None or path.suffix.lower() not in _ALPHA_FORMATS: + cv2.imwrite(str(path), bgr) + return + + alpha_out = alpha + if clear_region is not None: + alpha_out = alpha.copy() + x, y, w, h = clear_region + height, width = alpha.shape[:2] + x0, y0 = max(0, x - pad), max(0, y - pad) + x1, y1 = min(width, x + w + pad), min(height, y + h + pad) + if x1 > x0 and y1 > y0: + alpha_out[y0:y1, x0:x1] = 0 + + bgra = np.dstack([bgr, alpha_out]) + cv2.imwrite(str(path), bgra) + + # ── Main group ─────────────────────────────────────────────────────── @@ -110,8 +184,6 @@ def cmd_visible( Uses reverse alpha blending — fast, deterministic, offline. """ - import cv2 - from remove_ai_watermarks.gemini_engine import GeminiEngine _banner() @@ -122,8 +194,8 @@ def cmd_visible( engine = GeminiEngine() - # Load image - image = cv2.imread(str(source), cv2.IMREAD_COLOR) + # Load image (preserving any alpha channel separately) + image, alpha = _read_bgr_and_alpha(source) if image is None: console.print(f"[red]Error:[/] Failed to read image: {source}") raise SystemExit(1) @@ -151,18 +223,12 @@ def cmd_visible( # Removal t0 = time.monotonic() + region: tuple[int, int, int, int] | None = None with console.status("[cyan]Removing watermark…[/]"): result = engine.remove_watermark(image) if inpaint: - if det.confidence > 0.15: - region = det.region - else: - from remove_ai_watermarks.gemini_engine import get_watermark_config - - config = get_watermark_config(w, h) - pos = config.get_position(w, h) - region = (pos[0], pos[1], config.logo_size, config.logo_size) + region = _watermark_region(det, w, h) result = engine.inpaint_residual( result, region, @@ -172,9 +238,9 @@ def cmd_visible( elapsed = time.monotonic() - t0 - # Save + # Save (preserves transparency by clearing alpha in the watermark region) output.parent.mkdir(parents=True, exist_ok=True) - cv2.imwrite(str(output), result) + _write_bgr_with_alpha(output, result, alpha, clear_region=region) # Strip metadata if strip_metadata: @@ -435,9 +501,7 @@ def cmd_all( If invisible watermark deps are not installed, skips step 2 with a warning. """ - import cv2 - - from remove_ai_watermarks.gemini_engine import GeminiEngine, get_watermark_config + from remove_ai_watermarks.gemini_engine import GeminiEngine _banner() source = _validate_image(source) @@ -461,7 +525,7 @@ def cmd_all( # ── Step 1: Visible watermark ──────────────────────────────── console.print("\n [bold cyan]① Visible watermark removal[/]") engine = GeminiEngine() - image = cv2.imread(str(source), cv2.IMREAD_COLOR) + image, alpha = _read_bgr_and_alpha(source) if image is None: console.print(f"[red]Error:[/] Failed to read image: {source}") raise SystemExit(1) @@ -469,25 +533,21 @@ def cmd_all( h, w = image.shape[:2] console.print(f" [dim]Input:[/] {source.name} ({w}x{h})") + region: tuple[int, int, int, int] | None = None with console.status("[cyan]Removing visible watermark…[/]"): det = engine.detect_watermark(image) if det.detected: result = engine.remove_watermark(image) if inpaint: - if det.confidence > 0.15: - region = det.region - else: - config = get_watermark_config(w, h) - pos = config.get_position(w, h) - region = (pos[0], pos[1], config.logo_size, config.logo_size) + region = _watermark_region(det, w, h) result = engine.inpaint_residual(result, region, method=inpaint_method) console.print(" [green]✓[/] Visible watermark removed") else: result = image.copy() console.print(" [dim]Skipped (no visible watermark detected)[/]") - # Save to temp file for invisible engine input - cv2.imwrite(str(tmp_path), result) + # Save to temp file for invisible engine input (preserve alpha if present) + _write_bgr_with_alpha(tmp_path, result, alpha, clear_region=region) # ── Step 2: Invisible watermark ────────────────────────────── console.print("\n [bold cyan]② Invisible watermark removal[/]") @@ -536,10 +596,15 @@ def cmd_all( console.print(f" [yellow]⚠[/] Metadata strip failed: {e}") # ── Write final result ──────────────────────────────────────── - import shutil - + # The invisible step (and downstream cv2.IMREAD_COLOR paths) drops alpha, + # so re-attach the original alpha (with the watermark region cleared) + # when writing the final output for transparent formats. output.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(tmp_path), str(output)) + final_bgr, _ = _read_bgr_and_alpha(tmp_path) + if final_bgr is None: + console.print(f"[red]Error:[/] Failed to read intermediate file: {tmp_path}") + raise SystemExit(1) + _write_bgr_with_alpha(output, final_bgr, alpha, clear_region=region) finally: # Clean up temp file if it still exists @@ -577,13 +642,11 @@ def _process_batch_image( Raises: ValueError: If the image cannot be opened. """ - if mode in ("visible", "all"): - import cv2 + saved_alpha: np.ndarray | None = None + saved_region: tuple[int, int, int, int] | None = None - from remove_ai_watermarks.gemini_engine import ( - GeminiEngine, - get_watermark_config, - ) + if mode in ("visible", "all"): + from remove_ai_watermarks.gemini_engine import GeminiEngine if "_vis_engine" not in ctx.obj: ctx.obj["_vis_engine"] = GeminiEngine() @@ -591,27 +654,24 @@ def _process_batch_image( read_path = img_path if mode == "all" and out_path.exists(): read_path = out_path - image = cv2.imread(str(read_path), cv2.IMREAD_COLOR) + image, alpha = _read_bgr_and_alpha(read_path) if image is None: raise ValueError("Failed to read image") + region: tuple[int, int, int, int] | None = None det = engine.detect_watermark(image) if det.detected: result = engine.remove_watermark(image) if inpaint: - if det.confidence > 0.15: - region = det.region - else: - h, w = image.shape[:2] - config = get_watermark_config(w, h) - pos = config.get_position(w, h) - region = (pos[0], pos[1], config.logo_size, config.logo_size) - + h, w = image.shape[:2] + region = _watermark_region(det, w, h) result = engine.inpaint_residual(result, region) else: result = image.copy() - cv2.imwrite(str(out_path), result) + _write_bgr_with_alpha(out_path, result, alpha, clear_region=region) + saved_alpha = alpha + saved_region = region if mode in ("invisible", "all"): from remove_ai_watermarks.invisible_engine import ( @@ -642,6 +702,13 @@ def _process_batch_image( remove_ai_metadata(img_path if mode == "metadata" else out_path, out_path) + # In "all" mode, the invisible step (color-only OpenCV paths) drops alpha, + # so re-attach the cached alpha when the input had transparency. + if mode == "all" and saved_alpha is not None: + final_bgr, _ = _read_bgr_and_alpha(out_path) + if final_bgr is not None: + _write_bgr_with_alpha(out_path, final_bgr, saved_alpha, clear_region=saved_region) + @main.command("batch") @click.argument("directory", type=click.Path(exists=True, file_okay=False, path_type=Path)) diff --git a/tests/test_cli.py b/tests/test_cli.py index c8afaed..ecb1431 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -72,6 +72,27 @@ def _mock_invisible_engine(): return mock_cls, mock_engine +def _mock_invisible_engine_drops_alpha(): + """Mock InvisibleEngine that mimics the real engine's BGR-only output path. + + The real diffusion-based engine reads with cv2.IMREAD_COLOR and writes a + 3-channel result. This mock simulates that so we can regression-test alpha + preservation across the ``all`` pipeline. + """ + + def _mock_remove_watermark(image_path, output_path=None, **kwargs): + out = output_path or image_path.with_stem(image_path.stem + "_clean") + out.parent.mkdir(parents=True, exist_ok=True) + bgr = cv2.imread(str(image_path), cv2.IMREAD_COLOR) + cv2.imwrite(str(out), bgr) + return out + + mock_engine = MagicMock() + mock_engine.remove_watermark.side_effect = _mock_remove_watermark + mock_cls = MagicMock(return_value=mock_engine) + return mock_cls, mock_engine + + class TestMainGroup: """Tests for the top-level CLI group.""" @@ -144,6 +165,73 @@ class TestVisibleCommand: result = runner.invoke(main, ["visible", "/nonexistent/file.png"]) assert result.exit_code != 0 + def test_visible_preserves_rgba_transparency(self, runner, tmp_path): + """Visible removal on an RGBA PNG must keep the alpha channel, + not silently flatten the image onto an opaque background. + """ + rgba = np.zeros((200, 200, 4), dtype=np.uint8) + rgba[:, :, :3] = 200 # light grey foreground + rgba[50:150, 50:150, 3] = 255 # opaque square in the middle, rest transparent + src = tmp_path / "rgba_in.png" + cv2.imwrite(str(src), rgba) + + output = tmp_path / "rgba_out.png" + result = runner.invoke( + main, + ["visible", str(src), "-o", str(output), "--no-detect"], + ) + + assert result.exit_code == 0, result.output + assert output.exists() + + out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED) + assert out.ndim == 3, f"output is not 3D: shape={out.shape}" + assert out.shape[2] == 4, f"output is not RGBA: shape={out.shape}" + # The transparent corners must remain transparent. + assert out[0, 0, 3] == 0 + assert out[199, 199, 3] == 0 + # The opaque centre remains opaque (the watermark region default is bottom-right, + # which doesn't overlap the centre square at 200x200). + assert out[100, 100, 3] == 255 + + def test_visible_clears_alpha_in_watermark_region(self, runner, tmp_path): + """When inpainting an RGBA image, the watermark region must be cleared + in the alpha channel so the sparkle area becomes transparent, not opaque-black. + """ + rgba = np.full((200, 200, 4), 255, dtype=np.uint8) # fully opaque white + src = tmp_path / "rgba_full.png" + cv2.imwrite(str(src), rgba) + + output = tmp_path / "rgba_cleared.png" + result = runner.invoke( + main, + ["visible", str(src), "-o", str(output), "--no-detect"], + ) + + assert result.exit_code == 0, result.output + out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED) + assert out.shape[2] == 4 + # Default sparkle position is in the bottom-right; alpha there must be 0. + from remove_ai_watermarks.gemini_engine import get_watermark_config + + cfg = get_watermark_config(200, 200) + px, py = cfg.get_position(200, 200) + size = cfg.logo_size + assert out[py + size // 2, px + size // 2, 3] == 0, "alpha in the watermark region was not cleared" + + def test_visible_rgb_input_stays_rgb(self, runner, sample_png, tmp_path): + """Regression: a plain RGB PNG must NOT gain a spurious alpha channel.""" + output = tmp_path / "rgb_out.png" + result = runner.invoke( + main, + ["visible", str(sample_png), "-o", str(output), "--no-detect"], + ) + + assert result.exit_code == 0, result.output + out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED) + assert out.ndim == 3, f"output is not 3D: shape={out.shape}" + assert out.shape[2] == 3, f"RGB input produced non-RGB output: shape={out.shape}" + class TestInvisibleCommand: """Tests for the 'invisible' subcommand.""" @@ -210,6 +298,33 @@ class TestAllCommand: result = runner.invoke(main, ["all", "/nonexistent/file.png"]) assert result.exit_code != 0 + def test_all_preserves_rgba_across_invisible_step(self, runner, tmp_path): + """Regression: ``all`` must keep transparency even when the invisible + step writes a 3-channel result (as the real diffusion engine does). + """ + rgba = np.zeros((200, 200, 4), dtype=np.uint8) + rgba[:, :, :3] = 200 + rgba[50:150, 50:150, 3] = 255 # opaque square; corners transparent + src = tmp_path / "rgba_in.png" + cv2.imwrite(str(src), rgba) + + output = tmp_path / "rgba_out.png" + mock_cls, _engine = _mock_invisible_engine_drops_alpha() + with ( + patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), + patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls), + patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True), + patch("remove_ai_watermarks.invisible_engine.is_available", return_value=True), + ): + result = runner.invoke(main, ["all", str(src), "-o", str(output)]) + + assert result.exit_code == 0, result.output + out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED) + assert out.ndim == 3, f"output not 3D: shape={out.shape}" + assert out.shape[2] == 4, f"output is not RGBA: shape={out.shape}" + assert out[0, 0, 3] == 0 + assert out[100, 100, 3] == 255 + class TestMetadataCommand: """Tests for the 'metadata' subcommand.""" @@ -350,6 +465,37 @@ class TestBatchCommand: assert result.exit_code == 0, result.output assert "3 processed" in result.output + def test_batch_all_mode_preserves_rgba(self, runner, tmp_path): + """Regression: batch ``all`` must keep transparency across the + alpha-dropping invisible step (mirrors test_all_preserves_rgba_...). + """ + input_dir = tmp_path / "input" + input_dir.mkdir() + rgba = np.zeros((200, 200, 4), dtype=np.uint8) + rgba[:, :, :3] = 200 + rgba[50:150, 50:150, 3] = 255 + cv2.imwrite(str(input_dir / "rgba.png"), rgba) + + output_dir = tmp_path / "output" + mock_cls, _engine = _mock_invisible_engine_drops_alpha() + with ( + patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), + patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls), + patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True), + patch("remove_ai_watermarks.invisible_engine.is_available", return_value=True), + ): + result = runner.invoke( + main, + ["batch", str(input_dir), "-o", str(output_dir), "--mode", "all"], + ) + assert result.exit_code == 0, result.output + + out = cv2.imread(str(output_dir / "rgba.png"), cv2.IMREAD_UNCHANGED) + assert out.ndim == 3, f"output not 3D: shape={out.shape}" + assert out.shape[2] == 4, f"output is not RGBA: shape={out.shape}" + assert out[0, 0, 3] == 0 + assert out[100, 100, 3] == 255 + def test_batch_default_output_dir(self, runner, tmp_path): input_dir = _make_batch_dir(tmp_path) result = runner.invoke(