Add project files, tests, and documentation for GitHub release

- CLI with visible, invisible, all, metadata, and batch commands
- Gemini watermark removal via reverse alpha blending
- Invisible watermark removal via diffusion regeneration (SynthID, TreeRing)
- AI metadata stripping (EXIF, PNG text, C2PA)
- Face protection (YOLO/Haar) and analog humanizer
- 137 tests covering all CLI modes and core engines
- Ruff and Pyright clean
This commit is contained in:
test-user
2026-03-25 11:15:05 -07:00
parent 3055f1ae46
commit e5d8970add
51 changed files with 8859 additions and 0 deletions
+65
View File
@@ -0,0 +1,65 @@
---
trigger: always_on
description: Project rules for remove-ai-watermarks
---
# Project rules
## Role
You are a **principal Python engineer** working on `remove-ai-watermarks` — a CLI tool for removing visible and invisible AI watermarks from images.
## Python environment
- This project uses `uv` for Python package management.
- Use `uv pip install` instead of `pip install`.
- Use `uv run` to run Python scripts (e.g., `uv run python scripts/example.py`).
- Project uses **src-layout**: source code is in `src/remove_ai_watermarks/`.
## Project structure
- `src/remove_ai_watermarks/` — main package (CLI, engines, metadata)
- `src/remove_ai_watermarks/noai/` — vendored noai-watermark code (invisible watermark removal)
- `src/remove_ai_watermarks/noai/ctrlregen/` — vendored CtrlRegen pipeline
- `src/remove_ai_watermarks/assets/` — embedded watermark alpha maps (PNG)
- `tests/` — pytest test suite
- `data/samples/` — sample images for manual testing
## Testing
- Run tests: `uv run pytest`
- Run with coverage: `uv run pytest --cov=remove_ai_watermarks --cov-report=term-missing`
- Tests create temporary files via `tmp_path` — they do NOT depend on `data/samples/`.
- ML pipeline modules (invisible_engine, ctrlregen, watermark_remover) require GPU and multi-GB model downloads — unit tests for these are limited to availability checks and constants.
## Git
- Do not create commits or push unless explicitly asked by the user.
## Code quality
- Always run `./maintain.sh` before committing to ensure code quality (ruff, pyright).
- **CRITICAL**: Before using any method on a client or class, ALWAYS verify that the method exists by reading the class file first.
- **NEVER** assume or invent methods. If the method doesn't exist, either use an existing alternative method or explicitly create the new method first.
## Documentation
- Update documentation (README.md) when you change functionality or add new features.
- **DO NOT** create artifact documentation (walkthrough.md, verification.md) for bug fixes or small corrections.
## Language
- Use only English for all code, comments, docstrings, documentation, commit messages, and project artifacts.
- Communicate with users in Russian, but all technical content must be in English.
## API integrations
- Do not assume or invent API request/response structures.
- Always verify API payloads against official documentation before implementing.
- Use Context7 MCP to retrieve up-to-date library documentation when working with external APIs or packages.
## Work completion
- You have **NO time limits**. Always complete the full task in one go.
- Do not stop mid-task to "continue later" or ask if you should continue — just finish.
- Do not split work into multiple commits unless the task is genuinely large.
+3
View File
@@ -0,0 +1,3 @@
# HuggingFace token (required for invisible watermark removal)
# Get yours at: https://huggingface.co/settings/tokens
HF_TOKEN=
+7
View File
@@ -1,2 +1,9 @@
# Auto detect text files and perform LF normalization
* text=auto
# Binary files
*.png binary
*.jpg binary
*.jpeg binary
*.webp binary
*.pt binary
+29
View File
@@ -0,0 +1,29 @@
# Dependencies
.venv/
__pycache__/
*.egg-info/
dist/
build/
# Environment secrets
.env
# OS files
.DS_Store
Thumbs.db
# IDE
.idea/
.vscode/
*.swp
*.swo
# Test results
data/results/
# Reference materials
_refs/
# Downloaded model weights
yolov8n.pt
.coverage
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 wiltodelta
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+147
View File
@@ -0,0 +1,147 @@
# Remove-AI-Watermarks
Unified tool for removing **visible** and **invisible** AI watermarks from images.
## Features
- **Visible watermark removal** — Gemini sparkle logo via reverse alpha blending (fast, offline, deterministic)
- **Invisible watermark removal** — SynthID, StableSignature, TreeRing via diffusion-based regeneration
- **AI metadata stripping** — EXIF, PNG text chunks, C2PA provenance manifests
- **Analog Humanizer** — film grain and chromatic aberration injection to bypass AI image classifiers
- **Smart Face Protection** — automatic extraction and blending of human faces to prevent AI distortion
- **High-Res Upscaler** — prevents resolution degradation during invisible watermark removal
- **Batch processing** — process entire directories
- **Detection** — three-stage NCC watermark detection with confidence scoring
## Examples
| Before (Watermarked) | After (Cleaned) |
| --- | --- |
| ![Before](demo_banana_before.png) | ![After](demo_banana_after.png) |
## Installation
### Recommended (macOS)
Install as an isolated CLI tool — no need to manage virtual environments:
```bash
# Using pipx (brew install pipx)
pipx install git+https://github.com/wiltodelta/remove-ai-watermarks.git
# Or using uv (brew install uv)
uv tool install git+https://github.com/wiltodelta/remove-ai-watermarks.git
```
### Install from repository (macOS)
**Prerequisites:** Python 3.10+ and `pip` (or [`uv`](https://docs.astral.sh/uv/)).
```bash
# 1. Clone the repository
git clone https://github.com/wiltodelta/remove-ai-watermarks.git
cd remove-ai-watermarks
# 2. Install the package in editable mode
pip install -e .
# Or, if you use uv:
uv pip install -e .
```
After installation the `remove-ai-watermarks` command is available system-wide.
#### Invisible watermark removal (optional)
Invisible removal uses diffusion models and requires a **HuggingFace token** and a decent GPU (CUDA) or Apple Silicon (MPS).
```bash
# 1. Create a free token at https://huggingface.co/settings/tokens
# 2. Copy the example env file and paste your token
cp .env.example .env
# Edit .env and set HF_TOKEN=hf_your_token_here
# 3. On first run, the model (~2 GB) will be downloaded automatically.
# On macOS with Apple Silicon, MPS acceleration is used by default.
# On macOS without GPU, add --device cpu (inference will be slow).
```
#### Developer setup
```bash
# Install with dev dependencies (pytest, ruff, pyright)
pip install -e ".[dev]"
# Or with uv:
uv pip install -e ".[dev]"
# Run tests
pytest
# Run linters
./maintain.sh
```
## Usage
### CLI
```bash
# Remove visible Gemini watermark
remove-ai-watermarks visible image.png -o clean.png
# Remove invisible watermarks (SynthID etc.) with optimal quality retention
remove-ai-watermarks invisible image.png -o clean.png --humanize 4.0
# Strip AI metadata
remove-ai-watermarks metadata image.png --check
remove-ai-watermarks metadata image.png --remove
# Batch processing
remove-ai-watermarks batch ./images/ --mode visible
# Full pipeline: visible + invisible + metadata
remove-ai-watermarks all image.png -o clean.png
```
### Python API
```python
from remove_ai_watermarks.gemini_engine import GeminiEngine
import cv2
engine = GeminiEngine()
image = cv2.imread("watermarked.png")
# Detect
result = engine.detect_watermark(image)
print(f"Detected: {result.detected} (confidence: {result.confidence:.1%})")
# Remove
clean = engine.remove_watermark(image)
cv2.imwrite("clean.png", clean)
```
### Metadata stripping
```python
from remove_ai_watermarks.metadata import has_ai_metadata, remove_ai_metadata
from pathlib import Path
if has_ai_metadata(Path("image.png")):
remove_ai_metadata(Path("image.png"), Path("clean.png"))
```
## Requirements
- Python ≥ 3.10
- **Visible removal / metadata**: CPU only, no GPU required
- **Invisible removal**: GPU recommended (CUDA or MPS), works on CPU (slow)
## Credits
- [noai-watermark](https://github.com/mertizci/noai-watermark) by mertizci — invisible watermark removal
- [GeminiWatermarkTool](https://github.com/allenk/GeminiWatermarkTool) by Allen Kuo — visible watermark removal algorithm
## License
MIT
Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.7 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.7 MiB

Executable
+9
View File
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -euo pipefail
uv sync --all-extras
uv-outdated
uv run uv-secure --ignore-unfixed
uv run ruff check --fix
uv run ruff format
uv run pyright
+73
View File
@@ -0,0 +1,73 @@
[project]
name = "remove-ai-watermarks"
version = "0.1.0"
description = "Unified tool for removing visible and invisible AI watermarks from images"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
dependencies = [
"pillow>=10.0.0",
"piexif>=1.1.3",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"click>=8.0.0",
"rich>=13.0.0",
"torch>=2.0.0",
"diffusers>=0.25.0",
"transformers>=4.35.0",
"accelerate>=0.25.0",
"controlnet-aux>=0.0.9",
"color-matcher",
"safetensors",
"python-dotenv>=1.0.0",
"ultralytics>=8.0.0",
"requests>=2.33.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-cov>=4.1.0",
"ruff>=0.4.0",
]
all = ["remove-ai-watermarks[dev]"]
[project.scripts]
remove-ai-watermarks = "remove_ai_watermarks.cli:main"
[project.urls]
Repository = "https://github.com/wiltodelta/remove-ai-watermarks"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/remove_ai_watermarks"]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["src"]
addopts = "-v --tb=short"
[tool.ruff]
target-version = "py310"
line-length = 120
[tool.pyright]
exclude = ["tests"]
reportOptionalMemberAccess = "warning"
reportPrivateImportUsage = false
reportInvalidTypeForm = false
reportOptionalCall = false
reportMissingImports = "warning"
reportCallIssue = "warning"
reportAttributeAccessIssue = "warning"
reportArgumentType = "warning"
reportPossiblyUnboundVariable = "warning"
reportAssignmentType = "warning"
reportOperatorIssue = "warning"
[tool.ruff.lint]
select = ["E", "F", "W", "I", "N", "UP"]
ignore = ["E501"]
+3
View File
@@ -0,0 +1,3 @@
"""Remove-AI-Watermarks: Unified tool for removing visible and invisible AI watermarks."""
__version__ = "0.1.0"
@@ -0,0 +1 @@
"""Embedded assets for visible watermark removal."""
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

+598
View File
@@ -0,0 +1,598 @@
"""Unified CLI for remove-ai-watermarks.
Provides commands for:
- Visible watermark removal (Gemini sparkle) — works offline, fast
- Invisible watermark removal (SynthID etc.) — requires GPU/diffusion models
- AI metadata stripping — lightweight, no ML deps needed
"""
from __future__ import annotations
import logging
import time
from pathlib import Path
import click
from rich.console import Console
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table
from remove_ai_watermarks import __version__
console = Console()
SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp"}
def _setup_logging(verbose: bool) -> None:
level = logging.DEBUG if verbose else logging.WARNING
logging.basicConfig(
level=level,
format="%(name)s | %(message)s",
handlers=[logging.StreamHandler()],
)
def _banner() -> None:
console.print(
Panel(
f"[bold cyan]Remove-AI-Watermarks[/] [dim]v{__version__}[/]\n[dim]Visible & invisible watermark removal[/]",
border_style="cyan",
padding=(0, 2),
)
)
def _validate_image(path: Path) -> Path:
if not path.exists():
console.print(f"[red]Error:[/] File not found: {path}")
raise SystemExit(1)
if path.suffix.lower() not in SUPPORTED_FORMATS:
console.print(
f"[yellow]Warning:[/] {path.suffix} may not be supported (expected: {', '.join(SUPPORTED_FORMATS)})"
)
return path
# ── Main group ───────────────────────────────────────────────────────
@click.group(invoke_without_command=True)
@click.version_option(__version__, prog_name="remove-ai-watermarks")
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging.")
@click.pass_context
def main(ctx: click.Context, verbose: bool) -> None:
"""Remove visible and invisible AI watermarks from images."""
from dotenv import load_dotenv
load_dotenv() # Load .env (e.g. HF_TOKEN)
ctx.ensure_object(dict)
ctx.obj["verbose"] = verbose
_setup_logging(verbose)
if ctx.invoked_subcommand is None:
_banner()
click.echo(ctx.get_help())
# ── Visible (Gemini) watermark removal ───────────────────────────────
@main.command("visible")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option(
"-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
)
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting cleanup after removal.")
@click.option(
"--inpaint-method", type=click.Choice(["ns", "telea", "gaussian"]), default="ns", help="Inpainting method."
)
@click.option("--inpaint-strength", type=float, default=0.85, help="Inpainting blend strength (0.01.0).")
@click.option("--detect/--no-detect", default=True, help="Detect watermark before removal.")
@click.option("--detect-threshold", type=float, default=0.25, help="Detection confidence threshold.")
@click.option("--strip-metadata/--keep-metadata", default=True, help="Strip AI metadata from output.")
@click.pass_context
def cmd_visible(
ctx: click.Context,
source: Path,
output: Path | None,
inpaint: bool,
inpaint_method: str,
inpaint_strength: float,
detect: bool,
detect_threshold: float,
strip_metadata: bool,
) -> None:
"""Remove visible Gemini watermark (sparkle logo) from an image.
Uses reverse alpha blending — fast, deterministic, offline.
"""
import cv2
from remove_ai_watermarks.gemini_engine import GeminiEngine
_banner()
source = _validate_image(source)
if output is None:
output = source.with_stem(source.stem + "_clean")
engine = GeminiEngine()
# Load image
image = cv2.imread(str(source), cv2.IMREAD_COLOR)
if image is None:
console.print(f"[red]Error:[/] Failed to read image: {source}")
raise SystemExit(1)
h, w = image.shape[:2]
console.print(f" [dim]Input:[/] {source.name} ({w}×{h})")
# Detection (we always detect softly, to find dynamic region for inpainting)
with console.status("[cyan]Detecting watermark…[/]"):
det = engine.detect_watermark(image)
if detect:
if det.detected:
console.print(
f" [green]✓[/] Watermark detected "
f"[dim](confidence: {det.confidence:.1%}, "
f"spatial: {det.spatial_score:.3f}, "
f"gradient: {det.gradient_score:.3f})[/]"
)
else:
console.print(f" [yellow]⚠[/] Watermark not detected [dim](confidence: {det.confidence:.1%})[/]")
if det.confidence < detect_threshold:
console.print(" [dim]Skipping. Use --no-detect to force removal.[/]")
raise SystemExit(0)
# Removal
t0 = time.monotonic()
with console.status("[cyan]Removing watermark…[/]"):
result = engine.remove_watermark(image)
if inpaint:
if det.confidence > 0.15:
region = det.region
else:
from remove_ai_watermarks.gemini_engine import get_watermark_config
config = get_watermark_config(w, h)
pos = config.get_position(w, h)
region = (pos[0], pos[1], config.logo_size, config.logo_size)
result = engine.inpaint_residual(
result,
region,
strength=inpaint_strength,
method=inpaint_method,
)
elapsed = time.monotonic() - t0
# Save
output.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(output), result)
# Strip metadata
if strip_metadata:
try:
from remove_ai_watermarks.metadata import remove_ai_metadata
remove_ai_metadata(output, output)
except Exception:
pass
size_kb = output.stat().st_size / 1024
console.print(f" [green]✓[/] Saved: {output} [dim]({size_kb:.0f} KB, {elapsed:.2f}s)[/]")
# ── Invisible watermark removal ─────────────────────────────────────
@main.command("invisible")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option(
"-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
)
@click.option("--strength", type=float, default=0.02, help="Denoising strength (0.01.0). Default: 0.02.")
@click.option("--steps", type=int, default=100, help="Number of denoising steps. Default: 100.")
@click.option("--pipeline", type=click.Choice(["default", "ctrlregen"]), default="default", help="Pipeline profile.")
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
@click.option("--humanize", type=float, default=0.0, help="Humanization strength (0.01.0) for invisible removal.")
@click.pass_context
def cmd_invisible(
ctx: click.Context,
source: Path,
output: Path | None,
strength: float,
steps: int,
pipeline: str,
device: str,
seed: int | None,
hf_token: str | None,
humanize: float,
) -> None:
"""Remove invisible AI watermarks (SynthID, StableSignature, TreeRing).
Uses diffusion-based regeneration. Requires GPU for reasonable speed.
"""
from remove_ai_watermarks.invisible_engine import InvisibleEngine
source = _validate_image(source)
if output is None:
output = source.with_stem(source.stem + "_clean")
device_str = None if device == "auto" else device
def progress_cb(msg: str) -> None:
console.print(f" [dim]{msg}[/]")
engine = InvisibleEngine(
device=device_str,
pipeline=pipeline,
hf_token=hf_token,
progress_callback=progress_cb,
)
console.print(f" [dim]Input:[/] {source.name}")
console.print(f" [dim]Pipeline:[/] {pipeline}")
console.print(f" [dim]Strength:[/] {strength} Steps: {steps}")
t0 = time.monotonic()
result_path = engine.remove_watermark(
image_path=source,
output_path=output,
strength=strength,
num_inference_steps=steps,
guidance_scale=None,
seed=seed,
humanize=humanize,
)
elapsed = time.monotonic() - t0
size_kb = result_path.stat().st_size / 1024
console.print(f"\n [green]✓[/] Saved: {result_path} [dim]({size_kb:.0f} KB, {elapsed:.1f}s)[/]")
# ── Metadata operations ─────────────────────────────────────────────
@main.command("metadata")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option("--check", is_flag=True, help="Check for AI metadata (don't modify).")
@click.option("--remove", is_flag=True, help="Remove AI metadata.")
@click.option(
"-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: overwrite source)."
)
@click.option("--keep-standard/--remove-all", default=True, help="Keep standard metadata (Author, Title, etc.).")
@click.pass_context
def cmd_metadata(
ctx: click.Context,
source: Path,
check: bool,
remove: bool,
output: Path | None,
keep_standard: bool,
) -> None:
"""Check or remove AI-generation metadata from images.
Strips EXIF AI tags, PNG text chunks, and C2PA provenance manifests.
"""
from remove_ai_watermarks.metadata import get_ai_metadata, has_ai_metadata, remove_ai_metadata
_banner()
source = _validate_image(source)
if check or (not remove):
has_ai = has_ai_metadata(source)
if has_ai:
console.print(f" [yellow]⚠[/] AI metadata detected in {source.name}:")
meta = get_ai_metadata(source)
table = Table(show_header=True, header_style="bold")
table.add_column("Key", style="cyan")
table.add_column("Value")
for k, v in meta.items():
table.add_row(k, str(v)[:80])
console.print(table)
else:
console.print(f" [green]✓[/] No AI metadata found in {source.name}")
if not remove:
return
# Remove
out = remove_ai_metadata(source, output, keep_standard=keep_standard)
console.print(f" [green]✓[/] AI metadata stripped → {out}")
# ── Combined "all" mode ──────────────────────────────────────────────
@main.command("all")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option(
"-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
)
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting cleanup after visible removal.")
@click.option(
"--inpaint-method", type=click.Choice(["ns", "telea", "gaussian"]), default="ns", help="Inpainting method."
)
@click.option("--strength", type=float, default=0.04, help="Invisible watermark denoising strength (0.01.0).")
@click.option("--steps", type=int, default=50, help="Number of denoising steps for invisible removal.")
@click.option(
"--pipeline",
type=click.Choice(["default", "ctrlregen"]),
default="default",
help="Pipeline profile for invisible removal.",
)
@click.option("--model", type=str, default=None, help="HuggingFace model ID for invisible removal.")
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
@click.option("--humanize", type=float, default=0.0, help="Humanization strength (0.01.0) for invisible removal.")
@click.pass_context
def cmd_all(
ctx: click.Context,
source: Path,
output: Path | None,
inpaint: bool,
inpaint_method: str,
strength: float,
steps: int,
pipeline: str,
model: str | None,
device: str,
seed: int | None,
hf_token: str | None,
humanize: float,
) -> None:
"""Remove ALL watermarks: visible + invisible + metadata.
Runs the full pipeline in order:
1. Visible watermark removal (Gemini sparkle, reverse alpha blending)
2. Invisible watermark removal (SynthID etc., diffusion regeneration)
3. AI metadata stripping (EXIF, PNG text, C2PA)
If invisible watermark deps are not installed, skips step 2 with a warning.
"""
import cv2
from remove_ai_watermarks.gemini_engine import GeminiEngine, get_watermark_config
_banner()
source = _validate_image(source)
if output is None:
output = source.with_stem(source.stem + "_clean")
t0 = time.monotonic()
# ── Step 1: Visible watermark ────────────────────────────────
console.print("\n [bold cyan]① Visible watermark removal[/]")
engine = GeminiEngine()
image = cv2.imread(str(source), cv2.IMREAD_COLOR)
if image is None:
console.print(f"[red]Error:[/] Failed to read image: {source}")
raise SystemExit(1)
h, w = image.shape[:2]
console.print(f" [dim]Input:[/] {source.name} ({w}×{h})")
with console.status("[cyan]Removing visible watermark…[/]"):
det = engine.detect_watermark(image)
if det.detected:
result = engine.remove_watermark(image)
if inpaint:
if det.confidence > 0.15:
region = det.region
else:
config = get_watermark_config(w, h)
pos = config.get_position(w, h)
region = (pos[0], pos[1], config.logo_size, config.logo_size)
result = engine.inpaint_residual(result, region, method=inpaint_method)
console.print(" [green]✓[/] Visible watermark removed")
else:
result = image.copy()
console.print(" [dim]Skipped (no visible watermark detected)[/]")
# Save intermediate
output.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(str(output), result)
# ── Step 2: Invisible watermark ──────────────────────────────
console.print("\n [bold cyan]② Invisible watermark removal[/]")
from remove_ai_watermarks.invisible_engine import InvisibleEngine
device_str = None if device == "auto" else device
def progress_cb(msg: str) -> None:
console.print(f" [dim]{msg}[/]")
inv_engine = InvisibleEngine(
model_id=model,
device=device_str,
pipeline=pipeline,
hf_token=hf_token,
progress_callback=progress_cb,
)
console.print(f" [dim]Strength:[/] {strength} Steps: {steps}")
inv_engine.remove_watermark(
image_path=output,
output_path=output,
strength=strength,
num_inference_steps=steps,
seed=seed,
humanize=humanize,
)
console.print(" [green]✓[/] Invisible watermark removed")
# ── Step 3: Metadata ─────────────────────────────────────────
console.print("\n [bold cyan]③ AI metadata stripping[/]")
try:
from remove_ai_watermarks.metadata import remove_ai_metadata
remove_ai_metadata(output, output)
console.print(" [green]✓[/] AI metadata stripped")
except Exception as e:
console.print(f" [yellow]⚠[/] Metadata strip failed: {e}")
# ── Done ─────────────────────────────────────────────────────
elapsed = time.monotonic() - t0
size_kb = output.stat().st_size / 1024
console.print(f"\n [bold green]✓ Done:[/] {output} [dim]({size_kb:.0f} KB, {elapsed:.1f}s total)[/]")
# ── Batch command ────────────────────────────────────────────────────
@main.command("batch")
@click.argument("directory", type=click.Path(exists=True, file_okay=False, path_type=Path))
@click.option(
"-o",
"--output-dir",
type=click.Path(path_type=Path),
default=None,
help="Output directory (default: <dir>_clean/).",
)
@click.option(
"--mode", type=click.Choice(["visible", "invisible", "metadata", "all"]), default="visible", help="Processing mode."
)
@click.option("--strength", type=float, default=None, help="Denoising strength (invisible mode).")
@click.option("--steps", type=int, default=50, help="Number of denoising steps (invisible mode).")
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting (visible mode).")
@click.option("--humanize", type=float, default=0.0, help="Humanization strength (0.01.0) for invisible removal.")
@click.option("--pipeline", type=click.Choice(["default", "ctrlregen"]), default="default", help="Pipeline profile.")
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
@click.pass_context
def cmd_batch(
ctx: click.Context,
directory: Path,
mode: str,
output_dir: Path | None,
strength: float | None,
steps: int,
pipeline: str,
device: str,
seed: int | None,
hf_token: str | None,
inpaint: bool,
humanize: float,
) -> None:
"""Process all images in a directory."""
_banner()
if output_dir is None:
output_dir = directory.parent / (directory.name + "_clean")
output_dir.mkdir(parents=True, exist_ok=True)
# Assuming supported_formats is defined elsewhere
supported_formats = [".jpg", ".jpeg", ".png", ".webp"]
images = sorted(p for p in directory.iterdir() if p.suffix.lower() in supported_formats)
if not images:
console.print(f"[yellow]No supported images found in {directory}[/]")
return
console.print(f" Found [bold]{len(images)}[/] images in {directory}")
console.print(f" Output → {output_dir}")
console.print(f" Mode: [cyan]{mode}[/]")
processed = 0
errors = 0
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console,
) as progress:
task = progress.add_task("Processing…", total=len(images))
for img_path in images:
out_path = output_dir / img_path.name
progress.update(task, description=f"[cyan]{img_path.name}[/]")
try:
if mode in ("visible", "all"):
import cv2
from remove_ai_watermarks.gemini_engine import (
GeminiEngine,
get_watermark_config,
)
if "_vis_engine" not in ctx.obj:
ctx.obj["_vis_engine"] = GeminiEngine()
engine = ctx.obj["_vis_engine"]
read_path = img_path
if mode == "all" and out_path.exists():
read_path = out_path
image = cv2.imread(str(read_path), cv2.IMREAD_COLOR)
if image is None:
raise ValueError("Failed to read image")
det = engine.detect_watermark(image)
if det.detected:
result = engine.remove_watermark(image)
if inpaint:
if det.confidence > 0.15:
region = det.region
else:
h, w = image.shape[:2]
config = get_watermark_config(w, h)
pos = config.get_position(w, h)
region = (pos[0], pos[1], config.logo_size, config.logo_size)
result = engine.inpaint_residual(result, region)
else:
result = image.copy()
cv2.imwrite(str(out_path), result)
if mode in ("invisible", "all"):
from remove_ai_watermarks.invisible_engine import (
is_available as invisible_available,
)
if invisible_available():
from remove_ai_watermarks.invisible_engine import InvisibleEngine
engine_inv = InvisibleEngine()
engine_inv.remove_watermark(
img_path if mode == "invisible" else out_path,
out_path,
strength=strength,
num_inference_steps=steps,
seed=seed,
humanize=humanize,
)
if mode in ("metadata", "all"):
from remove_ai_watermarks.metadata import remove_ai_metadata
remove_ai_metadata(img_path if mode == "metadata" else out_path, out_path)
processed += 1
except Exception as e:
errors += 1
if ctx.obj.get("verbose"):
console.print(f" [red]✗[/] {img_path.name}: {e}")
progress.advance(task)
console.print(f"\n [green]✓[/] {processed} processed" + (f" [red]✗[/] {errors} errors" if errors else ""))
if __name__ == "__main__":
main()
+128
View File
@@ -0,0 +1,128 @@
import logging
from pathlib import Path
import cv2
import numpy as np
try:
from ultralytics import YOLO
HAS_YOLO = True
except ImportError:
HAS_YOLO = False
logger = logging.getLogger("face_protector")
class FaceProtector:
"""
Detects faces in an image and provides methods to seamlessly paste them back
onto the an upscaled/processed image to preserve facial details that may have
been destroyed by latent diffusion or other algorithms.
"""
def __init__(self, use_yolo: bool = True, model_name: str = "yolov8n.pt"):
self.use_yolo = use_yolo and HAS_YOLO
self.detector = None
self.haar_cascade = None
if self.use_yolo:
logger.info(f"Loading YOLO model '{model_name}' for face protection...")
self.detector = YOLO(model_name)
else:
if use_yolo and not HAS_YOLO:
logger.warning(
"ultralytics YOLO is not installed. Falling back to OpenCV Haar Cascades. Install ultralytics with `pip install ultralytics` for better face detection."
)
logger.info("Loading OpenCV Haar Cascade for face protection...")
cascade_path = Path(cv2.__file__).parent / "data" / "haarcascade_frontalface_default.xml"
if not cascade_path.exists():
cascade_path = "haarcascade_frontalface_default.xml"
self.haar_cascade = cv2.CascadeClassifier(str(cascade_path))
def detect_face_bboxes(self, image: np.ndarray) -> list[tuple[int, int, int, int]]:
"""
Detect faces and return bounding boxes as (x1, y1, x2, y2).
"""
if self.use_yolo and self.detector is not None:
# For standard YOLOv8n, 'person' is class 0. We'll use person bounding boxes
# as a proxy for faces/people to protect them. If using a specific face model, adjust classes.
results = self.detector(image, verbose=False, classes=[0])
bboxes = []
for r in results:
boxes = r.boxes
for box in boxes:
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
bboxes.append((int(x1), int(y1), int(x2), int(y2)))
return bboxes
else:
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
faces = self.haar_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
bboxes = []
for x, y, w, h in faces:
# Add a 20% margin around the haar cascade face box
margin_x = int(w * 0.2)
margin_y = int(h * 0.2)
x1 = max(0, x - margin_x)
y1 = max(0, y - int(margin_y * 1.5)) # more margin on top for hair
x2 = min(image.shape[1], x + w + margin_x)
y2 = min(image.shape[0], y + h + margin_y)
bboxes.append((x1, y1, x2, y2))
return bboxes
def extract_faces(self, image: np.ndarray) -> list[tuple[tuple[int, int, int, int], np.ndarray]]:
"""
Extract faces from the image.
Returns a list of (bbox, face_crop) tuples.
"""
bboxes = self.detect_face_bboxes(image)
faces = []
for bbox in bboxes:
x1, y1, x2, y2 = bbox
faces.append((bbox, image[y1:y2, x1:x2].copy()))
return faces
def restore_faces(
self, processed_image: np.ndarray, original_faces: list[tuple[tuple[int, int, int, int], np.ndarray]]
) -> np.ndarray:
"""
Paste original faces back onto the processed image using seamless cloning
or soft blending so the edges don't show.
"""
if not original_faces:
return processed_image
result = processed_image.copy()
for (x1, y1, x2, y2), face_crop in original_faces:
h, w = face_crop.shape[:2]
# If the processed image was resized, we'd need to resize face_crop, but
# pipeline ensures the output from InvisibleEngine is the same size or we resize it back before this.
if result.shape[:2] != processed_image.shape[:2]:
continue # Safety bypass
try:
# Create a soft alpha mask for the face crop to smoothly blend it
mask = np.zeros((h, w), dtype=np.float32)
# Inner ellipse is pure white
cv2.ellipse(mask, (w // 2, h // 2), (int(w * 0.4), int(h * 0.4)), 0, 0, 360, 1.0, -1)
# Blur the mask heavily for soft edges
blur_size = max(w, h) // 4
if blur_size % 2 == 0:
blur_size += 1
mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
mask = cv2.merge([mask, mask, mask])
# Blend
target_roi = result[y1:y2, x1:x2].astype(np.float32)
src_roi = face_crop.astype(np.float32)
blended = src_roi * mask + target_roi * (1.0 - mask)
result[y1:y2, x1:x2] = blended.astype(np.uint8)
except Exception as e:
logger.warning(f"Failed to restore face at {x1},{y1} to {x2},{y2}: {e}")
return result
+549
View File
@@ -0,0 +1,549 @@
"""Gemini visible watermark removal engine.
Port of the GeminiWatermarkTool reverse-alpha-blending algorithm from C++ to Python.
Original author: Allen Kuo (allenk) — https://github.com/allenk/GeminiWatermarkTool
The Gemini AI watermark is applied using alpha blending:
watermarked = α × logo + (1 - α) × original
We reverse this to recover the original:
original = (watermarked - α × logo) / (1 - α)
The alpha maps are derived from background captures of the Gemini watermark
on pure-black backgrounds (48×48 for small images, 96×96 for large images).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Literal
import cv2
import numpy as np
from numpy.typing import NDArray
logger = logging.getLogger(__name__)
class WatermarkSize(Enum):
"""Watermark size mode based on image dimensions."""
SMALL = "small" # 48×48, for images ≤ 1024×1024
LARGE = "large" # 96×96, for images > 1024×1024
@dataclass
class DetectionResult:
"""Result of watermark detection."""
detected: bool = False
confidence: float = 0.0
region: tuple[int, int, int, int] = (0, 0, 0, 0) # x, y, w, h
size: WatermarkSize = WatermarkSize.SMALL
# stage scores
spatial_score: float = 0.0
gradient_score: float = 0.0
variance_score: float = 0.0
@dataclass
class WatermarkPosition:
"""Watermark position configuration."""
margin_right: int
margin_bottom: int
logo_size: int
def get_position(self, image_width: int, image_height: int) -> tuple[int, int]:
"""Get top-left position for a given image size."""
x = image_width - self.margin_right - self.logo_size
y = image_height - self.margin_bottom - self.logo_size
return (x, y)
def get_watermark_config(width: int, height: int) -> WatermarkPosition:
"""Get the appropriate watermark configuration based on image size.
Rules discovered from Gemini:
- W > 1024 AND H > 1024: 96×96 logo at (W-64-96, H-64-96)
- Otherwise: 48×48 logo at (W-32-48, H-32-48)
"""
if width > 1024 and height > 1024:
return WatermarkPosition(margin_right=64, margin_bottom=64, logo_size=96)
return WatermarkPosition(margin_right=32, margin_bottom=32, logo_size=48)
def get_watermark_size(width: int, height: int) -> WatermarkSize:
"""Determine watermark size mode from image dimensions."""
if width > 1024 and height > 1024:
return WatermarkSize.LARGE
return WatermarkSize.SMALL
def _calculate_alpha_map(bg_capture: NDArray) -> NDArray:
"""Calculate alpha map from a background capture.
The alpha map represents how much the watermark affects each pixel.
alpha = max(R, G, B) / 255.0
"""
if len(bg_capture.shape) == 2:
gray = bg_capture.astype(np.float32)
elif bg_capture.shape[2] >= 3:
# Use max of channels for brightness
gray = np.max(bg_capture[:, :, :3], axis=2).astype(np.float32)
else:
gray = bg_capture[:, :, 0].astype(np.float32)
return gray / 255.0
def _load_embedded_asset(name: str) -> NDArray:
"""Load an embedded PNG asset and decode it with OpenCV."""
asset_path = Path(__file__).parent / "assets" / name
if not asset_path.exists():
raise FileNotFoundError(f"Embedded asset not found: {asset_path}")
data = asset_path.read_bytes()
buf = np.frombuffer(data, dtype=np.uint8)
img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
if img is None:
raise RuntimeError(f"Failed to decode embedded asset: {name}")
return img
class GeminiEngine:
"""Engine for removing visible Gemini watermarks via reverse alpha blending.
This is a Python port of the GeminiWatermarkTool C++ engine.
"""
def __init__(self, logo_value: float = 255.0) -> None:
"""Initialize the engine with embedded alpha maps.
Args:
logo_value: The logo brightness value (default 255.0 = white).
"""
self.logo_value = logo_value
# Load embedded background captures
bg_small = _load_embedded_asset("gemini_bg_48.png")
bg_large = _load_embedded_asset("gemini_bg_96.png")
# Ensure correct sizes
if bg_small.shape[:2] != (48, 48):
bg_small = cv2.resize(bg_small, (48, 48), interpolation=cv2.INTER_AREA)
if bg_large.shape[:2] != (96, 96):
bg_large = cv2.resize(bg_large, (96, 96), interpolation=cv2.INTER_AREA)
# Calculate alpha maps
self._alpha_small = _calculate_alpha_map(bg_small)
self._alpha_large = _calculate_alpha_map(bg_large)
logger.debug(
"Alpha maps loaded: small=%s, large=%s",
self._alpha_small.shape,
self._alpha_large.shape,
)
def get_alpha_map(self, size: WatermarkSize) -> NDArray:
"""Get the base alpha map for a specific standard size."""
if size == WatermarkSize.SMALL:
return self._alpha_small
return self._alpha_large
def get_interpolated_alpha(self, size_px: int) -> NDArray:
"""Create an interpolated alpha map dynamically scaled from the high-res 96x96 base."""
source = self._alpha_large
if size_px == source.shape[1]:
return source.copy()
interp = cv2.INTER_LINEAR if size_px > source.shape[1] else cv2.INTER_AREA
return cv2.resize(source, (size_px, size_px), interpolation=interp)
# ── Detection ────────────────────────────────────────────────────
def detect_watermark(
self,
image: NDArray,
force_size: WatermarkSize | None = None,
) -> DetectionResult:
"""Detect Gemini watermark using multi-scale Snap Engine logic (ported from C++ vendor algorithm)."""
result = DetectionResult()
if image is None or image.size == 0:
return result
h, w = image.shape[:2]
base_size = force_size or get_watermark_size(w, h)
result.size = base_size
# Use large alpha template (96x96) as the high-quality source for downscaling
source_alpha = self._alpha_large
# Dynamically search bottom-right corner (search up to 256x256 region)
search_size = int(min(min(w, h), 256))
sx1 = max(0, w - search_size)
sy1 = max(0, h - search_size)
search_region = image[sy1:h, sx1:w]
if len(search_region.shape) == 3 and search_region.shape[2] >= 3:
gray_sr = cv2.cvtColor(search_region, cv2.COLOR_BGR2GRAY)
else:
gray_sr = search_region.copy()
gray_sr_f = gray_sr.astype(np.float32) / 255.0
# Phase 1 & 2: Multi-scale spatial NCC search
best_scale = 0
best_score = -1.0
best_raw_ncc = -1.0
best_loc = (0, 0)
# Search scales from 16 to 120 (covering aggressively downscaled or slightly upscaled logos)
for scale in range(16, 120, 2):
if scale > search_region.shape[0] or scale > search_region.shape[1]:
continue
tmpl = cv2.resize(source_alpha, (scale, scale), interpolation=cv2.INTER_AREA)
match_res = cv2.matchTemplate(gray_sr_f, tmpl, cv2.TM_CCOEFF_NORMED)
_, max_val, _, max_loc = cv2.minMaxLoc(match_res)
# Size-adjusted score to overcome NCC bias toward tiny patches (mimics C++ weight)
weight = min(1.0, (scale / 96.0) ** 0.5)
adj_val = max_val * weight
if adj_val > best_score:
best_score = adj_val
best_scale = scale
best_loc = max_loc
best_raw_ncc = max_val
# Exact dynamic location & size
pos_x = sx1 + best_loc[0]
pos_y = sy1 + best_loc[1]
result.region = (pos_x, pos_y, best_scale, best_scale)
result.spatial_score = float(best_raw_ncc)
# Generate exact alpha map for matched size
alpha_region = self.get_interpolated_alpha(best_scale)
# Extract exactly the matched region for Gradient & Variance analysis
x1 = pos_x
y1 = pos_y
x2 = min(w, x1 + best_scale)
y2 = min(h, y1 + best_scale)
region = image[y1:y2, x1:x2]
if len(region.shape) == 3 and region.shape[2] >= 3:
gray_region = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
else:
gray_region = region.copy()
gray_f = gray_region.astype(np.float32) / 255.0
# Adjust alpha_region if clipped by image boundary (rare, but possible)
ay1, ax1 = 0, 0
alpha_region = alpha_region[ay1 : ay1 + (y2 - y1), ax1 : ax1 + (x2 - x1)]
if result.spatial_score < 0.25:
result.confidence = float(max(0.0, result.spatial_score * 0.5))
return result
# ── Stage 2: Gradient NCC ────────────────────────────────────
img_gx = cv2.Sobel(gray_f, cv2.CV_32F, 1, 0, ksize=3)
img_gy = cv2.Sobel(gray_f, cv2.CV_32F, 0, 1, ksize=3)
img_gmag = cv2.magnitude(img_gx, img_gy)
alpha_gx = cv2.Sobel(alpha_region, cv2.CV_32F, 1, 0, ksize=3)
alpha_gy = cv2.Sobel(alpha_region, cv2.CV_32F, 0, 1, ksize=3)
alpha_gmag = cv2.magnitude(alpha_gx, alpha_gy)
grad_match = cv2.matchTemplate(img_gmag, alpha_gmag, cv2.TM_CCOEFF_NORMED)
_, grad_score, _, _ = cv2.minMaxLoc(grad_match)
result.gradient_score = float(grad_score)
# ── Stage 3: Variance Analysis ───────────────────────────────
var_score = 0.0
ref_h = min(y1, best_scale)
if ref_h > 8:
ref_region = image[y1 - ref_h : y1, x1:x2]
if len(ref_region.shape) == 3:
gray_ref = cv2.cvtColor(ref_region, cv2.COLOR_BGR2GRAY)
else:
gray_ref = ref_region
_, s_wm = cv2.meanStdDev(gray_region)
_, s_ref = cv2.meanStdDev(gray_ref)
if s_ref[0][0] > 5.0:
var_score = max(0.0, min(1.0, 1.0 - (s_wm[0][0] / s_ref[0][0])))
result.variance_score = float(var_score)
# ── Fusion ───────────────────────────────────────────────────
confidence = result.spatial_score * 0.50 + result.gradient_score * 0.30 + var_score * 0.20
result.confidence = float(max(0.0, min(1.0, confidence)))
result.detected = result.confidence >= 0.35
logger.debug(
"Detection: spatial=%.3f, grad=%.3f, var=%.3f → conf=%.3f (%s)",
result.spatial_score,
result.gradient_score,
var_score,
result.confidence,
"DETECTED" if result.detected else "not detected",
)
return result
# ── Removal ──────────────────────────────────────────────────────
def remove_watermark(
self,
image: NDArray,
force_size: WatermarkSize | None = None,
) -> NDArray:
"""Remove Gemini visible watermark from an image using reverse alpha blending.
Args:
image: BGR image as numpy array (will NOT be modified in-place).
force_size: Force a specific watermark size (auto-detect if None).
Returns:
Cleaned BGR image as numpy array.
"""
result = image.copy()
# Handle alpha channel
if result.shape[2] == 4:
result = cv2.cvtColor(result, cv2.COLOR_BGRA2BGR)
elif result.shape[2] == 1:
result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)
h, w = result.shape[:2]
size = force_size or get_watermark_size(w, h)
# Detect dynamic position & size
detection = self.detect_watermark(image, force_size=size)
# If the confidence is really high (>0.15), use the dynamically detected position and size.
if detection.confidence > 0.15:
pos = (detection.region[0], detection.region[1])
alpha_map = self.get_interpolated_alpha(detection.region[2])
logger.debug(
"Using dynamic watermark position (%d, %d) at size %dx%d [conf=%.3f]",
pos[0],
pos[1],
detection.region[2],
detection.region[3],
detection.confidence,
)
else:
config = get_watermark_config(w, h)
pos = config.get_position(w, h)
alpha_map = self.get_alpha_map(size)
logger.debug(
"Dynamic search failed. Using fallback default position (%d, %d) with %dx%d alpha map.",
pos[0],
pos[1],
alpha_map.shape[1],
alpha_map.shape[0],
)
self._reverse_alpha_blend(result, alpha_map, pos)
return result
def remove_watermark_custom(
self,
image: NDArray,
region: tuple[int, int, int, int],
) -> NDArray:
"""Remove watermark from a custom region with interpolated alpha map.
Args:
image: BGR image (will NOT be modified in-place).
region: (x, y, width, height) of the watermark region.
Returns:
Cleaned BGR image.
"""
result = image.copy()
x, y, rw, rh = region
# Check standard sizes
if rw == 48 and rh == 48:
self._reverse_alpha_blend(result, self._alpha_small, (x, y))
return result
if rw == 96 and rh == 96:
self._reverse_alpha_blend(result, self._alpha_large, (x, y))
return result
# Interpolate alpha map for custom size
interp = cv2.INTER_LINEAR if rw > 96 else cv2.INTER_AREA
alpha = cv2.resize(self._alpha_large, (rw, rh), interpolation=interp)
self._reverse_alpha_blend(result, alpha, (x, y))
return result
def _reverse_alpha_blend(
self,
image: NDArray,
alpha_map: NDArray,
position: tuple[int, int],
) -> None:
"""Apply reverse alpha blending in-place.
Formula: original = (watermarked - α × logo) / (1 - α)
"""
x, y = position
ah, aw = alpha_map.shape[:2]
ih, iw = image.shape[:2]
# Clip to bounds
x1 = max(0, x)
y1 = max(0, y)
x2 = min(iw, x + aw)
y2 = min(ih, y + ah)
if x1 >= x2 or y1 >= y2:
return
# Get ROIs
ax1, ay1 = x1 - x, y1 - y
alpha_roi = alpha_map[ay1 : ay1 + (y2 - y1), ax1 : ax1 + (x2 - x1)]
image_roi = image[y1:y2, x1:x2].astype(np.float32)
alpha_threshold = 0.002
max_alpha = 0.99
# Vectorized reverse alpha blending
alpha = alpha_roi.copy()
mask = alpha >= alpha_threshold
alpha = np.clip(alpha, 0.0, max_alpha)
one_minus_alpha = 1.0 - alpha
# Expand alpha for 3-channel broadcast
alpha_3d = alpha[:, :, np.newaxis]
one_minus_3d = one_minus_alpha[:, :, np.newaxis]
mask_3d = mask[:, :, np.newaxis]
# original = (watermarked - alpha * logo) / (1 - alpha)
restored = (image_roi - alpha_3d * self.logo_value) / one_minus_3d
restored = np.clip(restored, 0.0, 255.0)
# Apply only where alpha is significant
image_roi = np.where(mask_3d, restored, image_roi)
image[y1:y2, x1:x2] = image_roi.astype(np.uint8)
# ── Inpainting cleanup ───────────────────────────────────────────
def inpaint_residual(
self,
image: NDArray,
region: tuple[int, int, int, int],
strength: float = 0.85,
method: Literal["gaussian", "telea", "ns"] = "ns",
inpaint_radius: int = 10,
padding: int = 32,
) -> NDArray:
"""Apply inpaint cleanup on residual artifacts after reverse alpha blend.
Uses a sparse mask derived from alpha map gradient to repair only
the sparkle-edge pixels where interpolation broke the math.
Args:
image: BGR image (will NOT be modified in-place).
region: (x, y, w, h) of the watermark region.
strength: Blend strength (0.0 = keep original, 1.0 = fully inpainted).
method: Inpaint method ("gaussian", "telea", or "ns").
inpaint_radius: Radius for cv2.inpaint.
padding: Context padding around region in pixels.
Returns:
Cleaned BGR image.
"""
result = image.copy()
x, y, rw, rh = region
if rw < 4 or rh < 4:
return result
strength = max(0.0, min(1.0, strength))
if strength < 0.001:
return result
# Padded region
px1 = max(0, x - padding)
py1 = max(0, y - padding)
px2 = min(image.shape[1], x + rw + padding)
py2 = min(image.shape[0], y + rh + padding)
if (px2 - px1) < 8 or (py2 - py1) < 8:
return result
# Inner rect relative to padded
ix1 = x - px1
iy1 = y - py1
# Get alpha map (interpolated if needed)
source_alpha = self._alpha_large
interp = cv2.INTER_LINEAR if rw > source_alpha.shape[1] else cv2.INTER_AREA
alpha_resized = cv2.resize(source_alpha, (rw, rh), interpolation=interp)
# Compute gradient mask from alpha
grad_x = cv2.Sobel(alpha_resized, cv2.CV_32F, 1, 0, ksize=3)
grad_y = cv2.Sobel(alpha_resized, cv2.CV_32F, 0, 1, ksize=3)
grad_mag = cv2.magnitude(grad_x, grad_y)
grad_min, grad_max = grad_mag.min(), grad_mag.max()
if grad_max <= grad_min:
return result
# Normalize and apply gamma correction
grad_norm = (grad_mag - grad_min) / (grad_max - grad_min)
grad_weight = np.sqrt(grad_norm)
# Dilate the mask
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
grad_weight = cv2.dilate(grad_weight, kernel)
if method == "gaussian":
# Soft blend with Gaussian blur
padded_roi = result[py1:py2, px1:px2].copy()
blurred = cv2.GaussianBlur(padded_roi, (0, 0), sigmaX=2.0)
# Create weight mask on padded area (only inner region has weights)
weight_full = np.zeros((py2 - py1, px2 - px1), dtype=np.float32)
weight_full[iy1 : iy1 + rh, ix1 : ix1 + rw] = grad_weight * strength
weight_3d = weight_full[:, :, np.newaxis]
blended = padded_roi.astype(np.float32) * (1 - weight_3d) + blurred.astype(np.float32) * weight_3d
result[py1:py2, px1:px2] = blended.astype(np.uint8)
else:
# OpenCV inpainting (TELEA or NS)
inpaint_flag = cv2.INPAINT_TELEA if method == "telea" else cv2.INPAINT_NS
# Create binary mask from gradient weight
binary_mask = (grad_weight * 255).astype(np.uint8)
_, binary_mask = cv2.threshold(binary_mask, 30, 255, cv2.THRESH_BINARY)
# Expand mask to padded region
mask_full = np.zeros((py2 - py1, px2 - px1), dtype=np.uint8)
mask_full[iy1 : iy1 + rh, ix1 : ix1 + rw] = binary_mask
padded_roi = result[py1:py2, px1:px2].copy()
inpainted = cv2.inpaint(padded_roi, mask_full, inpaint_radius, inpaint_flag)
# Blend with strength
weight_full = np.zeros((py2 - py1, px2 - px1), dtype=np.float32)
weight_full[iy1 : iy1 + rh, ix1 : ix1 + rw] = grad_weight * strength
weight_3d = weight_full[:, :, np.newaxis]
blended = padded_roi.astype(np.float32) * (1 - weight_3d) + inpainted.astype(np.float32) * weight_3d
result[py1:py2, px1:px2] = blended.astype(np.uint8)
return result
+45
View File
@@ -0,0 +1,45 @@
import cv2
import numpy as np
from numpy.typing import NDArray
def apply_analog_humanizer(image: NDArray, grain_intensity: float = 4.0, chromatic_shift: int = 1) -> NDArray:
"""
Apply Analog Humanizer (film grain and chromatic aberration) to an image.
This simulates analog film imperfections to defeat digital AI perfection classifiers.
Ported from NeuralBleach.
Args:
image: BGR image as numpy array (uint8).
grain_intensity: Standard deviation of the Gaussian noise (film grain).
chromatic_shift: Number of pixels to shift the red/blue color channels.
Returns:
Humanized BGR image.
"""
# Ensure image is BGR
if len(image.shape) != 3 or image.shape[2] != 3:
return image.copy()
# Split channels (OpenCV uses BGR)
# B = 0, G = 1, R = 2
b, g, r = cv2.split(image)
# 1. Chromatic Aberration
# Shift R channel left, B channel right
if chromatic_shift > 0:
r = np.roll(r, -chromatic_shift, axis=1)
b = np.roll(b, chromatic_shift, axis=1)
merged = cv2.merge((b, g, r))
# 2. Film Grain (Gaussian Noise)
if grain_intensity > 0:
img_f = merged.astype(np.float32)
noise = np.random.normal(0, grain_intensity, img_f.shape).astype(np.float32)
humanized = np.clip(img_f + noise, 0, 255).astype(np.uint8)
else:
humanized = merged
return humanized
@@ -0,0 +1,253 @@
"""Invisible watermark removal engine.
Wraps the vendored noai-watermark code for removing invisible AI watermarks
(SynthID, StableSignature, TreeRing) via diffusion-based regeneration.
This module requires the 'invisible' extra dependencies:
uv pip install 'remove-ai-watermarks[invisible]'
"""
from __future__ import annotations
import logging
import os
import warnings
from collections.abc import Callable
from pathlib import Path
# Suppress verbose deprecation warnings from diffusers/transformers/huggingface_hub
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
warnings.filterwarnings("ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings("ignore", module="transformers")
# Suppress HuggingFace internal logging
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["DIFFUSERS_VERBOSITY"] = "error"
logger = logging.getLogger(__name__)
def is_available() -> bool:
"""Check if invisible watermark removal dependencies are installed."""
try:
import diffusers # noqa: F401
import torch # noqa: F401
return True
except ImportError:
return False
class InvisibleEngine:
"""Remove invisible AI watermarks using diffusion model regeneration.
Based on noai-watermark by mertizci:
https://github.com/mertizci/noai-watermark
The approach encodes the image into latent space, injects controlled noise
to break watermark patterns, and reconstructs via reverse diffusion.
"""
DEFAULT_MODEL_ID = "Lykon/dreamshaper-8"
CTRLREGEN_MODEL_ID = "yepengliu/ctrlregen"
def __init__(
self,
model_id: str | None = None,
device: str | None = None,
pipeline: str = "default",
hf_token: str | None = None,
progress_callback: Callable[[str], None] | None = None,
) -> None:
"""Initialize the invisible watermark removal engine.
Args:
model_id: HuggingFace model ID. None = use default for pipeline.
device: Device for inference (auto/cpu/mps/cuda). None = auto.
pipeline: Pipeline profile ("default" or "ctrlregen").
hf_token: HuggingFace API token.
progress_callback: Optional callback for progress messages.
"""
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover
effective_model = model_id
if pipeline == "ctrlregen" and model_id is None:
effective_model = self.CTRLREGEN_MODEL_ID
elif model_id is None:
effective_model = self.DEFAULT_MODEL_ID
self._remover = WatermarkRemover(
model_id=effective_model,
device=device,
progress_callback=progress_callback,
hf_token=hf_token,
)
self._progress_callback = progress_callback
def preload(self) -> None:
"""Eagerly load the pipeline so download progress is visible."""
self._remover.preload()
def remove_watermark(
self,
image_path: Path,
output_path: Path | None = None,
strength: float | None = None,
num_inference_steps: int = 100,
guidance_scale: float | None = None,
seed: int | None = None,
humanize: float = 0.0,
protect_faces: bool = True,
) -> Path:
"""Remove invisible watermark from an image.
Args:
image_path: Path to the watermarked image.
output_path: Output path (None = overwrite source).
strength: Denoising strength (0.01.0). Default 0.04.
steps: Number of denoising steps.
guidance_scale: Classifier-free guidance scale.
seed: Random seed for reproducibility.
humanize: Intensity of Analog Humanizer film grain (0 = off).
protect_faces: Boolean to extract and restore faces intact.
Returns:
Path to the cleaned image.
"""
import tempfile
from PIL import Image, ImageOps
max_dimension = 768
image = Image.open(image_path)
image = ImageOps.exif_transpose(image)
orig_size = image.size # (width, height)
_tmp_path = None
if max(image.width, image.height) > max_dimension:
ratio = max_dimension / max(image.width, image.height)
new_size = (int(image.width * ratio), int(image.height * ratio))
if self._progress_callback:
self._progress_callback(
f"Auto-downscaling {image.width}x{image.height} "
f"to {new_size[0]}x{new_size[1]} to prevent Memory Error..."
)
image = image.resize(new_size, Image.Resampling.LANCZOS)
# Save to a temp file instead of overwriting the original
_tmp_fd, _tmp_str = tempfile.mkstemp(suffix=image_path.suffix)
_tmp_path = Path(_tmp_str)
image.save(_tmp_path)
import os as _os
_os.close(_tmp_fd)
image_path = _tmp_path
else:
# We must save the transposed image back to a tmp file if it was rotated
# otherwise WatermarkRemover will reload it without EXIF rotation!
_tmp_fd, _tmp_str = tempfile.mkstemp(suffix=image_path.suffix)
_tmp_path = Path(_tmp_str)
image.save(_tmp_path)
import os as _os
_os.close(_tmp_fd)
image_path = _tmp_path
try:
# Optional: Face protection (Phase 1 - Extraction)
original_faces = []
if protect_faces:
try:
import cv2
from remove_ai_watermarks.face_protector import FaceProtector
if self._progress_callback:
self._progress_callback("Detecting and extracting faces (protect-faces)...")
# Convert PIL to CV2 BGR
import numpy as np
cv_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
protector = FaceProtector(use_yolo=True)
original_faces = protector.extract_faces(cv_img)
if self._progress_callback:
self._progress_callback(f"Extracted {len(original_faces)} face(s) for protection.")
except Exception as e:
logger.error(f"Failed to extract faces: {e}")
out_path = self._remover.remove_watermark(
image_path=image_path,
output_path=output_path,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed,
)
# Optional: Face restoration & Humanizer (Phase 2 - Post-processing)
if protect_faces or humanize > 0.0:
import cv2
import numpy as np
out_cv = cv2.imread(str(out_path), cv2.IMREAD_COLOR)
if protect_faces and original_faces:
if self._progress_callback:
self._progress_callback("Restoring protected faces with soft blending...")
from remove_ai_watermarks.face_protector import FaceProtector
protector = FaceProtector(use_yolo=True)
out_cv = protector.restore_faces(out_cv, original_faces)
if humanize > 0.0:
if self._progress_callback:
self._progress_callback(f"Applying Analog Humanizer (grain: {humanize})...")
from remove_ai_watermarks.humanizer import apply_analog_humanizer
out_cv = apply_analog_humanizer(out_cv, grain_intensity=humanize, chromatic_shift=1)
# Restore original resolution
if (out_cv.shape[1], out_cv.shape[0]) != orig_size:
if self._progress_callback:
self._progress_callback(
f"Upscaling result back to original resolution {orig_size[0]}x{orig_size[1]}..."
)
# Using INTER_LANCZOS4 for high-quality upscaling back to original
out_cv = cv2.resize(out_cv, orig_size, interpolation=cv2.INTER_LANCZOS4)
cv2.imwrite(str(out_path), out_cv)
else:
# Even if no protect_faces or humanize, we must restore original size if needed
import cv2
out_cv = cv2.imread(str(out_path), cv2.IMREAD_COLOR)
if out_cv is not None and (out_cv.shape[1], out_cv.shape[0]) != orig_size:
if self._progress_callback:
self._progress_callback(
f"Upscaling result back to original resolution {orig_size[0]}x{orig_size[1]}..."
)
out_cv = cv2.resize(out_cv, orig_size, interpolation=cv2.INTER_LANCZOS4)
cv2.imwrite(str(out_path), out_cv)
return out_path
finally:
if _tmp_path is not None and _tmp_path.exists():
_tmp_path.unlink()
def remove_watermark_batch(
self,
input_dir: Path,
output_dir: Path,
strength: float = 0.04,
steps: int = 50,
) -> list[Path]:
"""Remove invisible watermarks from all images in a directory."""
return self._remover.remove_watermark_batch(
input_dir=input_dir,
output_dir=output_dir,
strength=strength,
num_inference_steps=steps,
)
+214
View File
@@ -0,0 +1,214 @@
"""AI metadata detection and removal.
Wraps the noai-watermark metadata handling for stripping AI-generation
metadata (EXIF, PNG text chunks, C2PA provenance) from images.
For metadata-only operations, the heavy ML dependencies are NOT required.
"""
from __future__ import annotations
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
# ── Known AI metadata keys ──────────────────────────────────────────
AI_METADATA_KEYS: frozenset[str] = frozenset(
k.lower()
for k in [
"parameters",
"prompt",
"negative_prompt",
"workflow",
"comfyui",
"sd-metadata",
"invokeai_metadata",
"generation_data",
"ai_metadata",
"dream",
"sd:prompt",
"sd:negative_prompt",
"sd:seed",
"sd:steps",
"sd:sampler",
"sd:cfg_scale",
"sd:model_hash",
"c2pa",
"c2pa_chunk",
"Software",
]
)
AI_KEYWORDS: tuple[str, ...] = (
"stable_diffusion",
"comfyui",
"automatic1111",
"invokeai",
"midjourney",
"dall-e",
"dalle",
"imagen",
"synthid",
"google_ai",
"openai",
"c2pa",
)
STANDARD_METADATA_KEYS: frozenset[str] = frozenset(
[
"Author",
"Title",
"Description",
"Copyright",
"Creation Time",
"Software",
"Comment",
"Disclaimer",
"Source",
"Warning",
]
)
def _is_ai_key(key: str) -> bool:
"""Check if a metadata key is AI-related."""
key_lower = key.lower()
if key_lower in AI_METADATA_KEYS:
return True
return any(kw in key_lower for kw in AI_KEYWORDS)
def has_ai_metadata(image_path: Path) -> bool:
"""Check if an image contains AI-generation metadata.
Args:
image_path: Path to the image.
Returns:
True if AI metadata is detected.
"""
from PIL import Image
with Image.open(image_path) as img:
for key in img.info:
if _is_ai_key(key):
return True
# Check C2PA
try:
from c2pa import has_c2pa_metadata
if has_c2pa_metadata(image_path):
return True
except ImportError:
# Try simple binary scan
data = image_path.read_bytes()
if b"c2pa" in data.lower() or b"C2PA" in data:
return True
return False
def get_ai_metadata(image_path: Path) -> dict[str, str]:
"""Extract AI-related metadata from an image.
Args:
image_path: Path to the image.
Returns:
Dictionary of AI metadata key-value pairs.
"""
from PIL import Image
result: dict[str, str] = {}
with Image.open(image_path) as img:
for key, value in img.info.items():
if _is_ai_key(key):
if isinstance(value, bytes):
result[key] = f"<binary {len(value)} bytes>"
elif isinstance(value, str) and len(value) > 200:
result[key] = value[:200] + ""
else:
result[key] = str(value)
return result
def remove_ai_metadata(
source_path: Path,
output_path: Path | None = None,
keep_standard: bool = True,
) -> Path:
"""Remove AI-generation metadata from an image.
Strips EXIF AI tags, PNG text chunks, and C2PA provenance manifests
while optionally preserving standard metadata (Author, Title, etc.).
Args:
source_path: Path to the source image.
output_path: Output path (None = overwrite source).
keep_standard: If True, preserve standard metadata fields.
Returns:
Path to the cleaned image.
"""
import piexif
from PIL import Image
from PIL.PngImagePlugin import PngInfo
if output_path is None:
output_path = source_path
# Read image and filter metadata
with Image.open(source_path) as img:
img = img.copy()
fmt = output_path.suffix.lower()
save_kwargs: dict = {}
if fmt in (".jpg", ".jpeg"):
save_kwargs["format"] = "JPEG"
if img.mode in ("RGBA", "P"):
img = img.convert("RGB")
else:
save_kwargs["format"] = "PNG"
# Collect non-AI metadata
kept_meta: dict[str, str] = {}
exif_data = None
for key, value in img.info.items():
if _is_ai_key(key):
continue
if key == "exif":
try:
exif_data = piexif.load(value)
except Exception:
pass
continue
if key in ("dpi", "gamma"):
save_kwargs[key] = value
continue
if keep_standard and key in STANDARD_METADATA_KEYS:
kept_meta[key] = str(value) if not isinstance(value, str) else value
# Apply cleaned metadata
if save_kwargs["format"] == "PNG" and kept_meta:
pnginfo = PngInfo()
for k, v in kept_meta.items():
pnginfo.add_text(k, v)
save_kwargs["pnginfo"] = pnginfo
if exif_data and save_kwargs["format"] == "JPEG":
try:
save_kwargs["exif"] = piexif.dump(exif_data)
except Exception:
pass
output_path.parent.mkdir(parents=True, exist_ok=True)
img.save(output_path, **save_kwargs)
logger.info("Stripped AI metadata → %s", output_path)
return output_path
@@ -0,0 +1,9 @@
"""Vendored noai-watermark code for invisible watermark removal.
Original: https://github.com/mertizci/noai-watermark (MIT License)
"""
from remove_ai_watermarks.noai.cleaner import remove_ai_metadata
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover, remove_watermark
__all__ = ["WatermarkRemover", "remove_watermark", "remove_ai_metadata"]
+297
View File
@@ -0,0 +1,297 @@
"""C2PA (Coalition for Content Provenance and Authenticity) metadata handling.
C2PA metadata is embedded in PNG files as a JUMBF container chunk
(``caBX``). This module can detect, extract, and re-inject those
chunks. Supported issuers:
- Google Imagen
- Adobe Firefly
- Microsoft Designer
- OpenAI (ChatGPT, GPT-4o, Sora, DALL-E)
- Truepic (signing authority)
The parser uses byte-level scanning — it does not validate JUMBF/CBOR
structure but reliably identifies known signatures, issuers, tools,
and actions.
"""
from __future__ import annotations
import re
import struct
from pathlib import Path
from typing import Any
from remove_ai_watermarks.noai.constants import (
C2PA_ACTIONS,
C2PA_AI_TOOLS,
C2PA_CHUNK_TYPE,
C2PA_ISSUERS,
C2PA_SIGNATURES,
PNG_SIGNATURE,
)
def has_c2pa_metadata(image_path: Path) -> bool:
"""
Check if an image contains C2PA metadata.
Args:
image_path: Path to the image file.
Returns:
True if C2PA metadata is detected, False otherwise.
"""
image_path = Path(image_path)
if image_path.suffix.lower() != ".png":
return False
try:
with open(image_path, "rb") as f:
signature = f.read(8)
if signature != PNG_SIGNATURE:
return False
while True:
chunk_header = f.read(8)
if len(chunk_header) < 8:
break
length = struct.unpack(">I", chunk_header[:4])[0]
chunk_type = chunk_header[4:8]
if chunk_type == C2PA_CHUNK_TYPE:
chunk_data = f.read(length)
# Check for any C2PA signature
for sig in C2PA_SIGNATURES:
if sig in chunk_data:
return True
# Also check if chunk_data itself contains C2PA-like patterns
if b"jumb" in chunk_data.lower() or b"c2pa" in chunk_data.lower():
return True
f.read(4)
else:
f.read(length + 4)
if chunk_type == b"IEND":
break
except Exception:
pass
return False
def extract_c2pa_info(image_path: Path) -> dict[str, Any]:
"""
Extract basic C2PA metadata information from an image.
Args:
image_path: Path to the image file.
Returns:
Dictionary containing C2PA metadata info.
"""
c2pa_info: dict[str, Any] = {}
if not has_c2pa_metadata(image_path):
return c2pa_info
c2pa_info["has_c2pa"] = True
c2pa_info["type"] = "C2PA (Coalition for Content Provenance and Authenticity)"
try:
with open(image_path, "rb") as f:
signature = f.read(8)
if signature != PNG_SIGNATURE:
return c2pa_info
while True:
chunk_header = f.read(8)
if len(chunk_header) < 8:
break
length = struct.unpack(">I", chunk_header[:4])[0]
chunk_type = chunk_header[4:8]
if chunk_type == C2PA_CHUNK_TYPE:
chunk_data = f.read(length)
_parse_c2pa_chunk(chunk_data, c2pa_info)
f.read(4)
else:
f.read(length + 4)
if chunk_type == b"IEND":
break
except Exception:
pass
return c2pa_info
def _parse_c2pa_chunk(chunk_data: bytes, c2pa_info: dict[str, Any]) -> None:
"""Parse C2PA chunk data and populate info dictionary."""
# Find issuers
issuers = []
for sig, name in C2PA_ISSUERS.items():
if sig in chunk_data:
issuers.append(name)
if issuers:
c2pa_info["issuer"] = ", ".join(set(issuers))
# Find AI tools
ai_tools = []
for sig, name in C2PA_AI_TOOLS.items():
if sig in chunk_data:
ai_tools.append(name)
if ai_tools:
c2pa_info["ai_tool"] = ", ".join(set(ai_tools))
# Extract software agent (multiple patterns)
patterns = [
rb"softwareAgent.*?dname([^\x00]+?)(?:q|l|m|n)",
rb"software_agent[^\x00]*?([A-Za-z0-9_\-\.]+)",
rb"Software[^\x00]*?([A-Za-z0-9_\-\. ]+)",
]
for pattern in patterns:
match = re.search(pattern, chunk_data, re.DOTALL | re.IGNORECASE)
if match:
agent = match.group(1).decode("utf-8", errors="ignore").strip()
if agent and len(agent) < 100:
c2pa_info["software_agent"] = agent
break
# Extract claim generator (multiple patterns)
claim_patterns = [
rb"claim_generator[^\x00]*?([A-Za-z0-9_\-\.\/\:]+)",
rb"claimGenerator[^\x00]*?([A-Za-z0-9_\-\.\/\:]+)",
rb"dname([^\x00]{3,50})(?:q|l|m|n|i)",
]
for pattern in claim_patterns:
match = re.search(pattern, chunk_data, re.DOTALL | re.IGNORECASE)
if match:
gen_name = match.group(1).decode("utf-8", errors="ignore").strip()
# Filter out common false positives
if gen_name and len(gen_name) < 100 and not gen_name.startswith(("\\x", "\\\\x")):
c2pa_info["claim_generator"] = gen_name
break
# Find actions
actions = []
for sig, name in C2PA_ACTIONS.items():
if sig in chunk_data:
actions.append(name)
if actions:
c2pa_info["actions"] = ", ".join(actions)
# Find timestamps
timestamp_matches = re.findall(rb"(\d{14}Z)", chunk_data)
if timestamp_matches:
c2pa_info["timestamp"] = timestamp_matches[0].decode("utf-8")
if len(timestamp_matches) > 1:
c2pa_info["timestamps"] = [t.decode("utf-8") for t in timestamp_matches[:3]]
# Find digital source type
if b"trainedAlgorithmicMedia" in chunk_data:
c2pa_info["source_type"] = "trainedAlgorithmicMedia (AI-generated)"
elif b"algorithmicMedia" in chunk_data:
c2pa_info["source_type"] = "algorithmicMedia"
elif b"compositeWithTrainedAlgorithmicMedia" in chunk_data:
c2pa_info["source_type"] = "compositeWithTrainedAlgorithmicMedia (AI-enhanced)"
def extract_c2pa_chunk(image_path: Path) -> bytes | None:
"""
Extract the raw C2PA JUMBF chunk from a PNG file.
Args:
image_path: Path to the source PNG file.
Returns:
Raw bytes of the C2PA chunk or None.
"""
if image_path.suffix.lower() != ".png":
return None
try:
with open(image_path, "rb") as f:
signature = f.read(8)
if signature != PNG_SIGNATURE:
return None
while True:
chunk_header = f.read(8)
if len(chunk_header) < 8:
break
length = struct.unpack(">I", chunk_header[:4])[0]
chunk_type = chunk_header[4:8]
if chunk_type == C2PA_CHUNK_TYPE:
chunk_data = f.read(length)
crc = f.read(4)
# Check for any C2PA signature
for sig in C2PA_SIGNATURES:
if sig in chunk_data:
return chunk_header + chunk_data + crc
# Also check lowercase variants
if b"jumb" in chunk_data.lower() or b"c2pa" in chunk_data.lower():
return chunk_header + chunk_data + crc
else:
f.read(length + 4)
if chunk_type == b"IEND":
break
except Exception:
pass
return None
def inject_c2pa_chunk(target_path: Path, output_path: Path, c2pa_chunk: bytes) -> None:
"""
Inject a C2PA JUMBF chunk into a PNG file.
Args:
target_path: Path to the target PNG file.
output_path: Path where the output file will be saved.
c2pa_chunk: Raw bytes of the C2PA chunk to inject.
Raises:
ValueError: If not PNG files.
"""
if target_path.suffix.lower() != ".png" or output_path.suffix.lower() != ".png":
raise ValueError("C2PA chunk injection is only supported for PNG files")
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(target_path, "rb") as f_in:
with open(output_path, "wb") as f_out:
f_out.write(f_in.read(8))
c2pa_injected = False
while True:
chunk_header = f_in.read(8)
if len(chunk_header) < 8:
break
length = struct.unpack(">I", chunk_header[:4])[0]
chunk_type = chunk_header[4:8]
chunk_data = f_in.read(length)
crc = f_in.read(4)
if chunk_type == b"IDAT" and not c2pa_injected:
f_out.write(c2pa_chunk)
c2pa_injected = True
if chunk_type == C2PA_CHUNK_TYPE:
continue
f_out.write(chunk_header)
f_out.write(chunk_data)
f_out.write(crc)
if chunk_type == b"IEND":
break
+181
View File
@@ -0,0 +1,181 @@
"""AI metadata cleaning and removal.
Provides functions to identify and strip AI-generation metadata from
PNG and JPEG images while optionally preserving standard fields like
Author, Title, and Copyright.
The removal pipeline:
1. Opens the image and iterates over all metadata keys.
2. Classifies each key as AI-related (using ``constants.AI_METADATA_KEYS``
and ``constants.AI_KEYWORDS``) or standard.
3. Rebuilds the image with only the desired metadata retained.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import piexif
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from remove_ai_watermarks.noai.constants import AI_KEYWORDS, AI_METADATA_KEYS, PNG_METADATA_KEYS
from remove_ai_watermarks.noai.utils import get_image_format
# Pre-compute a lowercase set for O(1) key lookups.
_AI_KEYS_LOWER: frozenset[str] = frozenset(k.lower() for k in AI_METADATA_KEYS)
def remove_ai_metadata(
source_path: Path,
output_path: Path | None = None,
keep_standard: bool = True,
) -> Path:
"""
Remove all AI-generated metadata from an image.
Removes:
- AI parameters (Stable Diffusion, ComfyUI, etc.)
- C2PA metadata (Google Imagen, OpenAI, etc.)
- Any metadata with AI-related keywords
Args:
source_path: Path to the source image file.
output_path: Optional output path. If not provided, modifies source in place.
keep_standard: If True, keeps standard metadata (Author, Title, etc.).
Returns:
Path to the output file with AI metadata removed.
"""
if output_path is None:
output_path = source_path
cleaned_metadata = _extract_non_ai_metadata(source_path, keep_standard)
with Image.open(source_path) as img:
img = img.copy()
output_format = get_image_format(output_path)
save_kwargs: dict[str, Any] = {"format": output_format}
# Handle EXIF data (keep it, just don't include AI-related fields)
if "exif" in cleaned_metadata:
save_kwargs["exif"] = cleaned_metadata["exif"]
if output_format == "PNG":
save_kwargs = _prepare_clean_png_kwargs(save_kwargs, cleaned_metadata)
elif output_format == "JPEG":
save_kwargs = _prepare_clean_jpeg_kwargs(save_kwargs, cleaned_metadata)
output_path.parent.mkdir(parents=True, exist_ok=True)
if output_format == "JPEG" and img.mode in ("RGBA", "P"):
img = img.convert("RGB")
img.save(output_path, **save_kwargs)
return output_path
def _extract_non_ai_metadata(source_path: Path, keep_standard: bool) -> dict[str, Any]:
"""Extract metadata excluding AI-related fields."""
cleaned_metadata: dict[str, Any] = {}
with Image.open(source_path) as img:
# Handle EXIF data
if "exif" in img.info:
try:
exif_dict = piexif.load(img.info["exif"])
cleaned_metadata["exif"] = exif_dict
except Exception:
pass
# Extract non-AI metadata
for key, value in img.info.items():
if _is_ai_metadata_key(key):
continue
if keep_standard and key in PNG_METADATA_KEYS:
cleaned_metadata[key] = value
elif not keep_standard:
# Remove standard metadata while still preserving non-standard fields.
if key not in ["exif", "dpi", "gamma"] and key not in PNG_METADATA_KEYS:
cleaned_metadata[key] = value
# Keep DPI and gamma
if "dpi" in img.info:
cleaned_metadata["dpi"] = img.info["dpi"]
if "gamma" in img.info:
cleaned_metadata["gamma"] = img.info["gamma"]
return cleaned_metadata
def _is_ai_metadata_key(key: str) -> bool:
"""Return True if *key* is an AI-generation metadata field.
Detection uses two layers:
1. Exact match against the canonical ``AI_METADATA_KEYS`` list.
2. Substring match against ``AI_KEYWORDS`` (covers partial hits
like ``"stable_diffusion_model"``).
"""
key_lower = key.lower()
if key_lower in _AI_KEYS_LOWER:
return True
return any(kw in key_lower for kw in AI_KEYWORDS)
def _prepare_clean_png_kwargs(save_kwargs: dict[str, Any], metadata: dict[str, Any]) -> dict[str, Any]:
"""Prepare save kwargs for clean PNG."""
pnginfo = {}
exclude_keys = ["exif", "exif_raw", "dpi", "gamma"]
for key, value in metadata.items():
if key not in exclude_keys:
pnginfo[key] = value
if pnginfo:
pnginfo_obj = PngInfo()
for key, value in pnginfo.items():
if isinstance(value, str):
pnginfo_obj.add_text(key, value)
elif isinstance(value, bytes):
pnginfo_obj.add_text(key, value.decode("utf-8", errors="replace"))
save_kwargs["pnginfo"] = pnginfo_obj
if "dpi" in metadata:
save_kwargs["dpi"] = metadata["dpi"]
return save_kwargs
def _prepare_clean_jpeg_kwargs(save_kwargs: dict[str, Any], metadata: dict[str, Any]) -> dict[str, Any]:
"""Prepare save kwargs for clean JPEG."""
exif_dict = metadata.get("exif", {"0th": {}, "Exif": {}, "1st": {}, "GPS": {}, "Interop": {}})
try:
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception:
pass
if "dpi" in metadata:
save_kwargs["dpi"] = metadata["dpi"]
return save_kwargs
def has_ai_content(image_path: Path) -> bool:
"""
Check if an image has any AI-generated content or metadata.
Args:
image_path: Path to the image file.
Returns:
True if the image contains AI metadata.
"""
from remove_ai_watermarks.noai.extractor import has_ai_metadata
return has_ai_metadata(image_path)
+112
View File
@@ -0,0 +1,112 @@
"""Shared constants for AI metadata detection, C2PA parsing, and format support.
All modules reference these constants rather than hard-coding values,
so adding a new AI tool or metadata key requires updating only this file.
"""
# Supported image formats
SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp"}
# AI-generated image metadata keys (Stable Diffusion, ComfyUI, Midjourney, etc.)
AI_METADATA_KEYS = [
"parameters", # Stable Diffusion WebUI (AUTOMATIC1111, Vladmandic)
"postprocessing", # SD WebUI post-processing info
"extras", # SD WebUI extras
"workflow", # ComfyUI workflow JSON
"prompt", # Some AI tools
"Dream", # DreamStudio
"SD:mode", # Stability AI
"StableDiffusionVersion", # SD version info
"generation_time", # Generation time info
"Model", # Model name
"Model hash", # Model hash
"Seed", # Seed value
]
# Standard PNG metadata keys
PNG_METADATA_KEYS = [
"Author",
"Title",
"Description",
"Copyright",
"Creation Time",
"Software",
"Disclaimer",
"Warning",
"Source",
"Comment",
]
# AI-related keywords for detection
AI_KEYWORDS = [
"prompt",
"negative_prompt",
"sampler",
"cfg_scale",
"lora",
"diffusion",
"comfy",
"midjourney",
"dall-e",
"dalle",
"imagen",
"firefly",
"c2pa",
"chatgpt",
"gpt-4",
"sora",
"openai",
"truepic",
"stable_diffusion",
"invokeai",
]
# C2PA (Coalition for Content Provenance and Authenticity) constants
# Used by Google Imagen, Adobe Firefly, Microsoft Designer, OpenAI, etc.
C2PA_CHUNK_TYPE = b"caBX" # JUMBF container chunk type for C2PA
C2PA_SIGNATURES = [
b"c2pa",
b"C2PA",
b"jumb",
b"jumd",
b"JUMBF",
b"jumbf",
b"cbor",
b"contentcreds",
b"digid",
b"assertions",
b"manifest",
]
# C2PA known issuers
C2PA_ISSUERS = {
b"Google": "Google LLC",
b"Adobe": "Adobe",
b"Microsoft": "Microsoft",
b"OpenAI": "OpenAI",
b"Truepic": "Truepic",
}
# C2PA known AI tools
C2PA_AI_TOOLS = {
b"GPT-4o": "GPT-4o",
b"ChatGPT": "ChatGPT",
b"Sora": "Sora",
b"DALL-E": "DALL-E",
b"DALL": "DALL-E",
b"Imagen": "Imagen",
b"Firefly": "Firefly",
}
# C2PA action types
C2PA_ACTIONS = {
b"c2pa.created": "created",
b"c2pa.converted": "converted",
b"c2pa.edited": "edited",
b"c2pa.filtered": "filtered",
b"c2pa.cropped": "cropped",
b"c2pa.resized": "resized",
}
# PNG signature
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
@@ -0,0 +1,18 @@
"""CtrlRegen watermark removal via controllable regeneration.
Implements the pipeline from "Image Watermarks Are Removable Using
Controllable Regeneration from Clean Noise" (ICLR 2025) by Liu et al.
This sub-package uses a ControlNet for spatial guidance (canny edges)
and a DINOv2-based IP Adapter for semantic guidance to regenerate
watermarked images from partially noised latents.
Attribution:
Based on https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""
from __future__ import annotations
from remove_ai_watermarks.noai.ctrlregen.engine import CtrlRegenEngine, is_ctrlregen_available
__all__ = ["CtrlRegenEngine", "is_ctrlregen_available"]
@@ -0,0 +1,40 @@
"""Color matching post-processing for CtrlRegen output.
After diffusion-based regeneration, the output image may have slight
color shifts. This module uses histogram-based color transfer to
align the regenerated image's color distribution back to the original.
Attribution:
Adapted from https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""
from __future__ import annotations
import numpy as np
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
from PIL import Image
def color_match(reference: Image.Image, source: Image.Image) -> Image.Image:
"""Transfer the color distribution of *reference* onto *source*.
Uses a two-pass histogram matching approach (``hm-mkl-hm``) that
preserves fine-grained color relationships while correcting global
shifts introduced by the regeneration pipeline.
Args:
reference: The original (watermarked) image whose colors should
be preserved.
source: The regenerated image whose colors will be adjusted.
Returns:
A new PIL Image with the structure of *source* but the color
palette of *reference*.
"""
cm = ColorMatcher()
ref_np = Normalizer(np.asarray(reference)).type_norm()
src_np = Normalizer(np.asarray(source)).type_norm()
result = cm.transfer(src=src_np, ref=ref_np, method="hm-mkl-hm")
result = Normalizer(result).uint8_norm()
return Image.fromarray(result)
@@ -0,0 +1,365 @@
"""CtrlRegen engine — orchestrates the full watermark removal pipeline.
Loads the base SD 1.5 model with a ControlNet (spatial control from
canny edges) and a DINOv2-based IP Adapter (semantic control), then
runs controllable regeneration with optional color matching.
Attribution:
Based on https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""
from __future__ import annotations
import logging
import os
import sys
import time
from collections.abc import Callable
from typing import Any
import torch
from PIL import Image
from remove_ai_watermarks.noai.progress import make_pipeline_progress
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Availability checks — these imports are optional.
# ---------------------------------------------------------------------------
_HAS_CONTROLNET_AUX = False
_HAS_COLOR_MATCHER = False
_HAS_DIFFUSERS = False
try:
from ctrlregen.pipeline import CustomCtrlRegenPipeline
from diffusers import AutoencoderKL, ControlNetModel, UniPCMultistepScheduler
_HAS_DIFFUSERS = True
except ImportError:
AutoencoderKL = None # type: ignore[assignment,misc]
ControlNetModel = None # type: ignore[assignment,misc]
UniPCMultistepScheduler = None # type: ignore[assignment,misc]
CustomCtrlRegenPipeline = None # type: ignore[assignment,misc]
try:
from controlnet_aux import CannyDetector
_HAS_CONTROLNET_AUX = True
except ImportError:
CannyDetector = None # type: ignore[assignment,misc]
try:
from ctrlregen.color import color_match
_HAS_COLOR_MATCHER = True
except ImportError:
color_match = None # type: ignore[assignment]
CTRLREGEN_HF_REPO = "yepengliu/ctrlregen"
SPATIAL_SUBFOLDER = "spatialnet_ckp/spatial_control_ckp_14000"
SEMANTIC_SUBFOLDER = "semanticnet_ckp/models"
SEMANTIC_WEIGHT_NAME = "semantic_control_ckp_435000.bin"
DEFAULT_BASE_MODEL = "SG161222/Realistic_Vision_V4.0_noVAE"
CUSTOM_VAE_ID = "stabilityai/sd-vae-ft-mse"
PROCESS_SIZE = 512
DEFAULT_GUIDANCE_SCALE = 2.0
QUALITY_PROMPT = "best quality, high quality"
NEGATIVE_PROMPT = "monochrome, lowres, bad anatomy, worst quality, low quality"
CANNY_LOW_THRESHOLD = 100
CANNY_HIGH_THRESHOLD = 150
TILE_SIZE = 512
TILE_OVERLAP = 192
def is_ctrlregen_available() -> bool:
"""Return True when all CtrlRegen-specific dependencies are installed."""
return _HAS_DIFFUSERS and _HAS_CONTROLNET_AUX and _HAS_COLOR_MATCHER
class CtrlRegenEngine:
"""End-to-end CtrlRegen watermark removal engine.
Handles model loading, canny edge extraction, controlled denoising,
and color-matched post-processing in a single ``run()`` call.
"""
def __init__(
self,
base_model_id: str | None = None,
device: str = "cpu",
torch_dtype: torch.dtype | None = None,
hf_token: str | None = None,
progress_callback: Callable[[str], None] | None = None,
) -> None:
if not is_ctrlregen_available():
missing: list[str] = []
if not _HAS_DIFFUSERS:
missing.extend(["diffusers", "transformers", "accelerate"])
if not _HAS_CONTROLNET_AUX:
missing.append("controlnet-aux")
if not _HAS_COLOR_MATCHER:
missing.append("color-matcher")
logger.info("Auto-installing missing dependencies: %s", missing)
import subprocess
try:
subprocess.check_call(
[sys.executable, "-m", "pip", "install", *missing],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
except (subprocess.CalledProcessError, FileNotFoundError):
raise ImportError(
"Failed to auto-install missing dependencies: "
+ ", ".join(missing)
+ ". Try manually: pip install --force-reinstall noai-watermark"
)
self.base_model_id = base_model_id or DEFAULT_BASE_MODEL
self.device = device
self.torch_dtype = torch_dtype or (torch.float32 if device in ("cpu", "mps") else torch.float16)
self.hf_token: str | None = hf_token or os.environ.get("HF_TOKEN")
self._progress_callback = progress_callback
self._pipeline: CustomCtrlRegenPipeline | None = None # type: ignore[assignment]
self._canny_detector: CannyDetector | None = None # type: ignore[assignment]
def _set_progress(self, message: str) -> None:
if self._progress_callback is None:
return
try:
self._progress_callback(message)
except Exception:
pass
# ------------------------------------------------------------------
# Loading
# ------------------------------------------------------------------
def load(self) -> None:
"""Download and assemble the full CtrlRegen pipeline."""
if self._pipeline is not None:
return
token_kwargs: dict[str, Any] = {}
if self.hf_token:
token_kwargs["token"] = self.hf_token
self._set_progress(f"Loading CtrlRegen spatial ControlNet from {CTRLREGEN_HF_REPO}...")
logger.info("Loading ControlNet from %s/%s", CTRLREGEN_HF_REPO, SPATIAL_SUBFOLDER)
controlnet = [
ControlNetModel.from_pretrained(
CTRLREGEN_HF_REPO,
subfolder=SPATIAL_SUBFOLDER,
torch_dtype=self.torch_dtype,
**token_kwargs,
)
]
self._set_progress(f"Loading SD base model ({self.base_model_id}) for CtrlRegen pipeline...")
logger.info("Loading base pipeline from %s", self.base_model_id)
pipe = CustomCtrlRegenPipeline.from_pretrained(
self.base_model_id,
controlnet=controlnet,
torch_dtype=self.torch_dtype,
safety_checker=None,
requires_safety_checker=False,
**token_kwargs,
)
self._set_progress(f"Loading CtrlRegen semantic IP-Adapter + DINOv2 from {CTRLREGEN_HF_REPO}...")
logger.info("Loading IP-Adapter from %s/%s", CTRLREGEN_HF_REPO, SEMANTIC_SUBFOLDER)
pipe.load_ctrlregen_ip_adapter(
CTRLREGEN_HF_REPO,
subfolder=SEMANTIC_SUBFOLDER,
weight_name=SEMANTIC_WEIGHT_NAME,
**token_kwargs,
)
from transformers import AutoImageProcessor, AutoModel
pipe.image_encoder = AutoModel.from_pretrained("facebook/dinov2-giant").to(self.device, dtype=self.torch_dtype)
pipe.feature_extractor = AutoImageProcessor.from_pretrained("facebook/dinov2-giant")
self._set_progress(f"Loading custom VAE ({CUSTOM_VAE_ID})...")
logger.info("Loading VAE from %s", CUSTOM_VAE_ID)
pipe.vae = AutoencoderKL.from_pretrained(
CUSTOM_VAE_ID,
torch_dtype=self.torch_dtype,
**token_kwargs,
)
self._set_progress("Configuring UniPC scheduler...")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.set_ip_adapter_scale(1.0)
self._set_progress(f"Moving CtrlRegen pipeline to {self.device}...")
pipe = pipe.to(self.device)
if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception:
pass
self._pipeline = pipe
self._canny_detector = CannyDetector()
self._set_progress("CtrlRegen pipeline ready.")
logger.info("CtrlRegen pipeline loaded on %s", self.device)
# ------------------------------------------------------------------
# Inference — public entry point
# ------------------------------------------------------------------
def run(
self,
image: Image.Image,
strength: float = 0.5,
num_inference_steps: int = 50,
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
seed: int | None = None,
) -> Image.Image:
"""Run CtrlRegen watermark removal on a single image.
Images that fit within ``TILE_SIZE`` (512) are processed as a
single pass. Larger images are split into overlapping tiles.
"""
self.load()
assert self._pipeline is not None
assert self._canny_detector is not None
orig_w, orig_h = image.size
orig_image = image
t0 = time.monotonic()
needs_tiling = orig_w > TILE_SIZE or orig_h > TILE_SIZE
if needs_tiling:
from ctrlregen.tiling import resize_center_crop, run_tiled
aligned_w = orig_w // 8 * 8
aligned_h = orig_h // 8 * 8
if aligned_w != orig_w or aligned_h != orig_h:
image = image.resize((aligned_w, aligned_h), Image.LANCZOS)
regen_image = run_tiled(
pipeline=self._pipeline,
canny_detector=self._canny_detector,
image=image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed,
tile_size=TILE_SIZE,
tile_overlap=TILE_OVERLAP,
quality_prompt=QUALITY_PROMPT,
negative_prompt=NEGATIVE_PROMPT,
canny_low=CANNY_LOW_THRESHOLD,
canny_high=CANNY_HIGH_THRESHOLD,
device=self.device,
set_progress=self._set_progress,
ip_adapter_image=orig_image,
)
else:
from ctrlregen.tiling import resize_center_crop
proc_image = resize_center_crop(image, PROCESS_SIZE)
self._set_progress(f"Preprocessed {orig_w}x{orig_h}px → {proc_image.size[0]}x{proc_image.size[1]}px")
regen_image = self._run_single(
proc_image,
strength,
num_inference_steps,
guidance_scale,
seed,
)
if regen_image.size != (orig_w, orig_h):
self._set_progress(f"Resizing {regen_image.size[0]}x{regen_image.size[1]}px → {orig_w}x{orig_h}px...")
regen_image = regen_image.resize((orig_w, orig_h), Image.LANCZOS)
self._set_progress(f"Applying color matching at {orig_w}x{orig_h}px...")
output = color_match(reference=orig_image, source=regen_image)
self._set_progress(f"✓ CtrlRegen done · {orig_w}x{orig_h}px · {time.monotonic() - t0:.0f}s total")
return output
# ------------------------------------------------------------------
# Single-image path (image <= 512x512)
# ------------------------------------------------------------------
def _run_single(
self,
image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
seed: int | None,
) -> Image.Image:
"""Process a single 512x512 image through the CtrlRegen pipeline."""
w, h = image.size
effective_steps = max(1, int(num_inference_steps * strength))
self._set_progress(
f"Extracting canny edges ({w}x{h}px, thresholds {CANNY_LOW_THRESHOLD}/{CANNY_HIGH_THRESHOLD})..."
)
control_image = self._canny_detector(
image,
low_threshold=CANNY_LOW_THRESHOLD,
high_threshold=CANNY_HIGH_THRESHOLD,
)
generator = torch.manual_seed(seed if seed is not None else 0)
self._set_progress(
f"Config: strength={strength}, steps={num_inference_steps} "
f"(~{effective_steps} effective), guidance={guidance_scale}"
)
step_cb, first_step, pipeline_done, start_updater = make_pipeline_progress(
effective_steps,
self.device,
self._set_progress,
label="CtrlRegen denoising",
)
start_updater()
try:
result = self._pipeline(
prompt=QUALITY_PROMPT,
negative_prompt=NEGATIVE_PROMPT,
image=[image],
control_image=[control_image],
controlnet_conditioning_scale=1.0,
ip_adapter_image=[image],
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback=step_cb,
callback_steps=1,
)
except TypeError:
first_step.set()
result = self._pipeline(
prompt=QUALITY_PROMPT,
negative_prompt=NEGATIVE_PROMPT,
image=[image],
control_image=[control_image],
controlnet_conditioning_scale=1.0,
ip_adapter_image=[image],
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
)
finally:
first_step.set()
pipeline_done.set()
return result.images[0]
@@ -0,0 +1,148 @@
"""Custom IP-Adapter mixin using DINOv2 as the image encoder.
The standard diffusers ``IPAdapterMixin`` uses a CLIP image encoder.
CtrlRegen replaces it with ``facebook/dinov2-giant`` for richer
semantic features. This mixin provides ``load_ctrlregen_ip_adapter``
which handles the custom weight format and encoder swap.
Attribution:
Adapted from https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""
from __future__ import annotations
import logging
import torch
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.utils import (
_get_model_file,
is_accelerate_available,
is_torch_version,
)
from diffusers.utils import (
logging as diffusers_logging,
)
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open
from transformers import AutoImageProcessor, AutoModel
logger = logging.getLogger(__name__)
_diffusers_logger = diffusers_logging.get_logger(__name__)
DINOV2_MODEL_ID = "facebook/dinov2-giant"
class CustomIPAdapterMixin:
"""Mixin that adds ``load_ctrlregen_ip_adapter`` to a diffusers pipeline."""
@validate_hf_hub_args
def load_ctrlregen_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
subfolder: str | list[str],
weight_name: str | list[str],
image_encoder_folder: str | None = "image_encoder",
**kwargs,
) -> None:
"""Load CtrlRegen IP-Adapter weights and DINOv2 image encoder.
Parameters mirror ``IPAdapterMixin.load_ip_adapter`` but the
image encoder is always ``facebook/dinov2-giant`` regardless of
the ``image_encoder_folder`` value in the checkpoint.
"""
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
_diffusers_logger.warning(
"Cannot initialize model with low cpu memory usage because "
"`accelerate` was not found. Defaulting to "
"`low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError("Low memory initialization requires torch >= 1.9.0.")
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts: list[dict] = []
for path_or_dict, wn, sf in zip(pretrained_model_name_or_path_or_dict, weight_name, subfolder):
if not isinstance(path_or_dict, dict):
model_file = _get_model_file(
path_or_dict,
weights_name=wn,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=sf,
user_agent=user_agent,
)
if wn.endswith(".safetensors"):
state_dict: dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# Always use DINOv2-giant as the image encoder.
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_folder is not None:
logger.info("Loading DINOv2-giant image encoder for CtrlRegen")
enc_dtype = getattr(self, "dtype", torch.float32) # type: ignore[attr-defined]
image_encoder = AutoModel.from_pretrained(DINOV2_MODEL_ID).to(
self.device,
dtype=enc_dtype, # type: ignore[attr-defined]
)
self.register_modules(image_encoder=image_encoder) # type: ignore[attr-defined]
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
feature_extractor = AutoImageProcessor.from_pretrained(DINOV2_MODEL_ID)
self.register_modules(feature_extractor=feature_extractor) # type: ignore[attr-defined]
unet = (
getattr(self, self.unet_name) # type: ignore[attr-defined]
if not hasattr(self, "unet")
else self.unet # type: ignore[attr-defined]
)
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
@@ -0,0 +1,37 @@
"""Custom Stable Diffusion ControlNet Img2Img pipeline for CtrlRegen.
Extends ``StableDiffusionControlNetImg2ImgPipeline`` with the
``load_ctrlregen_ip_adapter`` method (via ``CustomIPAdapterMixin``)
that swaps in DINOv2-giant as the image encoder and loads the
CtrlRegen semantic-control adapter weights.
No ``encode_image`` override is needed — the CtrlRegen checkpoint
creates an ``IPAdapterPlusImageProjection`` which tells diffusers to
call ``encode_image`` with ``output_hidden_states=True``. The
default implementation then uses ``hidden_states[-2]`` from DINOv2,
which is exactly what the projection was trained on.
Attribution:
Adapted from https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""
from __future__ import annotations
from diffusers import StableDiffusionControlNetImg2ImgPipeline
from remove_ai_watermarks.noai.ctrlregen.ip_adapter import CustomIPAdapterMixin
class CustomCtrlRegenPipeline(
StableDiffusionControlNetImg2ImgPipeline,
CustomIPAdapterMixin,
):
"""SD ControlNet Img2Img pipeline with DINOv2 IP-Adapter support.
MRO mirrors the original CtrlRegen repository: the base diffusers
pipeline comes first so all standard methods are resolved from it,
while ``CustomIPAdapterMixin`` only adds the
``load_ctrlregen_ip_adapter`` method.
"""
pass
@@ -0,0 +1,175 @@
"""Tile-based processing for large images in the CtrlRegen pipeline.
Extracted from ``ctrlregen.engine`` to keep the engine focused on
single-image inference and model orchestration.
"""
from __future__ import annotations
import math
import time
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
from PIL import Image
def tile_positions(total: int, tile: int, overlap: int) -> list[int]:
"""Compute evenly-spaced tile start positions covering *total* pixels."""
if total <= tile:
return [0]
n = max(2, math.ceil((total - overlap) / (tile - overlap)))
stride = (total - tile) / (n - 1)
return [round(i * stride) for i in range(n)]
def make_blend_weight(h: int, w: int, overlap: int) -> np.ndarray:
"""2-D weight mask: 1.0 in center, cosine ramp in overlap margins."""
wy = np.ones(h, dtype=np.float64)
wx = np.ones(w, dtype=np.float64)
if overlap > 0:
ramp = 0.5 - 0.5 * np.cos(np.linspace(0, np.pi, overlap))
wy[:overlap] = np.minimum(wy[:overlap], ramp)
wy[-overlap:] = np.minimum(wy[-overlap:], ramp[::-1])
wx[:overlap] = np.minimum(wx[:overlap], ramp)
wx[-overlap:] = np.minimum(wx[-overlap:], ramp[::-1])
return np.outer(wy, wx)
def resize_center_crop(image: Image.Image, size: int = 512) -> Image.Image:
"""Resize shortest edge to *size*, then center-crop to a square.
Matches the ``transforms.Resize(512) + CenterCrop(512)`` pipeline
used in the original CtrlRegen repository.
"""
w, h = image.size
short = min(w, h)
scale = size / short
new_w, new_h = round(w * scale), round(h * scale)
image = image.resize((new_w, new_h), Image.BILINEAR)
left = (new_w - size) // 2
top = (new_h - size) // 2
return image.crop((left, top, left + size, top + size))
def run_tiled(
pipeline: Any,
canny_detector: Any,
image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
seed: int | None,
*,
tile_size: int,
tile_overlap: int,
quality_prompt: str,
negative_prompt: str,
canny_low: int,
canny_high: int,
device: str,
set_progress: Callable[[str], None],
ip_adapter_image: Image.Image | None = None,
) -> Image.Image:
"""Split a large image into overlapping tiles, process each, blend."""
w, h = image.size
xs = tile_positions(w, tile_size, tile_overlap)
ys = tile_positions(h, tile_size, tile_overlap)
n_tiles = len(xs) * len(ys)
grid = f"{len(xs)}x{len(ys)}"
effective_steps = max(1, int(num_inference_steps * strength))
set_progress(f"Tiling {w}x{h}px → {n_tiles} tiles ({grid} grid, {tile_size}px, overlap {tile_overlap}px)")
canvas = np.zeros((h, w, 3), dtype=np.float64)
weight_sum = np.zeros((h, w), dtype=np.float64)
blend_w = make_blend_weight(tile_size, tile_size, tile_overlap)
t0 = time.monotonic()
bar_len = 20
tile_idx = 0
for ty in ys:
for tx in xs:
tile_idx += 1
prefix = f"[Tile {tile_idx}/{n_tiles}]"
tile = image.crop((tx, ty, tx + tile_size, ty + tile_size))
set_progress(f"{prefix} Extracting canny edges...")
control = canny_detector(
tile,
low_threshold=canny_low,
high_threshold=canny_high,
)
gen = None
if seed is not None:
gen = torch.Generator(device=device).manual_seed(seed + tile_idx)
tile_t0 = time.monotonic()
def _make_cb(
_prefix: str = prefix,
_t0: float = tile_t0,
_es: int = effective_steps,
) -> Callable:
def _cb(step: int, timestep: int, latents: Any) -> None: # noqa: ARG001
elapsed = time.monotonic() - _t0
cur = step + 1
per = elapsed / max(1, cur)
rem = per * max(0, _es - cur)
filled = int(bar_len * cur / max(1, _es))
bar = "" * filled + "" * (bar_len - filled)
set_progress(f"{_prefix} [{bar}] {cur}/{_es} | {elapsed:.0f}s, ~{rem:.0f}s left")
return _cb
sem_image = ip_adapter_image if ip_adapter_image is not None else tile
try:
result = pipeline(
prompt=quality_prompt,
negative_prompt=negative_prompt,
image=[tile],
control_image=[control],
controlnet_conditioning_scale=1.0,
ip_adapter_image=[sem_image],
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=gen,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback=_make_cb(),
callback_steps=1,
)
except TypeError:
result = pipeline(
prompt=quality_prompt,
negative_prompt=negative_prompt,
image=[tile],
control_image=[control],
controlnet_conditioning_scale=1.0,
ip_adapter_image=[sem_image],
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=gen,
)
proc_arr = np.array(result.images[0], dtype=np.float64)
th, tw = proc_arr.shape[:2]
mask = blend_w[:th, :tw]
canvas[ty : ty + th, tx : tx + tw] += proc_arr * mask[..., None]
weight_sum[ty : ty + th, tx : tx + tw] += mask
tile_time = time.monotonic() - tile_t0
total_elapsed = time.monotonic() - t0
set_progress(f"{prefix} Done ({tile_time:.0f}s) · Total: {total_elapsed:.0f}s")
set_progress(f"Blending {n_tiles} tiles → {w}x{h}px...")
canvas /= np.maximum(weight_sum[..., None], 1e-8)
return Image.fromarray(np.clip(canvas, 0, 255).astype(np.uint8))
+153
View File
@@ -0,0 +1,153 @@
"""Read-only metadata extraction from PNG and JPEG images.
Provides functions to pull all metadata, AI-only metadata, or a
human-readable summary without modifying the source file.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import piexif
from PIL import Image
from remove_ai_watermarks.noai.c2pa import extract_c2pa_chunk, extract_c2pa_info, has_c2pa_metadata
from remove_ai_watermarks.noai.constants import AI_KEYWORDS, AI_METADATA_KEYS, PNG_METADATA_KEYS
def extract_metadata(source_path: Path) -> dict[str, Any]:
"""
Extract all metadata from a PNG or JPG file.
Args:
source_path: Path to the source image file.
Returns:
Dictionary containing all extracted metadata.
"""
metadata: dict[str, Any] = {}
with Image.open(source_path) as img:
# Extract EXIF data
if "exif" in img.info:
try:
exif_dict = piexif.load(img.info["exif"])
metadata["exif"] = exif_dict
except Exception:
metadata["exif_raw"] = img.info["exif"]
# Extract standard PNG metadata
for key in PNG_METADATA_KEYS:
if key in img.info:
metadata[key] = img.info[key]
# Extract all other metadata including AI-specific
for key, value in img.info.items():
if key not in metadata and key not in ["exif"]:
metadata[key] = value
# Extract DPI and gamma if present
if "dpi" in img.info:
metadata["dpi"] = img.info["dpi"]
if "gamma" in img.info:
metadata["gamma"] = img.info["gamma"]
# Check for C2PA metadata
if has_c2pa_metadata(source_path):
metadata["c2pa"] = extract_c2pa_info(source_path)
c2pa_chunk = extract_c2pa_chunk(source_path)
if c2pa_chunk:
metadata["c2pa_chunk"] = c2pa_chunk
return metadata
def extract_ai_metadata(source_path: Path) -> dict[str, Any]:
"""
Extract only AI-generated metadata from a PNG or JPG file.
Args:
source_path: Path to the source image file.
Returns:
Dictionary containing only AI-related metadata.
"""
ai_metadata: dict[str, Any] = {}
with Image.open(source_path) as img:
for key in AI_METADATA_KEYS:
if key in img.info:
ai_metadata[key] = img.info[key]
for key, value in img.info.items():
key_lower = key.lower()
if key not in ai_metadata:
if any(kw in key_lower for kw in AI_KEYWORDS):
ai_metadata[key] = value
# Check for C2PA metadata
if has_c2pa_metadata(source_path):
ai_metadata["c2pa"] = extract_c2pa_info(source_path)
c2pa_chunk = extract_c2pa_chunk(source_path)
if c2pa_chunk:
ai_metadata["c2pa_chunk"] = c2pa_chunk
return ai_metadata
def has_ai_metadata(image_path: Path) -> bool:
"""
Check if an image contains AI-generated metadata.
Args:
image_path: Path to the image file.
Returns:
True if AI metadata is detected, False otherwise.
"""
with Image.open(image_path) as img:
for key in AI_METADATA_KEYS:
if key in img.info:
return True
if has_c2pa_metadata(image_path):
return True
return False
def get_ai_metadata_summary(source_path: Path) -> str:
"""
Get a human-readable summary of AI metadata.
Args:
source_path: Path to the source image file.
Returns:
Formatted string with AI metadata summary.
"""
ai_meta = extract_ai_metadata(source_path)
if not ai_meta:
return "No AI metadata found."
lines = ["AI Image Metadata:"]
lines.append("-" * 40)
for key, value in ai_meta.items():
if key == "c2pa_chunk":
continue
elif key == "c2pa" and isinstance(value, dict):
lines.append("C2PA Metadata:")
for ck, cv in value.items():
lines.append(f" {ck}: {cv}")
elif isinstance(value, str) and len(value) > 100:
value = value[:100] + "..."
lines.append(f"{key}: {value}")
elif isinstance(value, bytes):
lines.append(f"{key}: <binary data ({len(value)} bytes)>")
else:
lines.append(f"{key}: {value}")
return "\n".join(lines)
@@ -0,0 +1,152 @@
"""Img2img pipeline execution with progress monitoring and MPS fallback.
Extracted from ``watermark_remover.py`` to keep the ``WatermarkRemover``
class focused on orchestration.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
from PIL import Image
from remove_ai_watermarks.noai.progress import is_mps_error, make_pipeline_progress
logger = logging.getLogger(__name__)
def run_img2img(
pipeline: Any,
image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
generator: Any,
device: str,
set_progress: Callable[[str], None],
) -> Image.Image:
"""Execute img2img with live progress and return the generated image."""
w, h = image.size
effective_steps = max(1, int(num_inference_steps * strength))
step_cb, first_step, done_ev, start_updater = make_pipeline_progress(
effective_steps,
device,
set_progress,
)
start_updater()
try:
result = _call_pipeline(
pipeline,
image,
strength,
num_inference_steps,
guidance_scale,
generator,
step_cb,
)
done_ev.set()
return result.images[0]
except TypeError:
first_step.set()
result = _call_pipeline(
pipeline,
image,
strength,
num_inference_steps,
guidance_scale,
generator,
None,
)
done_ev.set()
return result.images[0]
finally:
first_step.set()
done_ev.set()
def run_img2img_with_mps_fallback(
load_pipeline: Callable[[], Any],
image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
generator: Any,
device: str,
set_progress: Callable[[str], None],
*,
reload_on_cpu: Callable[[], Any],
) -> tuple[Image.Image, str]:
"""Run img2img; on MPS error, fall back to CPU.
Returns:
(result_image, final_device) — device may change to ``"cpu"`` on fallback.
"""
pipeline = load_pipeline()
try:
img = run_img2img(
pipeline,
image,
strength,
num_inference_steps,
guidance_scale,
generator,
device,
set_progress,
)
return img, device
except RuntimeError as error:
if device == "mps" and is_mps_error(error):
logger.warning("MPS error detected: %s. Falling back to CPU.", error)
set_progress("MPS error! Clearing cache and retrying on CPU...")
_try_clear_mps_cache()
pipeline = reload_on_cpu()
img = run_img2img(
pipeline,
image,
strength,
num_inference_steps,
guidance_scale,
None,
"cpu",
set_progress,
)
return img, "cpu"
raise
def _call_pipeline(
pipeline: Any,
image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
generator: Any,
step_callback: Any,
) -> Any:
kwargs: dict[str, Any] = {
"prompt": "",
"image": image,
"strength": strength,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"generator": generator,
}
if step_callback is not None:
kwargs["callback"] = step_callback
kwargs["callback_steps"] = 1
return pipeline(**kwargs)
def _try_clear_mps_cache() -> None:
try:
import torch
if hasattr(torch, "mps"):
torch.mps.empty_cache() # type: ignore[attr-defined]
except Exception:
pass
+333
View File
@@ -0,0 +1,333 @@
"""Terminal progress animation and library output suppression.
This module provides two main capabilities for the CLI:
1. ``run_with_progress`` — a styled two-line terminal animation that
displays a bouncing highlight bar, a braille spinner, elapsed time,
and a live operation message while a background task executes.
2. ``silence_library_output`` — a wrapper that suppresses noisy log
output produced by third-party ML libraries (transformers, diffusers,
huggingface_hub, tqdm) so the user only sees our own progress messages.
"""
from __future__ import annotations
import contextlib
import io
import os
import sys
import threading
import time
import warnings
from collections.abc import Callable
from typing import Any
# ── ANSI color constants ────────────────────────────────────────────
_CYAN = "\033[36m"
_YELLOW = "\033[33m"
_GREEN = "\033[32m"
_DIM = "\033[2m"
_BOLD = "\033[1m"
_RESET = "\033[0m"
# Bar geometry
_BAR_WIDTH = 32
_HIGHLIGHT_WIDTH = 5
def _no_color() -> bool:
"""Respect the NO_COLOR convention (https://no-color.org/)."""
return bool(os.environ.get("NO_COLOR"))
def _truncate(text: str, max_len: int = 72) -> str:
"""Shorten a string with an ellipsis if it exceeds *max_len*."""
return text if len(text) <= max_len else text[: max_len - 1] + ""
def _build_bar(step: int) -> str:
"""Build a flowing highlight bar that bounces across the width.
The highlight segment (5 chars wide) travels left→right→left
continuously, giving the user a visual "working" signal.
"""
cycle = _BAR_WIDTH * 2 - 2
pos = step % cycle
if pos >= _BAR_WIDTH:
pos = cycle - pos
hl_start = max(0, pos - _HIGHLIGHT_WIDTH // 2)
hl_end = min(_BAR_WIDTH, pos + _HIGHLIGHT_WIDTH // 2 + 1)
before = "" * hl_start
highlight = "" * (hl_end - hl_start)
after = "" * (_BAR_WIDTH - hl_end)
if _no_color():
return before + highlight + after
return f"{_DIM}{before}{_RESET}{_BOLD}{_YELLOW}{highlight}{_RESET}{_DIM}{after}{_RESET}"
def run_with_progress(
task: Callable[[], Any],
progress_state: dict[str, str] | None = None,
) -> Any:
"""Execute *task* in a background thread while showing a progress animation.
The animation renders two lines to ``sys.__stderr__``:
- **Line 1**: braille spinner + bouncing bar + elapsed seconds
- **Line 2**: current operation message from *progress_state*
When the task finishes, a green "Completed" line replaces the animation.
Args:
task: A zero-argument callable to run in the background.
progress_state: Mutable dict whose ``"message"`` key is read
by the animation loop to display the current operation.
Returns:
Whatever *task* returns.
Raises:
Any exception raised by *task* is re-raised after the animation
is cleaned up.
"""
done = threading.Event()
output_holder: dict[str, Any] = {"result": None, "error": None}
def worker() -> None:
try:
output_holder["result"] = task()
except Exception as error: # pragma: no cover passthrough
output_holder["error"] = error
finally:
done.set()
thread = threading.Thread(target=worker, daemon=True)
thread.start()
spinner_frames = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
idx = 0
start_time = time.time()
no_color = _no_color()
def _get_operation() -> str:
if isinstance(progress_state, dict):
return progress_state.get("message", "Processing...")
return "Processing..."
# ── Animation loop ──────────────────────────────────────────────
while not done.is_set():
spinner = spinner_frames[idx % len(spinner_frames)]
elapsed = int(time.time() - start_time)
bar_str = _build_bar(idx)
operation = _truncate(_get_operation())
if no_color:
line1 = f" {spinner} Processing {bar_str} {elapsed:>3}s"
line2 = f" ╰─ {operation}"
else:
line1 = f" {_CYAN}{spinner}{_RESET} Processing {bar_str} {_BOLD}{_YELLOW}{elapsed:>3}s{_RESET}"
line2 = f" {_DIM}╰─ {operation}{_RESET}"
print(
f"\r\033[2K{line1}\n\033[2K{line2}\033[1A\r",
end="",
flush=True,
file=sys.__stderr__,
)
time.sleep(0.08)
idx += 1
# ── Final "done" frame ──────────────────────────────────────────
thread.join()
total = int(time.time() - start_time)
final_operation = _truncate(_get_operation())
done_bar = "" * _BAR_WIDTH
if no_color:
final_line1 = f" ✓ Completed {done_bar} {total:>3}s"
final_line2 = f" ╰─ {final_operation}"
else:
final_line1 = (
f" {_GREEN}{_BOLD}{_RESET} {_GREEN}Completed{_RESET} "
f"{_GREEN}{done_bar}{_RESET} {_BOLD}{_GREEN}{total:>3}s{_RESET}"
)
final_line2 = f" {_DIM}╰─ {final_operation}{_RESET}"
print(
f"\r\033[2K{final_line1}\n\033[2K{final_line2}",
file=sys.__stderr__,
)
if output_holder["error"] is not None:
raise output_holder["error"]
return output_holder["result"]
def silence_library_output(
run_func: Callable[[], Any],
set_progress: Callable[[str], None] | None = None,
) -> Callable[[], Any]:
"""Return a wrapper that silences noisy ML library output.
The wrapper:
1. Disables HuggingFace Hub progress bars via env var.
2. Sets ``transformers``, ``diffusers``, and ``huggingface_hub``
loggers to *error* level.
3. Redirects ``stdout`` and ``stderr`` to ``io.StringIO`` sinks so
that stray ``tqdm`` bars and model-loading chatter are invisible.
4. Suppresses all Python warnings during the call.
Args:
run_func: The callable to execute silently.
set_progress: Optional callback to report phase changes.
Returns:
A zero-argument callable that, when invoked, runs *run_func*
inside the silent context.
"""
def wrapped() -> Any:
if set_progress:
set_progress("Configuring runtime and suppressing noisy logs...")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
for _silence in (
lambda: __import__("transformers").logging.set_verbosity_error(),
lambda: _silence_diffusers(),
lambda: __import__("huggingface_hub").logging.set_verbosity_error(),
):
try:
_silence()
except Exception:
pass
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with contextlib.redirect_stdout(io.StringIO()):
with contextlib.redirect_stderr(io.StringIO()):
if set_progress:
set_progress("Executing watermark removal pipeline...")
return run_func()
return wrapped
def _silence_diffusers() -> None:
"""Silence diffusers logging and progress bars."""
from diffusers.utils import logging as diffusers_logging
diffusers_logging.set_verbosity_error()
if hasattr(diffusers_logging, "disable_progress_bar"):
diffusers_logging.disable_progress_bar()
# ── Shared pipeline progress helpers ─────────────────────────────────
_DEFAULT_PRE_PHASES: list[tuple[int, str]] = [
(0, "Encoding image with VAE encoder"),
(3, "Mapping pixel data → latent space"),
(7, "Injecting noise into latent representation"),
(12, "Building denoiser schedule"),
(18, "Starting reverse diffusion sampler"),
(30, "Running first denoising iteration"),
(50, "Still processing — this can take a while"),
(90, "Pipeline running — may take a few minutes"),
]
_DEFAULT_POST_PHASES: list[tuple[int, str]] = [
(0, "Denoising complete · Running VAE decoder"),
(2, "Decoding latent channels → RGB color space"),
(5, "Reconstructing pixel grid from latents"),
(10, "Applying color space conversion and normalization"),
(18, "Finalizing pixel output"),
(30, "Still decoding — large images take longer"),
(60, "Almost done — large images take longer to decode"),
]
def make_pipeline_progress(
effective_steps: int,
device: str,
set_progress: Callable[[str], None],
*,
bar_len: int = 20,
label: str = "Denoising",
pre_phases: list[tuple[int, str]] | None = None,
post_phases: list[tuple[int, str]] | None = None,
) -> tuple[Callable, threading.Event, threading.Event, Callable[[], threading.Thread]]:
"""Create step callback and background updater for a diffusion pipeline.
Returns:
(step_callback, first_step_event, pipeline_done_event, start_updater)
where ``start_updater()`` launches and returns the background thread.
"""
pre = pre_phases or [(s, f"{m} on {device}") for s, m in _DEFAULT_PRE_PHASES]
post = post_phases or [(s, f"{m} on {device}") for s, m in _DEFAULT_POST_PHASES]
t0_holder: list[float] = [time.monotonic()]
first_step = threading.Event()
pipeline_done = threading.Event()
last_cb_time: list[float] = [t0_holder[0]]
def _background_updater() -> None:
idx = 0
while not first_step.is_set():
elapsed = time.monotonic() - t0_holder[0]
while idx < len(pre) - 1 and elapsed >= pre[idx + 1][0]:
idx += 1
set_progress(pre[idx][1])
first_step.wait(timeout=0.4)
idx = 0
post_start: float | None = None
while not pipeline_done.is_set():
since_cb = time.monotonic() - last_cb_time[0]
if since_cb >= 1.5:
if post_start is None:
post_start = time.monotonic()
elapsed = time.monotonic() - post_start
while idx < len(post) - 1 and elapsed >= post[idx + 1][0]:
idx += 1
set_progress(post[idx][1])
else:
post_start = None
idx = 0
pipeline_done.wait(timeout=0.4)
def step_callback(step: int, timestep: int, latents: Any) -> None: # noqa: ARG001
first_step.set()
last_cb_time[0] = time.monotonic()
elapsed = time.monotonic() - t0_holder[0]
current = step + 1
per_step = elapsed / max(1, current)
remaining = per_step * max(0, effective_steps - current)
filled = int(bar_len * current / max(1, effective_steps))
bar = "" * filled + "" * (bar_len - filled)
set_progress(
f"{label} [{bar}] {current}/{effective_steps} | {elapsed:.0f}s elapsed, ~{remaining:.0f}s left | {device}"
)
def start_updater() -> threading.Thread:
t0_holder[0] = time.monotonic()
last_cb_time[0] = t0_holder[0]
first_step.clear()
pipeline_done.clear()
t = threading.Thread(target=_background_updater, daemon=True)
t.start()
return t
return step_callback, first_step, pipeline_done, start_updater
# ── MPS fallback helper ──────────────────────────────────────────────
def is_mps_error(error: Exception) -> bool:
"""Check whether an exception is an MPS-related runtime error."""
return "mps" in str(error).lower()
+40
View File
@@ -0,0 +1,40 @@
"""Low-level utility helpers used across the metadata pipeline.
Kept deliberately small — only format detection lives here so that
higher-level modules can import without circular dependencies.
"""
from __future__ import annotations
from pathlib import Path
from remove_ai_watermarks.noai.constants import SUPPORTED_FORMATS
def is_supported_format(file_path: Path) -> bool:
"""
Check if the file format is supported.
Args:
file_path: Path to the image file.
Returns:
True if the format is supported, False otherwise.
"""
return file_path.suffix.lower() in SUPPORTED_FORMATS
def get_image_format(file_path: Path) -> str:
"""
Get the image format from file path.
Args:
file_path: Path to the image file.
Returns:
Format string (PNG, JPEG, etc.).
"""
suffix = file_path.suffix.lower()
if suffix in {".jpg", ".jpeg"}:
return "JPEG"
return "PNG"
@@ -0,0 +1,51 @@
"""Watermark removal model profiles, strength presets, and profile detection.
Pure configuration and lookup functions with no ML dependencies.
"""
from __future__ import annotations
DEFAULT_MODEL_ID = "Lykon/dreamshaper-8"
CTRLREGEN_MODEL_ID = "yepengliu/ctrlregen"
LOW_STRENGTH = 0.04
MEDIUM_STRENGTH = 0.35
HIGH_STRENGTH = 0.7
_HIGH_PERTURBATION = ("stegasamp", "stegastamp", "treering", "ringid")
_LOW_PERTURBATION = ("stablesignature", "dwtectsvd", "rivagan", "ssl", "hidden")
def get_model_id_for_profile(profile: str) -> str:
"""Map CLI model profile names to concrete Hugging Face model IDs."""
normalized = profile.strip().lower()
if normalized == "default":
return DEFAULT_MODEL_ID
if normalized == "ctrlregen":
return CTRLREGEN_MODEL_ID
raise ValueError(f"Unknown model profile '{profile}'. Use one of: default, ctrlregen.")
def detect_model_profile(model_id: str) -> str:
"""Infer model profile from model identifier."""
if "ctrlregen" in model_id.lower():
return "ctrlregen"
return "default"
def get_recommended_strength(watermark_type: str) -> float:
"""Get recommended strength for different watermark types.
Args:
watermark_type: Type of watermark. One of: 'low', 'medium', 'high',
or specific names like 'stegastamp', 'treering', etc.
Returns:
Recommended strength value.
"""
wt = watermark_type.lower()
if any(name in wt for name in _HIGH_PERTURBATION):
return HIGH_STRENGTH
if any(name in wt for name in _LOW_PERTURBATION):
return LOW_STRENGTH
return MEDIUM_STRENGTH
@@ -0,0 +1,635 @@
"""Watermark removal using diffusion model regeneration attack.
Based on the paper "Image Watermarks Are Removable Using Controllable
Regeneration from Clean Noise" (ICLR 2025).
This module implements a simple regeneration attack that:
1. Encodes the watermarked image to latent space
2. Adds noise via forward diffusion process
3. Denoises via reverse diffusion process
4. Decodes back to pixel space
"""
from __future__ import annotations
import logging
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any
from PIL import Image
from remove_ai_watermarks.noai.watermark_profiles import (
CTRLREGEN_MODEL_ID,
DEFAULT_MODEL_ID,
HIGH_STRENGTH,
LOW_STRENGTH,
MEDIUM_STRENGTH,
detect_model_profile,
)
logger = logging.getLogger(__name__)
# Check for optional dependencies
_HAS_TORCH = False
_HAS_DIFFUSERS = False
try:
import torch
_HAS_TORCH = True
except ImportError:
torch = None # type: ignore
try:
from diffusers import AutoPipelineForImage2Image as AutoImg2ImgPipeline
_HAS_DIFFUSERS = True
except ImportError:
AutoImg2ImgPipeline = None # type: ignore
def is_watermark_removal_available() -> bool:
"""Check if watermark removal dependencies are installed."""
return _HAS_TORCH and _HAS_DIFFUSERS
_CUDA_FIX_ENV_KEY = "NOAI_CUDA_FIXED"
def _auto_install(packages: list[str], index_url: str | None = None) -> bool:
"""Attempt to install missing packages via pip. Returns True on success."""
import subprocess
cmd = [sys.executable, "-m", "pip", "install", "-q", *packages]
if index_url:
cmd.extend(["--index-url", index_url])
try:
subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def _has_nvidia_gpu() -> bool:
"""Check if an NVIDIA GPU is present via nvidia-smi."""
import subprocess
try:
subprocess.check_call(
["nvidia-smi"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def _detect_cuda_index_url() -> str:
"""Detect the appropriate PyTorch CUDA index URL from nvidia-smi output."""
import subprocess
try:
out = subprocess.check_output(
["nvidia-smi"],
stderr=subprocess.DEVNULL,
text=True,
)
for line in out.splitlines():
if "CUDA Version" in line:
version_str = line.split("CUDA Version:")[-1].strip().rstrip("|").strip()
major, minor = version_str.split(".")[:2]
cuda_tag = f"cu{major}{minor}"
return f"https://download.pytorch.org/whl/{cuda_tag}"
except Exception:
pass
return "https://download.pytorch.org/whl/cu121"
def _reinstall_torch_cuda_and_restart() -> None:
"""Reinstall torch with CUDA support showing live progress, then restart."""
import re
import subprocess
from remove_ai_watermarks.noai.progress import run_with_progress
index_url = _detect_cuda_index_url()
progress_state: dict[str, str] = {"message": "NVIDIA GPU detected — installing CUDA-enabled PyTorch..."}
pct_re = re.compile(r"(\d+)%")
pkg_re = re.compile(r"(?:Collecting|Downloading|Installing)\s+(\S+)")
def _run_pip() -> bool:
cmd = [
sys.executable,
"-m",
"pip",
"install",
"--force-reinstall",
"torch",
"--index-url",
index_url,
]
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
for line in iter(proc.stdout.readline, ""): # type: ignore[union-attr]
stripped = line.strip()
if not stripped:
continue
pkg_m = pkg_re.search(stripped)
pct_m = pct_re.search(stripped)
if pct_m and pkg_m:
progress_state["message"] = f"Downloading {pkg_m.group(1)} ({pct_m.group(1)}%)"
elif pct_m:
progress_state["message"] = f"Downloading CUDA packages ({pct_m.group(1)}%)"
elif pkg_m:
action = "Installing" if stripped.startswith("Installing") else "Downloading"
progress_state["message"] = f"{action} {pkg_m.group(1)}"
elif "Successfully installed" in stripped:
progress_state["message"] = "CUDA-enabled PyTorch installed successfully"
proc.wait()
return proc.returncode == 0
try:
success = run_with_progress(_run_pip, progress_state)
except Exception:
success = False
if not success:
print(
f"\n Failed to install CUDA-enabled PyTorch.\n"
f" Install manually:\n"
f" pip install torch --index-url {index_url}\n",
file=sys.stderr,
)
return
os.environ[_CUDA_FIX_ENV_KEY] = "1"
restart_code = f"import sys; sys.argv = {sys.argv!r}; from remove_ai_watermarks.cli import main; sys.exit(main())"
os.execl(sys.executable, sys.executable, "-c", restart_code)
def _ensure_watermark_deps() -> None:
"""Auto-install and re-import missing watermark removal dependencies."""
global _HAS_TORCH, _HAS_DIFFUSERS, torch, AutoImg2ImgPipeline
missing_pkgs: list[str] = []
if not _HAS_TORCH:
missing_pkgs.append("torch")
if not _HAS_DIFFUSERS:
missing_pkgs.extend(["diffusers", "transformers", "accelerate"])
logger.info("Auto-installing missing dependencies: %s", missing_pkgs)
if not _auto_install(missing_pkgs):
raise ImportError(
f"Failed to auto-install missing dependencies: {', '.join(missing_pkgs)}. "
"Try manually: pip install --force-reinstall noai-watermark"
)
import torch as _torch
torch = _torch
_HAS_TORCH = True
from diffusers import AutoPipelineForImage2Image # noqa: N813
AutoImg2ImgPipeline = AutoPipelineForImage2Image # noqa: N806
_HAS_DIFFUSERS = True
def get_device() -> str:
"""Get the best available device for inference."""
if not _HAS_TORCH:
return "cpu"
if torch.cuda.is_available(): # type: ignore
try:
t = torch.tensor([1.0], device="cuda")
_ = t + t
del t
return "cuda"
except (AssertionError, RuntimeError):
pass
if _has_nvidia_gpu() and not os.environ.get(_CUDA_FIX_ENV_KEY):
_reinstall_torch_cuda_and_restart()
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
# Keep legacy name available for backwards compatibility
_detect_model_profile_from_id = detect_model_profile
class WatermarkRemover:
"""Remove watermarks from images using diffusion model regeneration.
Attributes:
model_id: HuggingFace model ID for the diffusion model.
device: Device to run inference on (cuda, mps, or cpu).
"""
DEFAULT_MODEL_ID = DEFAULT_MODEL_ID
CTRLREGEN_MODEL_ID = CTRLREGEN_MODEL_ID
LOW_STRENGTH = LOW_STRENGTH
MEDIUM_STRENGTH = MEDIUM_STRENGTH
HIGH_STRENGTH = HIGH_STRENGTH
def __init__(
self,
model_id: str | None = None,
device: str | None = None,
torch_dtype: Any = None,
progress_callback: Callable[[str], None] | None = None,
hf_token: str | None = None,
):
self.model_id = model_id or self.DEFAULT_MODEL_ID
self.model_profile = detect_model_profile(self.model_id)
if not is_watermark_removal_available():
_ensure_watermark_deps()
self.device = (device or get_device()).lower()
if self.device == "auto":
self.device = get_device()
if self.device not in {"cpu", "mps", "cuda"}:
raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda.")
if torch_dtype is None:
if self.device == "cpu" or self.device == "mps":
self.torch_dtype = torch.float32 # type: ignore
else:
self.torch_dtype = torch.float16 # type: ignore
else:
self.torch_dtype = torch_dtype
self._pipeline: AutoImg2ImgPipeline | None = None
self._ctrlregen_engine: Any = None
self._progress_callback = progress_callback
self.hf_token: str | None = hf_token or os.environ.get("HF_TOKEN")
def _set_progress(self, message: str) -> None:
"""Send a progress update through callback when available."""
if self._progress_callback is None:
return
try:
self._progress_callback(message)
except Exception:
pass
# ── Preload ──────────────────────────────────────────────────────
def preload(self) -> None:
"""Eagerly load the pipeline so download progress bars are visible."""
if self.model_profile == "ctrlregen":
self._run_ctrlregen_preload()
else:
self._load_pipeline()
def _run_ctrlregen_preload(self) -> None:
"""Ensure the CtrlRegen engine and all its models are loaded."""
from remove_ai_watermarks.noai.ctrlregen import is_ctrlregen_available
if not is_ctrlregen_available():
missing_pkgs = ["controlnet-aux", "color-matcher", "safetensors"]
logger.info("Auto-installing missing CtrlRegen dependencies: %s", missing_pkgs)
if not _auto_install(missing_pkgs):
raise ImportError(
f"Failed to auto-install missing dependencies: {', '.join(missing_pkgs)}. "
"Try manually: pip install --force-reinstall noai-watermark"
)
if self._ctrlregen_engine is None:
self._ctrlregen_engine = self._make_ctrlregen_engine()
self._ctrlregen_engine.load()
def _make_ctrlregen_engine(self) -> Any:
"""Create a new CtrlRegenEngine with current settings."""
from remove_ai_watermarks.noai.ctrlregen import CtrlRegenEngine
base_model = self.model_id if self.model_id != self.CTRLREGEN_MODEL_ID else None
return CtrlRegenEngine(
base_model_id=base_model,
device=self.device,
torch_dtype=self.torch_dtype,
hf_token=self.hf_token,
progress_callback=self._progress_callback,
)
# ── Pipeline loading ─────────────────────────────────────────────
def _load_pipeline(self) -> AutoImg2ImgPipeline:
"""Load the diffusion pipeline lazily."""
if self._pipeline is None:
logger.info("Loading model %s on %s...", self.model_id, self.device)
self._set_progress(f"Loading model weights: {self.model_id}")
load_kwargs: dict[str, Any] = {
"torch_dtype": self.torch_dtype,
"safety_checker": None,
"requires_safety_checker": False,
}
if self.hf_token:
load_kwargs["token"] = self.hf_token
self._pipeline = AutoImg2ImgPipeline.from_pretrained( # type: ignore
self.model_id,
**load_kwargs,
)
self._set_progress(f"Moving model to device: {self.device}")
try:
self._pipeline = self._pipeline.to(self.device) # type: ignore
except (RuntimeError, AssertionError) as exc:
if self.device == "cuda" and not os.environ.get(_CUDA_FIX_ENV_KEY):
self._set_progress("CUDA failed. Reinstalling torch with CUDA support...")
_reinstall_torch_cuda_and_restart()
raise RuntimeError(
f"Failed to move model to {self.device} ({exc}). "
"Install CUDA-enabled PyTorch manually:\n"
f" pip install torch --index-url {_detect_cuda_index_url()}"
) from exc
if hasattr(self._pipeline, "enable_xformers_memory_efficient_attention"):
try:
self._set_progress("Enabling memory optimizations...")
self._pipeline.enable_xformers_memory_efficient_attention() # type: ignore
except Exception:
pass
# Mac Float32 memory slicing
if self.device == "mps" and hasattr(self._pipeline, "enable_attention_slicing"):
try:
self._pipeline.enable_attention_slicing("max")
except Exception:
pass
logger.info("Model loaded successfully")
self._set_progress("Model initialized. Preparing input image...")
return self._pipeline # type: ignore
# ── Core removal ─────────────────────────────────────────────────
def remove_watermark(
self,
image_path: Path,
output_path: Path | None = None,
strength: float | None = None,
num_inference_steps: int = 50,
guidance_scale: float | None = None,
seed: int | None = None,
) -> Path:
"""Remove watermark from an image using regeneration attack.
Args:
image_path: Path to the watermarked image.
output_path: Path for the cleaned image. If None, modifies in place.
strength: Denoising strength (0.0-1.0).
num_inference_steps: Number of denoising steps.
guidance_scale: Classifier-free guidance scale.
seed: Random seed for reproducibility.
Returns:
Path to the cleaned image.
Raises:
FileNotFoundError: If input image doesn't exist.
ValueError: If strength is not in valid range.
"""
if not image_path.exists():
raise FileNotFoundError(f"Image not found: {image_path}")
if output_path is None:
output_path = image_path
strength = strength or self.LOW_STRENGTH
if not 0.0 <= strength <= 1.0:
raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")
if guidance_scale is None:
guidance_scale = 2.0 if self.model_profile == "ctrlregen" else 7.5
self._set_progress("Loading and preprocessing input image...")
init_image = Image.open(image_path).convert("RGB")
w, h = init_image.size
self._set_progress(f"Image loaded: {w}x{h}px | Model: {self.model_id}")
generator = None
if seed is not None and _HAS_TORCH:
self._set_progress(f"Setting reproducible seed: {seed}")
generator = torch.Generator(device=self.device).manual_seed(seed) # type: ignore
effective_steps = max(1, int(num_inference_steps * strength))
self._set_progress(
f"Config: strength={strength}, steps={num_inference_steps} "
f"(~{effective_steps} effective), guidance={guidance_scale}, device={self.device}"
)
_total_start = time.monotonic()
if self.model_profile == "ctrlregen":
cleaned_image = self._run_ctrlregen(
init_image,
strength,
num_inference_steps,
guidance_scale,
generator,
)
else:
cleaned_image = self._run_img2img(
init_image,
strength,
num_inference_steps,
guidance_scale,
generator,
)
self._set_progress(f"Regeneration complete · Output: {w}x{h}px {cleaned_image.mode}")
output_path.parent.mkdir(parents=True, exist_ok=True)
fmt = output_path.suffix.lower()
if fmt in (".jpg", ".jpeg"):
self._set_progress(f"Encoding as JPEG → {output_path.name}...")
else:
self._set_progress(f"Encoding as PNG → {output_path.name}...")
cleaned_image.save(output_path)
if output_path.exists():
self._set_progress("Stripping AI metadata from output...")
try:
from remove_ai_watermarks.noai.cleaner import remove_ai_metadata
remove_ai_metadata(output_path, output_path, keep_standard=True)
except Exception:
logger.debug("AI metadata stripping skipped", exc_info=True)
total_time = time.monotonic() - _total_start
size_str = ""
try:
file_size = output_path.stat().st_size
if file_size < 1024 * 1024:
size_str = f" ({file_size / 1024:.0f}KB)"
else:
size_str = f" ({file_size / (1024 * 1024):.1f}MB)"
except OSError:
pass
logger.info("Cleaned image saved to %s", output_path)
self._set_progress(f"✓ Saved {output_path.name}{size_str} · {w}x{h}px · {total_time:.0f}s total")
return output_path
# ── Img2img runner ───────────────────────────────────────────────
def _run_img2img(
self,
init_image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
generator: Any,
) -> Image.Image:
"""Execute the img2img pipeline with progress and MPS fallback."""
from remove_ai_watermarks.noai.img2img_runner import run_img2img_with_mps_fallback
result_image, final_device = run_img2img_with_mps_fallback(
load_pipeline=self._load_pipeline,
image=init_image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator,
device=self.device,
set_progress=self._set_progress,
reload_on_cpu=self._reload_pipeline_on_cpu,
)
if final_device != self.device:
self.device = final_device
self.torch_dtype = torch.float32 # type: ignore[assignment]
return result_image
def _reload_pipeline_on_cpu(self) -> Any:
"""Reload pipeline on CPU after MPS failure."""
self.device = "cpu"
self.torch_dtype = torch.float32 # type: ignore[assignment]
self._pipeline = None
return self._load_pipeline()
# ── CtrlRegen runner ─────────────────────────────────────────────
def _run_ctrlregen(
self,
init_image: Image.Image,
strength: float,
num_inference_steps: int,
guidance_scale: float,
generator: Any,
) -> Image.Image:
"""Run CtrlRegen pipeline with MPS fallback."""
from remove_ai_watermarks.noai.ctrlregen import is_ctrlregen_available
from remove_ai_watermarks.noai.progress import is_mps_error
if not is_ctrlregen_available():
missing_pkgs = ["controlnet-aux", "color-matcher", "safetensors"]
logger.info("Auto-installing missing CtrlRegen dependencies: %s", missing_pkgs)
if not _auto_install(missing_pkgs):
raise ImportError(
f"Failed to auto-install missing dependencies: {', '.join(missing_pkgs)}. "
"Try manually: pip install --force-reinstall noai-watermark"
)
if self._ctrlregen_engine is None:
self._ctrlregen_engine = self._make_ctrlregen_engine()
seed = None
if generator is not None and hasattr(generator, "initial_seed"):
seed = generator.initial_seed()
try:
return self._ctrlregen_engine.run(
image=init_image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed,
)
except RuntimeError as error:
if self.device == "mps" and is_mps_error(error):
logger.warning("MPS out of memory during CtrlRegen. Falling back to CPU.")
self._set_progress("MPS out of memory! Retrying CtrlRegen on CPU...")
try:
if _HAS_TORCH and hasattr(torch, "mps"):
torch.mps.empty_cache() # type: ignore[attr-defined]
except Exception:
pass
self.device = "cpu"
self.torch_dtype = torch.float32 # type: ignore[assignment]
self._ctrlregen_engine = self._make_ctrlregen_engine()
return self._ctrlregen_engine.run(
image=init_image,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
seed=seed,
)
raise
# ── Batch ────────────────────────────────────────────────────────
def remove_watermark_batch(
self,
input_dir: Path,
output_dir: Path,
strength: float | None = None,
num_inference_steps: int = 50,
extensions: tuple[str, ...] = (".png", ".jpg", ".jpeg", ".webp"),
) -> list[Path]:
"""Remove watermarks from all images in a directory."""
if not input_dir.exists():
raise FileNotFoundError(f"Input directory not found: {input_dir}")
output_dir.mkdir(parents=True, exist_ok=True)
cleaned_paths: list[Path] = []
for ext in extensions:
for image_path in input_dir.glob(f"*{ext}"):
output_path = output_dir / image_path.name
try:
result_path = self.remove_watermark(
image_path=image_path,
output_path=output_path,
strength=strength,
num_inference_steps=num_inference_steps,
)
cleaned_paths.append(result_path)
except Exception as e:
logger.error("Failed to process %s: %s", image_path, e)
return cleaned_paths
# ── Convenience function ─────────────────────────────────────────────
def remove_watermark(
image_path: Path,
output_path: Path | None = None,
strength: float = 0.04,
model_id: str | None = None,
device: str | None = None,
hf_token: str | None = None,
) -> Path:
"""Convenience function to remove watermark from an image."""
remover = WatermarkRemover(model_id=model_id, device=device, hf_token=hf_token)
return remover.remove_watermark(
image_path=image_path,
output_path=output_path,
strength=strength,
)
+1
View File
@@ -0,0 +1 @@
"""Tests for remove-ai-watermarks."""
+63
View File
@@ -0,0 +1,63 @@
"""Shared fixtures for remove-ai-watermarks test suite."""
from __future__ import annotations
from pathlib import Path
import cv2
import numpy as np
import pytest
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@pytest.fixture()
def tmp_image_path(tmp_path: Path) -> Path:
"""Create a minimal 200×200 test PNG image and return its path."""
img = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
path = tmp_path / "test_image.png"
cv2.imwrite(str(path), img)
return path
@pytest.fixture()
def tmp_large_image_path(tmp_path: Path) -> Path:
"""Create a 1200×1200 test PNG image (triggers large watermark branch)."""
img = np.random.randint(0, 255, (1200, 1200, 3), dtype=np.uint8)
path = tmp_path / "test_large.png"
cv2.imwrite(str(path), img)
return path
@pytest.fixture()
def tmp_jpeg_path(tmp_path: Path) -> Path:
"""Create a minimal JPEG test image."""
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
path = tmp_path / "test_image.jpg"
cv2.imwrite(str(path), img)
return path
@pytest.fixture()
def tmp_png_with_ai_metadata(tmp_path: Path) -> Path:
"""Create a PNG with AI-related metadata keys."""
img = Image.new("RGB", (64, 64), color=(128, 128, 128))
pnginfo = PngInfo()
pnginfo.add_text("parameters", "Steps: 20, Sampler: Euler, CFG scale: 7")
pnginfo.add_text("prompt", "a beautiful landscape")
pnginfo.add_text("Author", "Test Author")
path = tmp_path / "ai_metadata.png"
img.save(path, pnginfo=pnginfo)
return path
@pytest.fixture()
def tmp_clean_png(tmp_path: Path) -> Path:
"""Create a PNG with no AI metadata."""
img = Image.new("RGB", (64, 64), color=(200, 100, 50))
pnginfo = PngInfo()
pnginfo.add_text("Author", "Human Artist")
pnginfo.add_text("Title", "Test Artwork")
path = tmp_path / "clean.png"
img.save(path, pnginfo=pnginfo)
return path
+322
View File
@@ -0,0 +1,322 @@
"""Tests for the CLI entry point."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import cv2
import numpy as np
import pytest
from click.testing import CliRunner
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from remove_ai_watermarks.cli import main
@pytest.fixture()
def runner():
return CliRunner()
@pytest.fixture()
def sample_png(tmp_path: Path) -> Path:
"""Create a sample PNG for CLI testing."""
img = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
path = tmp_path / "input.png"
cv2.imwrite(str(path), img)
return path
def _make_batch_dir(tmp_path: Path, count: int = 3) -> Path:
"""Create a directory with test images for batch testing."""
input_dir = tmp_path / "input"
input_dir.mkdir()
for i in range(count):
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
cv2.imwrite(str(input_dir / f"img_{i}.png"), img)
return input_dir
def _make_batch_dir_with_metadata(tmp_path: Path, count: int = 3) -> Path:
"""Create a directory with PNG images containing AI metadata."""
input_dir = tmp_path / "input"
input_dir.mkdir()
for i in range(count):
img = Image.new("RGB", (64, 64), color=(100 + i, 150, 200))
pnginfo = PngInfo()
pnginfo.add_text("parameters", f"Steps: 20, Sampler: Euler, img_{i}")
pnginfo.add_text("prompt", "a test landscape")
img.save(input_dir / f"img_{i}.png", pnginfo=pnginfo)
return input_dir
def _mock_invisible_engine():
"""Create a mock InvisibleEngine that writes a copy of the input image."""
def _mock_remove_watermark(image_path, output_path=None, **kwargs):
out = output_path or image_path.with_stem(image_path.stem + "_clean")
out.parent.mkdir(parents=True, exist_ok=True)
img = Image.open(image_path)
img.save(out)
return out
mock_engine = MagicMock()
mock_engine.remove_watermark.side_effect = _mock_remove_watermark
mock_cls = MagicMock(return_value=mock_engine)
return mock_cls, mock_engine
class TestMainGroup:
"""Tests for the top-level CLI group."""
def test_help(self, runner):
result = runner.invoke(main, ["--help"])
assert result.exit_code == 0
assert "Remove visible and invisible" in result.output
def test_version(self, runner):
result = runner.invoke(main, ["--version"])
assert result.exit_code == 0
assert "0.1.0" in result.output
def test_no_command_shows_banner(self, runner):
result = runner.invoke(main, [])
assert result.exit_code == 0
assert "Remove-AI-Watermarks" in result.output
class TestVisibleCommand:
"""Tests for the 'visible' subcommand."""
def test_visible_help(self, runner):
result = runner.invoke(main, ["visible", "--help"])
assert result.exit_code == 0
assert "Gemini watermark" in result.output
def test_visible_basic(self, runner, sample_png, tmp_path):
output = tmp_path / "clean.png"
result = runner.invoke(
main,
["visible", str(sample_png), "-o", str(output), "--no-detect"],
)
assert result.exit_code == 0
assert output.exists()
assert "Saved" in result.output
def test_visible_default_output_name(self, runner, sample_png):
result = runner.invoke(main, ["visible", str(sample_png), "--no-detect"])
assert result.exit_code == 0
expected = sample_png.with_stem(sample_png.stem + "_clean")
assert expected.exists()
def test_visible_no_inpaint(self, runner, sample_png, tmp_path):
output = tmp_path / "clean.png"
result = runner.invoke(
main,
[
"visible",
str(sample_png),
"-o",
str(output),
"--no-inpaint",
"--no-detect",
],
)
assert result.exit_code == 0
assert output.exists()
def test_visible_no_detect(self, runner, sample_png, tmp_path):
output = tmp_path / "clean.png"
result = runner.invoke(
main,
["visible", str(sample_png), "-o", str(output), "--no-detect"],
)
assert result.exit_code == 0
def test_visible_nonexistent_file(self, runner):
result = runner.invoke(main, ["visible", "/nonexistent/file.png"])
assert result.exit_code != 0
class TestInvisibleCommand:
"""Tests for the 'invisible' subcommand."""
def test_invisible_help(self, runner):
result = runner.invoke(main, ["invisible", "--help"])
assert result.exit_code == 0
assert "invisible" in result.output.lower()
def test_invisible_basic(self, runner, sample_png, tmp_path):
mock_cls, mock_engine = _mock_invisible_engine()
output = tmp_path / "clean.png"
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
):
result = runner.invoke(
main,
["invisible", str(sample_png), "-o", str(output)],
)
assert result.exit_code == 0, result.output
assert output.exists()
mock_engine.remove_watermark.assert_called_once()
def test_invisible_default_output(self, runner, sample_png):
mock_cls, mock_engine = _mock_invisible_engine()
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
):
result = runner.invoke(main, ["invisible", str(sample_png)])
assert result.exit_code == 0, result.output
expected = sample_png.with_stem(sample_png.stem + "_clean")
assert expected.exists()
def test_invisible_nonexistent_file(self, runner):
result = runner.invoke(main, ["invisible", "/nonexistent/file.png"])
assert result.exit_code != 0
class TestAllCommand:
"""Tests for the 'all' subcommand (full pipeline)."""
def test_all_help(self, runner):
result = runner.invoke(main, ["all", "--help"])
assert result.exit_code == 0
assert "visible" in result.output.lower()
def test_all_basic(self, runner, sample_png, tmp_path):
mock_cls, mock_engine = _mock_invisible_engine()
output = tmp_path / "clean.png"
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
):
result = runner.invoke(
main,
["all", str(sample_png), "-o", str(output)],
)
assert result.exit_code == 0, result.output
assert output.exists()
def test_all_nonexistent_file(self, runner):
result = runner.invoke(main, ["all", "/nonexistent/file.png"])
assert result.exit_code != 0
class TestMetadataCommand:
"""Tests for the 'metadata' subcommand."""
def test_metadata_help(self, runner):
result = runner.invoke(main, ["metadata", "--help"])
assert result.exit_code == 0
def test_metadata_check_clean(self, runner, tmp_clean_png):
result = runner.invoke(main, ["metadata", str(tmp_clean_png), "--check"])
assert result.exit_code == 0
assert "No AI metadata" in result.output
def test_metadata_check_ai(self, runner, tmp_png_with_ai_metadata):
result = runner.invoke(main, ["metadata", str(tmp_png_with_ai_metadata), "--check"])
assert result.exit_code == 0
assert "AI metadata detected" in result.output
def test_metadata_remove(self, runner, tmp_png_with_ai_metadata, tmp_path):
output = tmp_path / "stripped.png"
result = runner.invoke(
main,
[
"metadata",
str(tmp_png_with_ai_metadata),
"--remove",
"-o",
str(output),
],
)
assert result.exit_code == 0
assert "stripped" in result.output
class TestBatchCommand:
"""Tests for the 'batch' subcommand."""
def test_batch_help(self, runner):
result = runner.invoke(main, ["batch", "--help"])
assert result.exit_code == 0
def test_batch_empty_dir(self, runner, tmp_path):
empty_dir = tmp_path / "empty"
empty_dir.mkdir()
result = runner.invoke(main, ["batch", str(empty_dir)])
assert result.exit_code == 0
assert "No supported images" in result.output
def test_batch_visible_mode(self, runner, tmp_path):
input_dir = _make_batch_dir(tmp_path)
output_dir = tmp_path / "output"
result = runner.invoke(
main,
["batch", str(input_dir), "-o", str(output_dir), "--mode", "visible"],
)
assert result.exit_code == 0
assert "3 processed" in result.output
assert output_dir.exists()
assert len(list(output_dir.glob("*.png"))) == 3
def test_batch_metadata_mode(self, runner, tmp_path):
input_dir = _make_batch_dir_with_metadata(tmp_path)
output_dir = tmp_path / "output"
result = runner.invoke(
main,
["batch", str(input_dir), "-o", str(output_dir), "--mode", "metadata"],
)
assert result.exit_code == 0
assert "3 processed" in result.output
assert output_dir.exists()
assert len(list(output_dir.glob("*.png"))) == 3
# Verify AI metadata was stripped
for out_img in output_dir.glob("*.png"):
with Image.open(out_img) as img:
assert "parameters" not in img.info
def test_batch_invisible_mode(self, runner, tmp_path):
input_dir = _make_batch_dir(tmp_path)
output_dir = tmp_path / "output"
mock_cls, mock_engine = _mock_invisible_engine()
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
), patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True), patch(
"remove_ai_watermarks.invisible_engine.is_available", return_value=True
):
result = runner.invoke(
main,
["batch", str(input_dir), "-o", str(output_dir), "--mode", "invisible"],
)
assert result.exit_code == 0, result.output
assert "3 processed" in result.output
def test_batch_all_mode(self, runner, tmp_path):
input_dir = _make_batch_dir(tmp_path)
output_dir = tmp_path / "output"
mock_cls, mock_engine = _mock_invisible_engine()
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
), patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True), patch(
"remove_ai_watermarks.invisible_engine.is_available", return_value=True
):
result = runner.invoke(
main,
["batch", str(input_dir), "-o", str(output_dir), "--mode", "all"],
)
assert result.exit_code == 0, result.output
assert "3 processed" in result.output
def test_batch_default_output_dir(self, runner, tmp_path):
input_dir = _make_batch_dir(tmp_path)
result = runner.invoke(
main,
["batch", str(input_dir), "--mode", "visible"],
)
assert result.exit_code == 0
expected_dir = tmp_path / "input_clean"
assert expected_dir.exists()
+63
View File
@@ -0,0 +1,63 @@
import numpy as np
from remove_ai_watermarks.face_protector import FaceProtector
def test_face_protector_initialization():
# Will fallback to Haar cascade if ultralytics is missing
fp = FaceProtector(use_yolo=False)
assert fp.use_yolo is False
assert fp.haar_cascade is not None
def test_face_protector_lifecycle():
fp = FaceProtector(use_yolo=False)
# Create dummy black image
img = np.zeros((200, 200, 3), dtype=np.uint8)
# Since it's a black image, haar cascade should find 0 faces
faces = fp.extract_faces(img)
assert isinstance(faces, list)
assert len(faces) == 0
# Restoring 0 faces should result in strictly equal image
restored = fp.restore_faces(img, faces)
assert np.array_equal(img, restored)
def test_face_protector_restore_bypass_on_size_mismatch():
fp = FaceProtector(use_yolo=False)
img_small = np.zeros((100, 100, 3), dtype=np.uint8)
# Manually mock a face that is OUT OF BOUNDS for img_small
mock_bbox = (80, 80, 130, 130)
mock_crop = np.ones((50, 50, 3), dtype=np.uint8) * 255
mock_faces = [(mock_bbox, mock_crop)]
# Attempt to restore onto an image too small for this box
restored = fp.restore_faces(img_small, mock_faces)
# Should safely skip restoring and not crash
assert np.array_equal(restored, img_small)
def test_face_protector_restore_blending():
fp = FaceProtector(use_yolo=False)
# Background is black
img_target = np.zeros((100, 100, 3), dtype=np.uint8)
# Face crop is white
mock_bbox = (25, 25, 75, 75)
mock_crop = np.ones((50, 50, 3), dtype=np.uint8) * 255
mock_faces = [(mock_bbox, mock_crop)]
restored = fp.restore_faces(img_target, mock_faces)
# The center of the face should be perfectly white (255)
assert restored[50, 50, 0] >= 254
# The corner of the target should remain perfectly black (0)
assert restored[0, 0, 0] == 0
# We should have a blending gradient between them due to the gaussian blur mask
# For example, around (30, 30) or similar
assert 0 <= restored[28, 28, 0] <= 255
+216
View File
@@ -0,0 +1,216 @@
"""Tests for the Gemini visible-watermark engine."""
from __future__ import annotations
import cv2
import numpy as np
import pytest
from remove_ai_watermarks.gemini_engine import (
DetectionResult,
GeminiEngine,
WatermarkPosition,
WatermarkSize,
_calculate_alpha_map,
get_watermark_config,
get_watermark_size,
)
# ── WatermarkSize / config helpers ──────────────────────────────────
class TestWatermarkConfig:
"""Tests for watermark size detection and position calculation."""
def test_small_image_gets_small_watermark(self):
assert get_watermark_size(800, 600) == WatermarkSize.SMALL
def test_large_image_gets_large_watermark(self):
assert get_watermark_size(1920, 1080) == WatermarkSize.LARGE
def test_boundary_image_stays_small(self):
"""Exactly 1024×1024 should be SMALL (rule: > 1024 for LARGE)."""
assert get_watermark_size(1024, 1024) == WatermarkSize.SMALL
def test_one_dimension_small(self):
"""Only one dimension > 1024 → still SMALL."""
assert get_watermark_size(2000, 500) == WatermarkSize.SMALL
def test_config_small_returns_correct_values(self):
config = get_watermark_config(800, 600)
assert config.margin_right == 32
assert config.margin_bottom == 32
assert config.logo_size == 48
def test_config_large_returns_correct_values(self):
config = get_watermark_config(1920, 1080)
assert config.margin_right == 64
assert config.margin_bottom == 64
assert config.logo_size == 96
def test_position_calculation(self):
pos = WatermarkPosition(margin_right=32, margin_bottom=32, logo_size=48)
x, y = pos.get_position(800, 600)
assert x == 800 - 32 - 48 # 720
assert y == 600 - 32 - 48 # 520
# ── Alpha map ───────────────────────────────────────────────────────
class TestAlphaMap:
"""Tests for alpha map calculation."""
def test_pure_black_gives_zero_alpha(self):
black = np.zeros((10, 10, 3), dtype=np.uint8)
alpha = _calculate_alpha_map(black)
assert alpha.shape == (10, 10)
np.testing.assert_array_equal(alpha, 0.0)
def test_pure_white_gives_one_alpha(self):
white = np.full((10, 10, 3), 255, dtype=np.uint8)
alpha = _calculate_alpha_map(white)
np.testing.assert_allclose(alpha, 1.0)
def test_grayscale_input(self):
gray = np.full((10, 10), 128, dtype=np.uint8)
alpha = _calculate_alpha_map(gray)
np.testing.assert_allclose(alpha, 128 / 255.0)
def test_max_channel_used(self):
"""Alpha should use max(R, G, B)."""
img = np.zeros((1, 1, 3), dtype=np.uint8)
img[0, 0] = [50, 200, 100] # BGR
alpha = _calculate_alpha_map(img)
assert pytest.approx(alpha[0, 0], rel=1e-3) == 200 / 255.0
# ── GeminiEngine ────────────────────────────────────────────────────
class TestGeminiEngine:
"""Tests for the GeminiEngine class."""
@pytest.fixture(autouse=True)
def _setup_engine(self):
self.engine = GeminiEngine()
def test_engine_loads_alpha_maps(self):
small = self.engine.get_alpha_map(WatermarkSize.SMALL)
large = self.engine.get_alpha_map(WatermarkSize.LARGE)
assert small.shape == (48, 48)
assert large.shape == (96, 96)
def test_remove_watermark_returns_same_shape(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.remove_watermark(image)
assert result.shape == image.shape
assert result.dtype == np.uint8
def test_remove_watermark_does_not_modify_input(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
original = image.copy()
self.engine.remove_watermark(image)
np.testing.assert_array_equal(image, original)
def test_remove_watermark_large_image(self, tmp_large_image_path):
image = cv2.imread(str(tmp_large_image_path), cv2.IMREAD_COLOR)
result = self.engine.remove_watermark(image)
assert result.shape == image.shape
def test_remove_watermark_custom_region(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.remove_watermark_custom(image, (10, 10, 48, 48))
assert result.shape == image.shape
def test_remove_watermark_custom_large_region(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.remove_watermark_custom(image, (10, 10, 96, 96))
assert result.shape == image.shape
def test_remove_watermark_custom_arbitrary_region(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.remove_watermark_custom(image, (5, 5, 60, 60))
assert result.shape == image.shape
def test_force_size(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.remove_watermark(image, force_size=WatermarkSize.LARGE)
assert result.shape == image.shape
# ── Detection ───────────────────────────────────────────────────────
class TestDetection:
"""Tests for watermark detection."""
@pytest.fixture(autouse=True)
def _setup_engine(self):
self.engine = GeminiEngine()
def test_detect_returns_result_object(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.detect_watermark(image)
assert isinstance(result, DetectionResult)
assert 0.0 <= result.confidence <= 1.0
def test_detect_empty_image_returns_no_detection(self):
empty = np.zeros((0, 0, 3), dtype=np.uint8)
result = self.engine.detect_watermark(empty)
assert not result.detected
assert result.confidence == 0.0
def test_detect_none_image_returns_no_detection(self):
result = self.engine.detect_watermark(None)
assert not result.detected
def test_detect_random_image_low_confidence(self, tmp_image_path):
"""Random noise should not look like a watermark."""
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.detect_watermark(image)
# Random image may or may not be detected; confidence should be meaningful
assert isinstance(result.spatial_score, float)
assert isinstance(result.gradient_score, float)
# ── Inpainting ──────────────────────────────────────────────────────
class TestInpainting:
"""Tests for residual inpainting."""
@pytest.fixture(autouse=True)
def _setup_engine(self):
self.engine = GeminiEngine()
def test_inpaint_ns(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.inpaint_residual(image, (150, 150, 48, 48), method="ns")
assert result.shape == image.shape
def test_inpaint_telea(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.inpaint_residual(image, (150, 150, 48, 48), method="telea")
assert result.shape == image.shape
def test_inpaint_gaussian(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.inpaint_residual(image, (150, 150, 48, 48), method="gaussian")
assert result.shape == image.shape
def test_inpaint_zero_strength(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.inpaint_residual(image, (150, 150, 48, 48), strength=0.0)
np.testing.assert_array_equal(result, image)
def test_inpaint_tiny_region_returns_unchanged(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
result = self.engine.inpaint_residual(image, (10, 10, 2, 2))
np.testing.assert_array_equal(result, image)
def test_inpaint_does_not_modify_input(self, tmp_image_path):
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
original = image.copy()
self.engine.inpaint_residual(image, (150, 150, 48, 48))
np.testing.assert_array_equal(image, original)
+52
View File
@@ -0,0 +1,52 @@
import numpy as np
from remove_ai_watermarks.humanizer import apply_analog_humanizer
def test_humanizer_does_not_modify_original_if_disabled():
img = np.zeros((100, 100, 3), dtype=np.uint8)
img[50, 50] = [100, 150, 200]
org_img = img.copy()
# grain=0, shift=0 means disabled essentially. But wait, apply_analog_humanizer currently applies chromatic shift even if grain=0.
result = apply_analog_humanizer(img, grain_intensity=0.0, chromatic_shift=0)
assert np.array_equal(result, org_img)
def test_chromatic_shift():
# Only green channel is centered, red/blue should shift.
img = np.zeros((5, 5, 3), dtype=np.uint8)
img[2, 2] = [255, 255, 255] # B, G, R
# shift=1
result = apply_analog_humanizer(img, grain_intensity=0.0, chromatic_shift=1)
# G (index 1) stays at [2,2]
assert result[2, 2, 1] == 255
# B (index 0) shifted right (+1 axis 1) -> [2, 3]
assert result[2, 3, 0] == 255
# R (index 2) shifted left (-1 axis 1) -> [2, 1]
assert result[2, 1, 2] == 255
def test_grain_intensity():
# Gray image
img = np.full((100, 100, 3), 128, dtype=np.uint8)
# Add strong noise
result = apply_analog_humanizer(img, grain_intensity=10.0, chromatic_shift=0)
# Image should no longer be purely 128
unique_vals = np.unique(result)
assert len(unique_vals) > 5
# Mean should roughly be 128
assert 126 < np.mean(result) < 130
def test_invalid_shape():
# Missing color channel
img = np.zeros((100, 100), dtype=np.uint8)
img[0, 0] = 50
result = apply_analog_humanizer(img)
assert np.array_equal(img, result)
+27
View File
@@ -0,0 +1,27 @@
"""Tests for the invisible watermark engine (unit tests, no GPU required)."""
from __future__ import annotations
from remove_ai_watermarks.invisible_engine import InvisibleEngine, is_available
class TestIsAvailable:
"""Tests for dependency checking."""
def test_returns_bool(self):
result = is_available()
assert isinstance(result, bool)
def test_available_when_torch_installed(self):
"""torch + diffusers should be installed in dev env."""
assert is_available() is True
class TestInvisibleEngineInit:
"""Tests for InvisibleEngine construction (no GPU required)."""
def test_default_model_id(self):
assert InvisibleEngine.DEFAULT_MODEL_ID == "Lykon/dreamshaper-8"
def test_ctrlregen_model_id(self):
assert InvisibleEngine.CTRLREGEN_MODEL_ID == "yepengliu/ctrlregen"
+150
View File
@@ -0,0 +1,150 @@
"""Tests for AI metadata detection and removal."""
from __future__ import annotations
from pathlib import Path
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from remove_ai_watermarks.metadata import (
_is_ai_key,
get_ai_metadata,
has_ai_metadata,
remove_ai_metadata,
)
# ── Key detection ───────────────────────────────────────────────────
class TestIsAiKey:
"""Tests for _is_ai_key helper."""
def test_exact_match_lowercase(self):
assert _is_ai_key("parameters")
def test_exact_match_mixed_case(self):
assert _is_ai_key("Parameters")
def test_keyword_substring(self):
assert _is_ai_key("stable_diffusion_model_v2")
def test_c2pa_detected(self):
assert _is_ai_key("c2pa_chunk")
def test_standard_key_not_flagged(self):
assert not _is_ai_key("Author")
def test_innocuous_key_not_flagged(self):
assert not _is_ai_key("Title")
def test_dpi_not_flagged(self):
assert not _is_ai_key("dpi")
# ── has_ai_metadata / get_ai_metadata ───────────────────────────────
class TestHasAiMetadata:
"""Tests for detecting AI metadata in images."""
def test_detects_ai_metadata(self, tmp_png_with_ai_metadata):
assert has_ai_metadata(tmp_png_with_ai_metadata)
def test_clean_image_no_ai(self, tmp_clean_png):
assert not has_ai_metadata(tmp_clean_png)
class TestGetAiMetadata:
"""Tests for extracting AI metadata."""
def test_extracts_parameters_key(self, tmp_png_with_ai_metadata):
meta = get_ai_metadata(tmp_png_with_ai_metadata)
assert "parameters" in meta
assert "Euler" in meta["parameters"]
def test_extracts_prompt_key(self, tmp_png_with_ai_metadata):
meta = get_ai_metadata(tmp_png_with_ai_metadata)
assert "prompt" in meta
def test_does_not_extract_author(self, tmp_png_with_ai_metadata):
meta = get_ai_metadata(tmp_png_with_ai_metadata)
assert "Author" not in meta
def test_clean_image_empty_dict(self, tmp_clean_png):
meta = get_ai_metadata(tmp_clean_png)
assert meta == {}
# ── remove_ai_metadata ──────────────────────────────────────────────
class TestRemoveAiMetadata:
"""Tests for stripping AI metadata."""
def test_removes_ai_keys(self, tmp_png_with_ai_metadata):
output = tmp_png_with_ai_metadata.parent / "cleaned.png"
remove_ai_metadata(tmp_png_with_ai_metadata, output)
with Image.open(output) as img:
assert "parameters" not in img.info
assert "prompt" not in img.info
def test_keeps_standard_metadata(self, tmp_png_with_ai_metadata):
output = tmp_png_with_ai_metadata.parent / "cleaned.png"
remove_ai_metadata(tmp_png_with_ai_metadata, output, keep_standard=True)
with Image.open(output) as img:
assert "Author" in img.info
assert img.info["Author"] == "Test Author"
def test_remove_all_metadata(self, tmp_png_with_ai_metadata):
output = tmp_png_with_ai_metadata.parent / "cleaned.png"
remove_ai_metadata(tmp_png_with_ai_metadata, output, keep_standard=False)
with Image.open(output) as img:
assert "Author" not in img.info
assert "parameters" not in img.info
def test_overwrite_in_place(self, tmp_path):
"""When output_path is None, should overwrite source."""
img = Image.new("RGB", (32, 32))
pnginfo = PngInfo()
pnginfo.add_text("parameters", "test data")
path = tmp_path / "inplace.png"
img.save(path, pnginfo=pnginfo)
result = remove_ai_metadata(path)
assert result == path
with Image.open(path) as cleaned:
assert "parameters" not in cleaned.info
def test_jpeg_output(self, tmp_path):
"""Test metadata removal for JPEG format."""
img = Image.new("RGB", (64, 64), color=(100, 150, 200))
pnginfo = PngInfo()
pnginfo.add_text("parameters", "test")
png_path = tmp_path / "source.png"
img.save(png_path, pnginfo=pnginfo)
jpg_path = tmp_path / "output.jpg"
result = remove_ai_metadata(png_path, jpg_path)
assert result == jpg_path
assert jpg_path.exists()
def test_creates_parent_directories(self, tmp_path):
img = Image.new("RGB", (32, 32))
pnginfo = PngInfo()
pnginfo.add_text("prompt", "test")
path = tmp_path / "source.png"
img.save(path, pnginfo=pnginfo)
output = tmp_path / "sub" / "dir" / "cleaned.png"
remove_ai_metadata(path, output)
assert output.exists()
def test_returns_path(self, tmp_clean_png):
output = tmp_clean_png.parent / "out.png"
result = remove_ai_metadata(tmp_clean_png, output)
assert isinstance(result, Path)
assert result == output
+130
View File
@@ -0,0 +1,130 @@
"""Tests for vendored noai submodules: constants, extractor, cleaner, c2pa."""
from __future__ import annotations
from remove_ai_watermarks.noai.c2pa import (
extract_c2pa_chunk,
extract_c2pa_info,
has_c2pa_metadata,
)
from remove_ai_watermarks.noai.cleaner import (
has_ai_content,
)
from remove_ai_watermarks.noai.cleaner import (
remove_ai_metadata as noai_remove_ai_metadata,
)
from remove_ai_watermarks.noai.constants import (
AI_KEYWORDS,
AI_METADATA_KEYS,
C2PA_CHUNK_TYPE,
PNG_SIGNATURE,
SUPPORTED_FORMATS,
)
from remove_ai_watermarks.noai.extractor import (
extract_ai_metadata,
extract_metadata,
get_ai_metadata_summary,
has_ai_metadata,
)
# ── Constants ───────────────────────────────────────────────────────
class TestConstants:
"""Verify constant integrity."""
def test_supported_formats_include_png(self):
assert ".png" in SUPPORTED_FORMATS
def test_supported_formats_include_jpg(self):
assert ".jpg" in SUPPORTED_FORMATS
def test_ai_metadata_keys_not_empty(self):
assert len(AI_METADATA_KEYS) > 0
def test_ai_keywords_not_empty(self):
assert len(AI_KEYWORDS) > 0
def test_png_signature_bytes(self):
assert PNG_SIGNATURE == b"\x89PNG\r\n\x1a\n"
def test_c2pa_chunk_type(self):
assert C2PA_CHUNK_TYPE == b"caBX"
# ── Extractor ───────────────────────────────────────────────────────
class TestExtractor:
"""Tests for noai.extractor functions."""
def test_extract_metadata_returns_dict(self, tmp_clean_png):
meta = extract_metadata(tmp_clean_png)
assert isinstance(meta, dict)
def test_extract_metadata_gets_standard_keys(self, tmp_clean_png):
meta = extract_metadata(tmp_clean_png)
assert "Author" in meta
def test_extract_ai_metadata_from_ai_image(self, tmp_png_with_ai_metadata):
meta = extract_ai_metadata(tmp_png_with_ai_metadata)
assert "parameters" in meta
def test_extract_ai_metadata_from_clean_image(self, tmp_clean_png):
meta = extract_ai_metadata(tmp_clean_png)
assert len(meta) == 0
def test_has_ai_metadata_detects(self, tmp_png_with_ai_metadata):
assert has_ai_metadata(tmp_png_with_ai_metadata)
def test_has_ai_metadata_clean(self, tmp_clean_png):
assert not has_ai_metadata(tmp_clean_png)
def test_summary_with_ai(self, tmp_png_with_ai_metadata):
summary = get_ai_metadata_summary(tmp_png_with_ai_metadata)
assert "AI Image Metadata" in summary
def test_summary_clean(self, tmp_clean_png):
summary = get_ai_metadata_summary(tmp_clean_png)
assert "No AI metadata" in summary
# ── Cleaner ─────────────────────────────────────────────────────────
class TestCleaner:
"""Tests for noai.cleaner functions."""
def test_remove_ai_metadata(self, tmp_png_with_ai_metadata, tmp_path):
output = tmp_path / "cleaned.png"
noai_remove_ai_metadata(tmp_png_with_ai_metadata, output)
assert output.exists()
# Verify AI metadata removed
meta = extract_ai_metadata(output)
assert "parameters" not in meta
def test_has_ai_content(self, tmp_png_with_ai_metadata):
assert has_ai_content(tmp_png_with_ai_metadata)
# ── C2PA ────────────────────────────────────────────────────────────
class TestC2PA:
"""Tests for C2PA detection on regular (non-C2PA) images."""
def test_no_c2pa_on_regular_png(self, tmp_clean_png):
assert not has_c2pa_metadata(tmp_clean_png)
def test_no_c2pa_on_jpeg(self, tmp_jpeg_path):
assert not has_c2pa_metadata(tmp_jpeg_path)
def test_extract_c2pa_none_on_regular(self, tmp_clean_png):
assert extract_c2pa_chunk(tmp_clean_png) is None
def test_extract_c2pa_info_empty(self, tmp_clean_png):
info = extract_c2pa_info(tmp_clean_png)
assert info == {}
def test_c2pa_returns_false_for_non_png(self, tmp_jpeg_path):
assert not has_c2pa_metadata(tmp_jpeg_path)
+165
View File
@@ -0,0 +1,165 @@
"""Tests for cross-platform and cross-device compatibility.
Verifies that device detection, MPS fallback, and platform-specific
code paths work correctly on CPU, MPS (macOS), and CUDA (Linux/Windows).
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from remove_ai_watermarks.noai.progress import is_mps_error
from remove_ai_watermarks.noai.utils import get_image_format, is_supported_format
from remove_ai_watermarks.noai.watermark_profiles import (
detect_model_profile,
get_model_id_for_profile,
get_recommended_strength,
)
from remove_ai_watermarks.noai.watermark_remover import get_device, is_watermark_removal_available
# ── Device detection ────────────────────────────────────────────────
class TestDeviceDetection:
"""Tests for get_device() across platforms."""
def test_returns_valid_device(self):
device = get_device()
assert device in ("cpu", "mps", "cuda")
def test_cpu_fallback_when_no_gpu(self):
"""On CI / machines without GPU, should fall back to cpu or mps."""
device = get_device()
# Just verify it doesn't crash and returns a valid string
assert isinstance(device, str)
@patch("remove_ai_watermarks.noai.watermark_remover._HAS_TORCH", False)
def test_no_torch_returns_cpu(self):
assert get_device() == "cpu"
class TestMpsErrorDetection:
"""Tests for MPS error detection helper."""
def test_detects_mps_error(self):
err = RuntimeError("MPS backend out of memory")
assert is_mps_error(err) is True
def test_non_mps_error(self):
err = RuntimeError("CUDA out of memory")
assert is_mps_error(err) is False
def test_generic_error(self):
err = RuntimeError("something went wrong")
assert is_mps_error(err) is False
# ── Model profiles ──────────────────────────────────────────────────
class TestModelProfiles:
"""Tests for watermark_profiles.py."""
def test_default_profile(self):
assert get_model_id_for_profile("default") == "Lykon/dreamshaper-8"
def test_ctrlregen_profile(self):
assert get_model_id_for_profile("ctrlregen") == "yepengliu/ctrlregen"
def test_unknown_profile_raises(self):
with pytest.raises(ValueError, match="Unknown model profile"):
get_model_id_for_profile("nonexistent")
def test_detect_default(self):
assert detect_model_profile("Lykon/dreamshaper-8") == "default"
def test_detect_ctrlregen(self):
assert detect_model_profile("yepengliu/ctrlregen") == "ctrlregen"
def test_recommended_strength_high(self):
assert get_recommended_strength("treering") == 0.7
def test_recommended_strength_low(self):
assert get_recommended_strength("stablesignature") == 0.04
def test_recommended_strength_medium(self):
assert get_recommended_strength("unknown_type") == 0.35
# ── Format utilities ────────────────────────────────────────────────
class TestFormatUtils:
"""Tests for utils.py format helpers."""
def test_supported_png(self, tmp_path):
assert is_supported_format(tmp_path / "test.png")
def test_supported_jpg(self, tmp_path):
assert is_supported_format(tmp_path / "test.jpg")
def test_supported_jpeg(self, tmp_path):
assert is_supported_format(tmp_path / "test.jpeg")
def test_supported_webp(self, tmp_path):
assert is_supported_format(tmp_path / "test.webp")
def test_unsupported_bmp(self, tmp_path):
assert not is_supported_format(tmp_path / "test.bmp")
def test_unsupported_gif(self, tmp_path):
assert not is_supported_format(tmp_path / "test.gif")
def test_get_format_png(self, tmp_path):
assert get_image_format(tmp_path / "x.png") == "PNG"
def test_get_format_jpg(self, tmp_path):
assert get_image_format(tmp_path / "x.jpg") == "JPEG"
def test_get_format_jpeg(self, tmp_path):
assert get_image_format(tmp_path / "x.jpeg") == "JPEG"
def test_get_format_webp_defaults_png(self, tmp_path):
# .webp falls through to PNG in current implementation
assert get_image_format(tmp_path / "x.webp") == "PNG"
# ── Availability checks ────────────────────────────────────────────
class TestAvailability:
"""Tests for dependency availability checks."""
def test_watermark_removal_available(self):
# In dev env with torch+diffusers installed
assert is_watermark_removal_available() is True
def test_invisible_is_available(self):
from remove_ai_watermarks.invisible_engine import is_available
assert is_available() is True
# ── Platform-specific path handling ─────────────────────────────────
class TestPlatformPaths:
"""Verify path handling works on current platform."""
def test_pathlib_works_for_assets(self):
from pathlib import Path
asset_dir = Path(__file__).parent.parent / "src" / "remove_ai_watermarks" / "assets"
assert (asset_dir / "gemini_bg_48.png").exists()
assert (asset_dir / "gemini_bg_96.png").exists()
def test_asset_loading_works(self):
"""Verify embedded assets load correctly (critical for packaging)."""
from remove_ai_watermarks.gemini_engine import GeminiEngine
engine = GeminiEngine()
# If we get here without error, asset loading works
assert engine._alpha_small.shape == (48, 48)
assert engine._alpha_large.shape == (96, 96)
Generated
+2779
View File
File diff suppressed because it is too large Load Diff