mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-05-27 06:32:24 +02:00
fix(cli): preserve alpha channel in visible-watermark pipeline
`cv2.imread(..., IMREAD_COLOR)` was silently stripping the alpha channel on RGBA inputs, and `cv2.imwrite` then wrote opaque 3-channel PNGs — so images with transparent backgrounds came back with an opaque-black (or white) background and the sparkle area baked in as a solid blob. Read the source with `IMREAD_UNCHANGED`, keep the alpha plane out of the detection/inpaint path (those still operate on BGR), and rejoin alpha at save time. The detected watermark bbox is also zeroed in the alpha plane so the sparkle region becomes transparent rather than an opaque artifact. Applies to `visible`, `all`, and `batch` modes. RGB-only inputs and JPEG outputs are unaffected.
This commit is contained in:
+112
-45
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from rich.console import Console
|
||||
@@ -21,6 +22,11 @@ from rich.table import Table
|
||||
|
||||
from remove_ai_watermarks import __version__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
from remove_ai_watermarks.gemini_engine import DetectionResult
|
||||
|
||||
console = Console()
|
||||
|
||||
SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp"}
|
||||
@@ -56,6 +62,74 @@ def _validate_image(path: Path) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
_ALPHA_FORMATS = {".png", ".webp"}
|
||||
|
||||
|
||||
def _watermark_region(det: DetectionResult, width: int, height: int) -> tuple[int, int, int, int]:
|
||||
"""Pick a watermark bbox: detector's region if confident, else the default config slot."""
|
||||
if det.confidence > 0.15:
|
||||
return det.region
|
||||
from remove_ai_watermarks.gemini_engine import get_watermark_config
|
||||
|
||||
config = get_watermark_config(width, height)
|
||||
px, py = config.get_position(width, height)
|
||||
return (px, py, config.logo_size, config.logo_size)
|
||||
|
||||
|
||||
def _read_bgr_and_alpha(path: Path) -> tuple[np.ndarray | None, np.ndarray | None]:
|
||||
"""Read an image preserving its alpha channel separately.
|
||||
|
||||
Returns ``(bgr, alpha)`` where ``alpha`` is a single-channel ndarray when the
|
||||
source has transparency, else ``None``. Greyscale inputs are promoted to BGR.
|
||||
Returns ``(None, None)`` if the image cannot be decoded.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
image = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
|
||||
if image is None:
|
||||
return None, None
|
||||
if image.ndim == 2:
|
||||
return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR), None
|
||||
if image.shape[2] == 4:
|
||||
return image[:, :, :3].copy(), image[:, :, 3].copy()
|
||||
return image, None
|
||||
|
||||
|
||||
def _write_bgr_with_alpha(
|
||||
path: Path,
|
||||
bgr: np.ndarray,
|
||||
alpha: np.ndarray | None,
|
||||
clear_region: tuple[int, int, int, int] | None = None,
|
||||
pad: int = 6,
|
||||
) -> None:
|
||||
"""Write BGR (with optional alpha) to ``path``.
|
||||
|
||||
When ``alpha`` is provided and the output extension supports it, writes a
|
||||
4-channel image. If ``clear_region`` is given as ``(x, y, w, h)``, alpha is
|
||||
forced to 0 inside that bbox (expanded by ``pad`` px) so the watermark area
|
||||
becomes fully transparent in the saved file.
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
if alpha is None or path.suffix.lower() not in _ALPHA_FORMATS:
|
||||
cv2.imwrite(str(path), bgr)
|
||||
return
|
||||
|
||||
alpha_out = alpha
|
||||
if clear_region is not None:
|
||||
alpha_out = alpha.copy()
|
||||
x, y, w, h = clear_region
|
||||
height, width = alpha.shape[:2]
|
||||
x0, y0 = max(0, x - pad), max(0, y - pad)
|
||||
x1, y1 = min(width, x + w + pad), min(height, y + h + pad)
|
||||
if x1 > x0 and y1 > y0:
|
||||
alpha_out[y0:y1, x0:x1] = 0
|
||||
|
||||
bgra = np.dstack([bgr, alpha_out])
|
||||
cv2.imwrite(str(path), bgra)
|
||||
|
||||
|
||||
# ── Main group ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -110,8 +184,6 @@ def cmd_visible(
|
||||
|
||||
Uses reverse alpha blending — fast, deterministic, offline.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
from remove_ai_watermarks.gemini_engine import GeminiEngine
|
||||
|
||||
_banner()
|
||||
@@ -122,8 +194,8 @@ def cmd_visible(
|
||||
|
||||
engine = GeminiEngine()
|
||||
|
||||
# Load image
|
||||
image = cv2.imread(str(source), cv2.IMREAD_COLOR)
|
||||
# Load image (preserving any alpha channel separately)
|
||||
image, alpha = _read_bgr_and_alpha(source)
|
||||
if image is None:
|
||||
console.print(f"[red]Error:[/] Failed to read image: {source}")
|
||||
raise SystemExit(1)
|
||||
@@ -151,18 +223,12 @@ def cmd_visible(
|
||||
|
||||
# Removal
|
||||
t0 = time.monotonic()
|
||||
region: tuple[int, int, int, int] | None = None
|
||||
with console.status("[cyan]Removing watermark…[/]"):
|
||||
result = engine.remove_watermark(image)
|
||||
|
||||
if inpaint:
|
||||
if det.confidence > 0.15:
|
||||
region = det.region
|
||||
else:
|
||||
from remove_ai_watermarks.gemini_engine import get_watermark_config
|
||||
|
||||
config = get_watermark_config(w, h)
|
||||
pos = config.get_position(w, h)
|
||||
region = (pos[0], pos[1], config.logo_size, config.logo_size)
|
||||
region = _watermark_region(det, w, h)
|
||||
result = engine.inpaint_residual(
|
||||
result,
|
||||
region,
|
||||
@@ -172,9 +238,9 @@ def cmd_visible(
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
|
||||
# Save
|
||||
# Save (preserves transparency by clearing alpha in the watermark region)
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(str(output), result)
|
||||
_write_bgr_with_alpha(output, result, alpha, clear_region=region)
|
||||
|
||||
# Strip metadata
|
||||
if strip_metadata:
|
||||
@@ -435,9 +501,7 @@ def cmd_all(
|
||||
|
||||
If invisible watermark deps are not installed, skips step 2 with a warning.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
from remove_ai_watermarks.gemini_engine import GeminiEngine, get_watermark_config
|
||||
from remove_ai_watermarks.gemini_engine import GeminiEngine
|
||||
|
||||
_banner()
|
||||
source = _validate_image(source)
|
||||
@@ -461,7 +525,7 @@ def cmd_all(
|
||||
# ── Step 1: Visible watermark ────────────────────────────────
|
||||
console.print("\n [bold cyan]① Visible watermark removal[/]")
|
||||
engine = GeminiEngine()
|
||||
image = cv2.imread(str(source), cv2.IMREAD_COLOR)
|
||||
image, alpha = _read_bgr_and_alpha(source)
|
||||
if image is None:
|
||||
console.print(f"[red]Error:[/] Failed to read image: {source}")
|
||||
raise SystemExit(1)
|
||||
@@ -469,25 +533,21 @@ def cmd_all(
|
||||
h, w = image.shape[:2]
|
||||
console.print(f" [dim]Input:[/] {source.name} ({w}x{h})")
|
||||
|
||||
region: tuple[int, int, int, int] | None = None
|
||||
with console.status("[cyan]Removing visible watermark…[/]"):
|
||||
det = engine.detect_watermark(image)
|
||||
if det.detected:
|
||||
result = engine.remove_watermark(image)
|
||||
if inpaint:
|
||||
if det.confidence > 0.15:
|
||||
region = det.region
|
||||
else:
|
||||
config = get_watermark_config(w, h)
|
||||
pos = config.get_position(w, h)
|
||||
region = (pos[0], pos[1], config.logo_size, config.logo_size)
|
||||
region = _watermark_region(det, w, h)
|
||||
result = engine.inpaint_residual(result, region, method=inpaint_method)
|
||||
console.print(" [green]✓[/] Visible watermark removed")
|
||||
else:
|
||||
result = image.copy()
|
||||
console.print(" [dim]Skipped (no visible watermark detected)[/]")
|
||||
|
||||
# Save to temp file for invisible engine input
|
||||
cv2.imwrite(str(tmp_path), result)
|
||||
# Save to temp file for invisible engine input (preserve alpha if present)
|
||||
_write_bgr_with_alpha(tmp_path, result, alpha, clear_region=region)
|
||||
|
||||
# ── Step 2: Invisible watermark ──────────────────────────────
|
||||
console.print("\n [bold cyan]② Invisible watermark removal[/]")
|
||||
@@ -536,10 +596,15 @@ def cmd_all(
|
||||
console.print(f" [yellow]⚠[/] Metadata strip failed: {e}")
|
||||
|
||||
# ── Write final result ────────────────────────────────────────
|
||||
import shutil
|
||||
|
||||
# The invisible step (and downstream cv2.IMREAD_COLOR paths) drops alpha,
|
||||
# so re-attach the original alpha (with the watermark region cleared)
|
||||
# when writing the final output for transparent formats.
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(tmp_path), str(output))
|
||||
final_bgr, _ = _read_bgr_and_alpha(tmp_path)
|
||||
if final_bgr is None:
|
||||
console.print(f"[red]Error:[/] Failed to read intermediate file: {tmp_path}")
|
||||
raise SystemExit(1)
|
||||
_write_bgr_with_alpha(output, final_bgr, alpha, clear_region=region)
|
||||
|
||||
finally:
|
||||
# Clean up temp file if it still exists
|
||||
@@ -577,13 +642,11 @@ def _process_batch_image(
|
||||
Raises:
|
||||
ValueError: If the image cannot be opened.
|
||||
"""
|
||||
if mode in ("visible", "all"):
|
||||
import cv2
|
||||
saved_alpha: np.ndarray | None = None
|
||||
saved_region: tuple[int, int, int, int] | None = None
|
||||
|
||||
from remove_ai_watermarks.gemini_engine import (
|
||||
GeminiEngine,
|
||||
get_watermark_config,
|
||||
)
|
||||
if mode in ("visible", "all"):
|
||||
from remove_ai_watermarks.gemini_engine import GeminiEngine
|
||||
|
||||
if "_vis_engine" not in ctx.obj:
|
||||
ctx.obj["_vis_engine"] = GeminiEngine()
|
||||
@@ -591,27 +654,24 @@ def _process_batch_image(
|
||||
read_path = img_path
|
||||
if mode == "all" and out_path.exists():
|
||||
read_path = out_path
|
||||
image = cv2.imread(str(read_path), cv2.IMREAD_COLOR)
|
||||
image, alpha = _read_bgr_and_alpha(read_path)
|
||||
if image is None:
|
||||
raise ValueError("Failed to read image")
|
||||
|
||||
region: tuple[int, int, int, int] | None = None
|
||||
det = engine.detect_watermark(image)
|
||||
if det.detected:
|
||||
result = engine.remove_watermark(image)
|
||||
if inpaint:
|
||||
if det.confidence > 0.15:
|
||||
region = det.region
|
||||
else:
|
||||
h, w = image.shape[:2]
|
||||
config = get_watermark_config(w, h)
|
||||
pos = config.get_position(w, h)
|
||||
region = (pos[0], pos[1], config.logo_size, config.logo_size)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
region = _watermark_region(det, w, h)
|
||||
result = engine.inpaint_residual(result, region)
|
||||
else:
|
||||
result = image.copy()
|
||||
|
||||
cv2.imwrite(str(out_path), result)
|
||||
_write_bgr_with_alpha(out_path, result, alpha, clear_region=region)
|
||||
saved_alpha = alpha
|
||||
saved_region = region
|
||||
|
||||
if mode in ("invisible", "all"):
|
||||
from remove_ai_watermarks.invisible_engine import (
|
||||
@@ -642,6 +702,13 @@ def _process_batch_image(
|
||||
|
||||
remove_ai_metadata(img_path if mode == "metadata" else out_path, out_path)
|
||||
|
||||
# In "all" mode, the invisible step (color-only OpenCV paths) drops alpha,
|
||||
# so re-attach the cached alpha when the input had transparency.
|
||||
if mode == "all" and saved_alpha is not None:
|
||||
final_bgr, _ = _read_bgr_and_alpha(out_path)
|
||||
if final_bgr is not None:
|
||||
_write_bgr_with_alpha(out_path, final_bgr, saved_alpha, clear_region=saved_region)
|
||||
|
||||
|
||||
@main.command("batch")
|
||||
@click.argument("directory", type=click.Path(exists=True, file_okay=False, path_type=Path))
|
||||
|
||||
@@ -72,6 +72,27 @@ def _mock_invisible_engine():
|
||||
return mock_cls, mock_engine
|
||||
|
||||
|
||||
def _mock_invisible_engine_drops_alpha():
|
||||
"""Mock InvisibleEngine that mimics the real engine's BGR-only output path.
|
||||
|
||||
The real diffusion-based engine reads with cv2.IMREAD_COLOR and writes a
|
||||
3-channel result. This mock simulates that so we can regression-test alpha
|
||||
preservation across the ``all`` pipeline.
|
||||
"""
|
||||
|
||||
def _mock_remove_watermark(image_path, output_path=None, **kwargs):
|
||||
out = output_path or image_path.with_stem(image_path.stem + "_clean")
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
bgr = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
|
||||
cv2.imwrite(str(out), bgr)
|
||||
return out
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.remove_watermark.side_effect = _mock_remove_watermark
|
||||
mock_cls = MagicMock(return_value=mock_engine)
|
||||
return mock_cls, mock_engine
|
||||
|
||||
|
||||
class TestMainGroup:
|
||||
"""Tests for the top-level CLI group."""
|
||||
|
||||
@@ -144,6 +165,73 @@ class TestVisibleCommand:
|
||||
result = runner.invoke(main, ["visible", "/nonexistent/file.png"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_visible_preserves_rgba_transparency(self, runner, tmp_path):
|
||||
"""Visible removal on an RGBA PNG must keep the alpha channel,
|
||||
not silently flatten the image onto an opaque background.
|
||||
"""
|
||||
rgba = np.zeros((200, 200, 4), dtype=np.uint8)
|
||||
rgba[:, :, :3] = 200 # light grey foreground
|
||||
rgba[50:150, 50:150, 3] = 255 # opaque square in the middle, rest transparent
|
||||
src = tmp_path / "rgba_in.png"
|
||||
cv2.imwrite(str(src), rgba)
|
||||
|
||||
output = tmp_path / "rgba_out.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["visible", str(src), "-o", str(output), "--no-detect"],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert output.exists()
|
||||
|
||||
out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED)
|
||||
assert out.ndim == 3, f"output is not 3D: shape={out.shape}"
|
||||
assert out.shape[2] == 4, f"output is not RGBA: shape={out.shape}"
|
||||
# The transparent corners must remain transparent.
|
||||
assert out[0, 0, 3] == 0
|
||||
assert out[199, 199, 3] == 0
|
||||
# The opaque centre remains opaque (the watermark region default is bottom-right,
|
||||
# which doesn't overlap the centre square at 200x200).
|
||||
assert out[100, 100, 3] == 255
|
||||
|
||||
def test_visible_clears_alpha_in_watermark_region(self, runner, tmp_path):
|
||||
"""When inpainting an RGBA image, the watermark region must be cleared
|
||||
in the alpha channel so the sparkle area becomes transparent, not opaque-black.
|
||||
"""
|
||||
rgba = np.full((200, 200, 4), 255, dtype=np.uint8) # fully opaque white
|
||||
src = tmp_path / "rgba_full.png"
|
||||
cv2.imwrite(str(src), rgba)
|
||||
|
||||
output = tmp_path / "rgba_cleared.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["visible", str(src), "-o", str(output), "--no-detect"],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED)
|
||||
assert out.shape[2] == 4
|
||||
# Default sparkle position is in the bottom-right; alpha there must be 0.
|
||||
from remove_ai_watermarks.gemini_engine import get_watermark_config
|
||||
|
||||
cfg = get_watermark_config(200, 200)
|
||||
px, py = cfg.get_position(200, 200)
|
||||
size = cfg.logo_size
|
||||
assert out[py + size // 2, px + size // 2, 3] == 0, "alpha in the watermark region was not cleared"
|
||||
|
||||
def test_visible_rgb_input_stays_rgb(self, runner, sample_png, tmp_path):
|
||||
"""Regression: a plain RGB PNG must NOT gain a spurious alpha channel."""
|
||||
output = tmp_path / "rgb_out.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["visible", str(sample_png), "-o", str(output), "--no-detect"],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED)
|
||||
assert out.ndim == 3, f"output is not 3D: shape={out.shape}"
|
||||
assert out.shape[2] == 3, f"RGB input produced non-RGB output: shape={out.shape}"
|
||||
|
||||
|
||||
class TestInvisibleCommand:
|
||||
"""Tests for the 'invisible' subcommand."""
|
||||
@@ -210,6 +298,33 @@ class TestAllCommand:
|
||||
result = runner.invoke(main, ["all", "/nonexistent/file.png"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_all_preserves_rgba_across_invisible_step(self, runner, tmp_path):
|
||||
"""Regression: ``all`` must keep transparency even when the invisible
|
||||
step writes a 3-channel result (as the real diffusion engine does).
|
||||
"""
|
||||
rgba = np.zeros((200, 200, 4), dtype=np.uint8)
|
||||
rgba[:, :, :3] = 200
|
||||
rgba[50:150, 50:150, 3] = 255 # opaque square; corners transparent
|
||||
src = tmp_path / "rgba_in.png"
|
||||
cv2.imwrite(str(src), rgba)
|
||||
|
||||
output = tmp_path / "rgba_out.png"
|
||||
mock_cls, _engine = _mock_invisible_engine_drops_alpha()
|
||||
with (
|
||||
patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls),
|
||||
patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.is_available", return_value=True),
|
||||
):
|
||||
result = runner.invoke(main, ["all", str(src), "-o", str(output)])
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
out = cv2.imread(str(output), cv2.IMREAD_UNCHANGED)
|
||||
assert out.ndim == 3, f"output not 3D: shape={out.shape}"
|
||||
assert out.shape[2] == 4, f"output is not RGBA: shape={out.shape}"
|
||||
assert out[0, 0, 3] == 0
|
||||
assert out[100, 100, 3] == 255
|
||||
|
||||
|
||||
class TestMetadataCommand:
|
||||
"""Tests for the 'metadata' subcommand."""
|
||||
@@ -350,6 +465,37 @@ class TestBatchCommand:
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "3 processed" in result.output
|
||||
|
||||
def test_batch_all_mode_preserves_rgba(self, runner, tmp_path):
|
||||
"""Regression: batch ``all`` must keep transparency across the
|
||||
alpha-dropping invisible step (mirrors test_all_preserves_rgba_...).
|
||||
"""
|
||||
input_dir = tmp_path / "input"
|
||||
input_dir.mkdir()
|
||||
rgba = np.zeros((200, 200, 4), dtype=np.uint8)
|
||||
rgba[:, :, :3] = 200
|
||||
rgba[50:150, 50:150, 3] = 255
|
||||
cv2.imwrite(str(input_dir / "rgba.png"), rgba)
|
||||
|
||||
output_dir = tmp_path / "output"
|
||||
mock_cls, _engine = _mock_invisible_engine_drops_alpha()
|
||||
with (
|
||||
patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls),
|
||||
patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.is_available", return_value=True),
|
||||
):
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["batch", str(input_dir), "-o", str(output_dir), "--mode", "all"],
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
|
||||
out = cv2.imread(str(output_dir / "rgba.png"), cv2.IMREAD_UNCHANGED)
|
||||
assert out.ndim == 3, f"output not 3D: shape={out.shape}"
|
||||
assert out.shape[2] == 4, f"output is not RGBA: shape={out.shape}"
|
||||
assert out[0, 0, 3] == 0
|
||||
assert out[100, 100, 3] == 255
|
||||
|
||||
def test_batch_default_output_dir(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
result = runner.invoke(
|
||||
|
||||
Reference in New Issue
Block a user