Add project files, tests, and documentation for GitHub release

- CLI with visible, invisible, all, metadata, and batch commands
- Gemini watermark removal via reverse alpha blending
- Invisible watermark removal via diffusion regeneration (SynthID, TreeRing)
- AI metadata stripping (EXIF, PNG text, C2PA)
- Face protection (YOLO/Haar) and analog humanizer
- 137 tests covering all CLI modes and core engines
- Ruff and Pyright clean
@@ -0,0 +1,65 @@
---
trigger: always_on
description: Project rules for remove-ai-watermarks
---

# Project rules

## Role

You are a **principal Python engineer** working on `remove-ai-watermarks` — a CLI tool for removing visible and invisible AI watermarks from images.

## Python environment

- This project uses `uv` for Python package management.
- Use `uv pip install` instead of `pip install`.
- Use `uv run` to run Python scripts (e.g., `uv run python scripts/example.py`).
- Project uses **src-layout**: source code is in `src/remove_ai_watermarks/`.

## Project structure

- `src/remove_ai_watermarks/` — main package (CLI, engines, metadata)
- `src/remove_ai_watermarks/noai/` — vendored noai-watermark code (invisible watermark removal)
- `src/remove_ai_watermarks/noai/ctrlregen/` — vendored CtrlRegen pipeline
- `src/remove_ai_watermarks/assets/` — embedded watermark alpha maps (PNG)
- `tests/` — pytest test suite
- `data/samples/` — sample images for manual testing

## Testing

- Run tests: `uv run pytest`
- Run with coverage: `uv run pytest --cov=remove_ai_watermarks --cov-report=term-missing`
- Tests create temporary files via `tmp_path` — they do NOT depend on `data/samples/`.
- ML pipeline modules (invisible_engine, ctrlregen, watermark_remover) require GPU and multi-GB model downloads — unit tests for these are limited to availability checks and constants.
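
An availability check of the kind mentioned in the last bullet could look roughly like the sketch below. The helper name `is_available` and both test names are hypothetical and are not taken from the project's test suite; the only library call used is the standard `importlib.util.find_spec`, which looks a top-level module up without importing it.

```python
import importlib.util


def is_available(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


def test_stdlib_module_is_available() -> None:
    # A module that ships with Python must always be found.
    assert is_available("json")


def test_missing_module_is_reported() -> None:
    # A made-up name must be reported as unavailable rather than raising.
    assert not is_available("definitely_not_a_real_module_xyz")
```

Because nothing heavy is imported, a check like this stays fast and needs neither a GPU nor model downloads.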

## Git

- Do not create commits or push unless explicitly asked by the user.

## Code quality

- Always run `./maintain.sh` before committing to ensure code quality (ruff, pyright).
- **CRITICAL**: Before using any method on a client or class, ALWAYS verify that the method exists by reading the class file first.
- **NEVER** assume or invent methods. If the method doesn't exist, either use an existing alternative method or explicitly create the new method first.

## Documentation

- Update documentation (README.md) when you change functionality or add new features.
- **DO NOT** create artifact documentation (walkthrough.md, verification.md) for bug fixes or small corrections.

## Language

- Use only English for all code, comments, docstrings, documentation, commit messages, and project artifacts.
- Communicate with users in Russian, but all technical content must be in English.

## API integrations

- Do not assume or invent API request/response structures.
- Always verify API payloads against official documentation before implementing.
- Use Context7 MCP to retrieve up-to-date library documentation when working with external APIs or packages.

## Work completion

- You have **NO time limits**. Always complete the full task in one go.
- Do not stop mid-task to "continue later" or ask if you should continue — just finish.
- Do not split work into multiple commits unless the task is genuinely large.
@@ -0,0 +1,3 @@
# HuggingFace token (required for invisible watermark removal)
# Get yours at: https://huggingface.co/settings/tokens
HF_TOKEN=
@@ -1,2 +1,9 @@
# Auto detect text files and perform LF normalization
* text=auto

# Binary files
*.png binary
*.jpg binary
*.jpeg binary
*.webp binary
*.pt binary
@@ -0,0 +1,29 @@
# Dependencies
.venv/
__pycache__/
*.egg-info/
dist/
build/

# Environment secrets
.env

# OS files
.DS_Store
Thumbs.db

# IDE
.idea/
.vscode/
*.swp
*.swo

# Test results
data/results/

# Reference materials
_refs/

# Downloaded model weights
yolov8n.pt
.coverage
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 wiltodelta

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,147 @@
# Remove-AI-Watermarks

Unified tool for removing **visible** and **invisible** AI watermarks from images.

## Features

- **Visible watermark removal** — Gemini sparkle logo via reverse alpha blending (fast, offline, deterministic)
- **Invisible watermark removal** — SynthID, StableSignature, TreeRing via diffusion-based regeneration
- **AI metadata stripping** — EXIF, PNG text chunks, C2PA provenance manifests
- **Analog Humanizer** — film grain and chromatic aberration injection to bypass AI image classifiers
- **Smart Face Protection** — automatic extraction and blending of human faces to prevent AI distortion
- **High-Res Upscaler** — prevents resolution degradation during invisible watermark removal
- **Batch processing** — process entire directories
- **Detection** — three-stage NCC watermark detection with confidence scoring
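
To make "reverse alpha blending" concrete, here is a minimal NumPy sketch of the underlying arithmetic. It assumes the watermark was composited as `watermarked = alpha * logo + (1 - alpha) * original` with a uniform logo colour. The function name `reverse_alpha_blend`, the white `logo_value` default, and the `max_alpha` clamp are illustrative assumptions, not the project's `GeminiEngine` implementation.

```python
import numpy as np


def reverse_alpha_blend(
    watermarked: np.ndarray,
    alpha: np.ndarray,
    logo_value: float = 255.0,  # assumed logo colour (white)
    max_alpha: float = 0.99,  # assumed clamp to avoid dividing by zero
) -> np.ndarray:
    """Invert watermarked = alpha * logo + (1 - alpha) * original per pixel."""
    a = np.clip(alpha.astype(np.float32), 0.0, max_alpha)
    if a.ndim == 2 and watermarked.ndim == 3:
        a = a[..., None]  # broadcast a single alpha map over the colour channels
    original = (watermarked.astype(np.float32) - a * logo_value) / (1.0 - a)
    return np.clip(np.rint(original), 0, 255).astype(np.uint8)
```

Given the same alpha map that was used for compositing, the result matches the original up to rounding error, which is why this kind of removal is deterministic and needs no model.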

## Examples

| Before (Watermarked) | After (Cleaned) |
| --- | --- |
|  |  |

## Installation

### Recommended (macOS)

Install as an isolated CLI tool — no need to manage virtual environments:

```bash
# Using pipx (brew install pipx)
pipx install git+https://github.com/wiltodelta/remove-ai-watermarks.git

# Or using uv (brew install uv)
uv tool install git+https://github.com/wiltodelta/remove-ai-watermarks.git
```

### Install from repository (macOS)

**Prerequisites:** Python 3.10+ and `pip` (or [`uv`](https://docs.astral.sh/uv/)).

```bash
# 1. Clone the repository
git clone https://github.com/wiltodelta/remove-ai-watermarks.git
cd remove-ai-watermarks

# 2. Install the package in editable mode
pip install -e .

# Or, if you use uv:
uv pip install -e .
```

After installation, the `remove-ai-watermarks` command is available in the environment you installed it into.

#### Invisible watermark removal (optional)

Invisible removal uses diffusion models. It requires a **HuggingFace token** and runs best on a CUDA GPU or Apple Silicon (MPS); CPU inference works but is slow.

```bash
# 1. Create a free token at https://huggingface.co/settings/tokens
# 2. Copy the example env file and paste your token
cp .env.example .env
# Edit .env and set HF_TOKEN=hf_your_token_here

# 3. On first run, the model (~2 GB) will be downloaded automatically.
# On macOS with Apple Silicon, MPS acceleration is used by default.
# On macOS without GPU, add --device cpu (inference will be slow).
```
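
The token and device choices described above can be sketched as follows. This is only an illustration under stated assumptions: `resolve_hf_token` and `pick_device` are hypothetical helpers, and the CUDA, then MPS, then CPU fallback order is an assumption about what `auto` means rather than a description of how the project's `InvisibleEngine` resolves it.

```python
from __future__ import annotations

import os
from collections.abc import Mapping


def resolve_hf_token(cli_token: str | None = None, env: Mapping[str, str] | None = None) -> str | None:
    """Prefer an explicit token, then the HF_TOKEN variable; return None if unset."""
    env = os.environ if env is None else env
    token = cli_token or env.get("HF_TOKEN", "")
    return token.strip() or None


def pick_device(requested: str = "auto") -> str:
    """Map "auto" to the best available backend (assumed order: cuda, mps, cpu)."""
    if requested != "auto":
        return requested
    try:
        import torch  # optional here: fall back to CPU if it is not installed
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```

An explicit `--device cpu` skips the probing entirely, which matches the advice above for Macs without a GPU.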

#### Developer setup

```bash
# Install with dev dependencies (pytest, pytest-cov, ruff)
pip install -e ".[dev]"
# Or with uv:
uv pip install -e ".[dev]"

# Run tests
pytest

# Run linters
./maintain.sh
```

## Usage

### CLI

```bash
# Remove visible Gemini watermark
remove-ai-watermarks visible image.png -o clean.png

# Remove invisible watermarks (SynthID etc.) with optimal quality retention
remove-ai-watermarks invisible image.png -o clean.png --humanize 4.0

# Strip AI metadata
remove-ai-watermarks metadata image.png --check
remove-ai-watermarks metadata image.png --remove

# Batch processing
remove-ai-watermarks batch ./images/ --mode visible

# Full pipeline: visible + invisible + metadata
remove-ai-watermarks all image.png -o clean.png
```

### Python API

```python
from remove_ai_watermarks.gemini_engine import GeminiEngine
import cv2

engine = GeminiEngine()
image = cv2.imread("watermarked.png")

# Detect
result = engine.detect_watermark(image)
print(f"Detected: {result.detected} (confidence: {result.confidence:.1%})")

# Remove
clean = engine.remove_watermark(image)
cv2.imwrite("clean.png", clean)
```

### Metadata stripping

```python
from remove_ai_watermarks.metadata import has_ai_metadata, remove_ai_metadata
from pathlib import Path

if has_ai_metadata(Path("image.png")):
    remove_ai_metadata(Path("image.png"), Path("clean.png"))
```
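
For intuition about what stripping PNG text chunks involves, the standard-library sketch below walks the chunk list of a PNG file and drops the textual chunk types (`tEXt`, `zTXt`, `iTXt`), where generators commonly store prompts and parameters. It is an illustration only: `strip_png_text_chunks` is a hypothetical helper, it does not touch EXIF or C2PA data, and it is not how the project's `remove_ai_metadata` is known to work.

```python
import struct

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
TEXT_CHUNK_TYPES = {b"tEXt", b"zTXt", b"iTXt"}


def strip_png_text_chunks(data: bytes) -> bytes:
    """Return PNG bytes with all textual chunks removed; other chunks are copied as-is."""
    if not data.startswith(PNG_SIGNATURE):
        raise ValueError("not a PNG file")
    out = bytearray(PNG_SIGNATURE)
    pos = len(PNG_SIGNATURE)
    while pos < len(data):
        # Chunk layout: 4-byte big-endian length, 4-byte type, payload, 4-byte CRC.
        (length,) = struct.unpack(">I", data[pos : pos + 4])
        chunk_type = data[pos + 4 : pos + 8]
        end = pos + 12 + length
        if chunk_type not in TEXT_CHUNK_TYPES:
            out += data[pos:end]
        pos = end
    return bytes(out)
```

Since every kept chunk is copied byte for byte, including its CRC, the pixel data is left untouched.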

## Requirements

- Python ≥ 3.10
- **Visible removal / metadata**: CPU only, no GPU required
- **Invisible removal**: GPU recommended (CUDA or MPS), works on CPU (slow)

## Credits

- [noai-watermark](https://github.com/mertizci/noai-watermark) by mertizci — invisible watermark removal
- [GeminiWatermarkTool](https://github.com/allenk/GeminiWatermarkTool) by Allen Kuo — visible watermark removal algorithm

## License

MIT
Binary files added (6 images): 6.8 MiB, 5.7 MiB, 4.2 MiB, 5.6 MiB, 3.6 MiB, 5.7 MiB
@@ -0,0 +1,9 @@
#!/usr/bin/env bash

set -euo pipefail
uv sync --all-extras
uv-outdated
uv run uv-secure --ignore-unfixed
uv run ruff check --fix
uv run ruff format
uv run pyright
@@ -0,0 +1,73 @@
[project]
name = "remove-ai-watermarks"
version = "0.1.0"
description = "Unified tool for removing visible and invisible AI watermarks from images"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
dependencies = [
    "pillow>=10.0.0",
    "piexif>=1.1.3",
    "numpy>=1.24.0",
    "opencv-python>=4.8.0",
    "click>=8.0.0",
    "rich>=13.0.0",
    "torch>=2.0.0",
    "diffusers>=0.25.0",
    "transformers>=4.35.0",
    "accelerate>=0.25.0",
    "controlnet-aux>=0.0.9",
    "color-matcher",
    "safetensors",
    "python-dotenv>=1.0.0",
    "ultralytics>=8.0.0",
    "requests>=2.33.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-cov>=4.1.0",
    "ruff>=0.4.0",
]
all = ["remove-ai-watermarks[dev]"]

[project.scripts]
remove-ai-watermarks = "remove_ai_watermarks.cli:main"

[project.urls]
Repository = "https://github.com/wiltodelta/remove-ai-watermarks"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/remove_ai_watermarks"]

[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["src"]
addopts = "-v --tb=short"

[tool.ruff]
target-version = "py310"
line-length = 120

[tool.pyright]
exclude = ["tests"]
reportOptionalMemberAccess = "warning"
reportPrivateImportUsage = false
reportInvalidTypeForm = false
reportOptionalCall = false
reportMissingImports = "warning"
reportCallIssue = "warning"
reportAttributeAccessIssue = "warning"
reportArgumentType = "warning"
reportPossiblyUnboundVariable = "warning"
reportAssignmentType = "warning"
reportOperatorIssue = "warning"

[tool.ruff.lint]
select = ["E", "F", "W", "I", "N", "UP"]
ignore = ["E501"]
@@ -0,0 +1,3 @@
"""Remove-AI-Watermarks: Unified tool for removing visible and invisible AI watermarks."""

__version__ = "0.1.0"
@@ -0,0 +1 @@
"""Embedded assets for visible watermark removal."""
Binary files added (2 images): 1.6 KiB, 8.0 KiB
@@ -0,0 +1,598 @@
"""Unified CLI for remove-ai-watermarks.

Provides commands for:
- Visible watermark removal (Gemini sparkle) — works offline, fast
- Invisible watermark removal (SynthID etc.) — requires GPU/diffusion models
- AI metadata stripping — lightweight, no ML deps needed
"""

from __future__ import annotations

import logging
import time
from pathlib import Path

import click
from rich.console import Console
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table

from remove_ai_watermarks import __version__

console = Console()

SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp"}


def _setup_logging(verbose: bool) -> None:
    """Configure root logging: DEBUG when verbose, WARNING otherwise."""
    level = logging.DEBUG if verbose else logging.WARNING
    logging.basicConfig(
        level=level,
        format="%(name)s | %(message)s",
        handlers=[logging.StreamHandler()],
    )


def _banner() -> None:
    """Print the application banner."""
    console.print(
        Panel(
            f"[bold cyan]Remove-AI-Watermarks[/] [dim]v{__version__}[/]\n[dim]Visible & invisible watermark removal[/]",
            border_style="cyan",
            padding=(0, 2),
        )
    )


def _validate_image(path: Path) -> Path:
    """Exit if the path is missing; warn (but continue) on an unexpected extension."""
    if not path.exists():
        console.print(f"[red]Error:[/] File not found: {path}")
        raise SystemExit(1)
    if path.suffix.lower() not in SUPPORTED_FORMATS:
        # Sort the set so the message is stable across runs.
        console.print(
            f"[yellow]Warning:[/] {path.suffix} may not be supported (expected: {', '.join(sorted(SUPPORTED_FORMATS))})"
        )
    return path


# ── Main group ───────────────────────────────────────────────────────


@click.group(invoke_without_command=True)
@click.version_option(__version__, prog_name="remove-ai-watermarks")
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging.")
@click.pass_context
def main(ctx: click.Context, verbose: bool) -> None:
    """Remove visible and invisible AI watermarks from images."""
    from dotenv import load_dotenv

    load_dotenv()  # Load .env (e.g. HF_TOKEN)

    ctx.ensure_object(dict)
    ctx.obj["verbose"] = verbose
    _setup_logging(verbose)

    if ctx.invoked_subcommand is None:
        _banner()
        click.echo(ctx.get_help())


# ── Visible (Gemini) watermark removal ───────────────────────────────


@main.command("visible")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option(
    "-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
)
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting cleanup after removal.")
@click.option(
    "--inpaint-method", type=click.Choice(["ns", "telea", "gaussian"]), default="ns", help="Inpainting method."
)
@click.option("--inpaint-strength", type=float, default=0.85, help="Inpainting blend strength (0.0–1.0).")
@click.option("--detect/--no-detect", default=True, help="Detect watermark before removal.")
@click.option("--detect-threshold", type=float, default=0.25, help="Detection confidence threshold.")
@click.option("--strip-metadata/--keep-metadata", default=True, help="Strip AI metadata from output.")
@click.pass_context
def cmd_visible(
    ctx: click.Context,
    source: Path,
    output: Path | None,
    inpaint: bool,
    inpaint_method: str,
    inpaint_strength: float,
    detect: bool,
    detect_threshold: float,
    strip_metadata: bool,
) -> None:
    """Remove visible Gemini watermark (sparkle logo) from an image.

    Uses reverse alpha blending — fast, deterministic, offline.
    """
    import cv2

    from remove_ai_watermarks.gemini_engine import GeminiEngine

    _banner()
    source = _validate_image(source)

    if output is None:
        output = source.with_stem(source.stem + "_clean")

    engine = GeminiEngine()

    # Load image
    image = cv2.imread(str(source), cv2.IMREAD_COLOR)
    if image is None:
        console.print(f"[red]Error:[/] Failed to read image: {source}")
        raise SystemExit(1)

    h, w = image.shape[:2]
    console.print(f" [dim]Input:[/] {source.name} ({w}×{h})")

    # Detection (we always detect softly, to find dynamic region for inpainting)
    with console.status("[cyan]Detecting watermark…[/]"):
        det = engine.detect_watermark(image)

    if detect:
        if det.detected:
            console.print(
                f" [green]✓[/] Watermark detected "
                f"[dim](confidence: {det.confidence:.1%}, "
                f"spatial: {det.spatial_score:.3f}, "
                f"gradient: {det.gradient_score:.3f})[/]"
            )
        else:
            console.print(f" [yellow]⚠[/] Watermark not detected [dim](confidence: {det.confidence:.1%})[/]")
            if det.confidence < detect_threshold:
                console.print(" [dim]Skipping. Use --no-detect to force removal.[/]")
                raise SystemExit(0)

    # Removal
    t0 = time.monotonic()
    with console.status("[cyan]Removing watermark…[/]"):
        result = engine.remove_watermark(image)

        if inpaint:
            if det.confidence > 0.15:
                region = det.region
            else:
                # Low-confidence detection: fall back to the default logo position.
                from remove_ai_watermarks.gemini_engine import get_watermark_config

                config = get_watermark_config(w, h)
                pos = config.get_position(w, h)
                region = (pos[0], pos[1], config.logo_size, config.logo_size)
            result = engine.inpaint_residual(
                result,
                region,
                strength=inpaint_strength,
                method=inpaint_method,
            )

    elapsed = time.monotonic() - t0

    # Save
    output.parent.mkdir(parents=True, exist_ok=True)
    cv2.imwrite(str(output), result)

    # Strip metadata
    if strip_metadata:
        try:
            from remove_ai_watermarks.metadata import remove_ai_metadata

            remove_ai_metadata(output, output)
        except Exception as e:
            # Report the failure instead of silently ignoring it; the image itself is already saved.
            console.print(f" [yellow]⚠[/] Metadata strip failed: {e}")

    size_kb = output.stat().st_size / 1024
    console.print(f" [green]✓[/] Saved: {output} [dim]({size_kb:.0f} KB, {elapsed:.2f}s)[/]")


# ── Invisible watermark removal ─────────────────────────────────────


@main.command("invisible")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option(
    "-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
)
@click.option("--strength", type=float, default=0.02, help="Denoising strength (0.0–1.0). Default: 0.02.")
@click.option("--steps", type=int, default=100, help="Number of denoising steps. Default: 100.")
@click.option("--pipeline", type=click.Choice(["default", "ctrlregen"]), default="default", help="Pipeline profile.")
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
@click.option("--humanize", type=float, default=0.0, help="Humanization strength (0.0–1.0) for invisible removal.")
@click.pass_context
def cmd_invisible(
    ctx: click.Context,
    source: Path,
    output: Path | None,
    strength: float,
    steps: int,
    pipeline: str,
    device: str,
    seed: int | None,
    hf_token: str | None,
    humanize: float,
) -> None:
    """Remove invisible AI watermarks (SynthID, StableSignature, TreeRing).

    Uses diffusion-based regeneration. Requires GPU for reasonable speed.
    """
    from remove_ai_watermarks.invisible_engine import InvisibleEngine

    source = _validate_image(source)
    if output is None:
        output = source.with_stem(source.stem + "_clean")

    device_str = None if device == "auto" else device

    def progress_cb(msg: str) -> None:
        console.print(f" [dim]{msg}[/]")

    engine = InvisibleEngine(
        device=device_str,
        pipeline=pipeline,
        hf_token=hf_token,
        progress_callback=progress_cb,
    )

    console.print(f" [dim]Input:[/] {source.name}")
    console.print(f" [dim]Pipeline:[/] {pipeline}")
    console.print(f" [dim]Strength:[/] {strength} Steps: {steps}")

    t0 = time.monotonic()
    result_path = engine.remove_watermark(
        image_path=source,
        output_path=output,
        strength=strength,
        num_inference_steps=steps,
        guidance_scale=None,
        seed=seed,
        humanize=humanize,
    )
    elapsed = time.monotonic() - t0

    size_kb = result_path.stat().st_size / 1024
    console.print(f"\n [green]✓[/] Saved: {result_path} [dim]({size_kb:.0f} KB, {elapsed:.1f}s)[/]")


# ── Metadata operations ─────────────────────────────────────────────


@main.command("metadata")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option("--check", is_flag=True, help="Check for AI metadata (don't modify).")
@click.option("--remove", is_flag=True, help="Remove AI metadata.")
@click.option(
    "-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: overwrite source)."
)
@click.option("--keep-standard/--remove-all", default=True, help="Keep standard metadata (Author, Title, etc.).")
@click.pass_context
def cmd_metadata(
    ctx: click.Context,
    source: Path,
    check: bool,
    remove: bool,
    output: Path | None,
    keep_standard: bool,
) -> None:
    """Check or remove AI-generation metadata from images.

    Strips EXIF AI tags, PNG text chunks, and C2PA provenance manifests.
    """
    from remove_ai_watermarks.metadata import get_ai_metadata, has_ai_metadata, remove_ai_metadata

    _banner()
    source = _validate_image(source)

    if check or (not remove):
        has_ai = has_ai_metadata(source)
        if has_ai:
            console.print(f" [yellow]⚠[/] AI metadata detected in {source.name}:")
            meta = get_ai_metadata(source)
            table = Table(show_header=True, header_style="bold")
            table.add_column("Key", style="cyan")
            table.add_column("Value")
            for k, v in meta.items():
                table.add_row(k, str(v)[:80])
            console.print(table)
        else:
            console.print(f" [green]✓[/] No AI metadata found in {source.name}")

        if not remove:
            return

    # Remove
    out = remove_ai_metadata(source, output, keep_standard=keep_standard)
    console.print(f" [green]✓[/] AI metadata stripped → {out}")


# ── Combined "all" mode ──────────────────────────────────────────────


@main.command("all")
@click.argument("source", type=click.Path(exists=True, path_type=Path))
@click.option(
    "-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
)
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting cleanup after visible removal.")
@click.option(
    "--inpaint-method", type=click.Choice(["ns", "telea", "gaussian"]), default="ns", help="Inpainting method."
)
@click.option("--strength", type=float, default=0.04, help="Invisible watermark denoising strength (0.0–1.0).")
@click.option("--steps", type=int, default=50, help="Number of denoising steps for invisible removal.")
@click.option(
    "--pipeline",
    type=click.Choice(["default", "ctrlregen"]),
    default="default",
    help="Pipeline profile for invisible removal.",
)
@click.option("--model", type=str, default=None, help="HuggingFace model ID for invisible removal.")
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
@click.option("--humanize", type=float, default=0.0, help="Humanization strength (0.0–1.0) for invisible removal.")
@click.pass_context
def cmd_all(
    ctx: click.Context,
    source: Path,
    output: Path | None,
    inpaint: bool,
    inpaint_method: str,
    strength: float,
    steps: int,
    pipeline: str,
    model: str | None,
    device: str,
    seed: int | None,
    hf_token: str | None,
    humanize: float,
) -> None:
    """Remove ALL watermarks: visible + invisible + metadata.

    Runs the full pipeline in order:
    1. Visible watermark removal (Gemini sparkle, reverse alpha blending)
    2. Invisible watermark removal (SynthID etc., diffusion regeneration)
    3. AI metadata stripping (EXIF, PNG text, C2PA)

    If invisible watermark deps are not installed, skips step 2 with a warning.
    """
    import cv2

    from remove_ai_watermarks.gemini_engine import GeminiEngine, get_watermark_config

    _banner()
    source = _validate_image(source)

    if output is None:
        output = source.with_stem(source.stem + "_clean")

    t0 = time.monotonic()

    # ── Step 1: Visible watermark ────────────────────────────────
    console.print("\n [bold cyan]① Visible watermark removal[/]")
    engine = GeminiEngine()
    image = cv2.imread(str(source), cv2.IMREAD_COLOR)
    if image is None:
        console.print(f"[red]Error:[/] Failed to read image: {source}")
        raise SystemExit(1)

    h, w = image.shape[:2]
    console.print(f" [dim]Input:[/] {source.name} ({w}×{h})")

    with console.status("[cyan]Removing visible watermark…[/]"):
        det = engine.detect_watermark(image)
        if det.detected:
            result = engine.remove_watermark(image)
            if inpaint:
                if det.confidence > 0.15:
                    region = det.region
                else:
                    config = get_watermark_config(w, h)
                    pos = config.get_position(w, h)
                    region = (pos[0], pos[1], config.logo_size, config.logo_size)
                result = engine.inpaint_residual(result, region, method=inpaint_method)
            console.print(" [green]✓[/] Visible watermark removed")
        else:
            result = image.copy()
            console.print(" [dim]Skipped (no visible watermark detected)[/]")

    # Save intermediate
    output.parent.mkdir(parents=True, exist_ok=True)
    cv2.imwrite(str(output), result)

    # ── Step 2: Invisible watermark ──────────────────────────────
    console.print("\n [bold cyan]② Invisible watermark removal[/]")
    from remove_ai_watermarks.invisible_engine import InvisibleEngine

    device_str = None if device == "auto" else device

    def progress_cb(msg: str) -> None:
        console.print(f" [dim]{msg}[/]")

    inv_engine = InvisibleEngine(
        model_id=model,
        device=device_str,
        pipeline=pipeline,
        hf_token=hf_token,
        progress_callback=progress_cb,
    )

    console.print(f" [dim]Strength:[/] {strength} Steps: {steps}")
    inv_engine.remove_watermark(
        image_path=output,
        output_path=output,
        strength=strength,
        num_inference_steps=steps,
        seed=seed,
        humanize=humanize,
    )
    console.print(" [green]✓[/] Invisible watermark removed")

    # ── Step 3: Metadata ─────────────────────────────────────────
    console.print("\n [bold cyan]③ AI metadata stripping[/]")
    try:
        from remove_ai_watermarks.metadata import remove_ai_metadata

        remove_ai_metadata(output, output)
        console.print(" [green]✓[/] AI metadata stripped")
    except Exception as e:
        console.print(f" [yellow]⚠[/] Metadata strip failed: {e}")

    # ── Done ─────────────────────────────────────────────────────
    elapsed = time.monotonic() - t0
    size_kb = output.stat().st_size / 1024
    console.print(f"\n [bold green]✓ Done:[/] {output} [dim]({size_kb:.0f} KB, {elapsed:.1f}s total)[/]")


|
||||
# ── Batch command ────────────────────────────────────────────────────


@main.command("batch")
@click.argument("directory", type=click.Path(exists=True, file_okay=False, path_type=Path))
@click.option(
    "-o",
    "--output-dir",
    type=click.Path(path_type=Path),
    default=None,
    help="Output directory (default: <dir>_clean/).",
)
@click.option(
    "--mode", type=click.Choice(["visible", "invisible", "metadata", "all"]), default="visible", help="Processing mode."
)
@click.option("--strength", type=float, default=None, help="Denoising strength (invisible mode).")
@click.option("--steps", type=int, default=50, help="Number of denoising steps (invisible mode).")
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting (visible mode).")
@click.option("--humanize", type=float, default=0.0, help="Humanization strength (0.0–1.0) for invisible removal.")
@click.option("--pipeline", type=click.Choice(["default", "ctrlregen"]), default="default", help="Pipeline profile.")
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
@click.pass_context
def cmd_batch(
    ctx: click.Context,
    directory: Path,
    mode: str,
    output_dir: Path | None,
    strength: float | None,
    steps: int,
    pipeline: str,
    device: str,
    seed: int | None,
    hf_token: str | None,
    inpaint: bool,
    humanize: float,
) -> None:
    """Process all images in a directory."""
    _banner()

    if output_dir is None:
        output_dir = directory.parent / (directory.name + "_clean")
    output_dir.mkdir(parents=True, exist_ok=True)

    # File extensions (lowercase) that the batch command treats as images
    supported_formats = [".jpg", ".jpeg", ".png", ".webp"]

    images = sorted(p for p in directory.iterdir() if p.suffix.lower() in supported_formats)

    if not images:
        console.print(f"[yellow]No supported images found in {directory}[/]")
        return

    console.print(f" Found [bold]{len(images)}[/] images in {directory}")
    console.print(f" Output → {output_dir}")
    console.print(f" Mode: [cyan]{mode}[/]")

    processed = 0
    errors = 0

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Processing…", total=len(images))

        for img_path in images:
            out_path = output_dir / img_path.name
            progress.update(task, description=f"[cyan]{img_path.name}[/]")

            try:
                if mode in ("visible", "all"):
                    import cv2

                    from remove_ai_watermarks.gemini_engine import (
                        GeminiEngine,
                        get_watermark_config,
                    )

                    if "_vis_engine" not in ctx.obj:
                        ctx.obj["_vis_engine"] = GeminiEngine()
                    engine = ctx.obj["_vis_engine"]
                    read_path = img_path
                    if mode == "all" and out_path.exists():
                        read_path = out_path
                    image = cv2.imread(str(read_path), cv2.IMREAD_COLOR)
                    if image is None:
                        raise ValueError("Failed to read image")

                    det = engine.detect_watermark(image)
                    if det.detected:
                        result = engine.remove_watermark(image)
                        if inpaint:
                            if det.confidence > 0.15:
                                region = det.region
                            else:
                                h, w = image.shape[:2]
                                config = get_watermark_config(w, h)
                                pos = config.get_position(w, h)
                                region = (pos[0], pos[1], config.logo_size, config.logo_size)

                            result = engine.inpaint_residual(result, region)
                    else:
                        result = image.copy()

                    cv2.imwrite(str(out_path), result)

                if mode in ("invisible", "all"):
                    from remove_ai_watermarks.invisible_engine import (
                        is_available as invisible_available,
                    )

                    if invisible_available():
                        from remove_ai_watermarks.invisible_engine import InvisibleEngine

                        # Build the engine once per batch and honor --pipeline, --device and
                        # --hf-token, which were previously accepted but never passed on.
                        if "_inv_engine" not in ctx.obj:
                            ctx.obj["_inv_engine"] = InvisibleEngine(
                                device=None if device == "auto" else device,
                                pipeline=pipeline,
                                hf_token=hf_token,
                            )
                        engine_inv = ctx.obj["_inv_engine"]
                        engine_inv.remove_watermark(
                            img_path if mode == "invisible" else out_path,
                            out_path,
                            strength=strength,
                            num_inference_steps=steps,
                            seed=seed,
                            humanize=humanize,
                        )

                if mode in ("metadata", "all"):
                    from remove_ai_watermarks.metadata import remove_ai_metadata

                    remove_ai_metadata(img_path if mode == "metadata" else out_path, out_path)

                processed += 1

            except Exception as e:
                errors += 1
                if ctx.obj.get("verbose"):
                    console.print(f" [red]✗[/] {img_path.name}: {e}")

            progress.advance(task)

    console.print(f"\n [green]✓[/] {processed} processed" + (f" [red]✗[/] {errors} errors" if errors else ""))


if __name__ == "__main__":
    main()
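The batch command above derives a default output directory and filters the input directory by file extension. A minimal standalone sketch of those two steps follows; the helper names `default_output_dir` and `collect_images` are hypothetical and not part of the package, and the `is_file()` check is an added assumption that the original filter does not make.

```python
from pathlib import Path

# Same extension list as the batch command (compared case-insensitively).
SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".webp"}


def default_output_dir(directory: Path) -> Path:
    """Return the sibling '<dir>_clean' directory used when -o is omitted."""
    return directory.parent / (directory.name + "_clean")


def collect_images(directory: Path) -> list[Path]:
    """Return supported image files in `directory`, sorted by path."""
    return sorted(
        p
        for p in directory.iterdir()
        # is_file() is an extra guard (assumption): skip directories named like images
        if p.is_file() and p.suffix.lower() in SUPPORTED_FORMATS
    )
```

Only the top level of the directory is scanned, which matches the non-recursive `iterdir()` call in `cmd_batch`.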
@@ -0,0 +1,128 @@
import logging
from pathlib import Path

import cv2
import numpy as np

try:
    from ultralytics import YOLO

    HAS_YOLO = True
except ImportError:
    HAS_YOLO = False

logger = logging.getLogger("face_protector")


class FaceProtector:
    """
    Detects faces in an image and provides methods to seamlessly paste them back
    onto an upscaled/processed image to preserve facial details that may have
    been destroyed by latent diffusion or other algorithms.
    """

    def __init__(self, use_yolo: bool = True, model_name: str = "yolov8n.pt"):
        self.use_yolo = use_yolo and HAS_YOLO
        self.detector = None
        self.haar_cascade = None

        if self.use_yolo:
            logger.info(f"Loading YOLO model '{model_name}' for face protection...")
            self.detector = YOLO(model_name)
        else:
            if use_yolo and not HAS_YOLO:
                logger.warning(
                    "ultralytics YOLO is not installed. Falling back to OpenCV Haar Cascades. Install ultralytics with `uv pip install ultralytics` for better face detection."
                )
            logger.info("Loading OpenCV Haar Cascade for face protection...")
            cascade_path = Path(cv2.__file__).parent / "data" / "haarcascade_frontalface_default.xml"
            if not cascade_path.exists():
                cascade_path = "haarcascade_frontalface_default.xml"
            self.haar_cascade = cv2.CascadeClassifier(str(cascade_path))

    def detect_face_bboxes(self, image: np.ndarray) -> list[tuple[int, int, int, int]]:
        """
        Detect faces and return bounding boxes as (x1, y1, x2, y2).

        With the stock YOLO model the boxes cover whole persons rather than faces.
        """
        if self.use_yolo and self.detector is not None:
            # For standard YOLOv8n, 'person' is class 0. We'll use person bounding boxes
            # as a proxy for faces/people to protect them. If using a specific face model, adjust classes.
            results = self.detector(image, verbose=False, classes=[0])
            bboxes = []
            for r in results:
                boxes = r.boxes
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    bboxes.append((int(x1), int(y1), int(x2), int(y2)))
            return bboxes

        else:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            faces = self.haar_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
            bboxes = []
            for x, y, w, h in faces:
                # Add a 20% margin around the haar cascade face box
                margin_x = int(w * 0.2)
                margin_y = int(h * 0.2)
                x1 = max(0, x - margin_x)
                y1 = max(0, y - int(margin_y * 1.5))  # more margin on top for hair
                x2 = min(image.shape[1], x + w + margin_x)
                y2 = min(image.shape[0], y + h + margin_y)
                # Cast NumPy integers to plain ints to match the annotated return type
                bboxes.append((int(x1), int(y1), int(x2), int(y2)))
            return bboxes

    def extract_faces(self, image: np.ndarray) -> list[tuple[tuple[int, int, int, int], np.ndarray]]:
        """
        Extract faces from the image.
        Returns a list of (bbox, face_crop) tuples.
        """
        bboxes = self.detect_face_bboxes(image)
        faces = []
        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            faces.append((bbox, image[y1:y2, x1:x2].copy()))
        return faces

    def restore_faces(
        self, processed_image: np.ndarray, original_faces: list[tuple[tuple[int, int, int, int], np.ndarray]]
    ) -> np.ndarray:
        """
        Paste original faces back onto the processed image, blending each crop
        through a soft elliptical alpha mask so the edges don't show.
        """
        if not original_faces:
            return processed_image

        result = processed_image.copy()
        for (x1, y1, x2, y2), face_crop in original_faces:
            h, w = face_crop.shape[:2]

            # Each crop is pasted back at its original coordinates, so the target region
            # must have the same size. Skip the face if the processed image was resized
            # or the box no longer fits inside it.
            if result[y1:y2, x1:x2].shape[:2] != (h, w):
                logger.warning(f"Skipping face at {x1},{y1} to {x2},{y2}: region size mismatch")
                continue

            try:
                # Create a soft alpha mask for the face crop to smoothly blend it
                mask = np.zeros((h, w), dtype=np.float32)

                # Inner ellipse is pure white
                cv2.ellipse(mask, (w // 2, h // 2), (int(w * 0.4), int(h * 0.4)), 0, 0, 360, 1.0, -1)

                # Blur the mask heavily for soft edges
                blur_size = max(w, h) // 4
                if blur_size % 2 == 0:
                    blur_size += 1
                mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)
                mask = cv2.merge([mask, mask, mask])

                # Blend
                target_roi = result[y1:y2, x1:x2].astype(np.float32)
                src_roi = face_crop.astype(np.float32)

                blended = src_roi * mask + target_roi * (1.0 - mask)
                result[y1:y2, x1:x2] = blended.astype(np.uint8)
            except Exception as e:
                logger.warning(f"Failed to restore face at {x1},{y1} to {x2},{y2}: {e}")

        return result
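The soft paste-back in `restore_faces` is ordinary alpha compositing with a feathered mask. The sketch below shows the same idea using NumPy only. It is an illustration under stated assumptions, not the module's code: a linear radial falloff stands in for the Gaussian-blurred ellipse, and the names `soft_ellipse_mask` and `blend_crop` are hypothetical.

```python
import numpy as np


def soft_ellipse_mask(h: int, w: int, inner: float = 0.4, outer: float = 0.5) -> np.ndarray:
    """Return an (h, w) float mask: 1.0 inside the inner ellipse, fading to 0.0 at `outer`."""
    yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    # Normalized elliptical radius: 0 at the center, 0.5 at the middle of each box edge
    r = np.sqrt(((yy - cy) / h) ** 2 + ((xx - cx) / w) ** 2)
    return np.clip((outer - r) / (outer - inner), 0.0, 1.0).astype(np.float32)


def blend_crop(target: np.ndarray, crop: np.ndarray, bbox: tuple[int, int, int, int]) -> np.ndarray:
    """Composite `crop` onto a copy of `target` at bbox = (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = bbox
    out = target.copy()
    roi = out[y1:y2, x1:x2].astype(np.float32)
    if roi.shape != crop.shape:
        raise ValueError("crop does not fit the target region")
    mask = soft_ellipse_mask(*crop.shape[:2])[:, :, np.newaxis]
    # blended = source * mask + target * (1 - mask)
    out[y1:y2, x1:x2] = (crop.astype(np.float32) * mask + roi * (1.0 - mask)).astype(np.uint8)
    return out
```

Pixels near the crop center come entirely from the original face, while the corners of the box keep the processed image, so no rectangular seam is visible.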
@@ -0,0 +1,549 @@
"""Gemini visible watermark removal engine.

Port of the GeminiWatermarkTool reverse-alpha-blending algorithm from C++ to Python.
Original author: Allen Kuo (allenk) — https://github.com/allenk/GeminiWatermarkTool

The Gemini AI watermark is applied using alpha blending:
    watermarked = α × logo + (1 - α) × original

We reverse this to recover the original:
    original = (watermarked - α × logo) / (1 - α)

The alpha maps are derived from background captures of the Gemini watermark
on pure-black backgrounds (48×48 for small images, 96×96 for large images).
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Literal

import cv2
import numpy as np
from numpy.typing import NDArray

logger = logging.getLogger(__name__)


class WatermarkSize(Enum):
    """Watermark size mode based on image dimensions."""

    SMALL = "small"  # 48×48, used when width or height is ≤ 1024
    LARGE = "large"  # 96×96, used when both width and height are > 1024


@dataclass
class DetectionResult:
    """Result of watermark detection."""

    detected: bool = False
    confidence: float = 0.0
    region: tuple[int, int, int, int] = (0, 0, 0, 0)  # x, y, w, h
    size: WatermarkSize = WatermarkSize.SMALL

    # stage scores
    spatial_score: float = 0.0
    gradient_score: float = 0.0
    variance_score: float = 0.0


@dataclass
class WatermarkPosition:
    """Watermark position configuration."""

    margin_right: int
    margin_bottom: int
    logo_size: int

    def get_position(self, image_width: int, image_height: int) -> tuple[int, int]:
        """Get top-left position for a given image size."""
        x = image_width - self.margin_right - self.logo_size
        y = image_height - self.margin_bottom - self.logo_size
        return (x, y)


def get_watermark_config(width: int, height: int) -> WatermarkPosition:
    """Get the appropriate watermark configuration based on image size.

    Rules discovered from Gemini:
    - W > 1024 AND H > 1024: 96×96 logo at (W-64-96, H-64-96)
    - Otherwise: 48×48 logo at (W-32-48, H-32-48)
    """
    if width > 1024 and height > 1024:
        return WatermarkPosition(margin_right=64, margin_bottom=64, logo_size=96)
    return WatermarkPosition(margin_right=32, margin_bottom=32, logo_size=48)


def get_watermark_size(width: int, height: int) -> WatermarkSize:
    """Determine watermark size mode from image dimensions."""
    if width > 1024 and height > 1024:
        return WatermarkSize.LARGE
    return WatermarkSize.SMALL


def _calculate_alpha_map(bg_capture: NDArray) -> NDArray:
    """Calculate alpha map from a background capture.

    The alpha map represents how much the watermark affects each pixel.
        alpha = max(R, G, B) / 255.0
    """
    if len(bg_capture.shape) == 2:
        gray = bg_capture.astype(np.float32)
    elif bg_capture.shape[2] >= 3:
        # Use max of channels for brightness
        gray = np.max(bg_capture[:, :, :3], axis=2).astype(np.float32)
    else:
        gray = bg_capture[:, :, 0].astype(np.float32)

    return gray / 255.0


def _load_embedded_asset(name: str) -> NDArray:
    """Load an embedded PNG asset and decode it with OpenCV."""
    asset_path = Path(__file__).parent / "assets" / name
    if not asset_path.exists():
        raise FileNotFoundError(f"Embedded asset not found: {asset_path}")

    data = asset_path.read_bytes()
    buf = np.frombuffer(data, dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        raise RuntimeError(f"Failed to decode embedded asset: {name}")
    return img


class GeminiEngine:
    """Engine for removing visible Gemini watermarks via reverse alpha blending.

    This is a Python port of the GeminiWatermarkTool C++ engine.
    """

    def __init__(self, logo_value: float = 255.0) -> None:
        """Initialize the engine with embedded alpha maps.

        Args:
            logo_value: The logo brightness value (default 255.0 = white).
        """
        self.logo_value = logo_value

        # Load embedded background captures
        bg_small = _load_embedded_asset("gemini_bg_48.png")
        bg_large = _load_embedded_asset("gemini_bg_96.png")

        # Ensure correct sizes
        if bg_small.shape[:2] != (48, 48):
            bg_small = cv2.resize(bg_small, (48, 48), interpolation=cv2.INTER_AREA)
        if bg_large.shape[:2] != (96, 96):
            bg_large = cv2.resize(bg_large, (96, 96), interpolation=cv2.INTER_AREA)

        # Calculate alpha maps
        self._alpha_small = _calculate_alpha_map(bg_small)
        self._alpha_large = _calculate_alpha_map(bg_large)

        logger.debug(
            "Alpha maps loaded: small=%s, large=%s",
            self._alpha_small.shape,
            self._alpha_large.shape,
        )

    def get_alpha_map(self, size: WatermarkSize) -> NDArray:
        """Get the base alpha map for a specific standard size."""
        if size == WatermarkSize.SMALL:
            return self._alpha_small
        return self._alpha_large

    def get_interpolated_alpha(self, size_px: int) -> NDArray:
        """Create an interpolated alpha map dynamically scaled from the high-res 96x96 base."""
        source = self._alpha_large
        if size_px == source.shape[1]:
            return source.copy()

        interp = cv2.INTER_LINEAR if size_px > source.shape[1] else cv2.INTER_AREA
        return cv2.resize(source, (size_px, size_px), interpolation=interp)

    # ── Detection ────────────────────────────────────────────────────

    def detect_watermark(
        self,
        image: NDArray,
        force_size: WatermarkSize | None = None,
    ) -> DetectionResult:
        """Detect Gemini watermark using multi-scale Snap Engine logic (ported from C++ vendor algorithm)."""
        result = DetectionResult()

        if image is None or image.size == 0:
            return result

        h, w = image.shape[:2]
        base_size = force_size or get_watermark_size(w, h)
        result.size = base_size

        # Use large alpha template (96x96) as the high-quality source for downscaling
        source_alpha = self._alpha_large

        # Dynamically search bottom-right corner (search up to 256x256 region)
        search_size = int(min(min(w, h), 256))
        sx1 = max(0, w - search_size)
        sy1 = max(0, h - search_size)

        search_region = image[sy1:h, sx1:w]
        if len(search_region.shape) == 3 and search_region.shape[2] >= 3:
            gray_sr = cv2.cvtColor(search_region, cv2.COLOR_BGR2GRAY)
        else:
            gray_sr = search_region.copy()

        gray_sr_f = gray_sr.astype(np.float32) / 255.0

        # Phase 1 & 2: Multi-scale spatial NCC search
        best_scale = 0
        best_score = -1.0
        best_raw_ncc = -1.0
        best_loc = (0, 0)

        # Search scales from 16 to 118 px in steps of 2 (covering aggressively downscaled or slightly upscaled logos)
        for scale in range(16, 120, 2):
            if scale > search_region.shape[0] or scale > search_region.shape[1]:
                continue

            tmpl = cv2.resize(source_alpha, (scale, scale), interpolation=cv2.INTER_AREA)
            match_res = cv2.matchTemplate(gray_sr_f, tmpl, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(match_res)

            # Size-adjusted score to overcome NCC bias toward tiny patches (mimics C++ weight)
            weight = min(1.0, (scale / 96.0) ** 0.5)
            adj_val = max_val * weight

            if adj_val > best_score:
                best_score = adj_val
                best_scale = scale
                best_loc = max_loc
                best_raw_ncc = max_val

        # No scale produced a match (e.g. the image is smaller than the smallest
        # template), so there is nothing to analyze and resizing to 0 px would fail.
        if best_scale == 0:
            return result

        # Exact dynamic location & size
        pos_x = sx1 + best_loc[0]
        pos_y = sy1 + best_loc[1]
        result.region = (pos_x, pos_y, best_scale, best_scale)
        result.spatial_score = float(best_raw_ncc)

        # Generate exact alpha map for matched size
        alpha_region = self.get_interpolated_alpha(best_scale)

        # Extract exactly the matched region for Gradient & Variance analysis
        x1 = pos_x
        y1 = pos_y
        x2 = min(w, x1 + best_scale)
        y2 = min(h, y1 + best_scale)

        region = image[y1:y2, x1:x2]
        if len(region.shape) == 3 and region.shape[2] >= 3:
            gray_region = cv2.cvtColor(region, cv2.COLOR_BGR2GRAY)
        else:
            gray_region = region.copy()

        gray_f = gray_region.astype(np.float32) / 255.0

        # Adjust alpha_region if clipped by image boundary (rare, but possible)
        ay1, ax1 = 0, 0
        alpha_region = alpha_region[ay1 : ay1 + (y2 - y1), ax1 : ax1 + (x2 - x1)]

        if result.spatial_score < 0.25:
            result.confidence = float(max(0.0, result.spatial_score * 0.5))
            return result

        # ── Stage 2: Gradient NCC ────────────────────────────────────
        img_gx = cv2.Sobel(gray_f, cv2.CV_32F, 1, 0, ksize=3)
        img_gy = cv2.Sobel(gray_f, cv2.CV_32F, 0, 1, ksize=3)
        img_gmag = cv2.magnitude(img_gx, img_gy)

        alpha_gx = cv2.Sobel(alpha_region, cv2.CV_32F, 1, 0, ksize=3)
        alpha_gy = cv2.Sobel(alpha_region, cv2.CV_32F, 0, 1, ksize=3)
        alpha_gmag = cv2.magnitude(alpha_gx, alpha_gy)

        grad_match = cv2.matchTemplate(img_gmag, alpha_gmag, cv2.TM_CCOEFF_NORMED)
        _, grad_score, _, _ = cv2.minMaxLoc(grad_match)
        result.gradient_score = float(grad_score)

        # ── Stage 3: Variance Analysis ───────────────────────────────
        var_score = 0.0
        ref_h = min(y1, best_scale)

        if ref_h > 8:
            ref_region = image[y1 - ref_h : y1, x1:x2]
            if len(ref_region.shape) == 3:
                gray_ref = cv2.cvtColor(ref_region, cv2.COLOR_BGR2GRAY)
            else:
                gray_ref = ref_region

            _, s_wm = cv2.meanStdDev(gray_region)
            _, s_ref = cv2.meanStdDev(gray_ref)

            if s_ref[0][0] > 5.0:
                var_score = max(0.0, min(1.0, 1.0 - (s_wm[0][0] / s_ref[0][0])))

        result.variance_score = float(var_score)

        # ── Fusion ───────────────────────────────────────────────────
        confidence = result.spatial_score * 0.50 + result.gradient_score * 0.30 + var_score * 0.20
        result.confidence = float(max(0.0, min(1.0, confidence)))
        result.detected = result.confidence >= 0.35

        logger.debug(
            "Detection: spatial=%.3f, grad=%.3f, var=%.3f → conf=%.3f (%s)",
            result.spatial_score,
            result.gradient_score,
            var_score,
            result.confidence,
            "DETECTED" if result.detected else "not detected",
        )

        return result

    # ── Removal ──────────────────────────────────────────────────────

    def remove_watermark(
        self,
        image: NDArray,
        force_size: WatermarkSize | None = None,
    ) -> NDArray:
        """Remove Gemini visible watermark from an image using reverse alpha blending.

        Args:
            image: BGR image as numpy array (will NOT be modified in-place).
            force_size: Force a specific watermark size (auto-detect if None).

        Returns:
            Cleaned BGR image as numpy array.
        """
        result = image.copy()

        # Normalize to 3-channel BGR (handles 2-D grayscale, single-channel and BGRA input)
        if result.ndim == 2:
            result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)
        elif result.shape[2] == 4:
            result = cv2.cvtColor(result, cv2.COLOR_BGRA2BGR)
        elif result.shape[2] == 1:
            result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)

        h, w = result.shape[:2]
        size = force_size or get_watermark_size(w, h)

        # Detect dynamic position & size
        detection = self.detect_watermark(image, force_size=size)

        # If the confidence clears a low bar (>0.15, well below the 0.35 detection
        # threshold), trust the dynamically detected position and size.
        if detection.confidence > 0.15:
            pos = (detection.region[0], detection.region[1])
            alpha_map = self.get_interpolated_alpha(detection.region[2])
            logger.debug(
                "Using dynamic watermark position (%d, %d) at size %dx%d [conf=%.3f]",
                pos[0],
                pos[1],
                detection.region[2],
                detection.region[3],
                detection.confidence,
            )
        else:
            config = get_watermark_config(w, h)
            pos = config.get_position(w, h)
            alpha_map = self.get_alpha_map(size)
            logger.debug(
                "Dynamic search failed. Using fallback default position (%d, %d) with %dx%d alpha map.",
                pos[0],
                pos[1],
                alpha_map.shape[1],
                alpha_map.shape[0],
            )

        self._reverse_alpha_blend(result, alpha_map, pos)
        return result

    def remove_watermark_custom(
        self,
        image: NDArray,
        region: tuple[int, int, int, int],
    ) -> NDArray:
        """Remove watermark from a custom region with interpolated alpha map.

        Args:
            image: BGR image (will NOT be modified in-place).
            region: (x, y, width, height) of the watermark region.

        Returns:
            Cleaned BGR image.
        """
        result = image.copy()
        x, y, rw, rh = region

        # Check standard sizes
        if rw == 48 and rh == 48:
            self._reverse_alpha_blend(result, self._alpha_small, (x, y))
            return result
        if rw == 96 and rh == 96:
            self._reverse_alpha_blend(result, self._alpha_large, (x, y))
            return result

        # Interpolate alpha map for custom size
        interp = cv2.INTER_LINEAR if rw > 96 else cv2.INTER_AREA
        alpha = cv2.resize(self._alpha_large, (rw, rh), interpolation=interp)
        self._reverse_alpha_blend(result, alpha, (x, y))
        return result

    def _reverse_alpha_blend(
        self,
        image: NDArray,
        alpha_map: NDArray,
        position: tuple[int, int],
    ) -> None:
        """Apply reverse alpha blending in-place.

        Formula: original = (watermarked - α × logo) / (1 - α)
        """
        x, y = position
        ah, aw = alpha_map.shape[:2]
        ih, iw = image.shape[:2]

        # Clip to bounds
        x1 = max(0, x)
        y1 = max(0, y)
        x2 = min(iw, x + aw)
        y2 = min(ih, y + ah)

        if x1 >= x2 or y1 >= y2:
            return

        # Get ROIs
        ax1, ay1 = x1 - x, y1 - y
        alpha_roi = alpha_map[ay1 : ay1 + (y2 - y1), ax1 : ax1 + (x2 - x1)]
        image_roi = image[y1:y2, x1:x2].astype(np.float32)

        alpha_threshold = 0.002
        max_alpha = 0.99

        # Vectorized reverse alpha blending
        alpha = alpha_roi.copy()
        mask = alpha >= alpha_threshold
        alpha = np.clip(alpha, 0.0, max_alpha)
        one_minus_alpha = 1.0 - alpha

        # Expand alpha for 3-channel broadcast
        alpha_3d = alpha[:, :, np.newaxis]
        one_minus_3d = one_minus_alpha[:, :, np.newaxis]
        mask_3d = mask[:, :, np.newaxis]

        # original = (watermarked - alpha * logo) / (1 - alpha)
        restored = (image_roi - alpha_3d * self.logo_value) / one_minus_3d
        restored = np.clip(restored, 0.0, 255.0)

        # Apply only where alpha is significant
        image_roi = np.where(mask_3d, restored, image_roi)
        image[y1:y2, x1:x2] = image_roi.astype(np.uint8)

    # ── Inpainting cleanup ───────────────────────────────────────────

    def inpaint_residual(
        self,
        image: NDArray,
        region: tuple[int, int, int, int],
        strength: float = 0.85,
        method: Literal["gaussian", "telea", "ns"] = "ns",
        inpaint_radius: int = 10,
        padding: int = 32,
    ) -> NDArray:
        """Apply inpaint cleanup on residual artifacts after reverse alpha blend.

        Uses a sparse mask derived from alpha map gradient to repair only
        the sparkle-edge pixels where interpolation broke the math.

        Args:
            image: BGR image (will NOT be modified in-place).
            region: (x, y, w, h) of the watermark region.
            strength: Blend strength (0.0 = keep original, 1.0 = fully inpainted).
            method: Inpaint method ("gaussian", "telea", or "ns").
            inpaint_radius: Radius for cv2.inpaint.
            padding: Context padding around region in pixels.

        Returns:
            Cleaned BGR image.
        """
        result = image.copy()
        x, y, rw, rh = region

        if rw < 4 or rh < 4:
            return result

        strength = max(0.0, min(1.0, strength))
        if strength < 0.001:
            return result

        # Padded region
        px1 = max(0, x - padding)
        py1 = max(0, y - padding)
        px2 = min(image.shape[1], x + rw + padding)
        py2 = min(image.shape[0], y + rh + padding)

        if (px2 - px1) < 8 or (py2 - py1) < 8:
            return result

        # Inner rect relative to padded
        ix1 = x - px1
        iy1 = y - py1

        # Get alpha map (interpolated if needed)
        source_alpha = self._alpha_large
        interp = cv2.INTER_LINEAR if rw > source_alpha.shape[1] else cv2.INTER_AREA
        alpha_resized = cv2.resize(source_alpha, (rw, rh), interpolation=interp)

        # Compute gradient mask from alpha
        grad_x = cv2.Sobel(alpha_resized, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(alpha_resized, cv2.CV_32F, 0, 1, ksize=3)
        grad_mag = cv2.magnitude(grad_x, grad_y)

        grad_min, grad_max = grad_mag.min(), grad_mag.max()
        if grad_max <= grad_min:
            return result

        # Normalize and apply gamma correction
        grad_norm = (grad_mag - grad_min) / (grad_max - grad_min)
        grad_weight = np.sqrt(grad_norm)

        # Dilate the mask
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        grad_weight = cv2.dilate(grad_weight, kernel)

        if method == "gaussian":
            # Soft blend with Gaussian blur
            padded_roi = result[py1:py2, px1:px2].copy()
            blurred = cv2.GaussianBlur(padded_roi, (0, 0), sigmaX=2.0)

            # Create weight mask on padded area (only inner region has weights)
            weight_full = np.zeros((py2 - py1, px2 - px1), dtype=np.float32)
            weight_full[iy1 : iy1 + rh, ix1 : ix1 + rw] = grad_weight * strength

            weight_3d = weight_full[:, :, np.newaxis]
            blended = padded_roi.astype(np.float32) * (1 - weight_3d) + blurred.astype(np.float32) * weight_3d
            result[py1:py2, px1:px2] = blended.astype(np.uint8)
        else:
            # OpenCV inpainting (TELEA or NS)
            inpaint_flag = cv2.INPAINT_TELEA if method == "telea" else cv2.INPAINT_NS

            # Create binary mask from gradient weight
            binary_mask = (grad_weight * 255).astype(np.uint8)
            _, binary_mask = cv2.threshold(binary_mask, 30, 255, cv2.THRESH_BINARY)

            # Expand mask to padded region
            mask_full = np.zeros((py2 - py1, px2 - px1), dtype=np.uint8)
            mask_full[iy1 : iy1 + rh, ix1 : ix1 + rw] = binary_mask

            padded_roi = result[py1:py2, px1:px2].copy()
            inpainted = cv2.inpaint(padded_roi, mask_full, inpaint_radius, inpaint_flag)

            # Blend with strength
            weight_full = np.zeros((py2 - py1, px2 - px1), dtype=np.float32)
            weight_full[iy1 : iy1 + rh, ix1 : ix1 + rw] = grad_weight * strength
            weight_3d = weight_full[:, :, np.newaxis]

            blended = padded_roi.astype(np.float32) * (1 - weight_3d) + inpainted.astype(np.float32) * weight_3d
            result[py1:py2, px1:px2] = blended.astype(np.uint8)

        return result
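The blending formula from the module docstring can be checked with a small round trip: apply a synthetic watermark, then reverse it. The following is a NumPy-only sketch under stated assumptions, not the engine itself. The function names are hypothetical, the alpha map is random rather than the embedded Gemini capture, and results are rounded where `_reverse_alpha_blend` truncates.

```python
import numpy as np


def apply_watermark(original: np.ndarray, alpha: np.ndarray, logo: float = 255.0) -> np.ndarray:
    """Forward blend: watermarked = alpha * logo + (1 - alpha) * original."""
    a = alpha[:, :, np.newaxis]
    blended = a * logo + (1.0 - a) * original.astype(np.float32)
    return np.clip(blended, 0.0, 255.0).round().astype(np.uint8)


def reverse_alpha_blend(
    watermarked: np.ndarray,
    alpha: np.ndarray,
    logo: float = 255.0,
    alpha_threshold: float = 0.002,
    max_alpha: float = 0.99,
) -> np.ndarray:
    """Inverse blend: original = (watermarked - alpha * logo) / (1 - alpha)."""
    # Clamp alpha below 1 so the division stays finite
    a = np.clip(alpha, 0.0, max_alpha)[:, :, np.newaxis]
    wm = watermarked.astype(np.float32)
    restored = np.clip((wm - a * logo) / (1.0 - a), 0.0, 255.0)
    # Leave pixels untouched where the watermark contribution is negligible
    mask = (alpha >= alpha_threshold)[:, :, np.newaxis]
    return np.where(mask, restored, wm).round().astype(np.uint8)
```

Because the watermarked image is quantized to 8 bits, recovery is only approximate, and the error grows as alpha approaches 1 (the quantization error is divided by `1 - alpha`).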
@@ -0,0 +1,45 @@
import cv2
import numpy as np
from numpy.typing import NDArray


def apply_analog_humanizer(image: NDArray, grain_intensity: float = 4.0, chromatic_shift: int = 1) -> NDArray:
    """
    Apply Analog Humanizer (film grain and chromatic aberration) to an image.
    This simulates analog film imperfections to defeat digital AI perfection classifiers.

    Ported from NeuralBleach.

    Args:
        image: BGR image as numpy array (uint8).
        grain_intensity: Standard deviation of the Gaussian noise (film grain).
        chromatic_shift: Number of pixels to shift the red/blue color channels.

    Returns:
        Humanized BGR image.
    """
    # Ensure image is BGR
    if len(image.shape) != 3 or image.shape[2] != 3:
        return image.copy()

    # Split channels (OpenCV uses BGR)
    # B = 0, G = 1, R = 2
    b, g, r = cv2.split(image)

    # 1. Chromatic Aberration
    # Shift R channel left, B channel right
    if chromatic_shift > 0:
        r = np.roll(r, -chromatic_shift, axis=1)
        b = np.roll(b, chromatic_shift, axis=1)

    merged = cv2.merge((b, g, r))

    # 2. Film Grain (Gaussian Noise)
    if grain_intensity > 0:
        img_f = merged.astype(np.float32)
        noise = np.random.normal(0, grain_intensity, img_f.shape).astype(np.float32)
        humanized = np.clip(img_f + noise, 0, 255).astype(np.uint8)
    else:
        humanized = merged

    return humanized
@@ -0,0 +1,253 @@
"""Invisible watermark removal engine.

Wraps the vendored noai-watermark code for removing invisible AI watermarks
(SynthID, StableSignature, TreeRing) via diffusion-based regeneration.

This module requires the 'invisible' extra dependencies:
    uv pip install 'remove-ai-watermarks[invisible]'
"""

from __future__ import annotations

import logging
import os
import warnings
from collections.abc import Callable
from pathlib import Path

# Suppress verbose deprecation warnings from diffusers/transformers/huggingface_hub
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
warnings.filterwarnings("ignore", category=UserWarning, module="diffusers")
warnings.filterwarnings("ignore", module="transformers")

# Suppress HuggingFace internal logging
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["DIFFUSERS_VERBOSITY"] = "error"

logger = logging.getLogger(__name__)


def is_available() -> bool:
    """Check if invisible watermark removal dependencies are installed."""
    try:
        import diffusers  # noqa: F401
        import torch  # noqa: F401

        return True
    except ImportError:
        return False


class InvisibleEngine:
|
||||
"""Remove invisible AI watermarks using diffusion model regeneration.
|
||||
|
||||
Based on noai-watermark by mertizci:
|
||||
https://github.com/mertizci/noai-watermark
|
||||
|
||||
The approach encodes the image into latent space, injects controlled noise
|
||||
to break watermark patterns, and reconstructs via reverse diffusion.
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_ID = "Lykon/dreamshaper-8"
|
||||
CTRLREGEN_MODEL_ID = "yepengliu/ctrlregen"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str | None = None,
|
||||
device: str | None = None,
|
||||
pipeline: str = "default",
|
||||
hf_token: str | None = None,
|
||||
progress_callback: Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the invisible watermark removal engine.
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID. None = use default for pipeline.
|
||||
device: Device for inference (auto/cpu/mps/cuda). None = auto.
|
||||
pipeline: Pipeline profile ("default" or "ctrlregen").
|
||||
hf_token: HuggingFace API token.
|
||||
progress_callback: Optional callback for progress messages.
|
||||
"""
|
||||
|
||||
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover
|
||||
|
||||
effective_model = model_id
|
||||
if pipeline == "ctrlregen" and model_id is None:
|
||||
effective_model = self.CTRLREGEN_MODEL_ID
|
||||
elif model_id is None:
|
||||
effective_model = self.DEFAULT_MODEL_ID
|
||||
|
||||
self._remover = WatermarkRemover(
|
||||
model_id=effective_model,
|
||||
device=device,
|
||||
progress_callback=progress_callback,
|
||||
hf_token=hf_token,
|
||||
)
|
||||
self._progress_callback = progress_callback
|
||||
|
||||
def preload(self) -> None:
|
||||
"""Eagerly load the pipeline so download progress is visible."""
|
||||
self._remover.preload()
|
||||
|
||||
def remove_watermark(
|
||||
self,
|
||||
image_path: Path,
|
||||
output_path: Path | None = None,
|
||||
strength: float | None = None,
|
||||
num_inference_steps: int = 100,
|
||||
guidance_scale: float | None = None,
|
||||
seed: int | None = None,
|
||||
humanize: float = 0.0,
|
||||
protect_faces: bool = True,
|
||||
) -> Path:
|
||||
"""Remove invisible watermark from an image.
|
||||
|
||||
Args:
|
||||
image_path: Path to the watermarked image.
|
||||
output_path: Output path (None = overwrite source).
|
||||
strength: Denoising strength (0.0–1.0). Default 0.04.
|
||||
            num_inference_steps: Number of denoising steps.
|
||||
guidance_scale: Classifier-free guidance scale.
|
||||
seed: Random seed for reproducibility.
|
||||
humanize: Intensity of Analog Humanizer film grain (0 = off).
|
||||
            protect_faces: If True, extract faces before regeneration and blend them back afterwards.
|
||||
|
||||
Returns:
|
||||
Path to the cleaned image.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
max_dimension = 768
|
||||
image = Image.open(image_path)
|
||||
image = ImageOps.exif_transpose(image)
|
||||
orig_size = image.size # (width, height)
|
||||
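        _tmp_path = None

        # Resolve the default before image_path is redirected to a temp file below,
        # so that output_path=None still means "overwrite the source".
        if output_path is None:
            output_path = image_path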
|
||||
|
||||
if max(image.width, image.height) > max_dimension:
|
||||
ratio = max_dimension / max(image.width, image.height)
|
||||
new_size = (int(image.width * ratio), int(image.height * ratio))
|
||||
if self._progress_callback:
|
||||
self._progress_callback(
|
||||
f"Auto-downscaling {image.width}x{image.height} "
|
||||
f"to {new_size[0]}x{new_size[1]} to prevent Memory Error..."
|
||||
)
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
        # Always hand WatermarkRemover a temp copy instead of the original file:
        # it holds the EXIF-transposed (and possibly downscaled) pixels, which the
        # remover would otherwise lose by reloading the source without EXIF rotation.
        _tmp_fd, _tmp_str = tempfile.mkstemp(suffix=image_path.suffix)
        os.close(_tmp_fd)  # Pillow reopens the path itself, so release the descriptor first
        _tmp_path = Path(_tmp_str)
        image.save(_tmp_path)
        image_path = _tmp_path
|
||||
|
||||
try:
|
||||
# Optional: Face protection (Phase 1 - Extraction)
|
||||
original_faces = []
|
||||
if protect_faces:
|
||||
try:
|
||||
import cv2
|
||||
|
||||
from remove_ai_watermarks.face_protector import FaceProtector
|
||||
|
||||
if self._progress_callback:
|
||||
self._progress_callback("Detecting and extracting faces (protect-faces)...")
|
||||
# Convert PIL to CV2 BGR
|
||||
import numpy as np
|
||||
|
||||
cv_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
protector = FaceProtector(use_yolo=True)
|
||||
original_faces = protector.extract_faces(cv_img)
|
||||
if self._progress_callback:
|
||||
self._progress_callback(f"Extracted {len(original_faces)} face(s) for protection.")
|
||||
except Exception as e:
|
||||
                    logger.error("Failed to extract faces: %s", e)
|
||||
|
||||
out_path = self._remover.remove_watermark(
|
||||
image_path=image_path,
|
||||
output_path=output_path,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# Optional: Face restoration & Humanizer (Phase 2 - Post-processing)
|
||||
if protect_faces or humanize > 0.0:
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
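                out_cv = cv2.imread(str(out_path), cv2.IMREAD_COLOR)
                if out_cv is None:
                    # cv2.imread returns None instead of raising when the file cannot be decoded
                    raise RuntimeError(f"Could not read regenerated image: {out_path}")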
|
||||
|
||||
if protect_faces and original_faces:
|
||||
if self._progress_callback:
|
||||
self._progress_callback("Restoring protected faces with soft blending...")
|
||||
from remove_ai_watermarks.face_protector import FaceProtector
|
||||
|
||||
protector = FaceProtector(use_yolo=True)
|
||||
out_cv = protector.restore_faces(out_cv, original_faces)
|
||||
|
||||
if humanize > 0.0:
|
||||
if self._progress_callback:
|
||||
self._progress_callback(f"Applying Analog Humanizer (grain: {humanize})...")
|
||||
from remove_ai_watermarks.humanizer import apply_analog_humanizer
|
||||
|
||||
out_cv = apply_analog_humanizer(out_cv, grain_intensity=humanize, chromatic_shift=1)
|
||||
|
||||
# Restore original resolution
|
||||
if (out_cv.shape[1], out_cv.shape[0]) != orig_size:
|
||||
if self._progress_callback:
|
||||
self._progress_callback(
|
||||
f"Upscaling result back to original resolution {orig_size[0]}x{orig_size[1]}..."
|
||||
)
|
||||
# Using INTER_LANCZOS4 for high-quality upscaling back to original
|
||||
out_cv = cv2.resize(out_cv, orig_size, interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
cv2.imwrite(str(out_path), out_cv)
|
||||
|
||||
else:
|
||||
# Even if no protect_faces or humanize, we must restore original size if needed
|
||||
import cv2
|
||||
|
||||
out_cv = cv2.imread(str(out_path), cv2.IMREAD_COLOR)
|
||||
if out_cv is not None and (out_cv.shape[1], out_cv.shape[0]) != orig_size:
|
||||
if self._progress_callback:
|
||||
self._progress_callback(
|
||||
f"Upscaling result back to original resolution {orig_size[0]}x{orig_size[1]}..."
|
||||
)
|
||||
out_cv = cv2.resize(out_cv, orig_size, interpolation=cv2.INTER_LANCZOS4)
|
||||
cv2.imwrite(str(out_path), out_cv)
|
||||
|
||||
return out_path
|
||||
finally:
|
||||
if _tmp_path is not None and _tmp_path.exists():
|
||||
_tmp_path.unlink()
|
||||
|
||||
def remove_watermark_batch(
|
||||
self,
|
||||
input_dir: Path,
|
||||
output_dir: Path,
|
||||
strength: float = 0.04,
|
||||
steps: int = 50,
|
||||
) -> list[Path]:
|
||||
"""Remove invisible watermarks from all images in a directory."""
|
||||
return self._remover.remove_watermark_batch(
|
||||
input_dir=input_dir,
|
||||
output_dir=output_dir,
|
||||
strength=strength,
|
||||
num_inference_steps=steps,
|
||||
)
|
||||
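The auto-downscale step in `remove_watermark` caps the longer side at 768 pixels while keeping the aspect ratio. The arithmetic can be sketched in isolation as below. `fit_within` is a hypothetical name used only for this illustration, and the sketch assumes the same `int()` truncation as the engine.

```python
def fit_within(width: int, height: int, max_dimension: int = 768) -> tuple[int, int]:
    """Return a size whose longer side is at most max_dimension, keeping aspect ratio."""
    longest = max(width, height)
    if longest <= max_dimension:
        return (width, height)  # already small enough, no resize needed
    ratio = max_dimension / longest
    # int() truncates toward zero, matching the engine's own computation
    return (int(width * ratio), int(height * ratio))


print(fit_within(1536, 1024))
print(fit_within(640, 480))
```

Because the result is later resized back to the original size, any rounding in the intermediate dimensions does not change the final output resolution.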
@@ -0,0 +1,214 @@
|
||||
"""AI metadata detection and removal.
|
||||
|
||||
Wraps the noai-watermark metadata handling for stripping AI-generation
|
||||
metadata (EXIF, PNG text chunks, C2PA provenance) from images.
|
||||
|
||||
For metadata-only operations, the heavy ML dependencies are NOT required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Known AI metadata keys ──────────────────────────────────────────
|
||||
|
||||
AI_METADATA_KEYS: frozenset[str] = frozenset(
|
||||
k.lower()
|
||||
for k in [
|
||||
"parameters",
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"workflow",
|
||||
"comfyui",
|
||||
"sd-metadata",
|
||||
"invokeai_metadata",
|
||||
"generation_data",
|
||||
"ai_metadata",
|
||||
"dream",
|
||||
"sd:prompt",
|
||||
"sd:negative_prompt",
|
||||
"sd:seed",
|
||||
"sd:steps",
|
||||
"sd:sampler",
|
||||
"sd:cfg_scale",
|
||||
"sd:model_hash",
|
||||
"c2pa",
|
||||
"c2pa_chunk",
        "Software",  # also in STANDARD_METADATA_KEYS, but _is_ai_key() is checked first, so it is always stripped
|
||||
]
|
||||
)
|
||||
|
||||
AI_KEYWORDS: tuple[str, ...] = (
|
||||
"stable_diffusion",
|
||||
"comfyui",
|
||||
"automatic1111",
|
||||
"invokeai",
|
||||
"midjourney",
|
||||
"dall-e",
|
||||
"dalle",
|
||||
"imagen",
|
||||
"synthid",
|
||||
"google_ai",
|
||||
"openai",
|
||||
"c2pa",
|
||||
)
|
||||
|
||||
STANDARD_METADATA_KEYS: frozenset[str] = frozenset(
|
||||
[
|
||||
"Author",
|
||||
"Title",
|
||||
"Description",
|
||||
"Copyright",
|
||||
"Creation Time",
|
||||
"Software",
|
||||
"Comment",
|
||||
"Disclaimer",
|
||||
"Source",
|
||||
"Warning",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _is_ai_key(key: str) -> bool:
|
||||
"""Check if a metadata key is AI-related."""
|
||||
key_lower = key.lower()
|
||||
if key_lower in AI_METADATA_KEYS:
|
||||
return True
|
||||
return any(kw in key_lower for kw in AI_KEYWORDS)
|
||||
|
||||
|
||||
def has_ai_metadata(image_path: Path) -> bool:
|
||||
"""Check if an image contains AI-generation metadata.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image.
|
||||
|
||||
Returns:
|
||||
True if AI metadata is detected.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
with Image.open(image_path) as img:
|
||||
for key in img.info:
|
||||
if _is_ai_key(key):
|
||||
return True
|
||||
|
||||
# Check C2PA
|
||||
try:
|
||||
from c2pa import has_c2pa_metadata
|
||||
|
||||
if has_c2pa_metadata(image_path):
|
||||
return True
|
||||
except ImportError:
|
||||
# Try simple binary scan
|
||||
data = image_path.read_bytes()
|
||||
        if b"c2pa" in data.lower():  # case-insensitive, so a separate b"C2PA" check is redundant
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_ai_metadata(image_path: Path) -> dict[str, str]:
|
||||
"""Extract AI-related metadata from an image.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image.
|
||||
|
||||
Returns:
|
||||
Dictionary of AI metadata key-value pairs.
|
||||
"""
|
||||
from PIL import Image
|
||||
|
||||
result: dict[str, str] = {}
|
||||
|
||||
with Image.open(image_path) as img:
|
||||
for key, value in img.info.items():
|
||||
if _is_ai_key(key):
|
||||
if isinstance(value, bytes):
|
||||
result[key] = f"<binary {len(value)} bytes>"
|
||||
elif isinstance(value, str) and len(value) > 200:
|
||||
result[key] = value[:200] + "…"
|
||||
else:
|
||||
result[key] = str(value)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remove_ai_metadata(
|
||||
source_path: Path,
|
||||
output_path: Path | None = None,
|
||||
keep_standard: bool = True,
|
||||
) -> Path:
|
||||
"""Remove AI-generation metadata from an image.
|
||||
|
||||
Strips EXIF AI tags, PNG text chunks, and C2PA provenance manifests
|
||||
while optionally preserving standard metadata (Author, Title, etc.).
|
||||
|
||||
Args:
|
||||
source_path: Path to the source image.
|
||||
output_path: Output path (None = overwrite source).
|
||||
keep_standard: If True, preserve standard metadata fields.
|
||||
|
||||
Returns:
|
||||
Path to the cleaned image.
|
||||
"""
|
||||
import piexif
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
if output_path is None:
|
||||
output_path = source_path
|
||||
|
||||
# Read image and filter metadata
|
||||
with Image.open(source_path) as img:
|
||||
img = img.copy()
|
||||
fmt = output_path.suffix.lower()
|
||||
|
||||
save_kwargs: dict = {}
|
||||
if fmt in (".jpg", ".jpeg"):
|
||||
save_kwargs["format"] = "JPEG"
|
||||
if img.mode in ("RGBA", "P"):
|
||||
img = img.convert("RGB")
|
||||
else:
|
||||
save_kwargs["format"] = "PNG"
|
||||
|
||||
# Collect non-AI metadata
|
||||
kept_meta: dict[str, str] = {}
|
||||
exif_data = None
|
||||
|
||||
for key, value in img.info.items():
|
||||
if _is_ai_key(key):
|
||||
continue
|
||||
if key == "exif":
|
||||
try:
|
||||
exif_data = piexif.load(value)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
if key in ("dpi", "gamma"):
|
||||
save_kwargs[key] = value
|
||||
continue
|
||||
if keep_standard and key in STANDARD_METADATA_KEYS:
|
||||
kept_meta[key] = str(value) if not isinstance(value, str) else value
|
||||
|
||||
# Apply cleaned metadata
|
||||
if save_kwargs["format"] == "PNG" and kept_meta:
|
||||
pnginfo = PngInfo()
|
||||
for k, v in kept_meta.items():
|
||||
pnginfo.add_text(k, v)
|
||||
save_kwargs["pnginfo"] = pnginfo
|
||||
|
||||
if exif_data and save_kwargs["format"] == "JPEG":
|
||||
try:
|
||||
save_kwargs["exif"] = piexif.dump(exif_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(output_path, **save_kwargs)
|
||||
|
||||
logger.info("Stripped AI metadata → %s", output_path)
|
||||
return output_path
|
||||
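The key classification used by this module has two layers: an exact, case-insensitive match against a known key set, then a substring match against a keyword list. A minimal sketch with reduced copies of the two collections follows. The shortened sets and the `is_ai_key` name are assumptions made for illustration; the real lists are longer.

```python
# Reduced, lowercased copies of the module's collections (illustration only).
AI_KEYS = frozenset({"parameters", "prompt", "workflow", "software"})
AI_KEYWORDS = ("stable_diffusion", "comfyui", "midjourney", "c2pa")


def is_ai_key(key: str) -> bool:
    """Exact match first, then substring match, both case-insensitive."""
    key_lower = key.lower()
    if key_lower in AI_KEYS:
        return True
    return any(kw in key_lower for kw in AI_KEYWORDS)


for key in ("Parameters", "stable_diffusion_model", "Software", "Author"):
    print(key, is_ai_key(key))
```

Note that `Software` is classified as AI-related here because the source lists it in the AI key set as well as in the standard set, and the AI check runs first.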
@@ -0,0 +1,9 @@
|
||||
"""Vendored noai-watermark code for invisible watermark removal.
|
||||
|
||||
Original: https://github.com/mertizci/noai-watermark (MIT License)
|
||||
"""
|
||||
|
||||
from remove_ai_watermarks.noai.cleaner import remove_ai_metadata
|
||||
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover, remove_watermark
|
||||
|
||||
__all__ = ["WatermarkRemover", "remove_watermark", "remove_ai_metadata"]
|
||||
@@ -0,0 +1,297 @@
|
||||
"""C2PA (Coalition for Content Provenance and Authenticity) metadata handling.
|
||||
|
||||
C2PA metadata is embedded in PNG files as a JUMBF container chunk
|
||||
(``caBX``). This module can detect, extract, and re-inject those
|
||||
chunks. Supported issuers:
|
||||
|
||||
- Google Imagen
|
||||
- Adobe Firefly
|
||||
- Microsoft Designer
|
||||
- OpenAI (ChatGPT, GPT-4o, Sora, DALL-E)
|
||||
- Truepic (signing authority)
|
||||
|
||||
The parser uses byte-level scanning — it does not validate JUMBF/CBOR
|
||||
structure but reliably identifies known signatures, issuers, tools,
|
||||
and actions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from remove_ai_watermarks.noai.constants import (
|
||||
C2PA_ACTIONS,
|
||||
C2PA_AI_TOOLS,
|
||||
C2PA_CHUNK_TYPE,
|
||||
C2PA_ISSUERS,
|
||||
C2PA_SIGNATURES,
|
||||
PNG_SIGNATURE,
|
||||
)
|
||||
|
||||
|
||||
def has_c2pa_metadata(image_path: Path) -> bool:
|
||||
"""
|
||||
Check if an image contains C2PA metadata.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file.
|
||||
|
||||
Returns:
|
||||
True if C2PA metadata is detected, False otherwise.
|
||||
"""
|
||||
image_path = Path(image_path)
|
||||
|
||||
if image_path.suffix.lower() != ".png":
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
signature = f.read(8)
|
||||
if signature != PNG_SIGNATURE:
|
||||
return False
|
||||
|
||||
while True:
|
||||
chunk_header = f.read(8)
|
||||
if len(chunk_header) < 8:
|
||||
break
|
||||
|
||||
length = struct.unpack(">I", chunk_header[:4])[0]
|
||||
chunk_type = chunk_header[4:8]
|
||||
|
||||
if chunk_type == C2PA_CHUNK_TYPE:
|
||||
chunk_data = f.read(length)
|
||||
# Check for any C2PA signature
|
||||
for sig in C2PA_SIGNATURES:
|
||||
if sig in chunk_data:
|
||||
return True
|
||||
# Also check if chunk_data itself contains C2PA-like patterns
|
||||
if b"jumb" in chunk_data.lower() or b"c2pa" in chunk_data.lower():
|
||||
return True
|
||||
f.read(4)
|
||||
else:
|
||||
f.read(length + 4)
|
||||
|
||||
if chunk_type == b"IEND":
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def extract_c2pa_info(image_path: Path) -> dict[str, Any]:
|
||||
"""
|
||||
Extract basic C2PA metadata information from an image.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file.
|
||||
|
||||
Returns:
|
||||
Dictionary containing C2PA metadata info.
|
||||
"""
|
||||
c2pa_info: dict[str, Any] = {}
|
||||
|
||||
if not has_c2pa_metadata(image_path):
|
||||
return c2pa_info
|
||||
|
||||
c2pa_info["has_c2pa"] = True
|
||||
c2pa_info["type"] = "C2PA (Coalition for Content Provenance and Authenticity)"
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
signature = f.read(8)
|
||||
if signature != PNG_SIGNATURE:
|
||||
return c2pa_info
|
||||
|
||||
while True:
|
||||
chunk_header = f.read(8)
|
||||
if len(chunk_header) < 8:
|
||||
break
|
||||
|
||||
length = struct.unpack(">I", chunk_header[:4])[0]
|
||||
chunk_type = chunk_header[4:8]
|
||||
|
||||
if chunk_type == C2PA_CHUNK_TYPE:
|
||||
chunk_data = f.read(length)
|
||||
_parse_c2pa_chunk(chunk_data, c2pa_info)
|
||||
f.read(4)
|
||||
else:
|
||||
f.read(length + 4)
|
||||
|
||||
if chunk_type == b"IEND":
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return c2pa_info
|
||||
|
||||
|
||||
def _parse_c2pa_chunk(chunk_data: bytes, c2pa_info: dict[str, Any]) -> None:
|
||||
"""Parse C2PA chunk data and populate info dictionary."""
|
||||
# Find issuers
|
||||
issuers = []
|
||||
for sig, name in C2PA_ISSUERS.items():
|
||||
if sig in chunk_data:
|
||||
issuers.append(name)
|
||||
if issuers:
|
||||
        c2pa_info["issuer"] = ", ".join(sorted(set(issuers)))  # sorted for deterministic output
|
||||
|
||||
# Find AI tools
|
||||
ai_tools = []
|
||||
for sig, name in C2PA_AI_TOOLS.items():
|
||||
if sig in chunk_data:
|
||||
ai_tools.append(name)
|
||||
if ai_tools:
|
||||
        c2pa_info["ai_tool"] = ", ".join(sorted(set(ai_tools)))  # sorted for deterministic output
|
||||
|
||||
# Extract software agent (multiple patterns)
|
||||
patterns = [
|
||||
rb"softwareAgent.*?dname([^\x00]+?)(?:q|l|m|n)",
|
||||
rb"software_agent[^\x00]*?([A-Za-z0-9_\-\.]+)",
|
||||
rb"Software[^\x00]*?([A-Za-z0-9_\-\. ]+)",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, chunk_data, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
agent = match.group(1).decode("utf-8", errors="ignore").strip()
|
||||
if agent and len(agent) < 100:
|
||||
c2pa_info["software_agent"] = agent
|
||||
break
|
||||
|
||||
# Extract claim generator (multiple patterns)
|
||||
claim_patterns = [
|
||||
rb"claim_generator[^\x00]*?([A-Za-z0-9_\-\.\/\:]+)",
|
||||
rb"claimGenerator[^\x00]*?([A-Za-z0-9_\-\.\/\:]+)",
|
||||
rb"dname([^\x00]{3,50})(?:q|l|m|n|i)",
|
||||
]
|
||||
for pattern in claim_patterns:
|
||||
match = re.search(pattern, chunk_data, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
gen_name = match.group(1).decode("utf-8", errors="ignore").strip()
|
||||
# Filter out common false positives
|
||||
if gen_name and len(gen_name) < 100 and not gen_name.startswith(("\\x", "\\\\x")):
|
||||
c2pa_info["claim_generator"] = gen_name
|
||||
break
|
||||
|
||||
# Find actions
|
||||
actions = []
|
||||
for sig, name in C2PA_ACTIONS.items():
|
||||
if sig in chunk_data:
|
||||
actions.append(name)
|
||||
if actions:
|
||||
c2pa_info["actions"] = ", ".join(actions)
|
||||
|
||||
# Find timestamps
|
||||
timestamp_matches = re.findall(rb"(\d{14}Z)", chunk_data)
|
||||
if timestamp_matches:
|
||||
c2pa_info["timestamp"] = timestamp_matches[0].decode("utf-8")
|
||||
if len(timestamp_matches) > 1:
|
||||
c2pa_info["timestamps"] = [t.decode("utf-8") for t in timestamp_matches[:3]]
|
||||
|
||||
# Find digital source type
|
||||
if b"trainedAlgorithmicMedia" in chunk_data:
|
||||
c2pa_info["source_type"] = "trainedAlgorithmicMedia (AI-generated)"
|
||||
elif b"algorithmicMedia" in chunk_data:
|
||||
c2pa_info["source_type"] = "algorithmicMedia"
|
||||
elif b"compositeWithTrainedAlgorithmicMedia" in chunk_data:
|
||||
c2pa_info["source_type"] = "compositeWithTrainedAlgorithmicMedia (AI-enhanced)"
|
||||
|
||||
|
||||
def extract_c2pa_chunk(image_path: Path) -> bytes | None:
|
||||
"""
|
||||
Extract the raw C2PA JUMBF chunk from a PNG file.
|
||||
|
||||
Args:
|
||||
image_path: Path to the source PNG file.
|
||||
|
||||
Returns:
|
||||
Raw bytes of the C2PA chunk or None.
|
||||
"""
|
||||
if image_path.suffix.lower() != ".png":
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
signature = f.read(8)
|
||||
if signature != PNG_SIGNATURE:
|
||||
return None
|
||||
|
||||
while True:
|
||||
chunk_header = f.read(8)
|
||||
if len(chunk_header) < 8:
|
||||
break
|
||||
|
||||
length = struct.unpack(">I", chunk_header[:4])[0]
|
||||
chunk_type = chunk_header[4:8]
|
||||
|
||||
if chunk_type == C2PA_CHUNK_TYPE:
|
||||
chunk_data = f.read(length)
|
||||
crc = f.read(4)
|
||||
|
||||
# Check for any C2PA signature
|
||||
for sig in C2PA_SIGNATURES:
|
||||
if sig in chunk_data:
|
||||
return chunk_header + chunk_data + crc
|
||||
|
||||
# Also check lowercase variants
|
||||
if b"jumb" in chunk_data.lower() or b"c2pa" in chunk_data.lower():
|
||||
return chunk_header + chunk_data + crc
|
||||
else:
|
||||
f.read(length + 4)
|
||||
|
||||
if chunk_type == b"IEND":
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def inject_c2pa_chunk(target_path: Path, output_path: Path, c2pa_chunk: bytes) -> None:
|
||||
"""
|
||||
Inject a C2PA JUMBF chunk into a PNG file.
|
||||
|
||||
Args:
|
||||
target_path: Path to the target PNG file.
|
||||
        output_path: Path where the output file will be saved. Must differ from
            target_path, which is still being read while the output is written.
|
||||
c2pa_chunk: Raw bytes of the C2PA chunk to inject.
|
||||
|
||||
Raises:
|
||||
        ValueError: If either path is not a PNG file.
|
||||
"""
|
||||
if target_path.suffix.lower() != ".png" or output_path.suffix.lower() != ".png":
|
||||
raise ValueError("C2PA chunk injection is only supported for PNG files")
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(target_path, "rb") as f_in:
|
||||
with open(output_path, "wb") as f_out:
|
||||
f_out.write(f_in.read(8))
|
||||
|
||||
c2pa_injected = False
|
||||
while True:
|
||||
chunk_header = f_in.read(8)
|
||||
if len(chunk_header) < 8:
|
||||
break
|
||||
|
||||
length = struct.unpack(">I", chunk_header[:4])[0]
|
||||
chunk_type = chunk_header[4:8]
|
||||
chunk_data = f_in.read(length)
|
||||
crc = f_in.read(4)
|
||||
|
||||
if chunk_type == b"IDAT" and not c2pa_injected:
|
||||
f_out.write(c2pa_chunk)
|
||||
c2pa_injected = True
|
||||
|
||||
if chunk_type == C2PA_CHUNK_TYPE:
|
||||
continue
|
||||
|
||||
f_out.write(chunk_header)
|
||||
f_out.write(chunk_data)
|
||||
f_out.write(crc)
|
||||
|
||||
if chunk_type == b"IEND":
|
||||
break
|
||||
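All four functions in this module walk the PNG chunk stream the same way: an 8-byte signature, then repeated records of 4-byte big-endian length, 4-byte type, payload, and 4-byte CRC. The sketch below reproduces that walk on a synthetic in-memory file using only the standard library. `make_chunk`, `iter_chunks`, and the fake payload are hypothetical and exist only for this example; the byte string is not a decodable image.

```python
import io
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def make_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Serialize one PNG chunk: length, type, data, CRC-32 over type + data."""
    crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc)


def iter_chunks(stream):
    """Yield (type, data) per chunk until IEND or a truncated header."""
    if stream.read(8) != PNG_SIGNATURE:
        return
    while True:
        header = stream.read(8)
        if len(header) < 8:
            break
        (length,) = struct.unpack(">I", header[:4])
        chunk_type = header[4:8]
        data = stream.read(length)
        stream.read(4)  # skip the CRC, as the module does
        yield chunk_type, data
        if chunk_type == b"IEND":
            break


# Synthetic stream: header chunk, a fake caBX payload, end marker.
fake_png = (
    PNG_SIGNATURE
    + make_chunk(b"IHDR", struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0))
    + make_chunk(b"caBX", b"jumb....c2pa")
    + make_chunk(b"IEND", b"")
)
chunk_types = [t for t, _ in iter_chunks(io.BytesIO(fake_png))]
has_c2pa = any(t == b"caBX" for t in chunk_types)
print(chunk_types, has_c2pa)
```

A shared generator like this is one way the three near-identical read loops could be consolidated, though that is a suggestion rather than something the vendored code does.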
@@ -0,0 +1,181 @@
|
||||
"""AI metadata cleaning and removal.
|
||||
|
||||
Provides functions to identify and strip AI-generation metadata from
|
||||
PNG and JPEG images while optionally preserving standard fields like
|
||||
Author, Title, and Copyright.
|
||||
|
||||
The removal pipeline:
|
||||
1. Opens the image and iterates over all metadata keys.
|
||||
2. Classifies each key as AI-related (using ``constants.AI_METADATA_KEYS``
|
||||
and ``constants.AI_KEYWORDS``) or standard.
|
||||
3. Rebuilds the image with only the desired metadata retained.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import piexif
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
from remove_ai_watermarks.noai.constants import AI_KEYWORDS, AI_METADATA_KEYS, PNG_METADATA_KEYS
|
||||
from remove_ai_watermarks.noai.utils import get_image_format
|
||||
|
||||
# Pre-compute a lowercase set for O(1) key lookups.
|
||||
_AI_KEYS_LOWER: frozenset[str] = frozenset(k.lower() for k in AI_METADATA_KEYS)
|
||||
|
||||
|
||||
def remove_ai_metadata(
|
||||
source_path: Path,
|
||||
output_path: Path | None = None,
|
||||
keep_standard: bool = True,
|
||||
) -> Path:
|
||||
"""
|
||||
Remove all AI-generated metadata from an image.
|
||||
|
||||
Removes:
|
||||
- AI parameters (Stable Diffusion, ComfyUI, etc.)
|
||||
- C2PA metadata (Google Imagen, OpenAI, etc.)
|
||||
- Any metadata with AI-related keywords
|
||||
|
||||
Args:
|
||||
source_path: Path to the source image file.
|
||||
output_path: Optional output path. If not provided, modifies source in place.
|
||||
keep_standard: If True, keeps standard metadata (Author, Title, etc.).
|
||||
|
||||
Returns:
|
||||
Path to the output file with AI metadata removed.
|
||||
"""
|
||||
if output_path is None:
|
||||
output_path = source_path
|
||||
|
||||
cleaned_metadata = _extract_non_ai_metadata(source_path, keep_standard)
|
||||
|
||||
with Image.open(source_path) as img:
|
||||
img = img.copy()
|
||||
|
||||
output_format = get_image_format(output_path)
|
||||
save_kwargs: dict[str, Any] = {"format": output_format}
|
||||
|
||||
        # EXIF is carried over unfiltered. piexif.load() returned a dict, so
        # serialize it back to bytes before handing it to Pillow.
        if "exif" in cleaned_metadata:
            try:
                save_kwargs["exif"] = piexif.dump(cleaned_metadata["exif"])
            except Exception:
                pass
|
||||
|
||||
if output_format == "PNG":
|
||||
save_kwargs = _prepare_clean_png_kwargs(save_kwargs, cleaned_metadata)
|
||||
elif output_format == "JPEG":
|
||||
save_kwargs = _prepare_clean_jpeg_kwargs(save_kwargs, cleaned_metadata)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if output_format == "JPEG" and img.mode in ("RGBA", "P"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
img.save(output_path, **save_kwargs)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def _extract_non_ai_metadata(source_path: Path, keep_standard: bool) -> dict[str, Any]:
|
||||
"""Extract metadata excluding AI-related fields."""
|
||||
cleaned_metadata: dict[str, Any] = {}
|
||||
|
||||
with Image.open(source_path) as img:
|
||||
# Handle EXIF data
|
||||
if "exif" in img.info:
|
||||
try:
|
||||
exif_dict = piexif.load(img.info["exif"])
|
||||
cleaned_metadata["exif"] = exif_dict
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Extract non-AI metadata
|
||||
for key, value in img.info.items():
|
||||
if _is_ai_metadata_key(key):
|
||||
continue
|
||||
|
||||
if keep_standard and key in PNG_METADATA_KEYS:
|
||||
cleaned_metadata[key] = value
|
||||
elif not keep_standard:
|
||||
# Remove standard metadata while still preserving non-standard fields.
|
||||
if key not in ["exif", "dpi", "gamma"] and key not in PNG_METADATA_KEYS:
|
||||
cleaned_metadata[key] = value
|
||||
|
||||
# Keep DPI and gamma
|
||||
if "dpi" in img.info:
|
||||
cleaned_metadata["dpi"] = img.info["dpi"]
|
||||
if "gamma" in img.info:
|
||||
cleaned_metadata["gamma"] = img.info["gamma"]
|
||||
|
||||
return cleaned_metadata
|
||||
|
||||
|
||||
def _is_ai_metadata_key(key: str) -> bool:
|
||||
"""Return True if *key* is an AI-generation metadata field.
|
||||
|
||||
Detection uses two layers:
|
||||
1. Exact match against the canonical ``AI_METADATA_KEYS`` list.
|
||||
2. Substring match against ``AI_KEYWORDS`` (covers partial hits
|
||||
like ``"stable_diffusion_model"``).
|
||||
"""
|
||||
key_lower = key.lower()
|
||||
if key_lower in _AI_KEYS_LOWER:
|
||||
return True
|
||||
return any(kw in key_lower for kw in AI_KEYWORDS)
|
||||
|
||||
|
||||
def _prepare_clean_png_kwargs(save_kwargs: dict[str, Any], metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare save kwargs for clean PNG."""
|
||||
pnginfo = {}
|
||||
exclude_keys = ["exif", "exif_raw", "dpi", "gamma"]
|
||||
|
||||
for key, value in metadata.items():
|
||||
if key not in exclude_keys:
|
||||
pnginfo[key] = value
|
||||
|
||||
if pnginfo:
|
||||
pnginfo_obj = PngInfo()
|
||||
for key, value in pnginfo.items():
|
||||
if isinstance(value, str):
|
||||
pnginfo_obj.add_text(key, value)
|
||||
elif isinstance(value, bytes):
|
||||
pnginfo_obj.add_text(key, value.decode("utf-8", errors="replace"))
|
||||
save_kwargs["pnginfo"] = pnginfo_obj
|
||||
|
||||
if "dpi" in metadata:
|
||||
save_kwargs["dpi"] = metadata["dpi"]
|
||||
|
||||
return save_kwargs
|
||||
|
||||
|
||||
def _prepare_clean_jpeg_kwargs(save_kwargs: dict[str, Any], metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Prepare save kwargs for clean JPEG."""
|
||||
exif_dict = metadata.get("exif", {"0th": {}, "Exif": {}, "1st": {}, "GPS": {}, "Interop": {}})
|
||||
|
||||
try:
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if "dpi" in metadata:
|
||||
save_kwargs["dpi"] = metadata["dpi"]
|
||||
|
||||
return save_kwargs
|
||||
|
||||
|
||||
def has_ai_content(image_path: Path) -> bool:
|
||||
"""
|
||||
Check if an image has any AI-generated content or metadata.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file.
|
||||
|
||||
Returns:
|
||||
True if the image contains AI metadata.
|
||||
"""
|
||||
from remove_ai_watermarks.noai.extractor import has_ai_metadata
|
||||
|
||||
return has_ai_metadata(image_path)
|
||||
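The effect of `keep_standard` in `_extract_non_ai_metadata` is easy to misread: with `True` only the standard fields survive, with `False` only the non-standard, non-AI fields survive, and `dpi`/`gamma` are kept either way. A minimal sketch on a plain dictionary, leaving out EXIF handling, is given below. The reduced key collections and the `filter_info` name are assumptions for illustration.

```python
# Reduced copies of the constants (illustration only).
STANDARD_KEYS = ("Author", "Title", "Copyright")
AI_KEYS_LOWER = frozenset({"parameters", "workflow"})
AI_KEYWORDS = ("prompt", "diffusion", "c2pa")


def is_ai(key: str) -> bool:
    key_lower = key.lower()
    return key_lower in AI_KEYS_LOWER or any(kw in key_lower for kw in AI_KEYWORDS)


def filter_info(info: dict, keep_standard: bool) -> dict:
    """Mirror the keep/drop decisions of the metadata filter, without EXIF."""
    kept = {}
    for key, value in info.items():
        if is_ai(key):
            continue  # AI fields are always dropped
        if keep_standard and key in STANDARD_KEYS:
            kept[key] = value
        elif not keep_standard:
            if key not in ("exif", "dpi", "gamma") and key not in STANDARD_KEYS:
                kept[key] = value
    for key in ("dpi", "gamma"):  # always preserved when present
        if key in info:
            kept[key] = info[key]
    return kept


info = {"Author": "A. Person", "parameters": "steps: 20", "custom_note": "hello", "dpi": (72, 72)}
with_standard = filter_info(info, keep_standard=True)
without_standard = filter_info(info, keep_standard=False)
print(with_standard)
print(without_standard)
```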
@@ -0,0 +1,112 @@
|
||||
"""Shared constants for AI metadata detection, C2PA parsing, and format support.
|
||||
|
||||
All modules reference these constants rather than hard-coding values,
|
||||
so adding a new AI tool or metadata key requires updating only this file.
|
||||
"""
|
||||
|
||||
# Supported image formats
|
||||
SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp"}
|
||||
|
||||
# AI-generated image metadata keys (Stable Diffusion, ComfyUI, Midjourney, etc.)
|
||||
AI_METADATA_KEYS = [
|
||||
"parameters", # Stable Diffusion WebUI (AUTOMATIC1111, Vladmandic)
|
||||
"postprocessing", # SD WebUI post-processing info
|
||||
"extras", # SD WebUI extras
    "workflow",  # ComfyUI workflow JSON
    "prompt",  # Some AI tools
    "Dream",  # DreamStudio
    "SD:mode",  # Stability AI
    "StableDiffusionVersion",  # SD version info
    "generation_time",  # Generation time info
    "Model",  # Model name
    "Model hash",  # Model hash
    "Seed",  # Seed value
]

# Standard PNG metadata keys
PNG_METADATA_KEYS = [
    "Author",
    "Title",
    "Description",
    "Copyright",
    "Creation Time",
    "Software",
    "Disclaimer",
    "Warning",
    "Source",
    "Comment",
]

# AI-related keywords for detection
AI_KEYWORDS = [
    "prompt",
    "negative_prompt",
    "sampler",
    "cfg_scale",
    "lora",
    "diffusion",
    "comfy",
    "midjourney",
    "dall-e",
    "dalle",
    "imagen",
    "firefly",
    "c2pa",
    "chatgpt",
    "gpt-4",
    "sora",
    "openai",
    "truepic",
    "stable_diffusion",
    "invokeai",
]

# C2PA (Coalition for Content Provenance and Authenticity) constants
# Used by Google Imagen, Adobe Firefly, Microsoft Designer, OpenAI, etc.
C2PA_CHUNK_TYPE = b"caBX"  # JUMBF container chunk type for C2PA
C2PA_SIGNATURES = [
    b"c2pa",
    b"C2PA",
    b"jumb",
    b"jumd",
    b"JUMBF",
    b"jumbf",
    b"cbor",
    b"contentcreds",
    b"digid",
    b"assertions",
    b"manifest",
]

# C2PA known issuers
C2PA_ISSUERS = {
    b"Google": "Google LLC",
    b"Adobe": "Adobe",
    b"Microsoft": "Microsoft",
    b"OpenAI": "OpenAI",
    b"Truepic": "Truepic",
}

# C2PA known AI tools
C2PA_AI_TOOLS = {
    b"GPT-4o": "GPT-4o",
    b"ChatGPT": "ChatGPT",
    b"Sora": "Sora",
    b"DALL-E": "DALL-E",
    b"DALL": "DALL-E",
    b"Imagen": "Imagen",
    b"Firefly": "Firefly",
}

# C2PA action types
C2PA_ACTIONS = {
    b"c2pa.created": "created",
    b"c2pa.converted": "converted",
    b"c2pa.edited": "edited",
    b"c2pa.filtered": "filtered",
    b"c2pa.cropped": "cropped",
    b"c2pa.resized": "resized",
}

# PNG signature
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
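As a rough illustration of how constants such as `PNG_SIGNATURE` and `C2PA_CHUNK_TYPE` might be used, the sketch below walks the chunks of an in-memory PNG and reports whether a `caBX` chunk is present. The helpers `iter_png_chunks`, `has_c2pa_chunk`, and `make_chunk` are hypothetical names introduced for this example; the project's own detection lives in its `c2pa` module and may work differently.

```python
import struct
import zlib
from collections.abc import Iterator

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
C2PA_CHUNK_TYPE = b"caBX"


def iter_png_chunks(data: bytes) -> Iterator[tuple[bytes, bytes]]:
    """Yield (chunk_type, chunk_body) pairs from raw PNG bytes (hypothetical helper)."""
    if not data.startswith(PNG_SIGNATURE):
        raise ValueError("not a PNG file")
    offset = len(PNG_SIGNATURE)
    while offset + 8 <= len(data):
        # Chunk layout: 4-byte big-endian length, 4-byte type, body, 4-byte CRC.
        length, chunk_type = struct.unpack(">I4s", data[offset : offset + 8])
        body = data[offset + 8 : offset + 8 + length]
        yield chunk_type, body
        offset += 12 + length


def has_c2pa_chunk(data: bytes) -> bool:
    """Return True if any chunk carries the C2PA chunk type (hypothetical helper)."""
    return any(chunk_type == C2PA_CHUNK_TYPE for chunk_type, _ in iter_png_chunks(data))


def make_chunk(chunk_type: bytes, body: bytes) -> bytes:
    """Build one PNG chunk with a valid CRC, used only to create test data."""
    crc = zlib.crc32(chunk_type + body) & 0xFFFFFFFF
    return struct.pack(">I", len(body)) + chunk_type + body + struct.pack(">I", crc)


# Two tiny synthetic byte strings: one without and one with a caBX chunk.
plain_png = PNG_SIGNATURE + make_chunk(b"IEND", b"")
signed_png = PNG_SIGNATURE + make_chunk(C2PA_CHUNK_TYPE, b"jumb") + make_chunk(b"IEND", b"")
```

Walking chunks by length is more precise than searching the raw bytes for a signature, since a byte pattern such as `b"manifest"` could also appear inside compressed image data.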
@@ -0,0 +1,18 @@
"""CtrlRegen watermark removal via controllable regeneration.

Implements the pipeline from "Image Watermarks Are Removable Using
Controllable Regeneration from Clean Noise" (ICLR 2025) by Liu et al.

This sub-package uses a ControlNet for spatial guidance (canny edges)
and a DINOv2-based IP Adapter for semantic guidance to regenerate
watermarked images from partially noised latents.

Attribution:
    Based on https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""

from __future__ import annotations

from remove_ai_watermarks.noai.ctrlregen.engine import CtrlRegenEngine, is_ctrlregen_available

__all__ = ["CtrlRegenEngine", "is_ctrlregen_available"]
@@ -0,0 +1,40 @@
"""Color matching post-processing for CtrlRegen output.

After diffusion-based regeneration, the output image may have slight
color shifts. This module uses histogram-based color transfer to
align the regenerated image's color distribution back to the original.

Attribution:
    Adapted from https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""

from __future__ import annotations

import numpy as np
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
from PIL import Image


def color_match(reference: Image.Image, source: Image.Image) -> Image.Image:
    """Transfer the color distribution of *reference* onto *source*.

    Uses a two-pass histogram matching approach (``hm-mkl-hm``) that
    preserves fine-grained color relationships while correcting global
    shifts introduced by the regeneration pipeline.

    Args:
        reference: The original (watermarked) image whose colors should
            be preserved.
        source: The regenerated image whose colors will be adjusted.

    Returns:
        A new PIL Image with the structure of *source* but the color
        palette of *reference*.
    """
    cm = ColorMatcher()
    ref_np = Normalizer(np.asarray(reference)).type_norm()
    src_np = Normalizer(np.asarray(source)).type_norm()
    result = cm.transfer(src=src_np, ref=ref_np, method="hm-mkl-hm")
    result = Normalizer(result).uint8_norm()
    return Image.fromarray(result)
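The `hm` stages in `hm-mkl-hm` refer to histogram matching. As a simplified illustration only, the sketch below performs plain CDF-based histogram matching on a single channel with NumPy. It is not the `color-matcher` implementation and it omits the intermediate MKL step; `match_histogram` is a hypothetical helper name.

```python
import numpy as np


def match_histogram(source: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Remap one channel of *source* so its CDF follows that of *reference*."""
    # Unique values, their positions in the flattened source, and their counts.
    _, src_inverse, src_counts = np.unique(source.ravel(), return_inverse=True, return_counts=True)
    ref_values, ref_counts = np.unique(reference.ravel(), return_counts=True)

    # Empirical cumulative distribution functions of both channels.
    src_cdf = np.cumsum(src_counts) / source.size
    ref_cdf = np.cumsum(ref_counts) / reference.size

    # For each source quantile, look up the reference value at the same quantile.
    mapped = np.interp(src_cdf, ref_cdf, ref_values)
    return mapped[src_inverse].reshape(source.shape)


# Toy data: the reference is the source shifted by a constant offset.
source = np.arange(64, dtype=np.uint8).reshape(8, 8)
reference = source + 100
matched = match_histogram(source, reference)
```

For an RGB image the same function could be applied to each channel independently, although that ignores the cross-channel correlations that the MKL stage is meant to handle.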
@@ -0,0 +1,365 @@
"""CtrlRegen engine: orchestrates the full watermark removal pipeline.

Loads the base SD 1.5 model with a ControlNet (spatial control from
canny edges) and a DINOv2-based IP Adapter (semantic control), then
runs controllable regeneration followed by color matching.

Attribution:
    Based on https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""

from __future__ import annotations

import logging
import os
import sys
import time
from collections.abc import Callable
from typing import Any

import torch
from PIL import Image

from remove_ai_watermarks.noai.progress import make_pipeline_progress

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Availability checks: these imports are optional.
# ---------------------------------------------------------------------------

_HAS_CONTROLNET_AUX = False
_HAS_COLOR_MATCHER = False
_HAS_DIFFUSERS = False

try:
    from diffusers import AutoencoderKL, ControlNetModel, UniPCMultistepScheduler

    # Import from the vendored package path; a bare ``ctrlregen`` package
    # is not importable in this src-layout.
    from remove_ai_watermarks.noai.ctrlregen.pipeline import CustomCtrlRegenPipeline

    _HAS_DIFFUSERS = True
except ImportError:
    AutoencoderKL = None  # type: ignore[assignment,misc]
    ControlNetModel = None  # type: ignore[assignment,misc]
    UniPCMultistepScheduler = None  # type: ignore[assignment,misc]
    CustomCtrlRegenPipeline = None  # type: ignore[assignment,misc]

try:
    from controlnet_aux import CannyDetector

    _HAS_CONTROLNET_AUX = True
except ImportError:
    CannyDetector = None  # type: ignore[assignment,misc]

try:
    from remove_ai_watermarks.noai.ctrlregen.color import color_match

    _HAS_COLOR_MATCHER = True
except ImportError:
    color_match = None  # type: ignore[assignment]

CTRLREGEN_HF_REPO = "yepengliu/ctrlregen"
SPATIAL_SUBFOLDER = "spatialnet_ckp/spatial_control_ckp_14000"
SEMANTIC_SUBFOLDER = "semanticnet_ckp/models"
SEMANTIC_WEIGHT_NAME = "semantic_control_ckp_435000.bin"

DEFAULT_BASE_MODEL = "SG161222/Realistic_Vision_V4.0_noVAE"
CUSTOM_VAE_ID = "stabilityai/sd-vae-ft-mse"

PROCESS_SIZE = 512
DEFAULT_GUIDANCE_SCALE = 2.0
QUALITY_PROMPT = "best quality, high quality"
NEGATIVE_PROMPT = "monochrome, lowres, bad anatomy, worst quality, low quality"

CANNY_LOW_THRESHOLD = 100
CANNY_HIGH_THRESHOLD = 150

TILE_SIZE = 512
TILE_OVERLAP = 192


def is_ctrlregen_available() -> bool:
    """Return True when all CtrlRegen-specific dependencies are installed."""
    return _HAS_DIFFUSERS and _HAS_CONTROLNET_AUX and _HAS_COLOR_MATCHER


class CtrlRegenEngine:
    """End-to-end CtrlRegen watermark removal engine.

    Handles model loading, canny edge extraction, controlled denoising,
    and color-matched post-processing in a single ``run()`` call.
    """

    def __init__(
        self,
        base_model_id: str | None = None,
        device: str = "cpu",
        torch_dtype: torch.dtype | None = None,
        hf_token: str | None = None,
        progress_callback: Callable[[str], None] | None = None,
    ) -> None:
        if not is_ctrlregen_available():
            missing: list[str] = []
            if not _HAS_DIFFUSERS:
                missing.extend(["diffusers", "transformers", "accelerate"])
            if not _HAS_CONTROLNET_AUX:
                missing.append("controlnet-aux")
            if not _HAS_COLOR_MATCHER:
                missing.append("color-matcher")
            logger.info("Auto-installing missing dependencies: %s", missing)
            import subprocess

            try:
                subprocess.check_call(
                    [sys.executable, "-m", "pip", "install", *missing],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
            except (subprocess.CalledProcessError, FileNotFoundError) as exc:
                raise ImportError(
                    "Failed to auto-install missing dependencies: "
                    + ", ".join(missing)
                    + ". Try manually: uv pip install "
                    + " ".join(missing)
                ) from exc
            # NOTE: a successful install does not refresh the module-level
            # names imported above (they stay ``None`` until the module is
            # re-imported), so ``load()`` can still fail in this process.

        self.base_model_id = base_model_id or DEFAULT_BASE_MODEL
        self.device = device
        self.torch_dtype = torch_dtype or (torch.float32 if device in ("cpu", "mps") else torch.float16)
        self.hf_token: str | None = hf_token or os.environ.get("HF_TOKEN")
        self._progress_callback = progress_callback
        self._pipeline: CustomCtrlRegenPipeline | None = None  # type: ignore[assignment]
        self._canny_detector: CannyDetector | None = None  # type: ignore[assignment]

    def _set_progress(self, message: str) -> None:
        if self._progress_callback is None:
            return
        try:
            self._progress_callback(message)
        except Exception:
            pass

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load(self) -> None:
        """Download and assemble the full CtrlRegen pipeline."""
        if self._pipeline is not None:
            return

        token_kwargs: dict[str, Any] = {}
        if self.hf_token:
            token_kwargs["token"] = self.hf_token

        self._set_progress(f"Loading CtrlRegen spatial ControlNet from {CTRLREGEN_HF_REPO}...")
        logger.info("Loading ControlNet from %s/%s", CTRLREGEN_HF_REPO, SPATIAL_SUBFOLDER)
        controlnet = [
            ControlNetModel.from_pretrained(
                CTRLREGEN_HF_REPO,
                subfolder=SPATIAL_SUBFOLDER,
                torch_dtype=self.torch_dtype,
                **token_kwargs,
            )
        ]

        self._set_progress(f"Loading SD base model ({self.base_model_id}) for CtrlRegen pipeline...")
        logger.info("Loading base pipeline from %s", self.base_model_id)
        pipe = CustomCtrlRegenPipeline.from_pretrained(
            self.base_model_id,
            controlnet=controlnet,
            torch_dtype=self.torch_dtype,
            safety_checker=None,
            requires_safety_checker=False,
            **token_kwargs,
        )

        self._set_progress(f"Loading CtrlRegen semantic IP-Adapter + DINOv2 from {CTRLREGEN_HF_REPO}...")
        logger.info("Loading IP-Adapter from %s/%s", CTRLREGEN_HF_REPO, SEMANTIC_SUBFOLDER)
        pipe.load_ctrlregen_ip_adapter(
            CTRLREGEN_HF_REPO,
            subfolder=SEMANTIC_SUBFOLDER,
            weight_name=SEMANTIC_WEIGHT_NAME,
            **token_kwargs,
        )

        from transformers import AutoImageProcessor, AutoModel

        pipe.image_encoder = AutoModel.from_pretrained("facebook/dinov2-giant").to(self.device, dtype=self.torch_dtype)
        pipe.feature_extractor = AutoImageProcessor.from_pretrained("facebook/dinov2-giant")

        self._set_progress(f"Loading custom VAE ({CUSTOM_VAE_ID})...")
        logger.info("Loading VAE from %s", CUSTOM_VAE_ID)
        pipe.vae = AutoencoderKL.from_pretrained(
            CUSTOM_VAE_ID,
            torch_dtype=self.torch_dtype,
            **token_kwargs,
        )

        self._set_progress("Configuring UniPC scheduler...")
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.set_ip_adapter_scale(1.0)

        self._set_progress(f"Moving CtrlRegen pipeline to {self.device}...")
        pipe = pipe.to(self.device)

        if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                pass

        self._pipeline = pipe
        self._canny_detector = CannyDetector()
        self._set_progress("CtrlRegen pipeline ready.")
        logger.info("CtrlRegen pipeline loaded on %s", self.device)

    # ------------------------------------------------------------------
    # Inference: public entry point
    # ------------------------------------------------------------------

    def run(
        self,
        image: Image.Image,
        strength: float = 0.5,
        num_inference_steps: int = 50,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
        seed: int | None = None,
    ) -> Image.Image:
        """Run CtrlRegen watermark removal on a single image.

        Images that fit within ``TILE_SIZE`` (512) are processed in a
        single pass: they are resized and center-cropped to a
        ``PROCESS_SIZE`` square, so non-square inputs lose their outer
        edges before the result is resized back. Larger images are
        split into overlapping tiles.
        """
        self.load()
        assert self._pipeline is not None
        assert self._canny_detector is not None

        orig_w, orig_h = image.size
        orig_image = image
        t0 = time.monotonic()

        needs_tiling = orig_w > TILE_SIZE or orig_h > TILE_SIZE

        if needs_tiling:
            from remove_ai_watermarks.noai.ctrlregen.tiling import run_tiled

            aligned_w = orig_w // 8 * 8
            aligned_h = orig_h // 8 * 8
            if aligned_w != orig_w or aligned_h != orig_h:
                image = image.resize((aligned_w, aligned_h), Image.LANCZOS)
            regen_image = run_tiled(
                pipeline=self._pipeline,
                canny_detector=self._canny_detector,
                image=image,
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
                tile_size=TILE_SIZE,
                tile_overlap=TILE_OVERLAP,
                quality_prompt=QUALITY_PROMPT,
                negative_prompt=NEGATIVE_PROMPT,
                canny_low=CANNY_LOW_THRESHOLD,
                canny_high=CANNY_HIGH_THRESHOLD,
                device=self.device,
                set_progress=self._set_progress,
                ip_adapter_image=orig_image,
            )
        else:
            from remove_ai_watermarks.noai.ctrlregen.tiling import resize_center_crop

            proc_image = resize_center_crop(image, PROCESS_SIZE)
            self._set_progress(f"Preprocessed {orig_w}x{orig_h}px → {proc_image.size[0]}x{proc_image.size[1]}px")
            regen_image = self._run_single(
                proc_image,
                strength,
                num_inference_steps,
                guidance_scale,
                seed,
            )

        if regen_image.size != (orig_w, orig_h):
            self._set_progress(f"Resizing {regen_image.size[0]}x{regen_image.size[1]}px → {orig_w}x{orig_h}px...")
            regen_image = regen_image.resize((orig_w, orig_h), Image.LANCZOS)

        self._set_progress(f"Applying color matching at {orig_w}x{orig_h}px...")
        output = color_match(reference=orig_image, source=regen_image)

        self._set_progress(f"✓ CtrlRegen done · {orig_w}x{orig_h}px · {time.monotonic() - t0:.0f}s total")
        return output

    # ------------------------------------------------------------------
    # Single-image path (image <= 512x512)
    # ------------------------------------------------------------------

    def _run_single(
        self,
        image: Image.Image,
        strength: float,
        num_inference_steps: int,
        guidance_scale: float,
        seed: int | None,
    ) -> Image.Image:
        """Process a single 512x512 image through the CtrlRegen pipeline."""
        w, h = image.size
        effective_steps = max(1, int(num_inference_steps * strength))

        self._set_progress(
            f"Extracting canny edges ({w}x{h}px, thresholds {CANNY_LOW_THRESHOLD}/{CANNY_HIGH_THRESHOLD})..."
        )
        control_image = self._canny_detector(
            image,
            low_threshold=CANNY_LOW_THRESHOLD,
            high_threshold=CANNY_HIGH_THRESHOLD,
        )

        # torch.manual_seed seeds the global RNG and returns the default
        # generator; seed 0 is used when the caller passes no seed.
        generator = torch.manual_seed(seed if seed is not None else 0)

        self._set_progress(
            f"Config: strength={strength}, steps={num_inference_steps} "
            f"(~{effective_steps} effective), guidance={guidance_scale}"
        )

        step_cb, first_step, pipeline_done, start_updater = make_pipeline_progress(
            effective_steps,
            self.device,
            self._set_progress,
            label="CtrlRegen denoising",
        )
        start_updater()

        try:
            result = self._pipeline(
                prompt=QUALITY_PROMPT,
                negative_prompt=NEGATIVE_PROMPT,
                image=[image],
                control_image=[control_image],
                controlnet_conditioning_scale=1.0,
                ip_adapter_image=[image],
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
                control_guidance_start=0.0,
                control_guidance_end=1.0,
                callback=step_cb,
                callback_steps=1,
            )
        except TypeError:
            # Fallback for pipelines that reject the legacy ``callback``
            # arguments: retry without per-step progress reporting.
            first_step.set()
            result = self._pipeline(
                prompt=QUALITY_PROMPT,
                negative_prompt=NEGATIVE_PROMPT,
                image=[image],
                control_image=[control_image],
                controlnet_conditioning_scale=1.0,
                ip_adapter_image=[image],
                strength=strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
        finally:
            first_step.set()
            pipeline_done.set()

        return result.images[0]
@@ -0,0 +1,148 @@
"""Custom IP-Adapter mixin using DINOv2 as the image encoder.

The standard diffusers ``IPAdapterMixin`` uses a CLIP image encoder.
CtrlRegen replaces it with ``facebook/dinov2-giant`` for richer
semantic features. This mixin provides ``load_ctrlregen_ip_adapter``
which handles the custom weight format and encoder swap.

Attribution:
    Adapted from https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""

from __future__ import annotations

import logging

import torch
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.utils import (
    _get_model_file,
    is_accelerate_available,
    is_torch_version,
)
from diffusers.utils import (
    logging as diffusers_logging,
)
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open
from transformers import AutoImageProcessor, AutoModel

logger = logging.getLogger(__name__)
_diffusers_logger = diffusers_logging.get_logger(__name__)

DINOV2_MODEL_ID = "facebook/dinov2-giant"


class CustomIPAdapterMixin:
    """Mixin that adds ``load_ctrlregen_ip_adapter`` to a diffusers pipeline."""

    @validate_hf_hub_args
    def load_ctrlregen_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: str | list[str] | dict[str, torch.Tensor],
        subfolder: str | list[str],
        weight_name: str | list[str],
        image_encoder_folder: str | None = "image_encoder",
        **kwargs,
    ) -> None:
        """Load CtrlRegen IP-Adapter weights and DINOv2 image encoder.

        Parameters mirror ``IPAdapterMixin.load_ip_adapter``, but the
        image encoder is always ``facebook/dinov2-giant``. The
        ``image_encoder_folder`` argument only controls whether an
        encoder is loaded at all: passing ``None`` skips it.
        """
        if not isinstance(weight_name, list):
            weight_name = [weight_name]
        if not isinstance(pretrained_model_name_or_path_or_dict, list):
            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
        if len(pretrained_model_name_or_path_or_dict) == 1:
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
        if not isinstance(subfolder, list):
            subfolder = [subfolder]
        if len(subfolder) == 1:
            subfolder = subfolder * len(weight_name)

        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
        if len(weight_name) != len(subfolder):
            raise ValueError("`weight_name` and `subfolder` must have the same length.")

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            _diffusers_logger.warning(
                "Cannot initialize model with low cpu memory usage because "
                "`accelerate` was not found. Defaulting to "
                "`low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError("Low memory initialization requires torch >= 1.9.0.")

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dicts: list[dict] = []
        for path_or_dict, wn, sf in zip(pretrained_model_name_or_path_or_dict, weight_name, subfolder):
            if not isinstance(path_or_dict, dict):
                model_file = _get_model_file(
                    path_or_dict,
                    weights_name=wn,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=sf,
                    user_agent=user_agent,
                )
                if wn.endswith(".safetensors"):
                    # Regroup the flat safetensors keys into the two
                    # sub-dicts that the UNet loader expects.
                    state_dict: dict = {"image_proj": {}, "ip_adapter": {}}
                    with safe_open(model_file, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            if key.startswith("image_proj."):
                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                            elif key.startswith("ip_adapter."):
                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
                else:
                    state_dict = torch.load(model_file, map_location="cpu")
            else:
                state_dict = path_or_dict

            keys = list(state_dict.keys())
            if keys != ["image_proj", "ip_adapter"]:
                raise ValueError("Required keys (`image_proj` and `ip_adapter`) missing from the state dict.")
            state_dicts.append(state_dict)

        # Always use DINOv2-giant as the image encoder.
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
            if image_encoder_folder is not None:
                logger.info("Loading DINOv2-giant image encoder for CtrlRegen")
                enc_dtype = getattr(self, "dtype", torch.float32)  # type: ignore[attr-defined]
                image_encoder = AutoModel.from_pretrained(DINOV2_MODEL_ID).to(
                    self.device,
                    dtype=enc_dtype,  # type: ignore[attr-defined]
                )
                self.register_modules(image_encoder=image_encoder)  # type: ignore[attr-defined]

        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
            feature_extractor = AutoImageProcessor.from_pretrained(DINOV2_MODEL_ID)
            self.register_modules(feature_extractor=feature_extractor)  # type: ignore[attr-defined]

        unet = (
            getattr(self, self.unet_name)  # type: ignore[attr-defined]
            if not hasattr(self, "unet")
            else self.unet  # type: ignore[attr-defined]
        )
        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
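The safetensors branch above regroups a flat key space into two nested dictionaries, `image_proj` and `ip_adapter`. The sketch below reproduces that regrouping on a plain dictionary so it can be run without any model files. `split_by_prefix` is a hypothetical helper, the strings stand in for tensors, and the prefix is removed by slicing rather than `str.replace`.

```python
def split_by_prefix(flat: dict[str, object]) -> dict[str, dict[str, object]]:
    """Group ``prefix.rest`` keys under their prefix, dropping unknown prefixes."""
    state_dict: dict[str, dict[str, object]] = {"image_proj": {}, "ip_adapter": {}}
    for key, value in flat.items():
        for prefix, group in state_dict.items():
            if key.startswith(prefix + "."):
                # Strip only the leading "prefix." part of the key.
                group[key[len(prefix) + 1 :]] = value
    return state_dict


# Placeholder values: a real checkpoint would hold tensors here.
flat = {
    "image_proj.proj.weight": "w0",
    "ip_adapter.1.to_k_ip.weight": "w1",
    "unrelated.bias": "b",
}
nested = split_by_prefix(flat)
```

The resulting top-level keys come out in the order `image_proj`, `ip_adapter`, which is the order the loader's key check compares against.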
@@ -0,0 +1,37 @@
"""Custom Stable Diffusion ControlNet Img2Img pipeline for CtrlRegen.

Extends ``StableDiffusionControlNetImg2ImgPipeline`` with the
``load_ctrlregen_ip_adapter`` method (via ``CustomIPAdapterMixin``)
that swaps in DINOv2-giant as the image encoder and loads the
CtrlRegen semantic-control adapter weights.

No ``encode_image`` override is needed: the CtrlRegen checkpoint
creates an ``IPAdapterPlusImageProjection`` which tells diffusers to
call ``encode_image`` with ``output_hidden_states=True``. The
default implementation then uses ``hidden_states[-2]`` from DINOv2,
which is exactly what the projection was trained on.

Attribution:
    Adapted from https://github.com/yepengliu/CtrlRegen (Apache-2.0).
"""

from __future__ import annotations

from diffusers import StableDiffusionControlNetImg2ImgPipeline

from remove_ai_watermarks.noai.ctrlregen.ip_adapter import CustomIPAdapterMixin


class CustomCtrlRegenPipeline(
    StableDiffusionControlNetImg2ImgPipeline,
    CustomIPAdapterMixin,
):
    """SD ControlNet Img2Img pipeline with DINOv2 IP-Adapter support.

    MRO mirrors the original CtrlRegen repository: the base diffusers
    pipeline comes first so all standard methods are resolved from it,
    while ``CustomIPAdapterMixin`` only adds the
    ``load_ctrlregen_ip_adapter`` method.
    """

    pass
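The note on method resolution order can be checked with a toy hierarchy. The classes below are stand-ins invented for this example (they are not diffusers classes): with the base class listed first, any method both parents define resolves to the base, while a method only the mixin defines is still reachable.

```python
class BasePipeline:
    """Stand-in for the base diffusers pipeline (hypothetical)."""

    def load_ip_adapter(self) -> str:
        return "base"


class AdapterMixin:
    """Stand-in for the custom mixin (hypothetical)."""

    def load_ip_adapter(self) -> str:
        return "mixin"

    def load_custom_adapter(self) -> str:
        return "custom"


class Combined(BasePipeline, AdapterMixin):
    """Base first, mixin second: same ordering as the class above."""


pipe = Combined()
mro_names = [cls.__name__ for cls in Combined.__mro__]
shared_result = pipe.load_ip_adapter()  # resolved from BasePipeline
extra_result = pipe.load_custom_adapter()  # only defined on the mixin
```

Swapping the two parents would make the mixin's `load_ip_adapter` win instead, which is why the order of the bases matters.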
@@ -0,0 +1,175 @@
"""Tile-based processing for large images in the CtrlRegen pipeline.

Extracted from ``ctrlregen.engine`` to keep the engine focused on
single-image inference and model orchestration.
"""

from __future__ import annotations

import math
import time
from collections.abc import Callable
from typing import Any

import numpy as np
import torch
from PIL import Image


def tile_positions(total: int, tile: int, overlap: int) -> list[int]:
    """Compute evenly-spaced tile start positions covering *total* pixels."""
    if total <= tile:
        return [0]
    n = max(2, math.ceil((total - overlap) / (tile - overlap)))
    stride = (total - tile) / (n - 1)
    return [round(i * stride) for i in range(n)]


def make_blend_weight(h: int, w: int, overlap: int) -> np.ndarray:
    """2-D weight mask: 1.0 in center, cosine ramp in overlap margins."""
    wy = np.ones(h, dtype=np.float64)
    wx = np.ones(w, dtype=np.float64)
    if overlap > 0:
        # Floor the ramp at a small positive value. A weight of exactly 0
        # at the tile edge would leave the outermost rows and columns of
        # the image (covered by a single tile) with no contribution, so
        # they would come out black after normalisation.
        ramp = np.maximum(0.5 - 0.5 * np.cos(np.linspace(0, np.pi, overlap)), 1e-3)
        wy[:overlap] = np.minimum(wy[:overlap], ramp)
        wy[-overlap:] = np.minimum(wy[-overlap:], ramp[::-1])
        wx[:overlap] = np.minimum(wx[:overlap], ramp)
        wx[-overlap:] = np.minimum(wx[-overlap:], ramp[::-1])
    return np.outer(wy, wx)


def resize_center_crop(image: Image.Image, size: int = 512) -> Image.Image:
    """Resize shortest edge to *size*, then center-crop to a square.

    Matches the ``transforms.Resize(512) + CenterCrop(512)`` pipeline
    used in the original CtrlRegen repository.
    """
    w, h = image.size
    short = min(w, h)
    scale = size / short
    new_w, new_h = round(w * scale), round(h * scale)
    image = image.resize((new_w, new_h), Image.BILINEAR)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return image.crop((left, top, left + size, top + size))


def run_tiled(
    pipeline: Any,
    canny_detector: Any,
    image: Image.Image,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int | None,
    *,
    tile_size: int,
    tile_overlap: int,
    quality_prompt: str,
    negative_prompt: str,
    canny_low: int,
    canny_high: int,
    device: str,
    set_progress: Callable[[str], None],
    ip_adapter_image: Image.Image | None = None,
) -> Image.Image:
    """Split a large image into overlapping tiles, process each, blend."""
    w, h = image.size
    xs = tile_positions(w, tile_size, tile_overlap)
    ys = tile_positions(h, tile_size, tile_overlap)
    n_tiles = len(xs) * len(ys)
    grid = f"{len(xs)}x{len(ys)}"
    effective_steps = max(1, int(num_inference_steps * strength))

    set_progress(f"Tiling {w}x{h}px → {n_tiles} tiles ({grid} grid, {tile_size}px, overlap {tile_overlap}px)")

    canvas = np.zeros((h, w, 3), dtype=np.float64)
    weight_sum = np.zeros((h, w), dtype=np.float64)
    blend_w = make_blend_weight(tile_size, tile_size, tile_overlap)

    t0 = time.monotonic()
    bar_len = 20
    tile_idx = 0

    for ty in ys:
        for tx in xs:
            tile_idx += 1
            prefix = f"[Tile {tile_idx}/{n_tiles}]"

            tile = image.crop((tx, ty, tx + tile_size, ty + tile_size))

            set_progress(f"{prefix} Extracting canny edges...")
            control = canny_detector(
                tile,
                low_threshold=canny_low,
                high_threshold=canny_high,
            )

            gen = None
            if seed is not None:
                gen = torch.Generator(device=device).manual_seed(seed + tile_idx)

            tile_t0 = time.monotonic()

            def _make_cb(
                _prefix: str = prefix,
                _t0: float = tile_t0,
                _es: int = effective_steps,
            ) -> Callable:
                def _cb(step: int, timestep: int, latents: Any) -> None:  # noqa: ARG001
                    elapsed = time.monotonic() - _t0
                    cur = step + 1
                    per = elapsed / max(1, cur)
                    rem = per * max(0, _es - cur)
                    filled = int(bar_len * cur / max(1, _es))
                    bar = "█" * filled + "░" * (bar_len - filled)
                    set_progress(f"{_prefix} [{bar}] {cur}/{_es} | {elapsed:.0f}s, ~{rem:.0f}s left")

                return _cb

            sem_image = ip_adapter_image if ip_adapter_image is not None else tile

            try:
                result = pipeline(
                    prompt=quality_prompt,
                    negative_prompt=negative_prompt,
                    image=[tile],
                    control_image=[control],
                    controlnet_conditioning_scale=1.0,
                    ip_adapter_image=[sem_image],
                    strength=strength,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=gen,
                    control_guidance_start=0.0,
                    control_guidance_end=1.0,
                    callback=_make_cb(),
                    callback_steps=1,
                )
            except TypeError:
                result = pipeline(
                    prompt=quality_prompt,
                    negative_prompt=negative_prompt,
                    image=[tile],
                    control_image=[control],
                    controlnet_conditioning_scale=1.0,
                    ip_adapter_image=[sem_image],
                    strength=strength,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=gen,
                )

            proc_arr = np.array(result.images[0], dtype=np.float64)
            # When one image side is shorter than the tile, the crop above is
            # padded past the image bounds; trim the result back to the
            # canvas so the shapes match in the accumulation below.
            proc_arr = proc_arr[: h - ty, : w - tx]
            th, tw = proc_arr.shape[:2]
            mask = blend_w[:th, :tw]
            canvas[ty : ty + th, tx : tx + tw] += proc_arr * mask[..., None]
            weight_sum[ty : ty + th, tx : tx + tw] += mask

            tile_time = time.monotonic() - tile_t0
            total_elapsed = time.monotonic() - t0
            set_progress(f"{prefix} Done ({tile_time:.0f}s) · Total: {total_elapsed:.0f}s")

    set_progress(f"Blending {n_tiles} tiles → {w}x{h}px...")
    canvas /= np.maximum(weight_sum[..., None], 1e-8)
    return Image.fromarray(np.clip(canvas, 0, 255).astype(np.uint8))
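To make the tiling arithmetic concrete, the sketch below recomputes the tile starts for a 1024-pixel axis and runs a one-dimensional version of the weighted blend. `tile_positions` is copied from the module above; the 1-D blend and the small floor applied to the ramp are assumptions of this sketch, chosen so every pixel keeps a positive weight.

```python
import math

import numpy as np


def tile_positions(total: int, tile: int, overlap: int) -> list[int]:
    """Evenly spaced tile start positions covering *total* pixels."""
    if total <= tile:
        return [0]
    n = max(2, math.ceil((total - overlap) / (tile - overlap)))
    stride = (total - tile) / (n - 1)
    return [round(i * stride) for i in range(n)]


TOTAL, TILE, OVERLAP = 1024, 512, 192
starts = tile_positions(TOTAL, TILE, OVERLAP)

# 1-D blend weight: cosine ramp in both margins, floored so it is never zero.
ramp = np.maximum(0.5 - 0.5 * np.cos(np.linspace(0, np.pi, OVERLAP)), 1e-3)
weight = np.ones(TILE)
weight[:OVERLAP] = ramp
weight[-OVERLAP:] = ramp[::-1]

# Every tile contributes the same constant signal, so after dividing by the
# accumulated weights the blended result should equal that constant.
canvas = np.zeros(TOTAL)
weight_sum = np.zeros(TOTAL)
for start in starts:
    canvas[start : start + TILE] += 7.0 * weight
    weight_sum[start : start + TILE] += weight
blended = canvas / np.maximum(weight_sum, 1e-8)
```

Here the formula yields three tiles with a stride of 256 pixels, so neighbouring tiles actually overlap by 256 pixels, more than the requested minimum of 192.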
@@ -0,0 +1,153 @@
"""Read-only metadata extraction from PNG and JPEG images.

Provides functions to pull all metadata, AI-only metadata, or a
human-readable summary without modifying the source file.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import piexif
from PIL import Image

from remove_ai_watermarks.noai.c2pa import extract_c2pa_chunk, extract_c2pa_info, has_c2pa_metadata
from remove_ai_watermarks.noai.constants import AI_KEYWORDS, AI_METADATA_KEYS, PNG_METADATA_KEYS


def extract_metadata(source_path: Path) -> dict[str, Any]:
    """
    Extract all metadata from a PNG or JPG file.

    Args:
        source_path: Path to the source image file.

    Returns:
        Dictionary containing all extracted metadata.
    """
    metadata: dict[str, Any] = {}

    with Image.open(source_path) as img:
        # Extract EXIF data
        if "exif" in img.info:
            try:
                exif_dict = piexif.load(img.info["exif"])
                metadata["exif"] = exif_dict
            except Exception:
                metadata["exif_raw"] = img.info["exif"]

        # Extract standard PNG metadata
        for key in PNG_METADATA_KEYS:
            if key in img.info:
                metadata[key] = img.info[key]

        # Extract all other metadata including AI-specific
        for key, value in img.info.items():
            if key not in metadata and key not in ["exif"]:
                metadata[key] = value

        # Extract DPI and gamma if present
        if "dpi" in img.info:
            metadata["dpi"] = img.info["dpi"]
        if "gamma" in img.info:
            metadata["gamma"] = img.info["gamma"]

    # Check for C2PA metadata
    if has_c2pa_metadata(source_path):
        metadata["c2pa"] = extract_c2pa_info(source_path)
        c2pa_chunk = extract_c2pa_chunk(source_path)
        if c2pa_chunk:
            metadata["c2pa_chunk"] = c2pa_chunk

    return metadata


def extract_ai_metadata(source_path: Path) -> dict[str, Any]:
    """
    Extract only AI-generated metadata from a PNG or JPG file.

    Args:
        source_path: Path to the source image file.

    Returns:
        Dictionary containing only AI-related metadata.
    """
    ai_metadata: dict[str, Any] = {}

    with Image.open(source_path) as img:
        for key in AI_METADATA_KEYS:
            if key in img.info:
                ai_metadata[key] = img.info[key]

        for key, value in img.info.items():
            key_lower = key.lower()
            if key not in ai_metadata:
                if any(kw in key_lower for kw in AI_KEYWORDS):
                    ai_metadata[key] = value

    # Check for C2PA metadata
    if has_c2pa_metadata(source_path):
        ai_metadata["c2pa"] = extract_c2pa_info(source_path)
        c2pa_chunk = extract_c2pa_chunk(source_path)
        if c2pa_chunk:
            ai_metadata["c2pa_chunk"] = c2pa_chunk

    return ai_metadata


def has_ai_metadata(image_path: Path) -> bool:
    """
    Check if an image contains AI-generated metadata.

    Args:
        image_path: Path to the image file.

    Returns:
        True if AI metadata is detected, False otherwise.
    """
    with Image.open(image_path) as img:
        for key in AI_METADATA_KEYS:
            if key in img.info:
                return True

    if has_c2pa_metadata(image_path):
        return True

    return False


def get_ai_metadata_summary(source_path: Path) -> str:
    """
    Get a human-readable summary of AI metadata.

    Args:
        source_path: Path to the source image file.

    Returns:
        Formatted string with AI metadata summary.
    """
    ai_meta = extract_ai_metadata(source_path)

    if not ai_meta:
        return "No AI metadata found."

    lines = ["AI Image Metadata:"]
    lines.append("-" * 40)

    for key, value in ai_meta.items():
        if key == "c2pa_chunk":
            continue
        elif key == "c2pa" and isinstance(value, dict):
            lines.append("C2PA Metadata:")
            for ck, cv in value.items():
                lines.append(f" {ck}: {cv}")
        elif isinstance(value, str) and len(value) > 100:
            value = value[:100] + "..."
            lines.append(f"{key}: {value}")
        elif isinstance(value, bytes):
            lines.append(f"{key}: <binary data ({len(value)} bytes)>")
        else:
            lines.append(f"{key}: {value}")

    return "\n".join(lines)
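Reviewer's sketch of the key-selection rule in `extract_ai_metadata` above: exact key hits are collected first, then any remaining key whose lowercased name contains a keyword. The key and keyword tuples are hypothetical stand-ins, because the real `AI_METADATA_KEYS` and `AI_KEYWORDS` live in `noai/constants.py` and are not shown in this diff. `filter_ai_keys` is likewise an illustrative name, not part of the package.

```python
from typing import Any

# Hypothetical stand-ins for AI_METADATA_KEYS / AI_KEYWORDS (real values not shown here).
SAMPLE_AI_KEYS = ("parameters", "prompt")
SAMPLE_AI_KEYWORDS = ("diffusion", "generat")


def filter_ai_keys(info: dict[str, Any]) -> dict[str, Any]:
    """Apply the two-pass selection to a plain dict such as ``img.info``."""
    selected: dict[str, Any] = {}
    # Pass 1: exact key matches.
    for key in SAMPLE_AI_KEYS:
        if key in info:
            selected[key] = info[key]
    # Pass 2: case-insensitive substring match on the remaining keys.
    for key, value in info.items():
        if key not in selected and any(kw in key.lower() for kw in SAMPLE_AI_KEYWORDS):
            selected[key] = value
    return selected


info = {"prompt": "a cat", "Stable-Diffusion-Version": "1.5", "dpi": (72, 72)}
print(filter_ai_keys(info))
```

With these sample values, `prompt` is kept by the exact pass, `Stable-Diffusion-Version` by the substring pass, and `dpi` is dropped. Note that `has_ai_metadata` in the diff applies only the exact pass plus the C2PA check, so the two functions can disagree on keyword-only matches.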
@@ -0,0 +1,152 @@
"""Img2img pipeline execution with progress monitoring and MPS fallback.

Extracted from ``watermark_remover.py`` to keep the ``WatermarkRemover``
class focused on orchestration.
"""

from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any

from PIL import Image

from remove_ai_watermarks.noai.progress import is_mps_error, make_pipeline_progress

logger = logging.getLogger(__name__)


def run_img2img(
    pipeline: Any,
    image: Image.Image,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    generator: Any,
    device: str,
    set_progress: Callable[[str], None],
) -> Image.Image:
    """Execute img2img with live progress and return the generated image."""
    w, h = image.size
    # Img2img only runs the tail of the schedule, so scale the step count by strength.
    effective_steps = max(1, int(num_inference_steps * strength))

    step_cb, first_step, done_ev, start_updater = make_pipeline_progress(
        effective_steps,
        device,
        set_progress,
    )
    start_updater()

    try:
        result = _call_pipeline(
            pipeline,
            image,
            strength,
            num_inference_steps,
            guidance_scale,
            generator,
            step_cb,
        )
        done_ev.set()
        return result.images[0]
    except TypeError:
        # A TypeError here typically means the pipeline does not accept the
        # callback kwargs; retry without per-step progress.
        first_step.set()
        result = _call_pipeline(
            pipeline,
            image,
            strength,
            num_inference_steps,
            guidance_scale,
            generator,
            None,
        )
        done_ev.set()
        return result.images[0]
    finally:
        # Always release the background updater thread, even on failure.
        first_step.set()
        done_ev.set()


def run_img2img_with_mps_fallback(
    load_pipeline: Callable[[], Any],
    image: Image.Image,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    generator: Any,
    device: str,
    set_progress: Callable[[str], None],
    *,
    reload_on_cpu: Callable[[], Any],
) -> tuple[Image.Image, str]:
    """Run img2img; on MPS error, fall back to CPU.

    Returns:
        (result_image, final_device) — device may change to ``"cpu"`` on fallback.
    """
    pipeline = load_pipeline()

    try:
        img = run_img2img(
            pipeline,
            image,
            strength,
            num_inference_steps,
            guidance_scale,
            generator,
            device,
            set_progress,
        )
        return img, device
    except RuntimeError as error:
        if device == "mps" and is_mps_error(error):
            logger.warning("MPS error detected: %s. Falling back to CPU.", error)
            set_progress("MPS error! Clearing cache and retrying on CPU...")
            _try_clear_mps_cache()
            pipeline = reload_on_cpu()
            img = run_img2img(
                pipeline,
                image,
                strength,
                num_inference_steps,
                guidance_scale,
                None,
                "cpu",
                set_progress,
            )
            return img, "cpu"
        raise


def _call_pipeline(
    pipeline: Any,
    image: Image.Image,
    strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    generator: Any,
    step_callback: Any,
) -> Any:
    kwargs: dict[str, Any] = {
        "prompt": "",
        "image": image,
        "strength": strength,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "generator": generator,
    }
    if step_callback is not None:
        kwargs["callback"] = step_callback
        kwargs["callback_steps"] = 1
    return pipeline(**kwargs)


def _try_clear_mps_cache() -> None:
    try:
        import torch

        if hasattr(torch, "mps"):
            torch.mps.empty_cache()  # type: ignore[attr-defined]
    except Exception:
        pass
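Reviewer's sketch of the control flow in `run_img2img_with_mps_fallback`, reduced to plain callables so it runs without torch or diffusers. `fake_run` is a hypothetical stand-in for the real pipeline call, and the other names are illustrative. Only the substring heuristic is taken from `is_mps_error` in this diff.

```python
from collections.abc import Callable


def looks_like_mps_error(error: Exception) -> bool:
    # Same substring heuristic as is_mps_error in the diff.
    return "mps" in str(error).lower()


def run_with_cpu_fallback(run: Callable[[str], str], device: str) -> tuple[str, str]:
    """Call run(device); on an MPS-looking RuntimeError, retry once on CPU."""
    try:
        return run(device), device
    except RuntimeError as error:
        if device == "mps" and looks_like_mps_error(error):
            return run("cpu"), "cpu"
        raise  # Anything else propagates unchanged.


def fake_run(device: str) -> str:
    # Hypothetical stand-in for the real pipeline call.
    if device == "mps":
        raise RuntimeError("MPS backend out of memory")
    return f"image generated on {device}"


print(run_with_cpu_fallback(fake_run, "mps"))
```

The returned device string matters to the caller: after a fallback, later stages should treat the run as a CPU run rather than assume the device originally requested.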
@@ -0,0 +1,333 @@
"""Terminal progress animation and library output suppression.

This module provides two main capabilities for the CLI:

1. ``run_with_progress`` — a styled two-line terminal animation that
   displays a bouncing highlight bar, a braille spinner, elapsed time,
   and a live operation message while a background task executes.

2. ``silence_library_output`` — a wrapper that suppresses noisy log
   output produced by third-party ML libraries (transformers, diffusers,
   huggingface_hub, tqdm) so the user only sees our own progress messages.
"""

from __future__ import annotations

import contextlib
import io
import os
import sys
import threading
import time
import warnings
from collections.abc import Callable
from typing import Any

# ── ANSI color constants ────────────────────────────────────────────
_CYAN = "\033[36m"
_YELLOW = "\033[33m"
_GREEN = "\033[32m"
_DIM = "\033[2m"
_BOLD = "\033[1m"
_RESET = "\033[0m"

# Bar geometry
_BAR_WIDTH = 32
_HIGHLIGHT_WIDTH = 5


def _no_color() -> bool:
    """Respect the NO_COLOR convention (https://no-color.org/)."""
    return bool(os.environ.get("NO_COLOR"))


def _truncate(text: str, max_len: int = 72) -> str:
    """Shorten a string with an ellipsis if it exceeds *max_len*."""
    return text if len(text) <= max_len else text[: max_len - 1] + "…"


def _build_bar(step: int) -> str:
    """Build a flowing highlight bar that bounces across the width.

    The highlight segment (5 chars wide) travels left→right→left
    continuously, giving the user a visual "working" signal.
    """
    cycle = _BAR_WIDTH * 2 - 2
    pos = step % cycle
    if pos >= _BAR_WIDTH:
        pos = cycle - pos

    hl_start = max(0, pos - _HIGHLIGHT_WIDTH // 2)
    hl_end = min(_BAR_WIDTH, pos + _HIGHLIGHT_WIDTH // 2 + 1)
    before = "━" * hl_start
    highlight = "━" * (hl_end - hl_start)
    after = "━" * (_BAR_WIDTH - hl_end)

    if _no_color():
        return before + highlight + after
    return f"{_DIM}{before}{_RESET}{_BOLD}{_YELLOW}{highlight}{_RESET}{_DIM}{after}{_RESET}"


def run_with_progress(
    task: Callable[[], Any],
    progress_state: dict[str, str] | None = None,
) -> Any:
    """Execute *task* in a background thread while showing a progress animation.

    The animation renders two lines to ``sys.__stderr__``:

    - **Line 1**: braille spinner + bouncing bar + elapsed seconds
    - **Line 2**: current operation message from *progress_state*

    When the task finishes, a green "Completed" line replaces the animation.

    Args:
        task: A zero-argument callable to run in the background.
        progress_state: Mutable dict whose ``"message"`` key is read
            by the animation loop to display the current operation.

    Returns:
        Whatever *task* returns.

    Raises:
        Any exception raised by *task* is re-raised after the animation
        is cleaned up.
    """
    done = threading.Event()
    output_holder: dict[str, Any] = {"result": None, "error": None}

    def worker() -> None:
        try:
            output_holder["result"] = task()
        except Exception as error:  # pragma: no cover – passthrough
            output_holder["error"] = error
        finally:
            done.set()

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()

    spinner_frames = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
    idx = 0
    start_time = time.time()
    no_color = _no_color()

    def _get_operation() -> str:
        if isinstance(progress_state, dict):
            return progress_state.get("message", "Processing...")
        return "Processing..."

    # ── Animation loop ──────────────────────────────────────────────
    while not done.is_set():
        spinner = spinner_frames[idx % len(spinner_frames)]
        elapsed = int(time.time() - start_time)
        bar_str = _build_bar(idx)
        operation = _truncate(_get_operation())

        if no_color:
            line1 = f" {spinner} Processing {bar_str} {elapsed:>3}s"
            line2 = f" ╰─ {operation}"
        else:
            line1 = f" {_CYAN}{spinner}{_RESET} Processing {bar_str} {_BOLD}{_YELLOW}{elapsed:>3}s{_RESET}"
            line2 = f" {_DIM}╰─ {operation}{_RESET}"

        print(
            f"\r\033[2K{line1}\n\033[2K{line2}\033[1A\r",
            end="",
            flush=True,
            file=sys.__stderr__,
        )
        time.sleep(0.08)
        idx += 1

    # ── Final "done" frame ──────────────────────────────────────────
    thread.join()
    total = int(time.time() - start_time)
    final_operation = _truncate(_get_operation())
    done_bar = "━" * _BAR_WIDTH

    if no_color:
        final_line1 = f" ✓ Completed {done_bar} {total:>3}s"
        final_line2 = f" ╰─ {final_operation}"
    else:
        final_line1 = (
            f" {_GREEN}{_BOLD}✓{_RESET} {_GREEN}Completed{_RESET} "
            f"{_GREEN}{done_bar}{_RESET} {_BOLD}{_GREEN}{total:>3}s{_RESET}"
        )
        final_line2 = f" {_DIM}╰─ {final_operation}{_RESET}"

    print(
        f"\r\033[2K{final_line1}\n\033[2K{final_line2}",
        file=sys.__stderr__,
    )

    if output_holder["error"] is not None:
        raise output_holder["error"]

    return output_holder["result"]


def silence_library_output(
    run_func: Callable[[], Any],
    set_progress: Callable[[str], None] | None = None,
) -> Callable[[], Any]:
    """Return a wrapper that silences noisy ML library output.

    The wrapper:

    1. Disables HuggingFace Hub progress bars via env var.
    2. Sets ``transformers``, ``diffusers``, and ``huggingface_hub``
       loggers to *error* level.
    3. Redirects ``stdout`` and ``stderr`` to ``io.StringIO`` sinks so
       that stray ``tqdm`` bars and model-loading chatter are invisible.
    4. Suppresses all Python warnings during the call.

    Args:
        run_func: The callable to execute silently.
        set_progress: Optional callback to report phase changes.

    Returns:
        A zero-argument callable that, when invoked, runs *run_func*
        inside the silent context.
    """

    def wrapped() -> Any:
        if set_progress:
            set_progress("Configuring runtime and suppressing noisy logs...")

        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")

        for _silence in (
            lambda: __import__("transformers").logging.set_verbosity_error(),
            lambda: _silence_diffusers(),
            lambda: __import__("huggingface_hub").logging.set_verbosity_error(),
        ):
            try:
                _silence()
            except Exception:
                pass

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with contextlib.redirect_stdout(io.StringIO()):
                with contextlib.redirect_stderr(io.StringIO()):
                    if set_progress:
                        set_progress("Executing watermark removal pipeline...")
                    return run_func()

    return wrapped


def _silence_diffusers() -> None:
    """Silence diffusers logging and progress bars."""
    from diffusers.utils import logging as diffusers_logging

    diffusers_logging.set_verbosity_error()
    if hasattr(diffusers_logging, "disable_progress_bar"):
        diffusers_logging.disable_progress_bar()


# ── Shared pipeline progress helpers ─────────────────────────────────

_DEFAULT_PRE_PHASES: list[tuple[int, str]] = [
    (0, "Encoding image with VAE encoder"),
    (3, "Mapping pixel data → latent space"),
    (7, "Injecting noise into latent representation"),
    (12, "Building denoiser schedule"),
    (18, "Starting reverse diffusion sampler"),
    (30, "Running first denoising iteration"),
    (50, "Still processing — this can take a while"),
    (90, "Pipeline running — may take a few minutes"),
]

_DEFAULT_POST_PHASES: list[tuple[int, str]] = [
    (0, "Denoising complete · Running VAE decoder"),
    (2, "Decoding latent channels → RGB color space"),
    (5, "Reconstructing pixel grid from latents"),
    (10, "Applying color space conversion and normalization"),
    (18, "Finalizing pixel output"),
    (30, "Still decoding — large images take longer"),
    (60, "Almost done — large images take longer to decode"),
]


def make_pipeline_progress(
    effective_steps: int,
    device: str,
    set_progress: Callable[[str], None],
    *,
    bar_len: int = 20,
    label: str = "Denoising",
    pre_phases: list[tuple[int, str]] | None = None,
    post_phases: list[tuple[int, str]] | None = None,
) -> tuple[Callable, threading.Event, threading.Event, Callable[[], threading.Thread]]:
    """Create step callback and background updater for a diffusion pipeline.

    Returns:
        (step_callback, first_step_event, pipeline_done_event, start_updater)
        where ``start_updater()`` launches and returns the background thread.
    """
    pre = pre_phases or [(s, f"{m} on {device}") for s, m in _DEFAULT_PRE_PHASES]
    post = post_phases or [(s, f"{m} on {device}") for s, m in _DEFAULT_POST_PHASES]

    t0_holder: list[float] = [time.monotonic()]
    first_step = threading.Event()
    pipeline_done = threading.Event()
    last_cb_time: list[float] = [t0_holder[0]]

    def _background_updater() -> None:
        idx = 0
        while not first_step.is_set():
            elapsed = time.monotonic() - t0_holder[0]
            while idx < len(pre) - 1 and elapsed >= pre[idx + 1][0]:
                idx += 1
            set_progress(pre[idx][1])
            first_step.wait(timeout=0.4)

        idx = 0
        post_start: float | None = None
        while not pipeline_done.is_set():
            since_cb = time.monotonic() - last_cb_time[0]
            if since_cb >= 1.5:
                if post_start is None:
                    post_start = time.monotonic()
                elapsed = time.monotonic() - post_start
                while idx < len(post) - 1 and elapsed >= post[idx + 1][0]:
                    idx += 1
                set_progress(post[idx][1])
            else:
                post_start = None
                idx = 0
            pipeline_done.wait(timeout=0.4)

    def step_callback(step: int, timestep: int, latents: Any) -> None:  # noqa: ARG001
        first_step.set()
        last_cb_time[0] = time.monotonic()
        elapsed = time.monotonic() - t0_holder[0]
        current = step + 1
        per_step = elapsed / max(1, current)
        remaining = per_step * max(0, effective_steps - current)
        filled = int(bar_len * current / max(1, effective_steps))
        bar = "█" * filled + "░" * (bar_len - filled)
        set_progress(
            f"{label} [{bar}] {current}/{effective_steps} | {elapsed:.0f}s elapsed, ~{remaining:.0f}s left | {device}"
        )

    def start_updater() -> threading.Thread:
        t0_holder[0] = time.monotonic()
        last_cb_time[0] = t0_holder[0]
        first_step.clear()
        pipeline_done.clear()
        t = threading.Thread(target=_background_updater, daemon=True)
        t.start()
        return t

    return step_callback, first_step, pipeline_done, start_updater


# ── MPS fallback helper ──────────────────────────────────────────────


def is_mps_error(error: Exception) -> bool:
    """Check whether an exception is an MPS-related runtime error."""
    return "mps" in str(error).lower()
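Reviewer's sketch of the two pieces of arithmetic in this module: the bounce position used by `_build_bar` and the linear time estimate computed inside `step_callback`. The function names are illustrative only and the formulas are copied from the diff. The bar width must be at least 2 so that the cycle length is positive.

```python
def bounce_position(step: int, width: int) -> int:
    """Return the highlight centre for a bar of *width* cells at animation *step*."""
    cycle = width * 2 - 2  # one full left-to-right-to-left trip
    pos = step % cycle
    return cycle - pos if pos >= width else pos


def estimate_remaining(elapsed: float, current: int, total: int) -> float:
    """Linear ETA: average seconds per finished step times the steps left."""
    per_step = elapsed / max(1, current)
    return per_step * max(0, total - current)


print([bounce_position(s, 4) for s in range(8)])  # [0, 1, 2, 3, 2, 1, 0, 1]
print(estimate_remaining(10.0, 5, 20))  # 30.0
```

The estimate assumes every step costs about the same, so it will tend to overshoot early in a run when the first step includes warm-up work.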
@@ -0,0 +1,40 @@
"""Low-level utility helpers used across the metadata pipeline.

Kept deliberately small — only format detection lives here so that
higher-level modules can import without circular dependencies.
"""

from __future__ import annotations

from pathlib import Path

from remove_ai_watermarks.noai.constants import SUPPORTED_FORMATS


def is_supported_format(file_path: Path) -> bool:
    """
    Check if the file format is supported.

    Args:
        file_path: Path to the image file.

    Returns:
        True if the format is supported, False otherwise.
    """
    return file_path.suffix.lower() in SUPPORTED_FORMATS


def get_image_format(file_path: Path) -> str:
    """
    Get the image format from file path.

    Args:
        file_path: Path to the image file.

    Returns:
        "JPEG" for ``.jpg``/``.jpeg`` suffixes; "PNG" for every other suffix.
    """
    suffix = file_path.suffix.lower()
    if suffix in {".jpg", ".jpeg"}:
        return "JPEG"
    return "PNG"
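Reviewer's sketch of how the two helpers interact. `SAMPLE_SUPPORTED_FORMATS` is a hypothetical value, since the real `SUPPORTED_FORMATS` is defined in `noai/constants.py` and not shown in this diff. The function names are shortened for illustration.

```python
from pathlib import Path

# Hypothetical value; the real SUPPORTED_FORMATS lives in noai/constants.py.
SAMPLE_SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg"}


def is_supported(path: Path) -> bool:
    # Suffix comparison is case-insensitive, as in is_supported_format.
    return path.suffix.lower() in SAMPLE_SUPPORTED_FORMATS


def image_format(path: Path) -> str:
    # Mirrors get_image_format: anything that is not JPEG is reported as PNG.
    return "JPEG" if path.suffix.lower() in {".jpg", ".jpeg"} else "PNG"


for name in ("photo.JPG", "art.png", "clip.webp"):
    p = Path(name)
    print(name, is_supported(p), image_format(p))
```

Because the format helper falls back to "PNG" for unknown suffixes, callers probably want to check support first. Under the sample set above, `clip.webp` is unsupported yet still maps to "PNG".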
@@ -0,0 +1,51 @@
"""Watermark removal model profiles, strength presets, and profile detection.

Pure configuration and lookup functions with no ML dependencies.
"""

from __future__ import annotations

DEFAULT_MODEL_ID = "Lykon/dreamshaper-8"
CTRLREGEN_MODEL_ID = "yepengliu/ctrlregen"

LOW_STRENGTH = 0.04
MEDIUM_STRENGTH = 0.35
HIGH_STRENGTH = 0.7

# Lowercase substrings matched against the watermark type name.
_HIGH_PERTURBATION = ("stegasamp", "stegastamp", "treering", "ringid")
_LOW_PERTURBATION = ("stablesignature", "dwtdctsvd", "rivagan", "ssl", "hidden")


def get_model_id_for_profile(profile: str) -> str:
    """Map CLI model profile names to concrete Hugging Face model IDs."""
    normalized = profile.strip().lower()
    if normalized == "default":
        return DEFAULT_MODEL_ID
    if normalized == "ctrlregen":
        return CTRLREGEN_MODEL_ID
    raise ValueError(f"Unknown model profile '{profile}'. Use one of: default, ctrlregen.")


def detect_model_profile(model_id: str) -> str:
    """Infer model profile from model identifier."""
    if "ctrlregen" in model_id.lower():
        return "ctrlregen"
    return "default"


def get_recommended_strength(watermark_type: str) -> float:
    """Get the recommended strength for a watermark scheme.

    Matching is a case-insensitive substring test against known scheme names.

    Args:
        watermark_type: Watermark scheme name, e.g. 'stegastamp', 'treering',
            or 'stablesignature'. Unrecognized names (including generic
            labels such as 'low' or 'high') fall back to MEDIUM_STRENGTH.

    Returns:
        Recommended strength value.
    """
    wt = watermark_type.lower()
    if any(name in wt for name in _HIGH_PERTURBATION):
        return HIGH_STRENGTH
    if any(name in wt for name in _LOW_PERTURBATION):
        return LOW_STRENGTH
    return MEDIUM_STRENGTH
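Reviewer's sketch tying the strength presets to the `effective_steps` formula from `run_img2img` earlier in this diff. The preset values and the formula are copied from the source. The shortened name tuples, the `plan` helper and the choice of 25 inference steps are illustrative assumptions.

```python
LOW_STRENGTH, MEDIUM_STRENGTH, HIGH_STRENGTH = 0.04, 0.35, 0.7

# Shortened illustrative subsets of the substring tables in the diff.
HIGH_NAMES = ("stegastamp", "treering", "ringid")
LOW_NAMES = ("stablesignature", "rivagan")


def recommended_strength(watermark_type: str) -> float:
    wt = watermark_type.lower()
    if any(name in wt for name in HIGH_NAMES):
        return HIGH_STRENGTH
    if any(name in wt for name in LOW_NAMES):
        return LOW_STRENGTH
    return MEDIUM_STRENGTH  # fallback for unrecognized names


def plan(watermark_type: str, num_inference_steps: int) -> tuple[float, int]:
    """Return (strength, denoising steps actually executed)."""
    strength = recommended_strength(watermark_type)
    # Same formula as run_img2img: at least one step always runs.
    return strength, max(1, int(num_inference_steps * strength))


for name in ("TreeRing", "StableSignature", "unknown"):
    print(name, plan(name, 25))
```

At 25 requested steps this gives 17 executed steps for the high preset, 8 for the medium preset and 1 for the low preset, which shows how strongly the preset choice drives runtime.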
@@ -0,0 +1,635 @@
"""Watermark removal using diffusion model regeneration attack.

Based on the paper "Image Watermarks Are Removable Using Controllable
Regeneration from Clean Noise" (ICLR 2025).

This module implements a simple regeneration attack that:
1. Encodes the watermarked image to latent space
2. Adds noise via forward diffusion process
3. Denoises via reverse diffusion process
4. Decodes back to pixel space
"""

from __future__ import annotations

import logging
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any

from PIL import Image

from remove_ai_watermarks.noai.watermark_profiles import (
    CTRLREGEN_MODEL_ID,
    DEFAULT_MODEL_ID,
    HIGH_STRENGTH,
    LOW_STRENGTH,
    MEDIUM_STRENGTH,
    detect_model_profile,
)

logger = logging.getLogger(__name__)

# Check for optional dependencies
_HAS_TORCH = False
_HAS_DIFFUSERS = False

try:
    import torch

    _HAS_TORCH = True
except ImportError:
    torch = None  # type: ignore

try:
    from diffusers import AutoPipelineForImage2Image as AutoImg2ImgPipeline

    _HAS_DIFFUSERS = True
except ImportError:
    AutoImg2ImgPipeline = None  # type: ignore


def is_watermark_removal_available() -> bool:
    """Check if watermark removal dependencies are installed."""
    return _HAS_TORCH and _HAS_DIFFUSERS


_CUDA_FIX_ENV_KEY = "NOAI_CUDA_FIXED"


def _auto_install(packages: list[str], index_url: str | None = None) -> bool:
    """Attempt to install missing packages via pip. Returns True on success."""
    import subprocess

    cmd = [sys.executable, "-m", "pip", "install", "-q", *packages]
    if index_url:
        cmd.extend(["--index-url", index_url])
    try:
        subprocess.check_call(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def _has_nvidia_gpu() -> bool:
    """Check if an NVIDIA GPU is present via nvidia-smi."""
    import subprocess

    try:
        subprocess.check_call(
            ["nvidia-smi"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False


def _detect_cuda_index_url() -> str:
    """Detect the appropriate PyTorch CUDA index URL from nvidia-smi output."""
    import subprocess

    try:
        out = subprocess.check_output(
            ["nvidia-smi"],
            stderr=subprocess.DEVNULL,
            text=True,
        )
        for line in out.splitlines():
            if "CUDA Version" in line:
                version_str = line.split("CUDA Version:")[-1].strip().rstrip("|").strip()
                major, minor = version_str.split(".")[:2]
                cuda_tag = f"cu{major}{minor}"
                return f"https://download.pytorch.org/whl/{cuda_tag}"
    except Exception:
        pass
    return "https://download.pytorch.org/whl/cu121"


def _reinstall_torch_cuda_and_restart() -> None:
    """Reinstall torch with CUDA support showing live progress, then restart."""
    import re
    import subprocess

    from remove_ai_watermarks.noai.progress import run_with_progress

    index_url = _detect_cuda_index_url()
    progress_state: dict[str, str] = {"message": "NVIDIA GPU detected — installing CUDA-enabled PyTorch..."}

    pct_re = re.compile(r"(\d+)%")
    pkg_re = re.compile(r"(?:Collecting|Downloading|Installing)\s+(\S+)")

    def _run_pip() -> bool:
        cmd = [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--force-reinstall",
            "torch",
            "--index-url",
            index_url,
        ]
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        for line in iter(proc.stdout.readline, ""):  # type: ignore[union-attr]
            stripped = line.strip()
            if not stripped:
                continue
            pkg_m = pkg_re.search(stripped)
            pct_m = pct_re.search(stripped)
            if pct_m and pkg_m:
                progress_state["message"] = f"Downloading {pkg_m.group(1)} ({pct_m.group(1)}%)"
            elif pct_m:
                progress_state["message"] = f"Downloading CUDA packages ({pct_m.group(1)}%)"
            elif pkg_m:
                action = "Installing" if stripped.startswith("Installing") else "Downloading"
                progress_state["message"] = f"{action} {pkg_m.group(1)}"
            elif "Successfully installed" in stripped:
                progress_state["message"] = "CUDA-enabled PyTorch installed successfully"
        proc.wait()
        return proc.returncode == 0

    try:
        success = run_with_progress(_run_pip, progress_state)
    except Exception:
        success = False

    if not success:
        print(
            f"\n Failed to install CUDA-enabled PyTorch.\n"
            f" Install manually:\n"
            f" pip install torch --index-url {index_url}\n",
            file=sys.stderr,
        )
        return

    os.environ[_CUDA_FIX_ENV_KEY] = "1"
    restart_code = f"import sys; sys.argv = {sys.argv!r}; from remove_ai_watermarks.cli import main; sys.exit(main())"
    os.execl(sys.executable, sys.executable, "-c", restart_code)


def _ensure_watermark_deps() -> None:
    """Auto-install and re-import missing watermark removal dependencies."""
    global _HAS_TORCH, _HAS_DIFFUSERS, torch, AutoImg2ImgPipeline
    missing_pkgs: list[str] = []
    if not _HAS_TORCH:
        missing_pkgs.append("torch")
    if not _HAS_DIFFUSERS:
        missing_pkgs.extend(["diffusers", "transformers", "accelerate"])
    logger.info("Auto-installing missing dependencies: %s", missing_pkgs)
    if not _auto_install(missing_pkgs):
        raise ImportError(
            f"Failed to auto-install missing dependencies: {', '.join(missing_pkgs)}. "
            "Try manually: pip install --force-reinstall noai-watermark"
        )
    import torch as _torch

    torch = _torch
    _HAS_TORCH = True
    from diffusers import AutoPipelineForImage2Image  # noqa: N813

    AutoImg2ImgPipeline = AutoPipelineForImage2Image  # noqa: N806
    _HAS_DIFFUSERS = True


def get_device() -> str:
    """Get the best available device for inference."""
    if not _HAS_TORCH:
        return "cpu"
    if torch.cuda.is_available():  # type: ignore
        try:
            t = torch.tensor([1.0], device="cuda")
            _ = t + t
            del t
            return "cuda"
        except (AssertionError, RuntimeError):
            pass
    if _has_nvidia_gpu() and not os.environ.get(_CUDA_FIX_ENV_KEY):
        _reinstall_torch_cuda_and_restart()
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# Keep legacy name available for backwards compatibility
_detect_model_profile_from_id = detect_model_profile


class WatermarkRemover:
    """Remove watermarks from images using diffusion model regeneration.

    Attributes:
        model_id: HuggingFace model ID for the diffusion model.
        device: Device to run inference on (cuda, mps, or cpu).
    """

    DEFAULT_MODEL_ID = DEFAULT_MODEL_ID
    CTRLREGEN_MODEL_ID = CTRLREGEN_MODEL_ID
    LOW_STRENGTH = LOW_STRENGTH
    MEDIUM_STRENGTH = MEDIUM_STRENGTH
    HIGH_STRENGTH = HIGH_STRENGTH

    def __init__(
        self,
        model_id: str | None = None,
        device: str | None = None,
        torch_dtype: Any = None,
        progress_callback: Callable[[str], None] | None = None,
        hf_token: str | None = None,
    ):
        self.model_id = model_id or self.DEFAULT_MODEL_ID
        self.model_profile = detect_model_profile(self.model_id)

        if not is_watermark_removal_available():
            _ensure_watermark_deps()
        self.device = (device or get_device()).lower()
        if self.device == "auto":
            self.device = get_device()
        if self.device not in {"cpu", "mps", "cuda"}:
            raise ValueError(f"Unsupported device '{device}'. Use one of: auto, cpu, mps, cuda.")
        if torch_dtype is None:
            if self.device == "cpu" or self.device == "mps":
                self.torch_dtype = torch.float32  # type: ignore
            else:
                self.torch_dtype = torch.float16  # type: ignore
        else:
            self.torch_dtype = torch_dtype

        self._pipeline: AutoImg2ImgPipeline | None = None
        self._ctrlregen_engine: Any = None
        self._progress_callback = progress_callback
        self.hf_token: str | None = hf_token or os.environ.get("HF_TOKEN")

    def _set_progress(self, message: str) -> None:
        """Send a progress update through callback when available."""
        if self._progress_callback is None:
            return
        try:
            self._progress_callback(message)
        except Exception:
            pass

    # ── Preload ──────────────────────────────────────────────────────

    def preload(self) -> None:
        """Eagerly load the pipeline so download progress bars are visible."""
        if self.model_profile == "ctrlregen":
            self._run_ctrlregen_preload()
        else:
            self._load_pipeline()

    def _run_ctrlregen_preload(self) -> None:
        """Ensure the CtrlRegen engine and all its models are loaded."""
        from remove_ai_watermarks.noai.ctrlregen import is_ctrlregen_available

        if not is_ctrlregen_available():
            missing_pkgs = ["controlnet-aux", "color-matcher", "safetensors"]
            logger.info("Auto-installing missing CtrlRegen dependencies: %s", missing_pkgs)
            if not _auto_install(missing_pkgs):
                raise ImportError(
                    f"Failed to auto-install missing dependencies: {', '.join(missing_pkgs)}. "
                    "Try manually: pip install --force-reinstall noai-watermark"
                )
        if self._ctrlregen_engine is None:
            self._ctrlregen_engine = self._make_ctrlregen_engine()
        self._ctrlregen_engine.load()

    def _make_ctrlregen_engine(self) -> Any:
        """Create a new CtrlRegenEngine with current settings."""
        from remove_ai_watermarks.noai.ctrlregen import CtrlRegenEngine

        base_model = self.model_id if self.model_id != self.CTRLREGEN_MODEL_ID else None
        return CtrlRegenEngine(
            base_model_id=base_model,
            device=self.device,
            torch_dtype=self.torch_dtype,
            hf_token=self.hf_token,
            progress_callback=self._progress_callback,
        )

# ── Pipeline loading ─────────────────────────────────────────────
|
||||
|
||||
def _load_pipeline(self) -> AutoImg2ImgPipeline:
|
||||
"""Load the diffusion pipeline lazily."""
|
||||
if self._pipeline is None:
|
||||
logger.info("Loading model %s on %s...", self.model_id, self.device)
|
||||
self._set_progress(f"Loading model weights: {self.model_id}")
|
||||
|
||||
load_kwargs: dict[str, Any] = {
|
||||
"torch_dtype": self.torch_dtype,
|
||||
"safety_checker": None,
|
||||
"requires_safety_checker": False,
|
||||
}
|
||||
if self.hf_token:
|
||||
load_kwargs["token"] = self.hf_token
|
||||
|
||||
self._pipeline = AutoImg2ImgPipeline.from_pretrained( # type: ignore
|
||||
self.model_id,
|
||||
**load_kwargs,
|
||||
)
|
||||
|
||||
self._set_progress(f"Moving model to device: {self.device}")
|
||||
try:
|
||||
self._pipeline = self._pipeline.to(self.device) # type: ignore
|
||||
except (RuntimeError, AssertionError) as exc:
|
||||
if self.device == "cuda" and not os.environ.get(_CUDA_FIX_ENV_KEY):
|
||||
self._set_progress("CUDA failed. Reinstalling torch with CUDA support...")
|
||||
_reinstall_torch_cuda_and_restart()
|
||||
raise RuntimeError(
|
||||
f"Failed to move model to {self.device} ({exc}). "
|
||||
"Install CUDA-enabled PyTorch manually:\n"
|
||||
f" pip install torch --index-url {_detect_cuda_index_url()}"
|
||||
) from exc
|
||||
|
||||
if hasattr(self._pipeline, "enable_xformers_memory_efficient_attention"):
|
||||
try:
|
||||
self._set_progress("Enabling memory optimizations...")
|
||||
self._pipeline.enable_xformers_memory_efficient_attention() # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Apple Silicon (MPS) runs in float32, so slice attention to reduce peak memory
|
||||
if self.device == "mps" and hasattr(self._pipeline, "enable_attention_slicing"):
|
||||
try:
|
||||
self._pipeline.enable_attention_slicing("max")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("Model loaded successfully")
|
||||
self._set_progress("Model initialized. Preparing input image...")
|
||||
|
||||
return self._pipeline # type: ignore
|
||||
|
||||
# ── Core removal ─────────────────────────────────────────────────
|
||||
|
||||
def remove_watermark(
|
||||
self,
|
||||
image_path: Path,
|
||||
output_path: Path | None = None,
|
||||
strength: float | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float | None = None,
|
||||
seed: int | None = None,
|
||||
) -> Path:
|
||||
"""Remove a watermark from an image using a regeneration attack.
|
||||
|
||||
Args:
|
||||
image_path: Path to the watermarked image.
|
||||
output_path: Path for the cleaned image. If None, the input file is overwritten in place.
|
||||
strength: Denoising strength (0.0-1.0). If None, LOW_STRENGTH is used.
|
||||
num_inference_steps: Number of denoising steps.
|
||||
guidance_scale: Classifier-free guidance scale. If None, defaults to 2.0 for the ctrlregen profile and 7.5 otherwise.
|
||||
seed: Random seed for reproducibility.
|
||||
|
||||
Returns:
|
||||
Path to the cleaned image.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the input image does not exist.
|
||||
ValueError: If strength is not in valid range.
|
||||
"""
|
||||
if not image_path.exists():
|
||||
raise FileNotFoundError(f"Image not found: {image_path}")
|
||||
|
||||
if output_path is None:
|
||||
output_path = image_path
|
||||
|
||||
strength = self.LOW_STRENGTH if strength is None else strength  # keep an explicit 0.0
|
||||
|
||||
if not 0.0 <= strength <= 1.0:
|
||||
raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")
|
||||
|
||||
if guidance_scale is None:
|
||||
guidance_scale = 2.0 if self.model_profile == "ctrlregen" else 7.5
|
||||
|
||||
self._set_progress("Loading and preprocessing input image...")
|
||||
init_image = Image.open(image_path).convert("RGB")
|
||||
w, h = init_image.size
|
||||
self._set_progress(f"Image loaded: {w}x{h}px | Model: {self.model_id}")
|
||||
|
||||
generator = None
|
||||
if seed is not None and _HAS_TORCH:
|
||||
self._set_progress(f"Setting reproducible seed: {seed}")
|
||||
generator = torch.Generator(device=self.device).manual_seed(seed) # type: ignore
|
||||
|
||||
effective_steps = max(1, int(num_inference_steps * strength))
|
||||
self._set_progress(
|
||||
f"Config: strength={strength}, steps={num_inference_steps} "
|
||||
f"(~{effective_steps} effective), guidance={guidance_scale}, device={self.device}"
|
||||
)
|
||||
|
||||
_total_start = time.monotonic()
|
||||
|
||||
if self.model_profile == "ctrlregen":
|
||||
cleaned_image = self._run_ctrlregen(
|
||||
init_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
generator,
|
||||
)
|
||||
else:
|
||||
cleaned_image = self._run_img2img(
|
||||
init_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
generator,
|
||||
)
|
||||
|
||||
self._set_progress(f"Regeneration complete · Output: {w}x{h}px {cleaned_image.mode}")
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fmt = output_path.suffix.lower()
|
||||
if fmt in (".jpg", ".jpeg"):
|
||||
self._set_progress(f"Encoding as JPEG → {output_path.name}...")
|
||||
else:
|
||||
self._set_progress(f"Encoding as {fmt[1:].upper() or 'PNG'} → {output_path.name}...")
|
||||
cleaned_image.save(output_path)
|
||||
|
||||
if output_path.exists():
|
||||
self._set_progress("Stripping AI metadata from output...")
|
||||
try:
|
||||
from remove_ai_watermarks.noai.cleaner import remove_ai_metadata
|
||||
|
||||
remove_ai_metadata(output_path, output_path, keep_standard=True)
|
||||
except Exception:
|
||||
logger.debug("AI metadata stripping skipped", exc_info=True)
|
||||
|
||||
total_time = time.monotonic() - _total_start
|
||||
|
||||
size_str = ""
|
||||
try:
|
||||
file_size = output_path.stat().st_size
|
||||
if file_size < 1024 * 1024:
|
||||
size_str = f" ({file_size / 1024:.0f}KB)"
|
||||
else:
|
||||
size_str = f" ({file_size / (1024 * 1024):.1f}MB)"
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
logger.info("Cleaned image saved to %s", output_path)
|
||||
self._set_progress(f"✓ Saved {output_path.name}{size_str} · {w}x{h}px · {total_time:.0f}s total")
|
||||
|
||||
return output_path
|
||||
|
||||
# ── Img2img runner ───────────────────────────────────────────────
|
||||
|
||||
def _run_img2img(
|
||||
self,
|
||||
init_image: Image.Image,
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
generator: Any,
|
||||
) -> Image.Image:
|
||||
"""Execute the img2img pipeline with progress and MPS fallback."""
|
||||
from remove_ai_watermarks.noai.img2img_runner import run_img2img_with_mps_fallback
|
||||
|
||||
result_image, final_device = run_img2img_with_mps_fallback(
|
||||
load_pipeline=self._load_pipeline,
|
||||
image=init_image,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
set_progress=self._set_progress,
|
||||
reload_on_cpu=self._reload_pipeline_on_cpu,
|
||||
)
|
||||
|
||||
if final_device != self.device:
|
||||
self.device = final_device
|
||||
self.torch_dtype = torch.float32 # type: ignore[assignment]
|
||||
|
||||
return result_image
|
||||
|
||||
def _reload_pipeline_on_cpu(self) -> Any:
|
||||
"""Reload pipeline on CPU after MPS failure."""
|
||||
self.device = "cpu"
|
||||
self.torch_dtype = torch.float32 # type: ignore[assignment]
|
||||
self._pipeline = None
|
||||
return self._load_pipeline()
|
||||
|
||||
# ── CtrlRegen runner ─────────────────────────────────────────────
|
||||
|
||||
def _run_ctrlregen(
|
||||
self,
|
||||
init_image: Image.Image,
|
||||
strength: float,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
generator: Any,
|
||||
) -> Image.Image:
|
||||
"""Run the CtrlRegen pipeline, retrying once on CPU after an MPS error."""
|
||||
from remove_ai_watermarks.noai.ctrlregen import is_ctrlregen_available
|
||||
from remove_ai_watermarks.noai.progress import is_mps_error
|
||||
|
||||
if not is_ctrlregen_available():
|
||||
missing_pkgs = ["controlnet-aux", "color-matcher", "safetensors"]
|
||||
logger.info("Auto-installing missing CtrlRegen dependencies: %s", missing_pkgs)
|
||||
if not _auto_install(missing_pkgs):
|
||||
raise ImportError(
|
||||
f"Failed to auto-install missing dependencies: {', '.join(missing_pkgs)}. "
|
||||
"Try manually: pip install --force-reinstall noai-watermark"
|
||||
)
|
||||
|
||||
if self._ctrlregen_engine is None:
|
||||
self._ctrlregen_engine = self._make_ctrlregen_engine()
|
||||
|
||||
seed = None
|
||||
if generator is not None and hasattr(generator, "initial_seed"):
|
||||
seed = generator.initial_seed()
|
||||
|
||||
try:
|
||||
return self._ctrlregen_engine.run(
|
||||
image=init_image,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
if self.device == "mps" and is_mps_error(error):
|
||||
logger.warning("MPS out of memory during CtrlRegen. Falling back to CPU.")
|
||||
self._set_progress("MPS out of memory! Retrying CtrlRegen on CPU...")
|
||||
try:
|
||||
if _HAS_TORCH and hasattr(torch, "mps"):
|
||||
torch.mps.empty_cache() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.device = "cpu"
|
||||
self.torch_dtype = torch.float32 # type: ignore[assignment]
|
||||
self._ctrlregen_engine = self._make_ctrlregen_engine()
|
||||
|
||||
return self._ctrlregen_engine.run(
|
||||
image=init_image,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
)
|
||||
raise
|
||||
|
||||
# ── Batch ────────────────────────────────────────────────────────
|
||||
|
||||
def remove_watermark_batch(
|
||||
self,
|
||||
input_dir: Path,
|
||||
output_dir: Path,
|
||||
strength: float | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
extensions: tuple[str, ...] = (".png", ".jpg", ".jpeg", ".webp"),
|
||||
) -> list[Path]:
|
||||
"""Remove watermarks from all images in a directory."""
|
||||
if not input_dir.exists():
|
||||
raise FileNotFoundError(f"Input directory not found: {input_dir}")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
cleaned_paths: list[Path] = []
|
||||
|
||||
for ext in extensions:
|
||||
for image_path in input_dir.glob(f"*{ext}"):
|
||||
output_path = output_dir / image_path.name
|
||||
try:
|
||||
result_path = self.remove_watermark(
|
||||
image_path=image_path,
|
||||
output_path=output_path,
|
||||
strength=strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
)
|
||||
cleaned_paths.append(result_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to process %s: %s", image_path, e)
|
||||
|
||||
return cleaned_paths
|
||||
|
||||
|
||||
# ── Convenience function ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def remove_watermark(
|
||||
image_path: Path,
|
||||
output_path: Path | None = None,
|
||||
strength: float = 0.04,
|
||||
model_id: str | None = None,
|
||||
device: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
) -> Path:
|
||||
"""Convenience function to remove a watermark from a single image."""
|
||||
remover = WatermarkRemover(model_id=model_id, device=device, hf_token=hf_token)
|
||||
return remover.remove_watermark(
|
||||
image_path=image_path,
|
||||
output_path=output_path,
|
||||
strength=strength,
|
||||
)
|
||||
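The strength handling in `remove_watermark` can be sketched in isolation. This is a minimal illustration, not the module's code: `resolve_strength`, `effective_steps`, and `DEFAULT_STRENGTH` are hypothetical names, and the 0.04 default is assumed from the convenience function's signature. It shows why a `None` check is safer than `strength or default` (an explicit `0.0` passes the range check and should be kept) and how the "effective steps" figure in the progress message is derived.

```python
from __future__ import annotations

# Assumption: mirrors the default of the module-level convenience function.
DEFAULT_STRENGTH = 0.04


def resolve_strength(strength: float | None, default: float = DEFAULT_STRENGTH) -> float:
    """Return the caller's strength, using the default only when it is None."""
    # `strength or default` would silently replace a valid 0.0 with the default.
    value = default if strength is None else strength
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"Strength must be between 0.0 and 1.0, got {value}")
    return value


def effective_steps(num_inference_steps: int, strength: float) -> int:
    """Approximate number of denoising steps an img2img run performs."""
    # At least one step is always reported, even for very low strength.
    return max(1, int(num_inference_steps * strength))
```

With 50 inference steps, a strength of 0.5 reports 25 effective steps, while a strength of 0.0 still reports 1.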
@@ -0,0 +1 @@
|
||||
"""Tests for remove-ai-watermarks."""
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Shared fixtures for remove-ai-watermarks test suite."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_image_path(tmp_path: Path) -> Path:
|
||||
"""Create a minimal 200×200 test PNG image and return its path."""
|
||||
img = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
|
||||
path = tmp_path / "test_image.png"
|
||||
cv2.imwrite(str(path), img)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_large_image_path(tmp_path: Path) -> Path:
|
||||
"""Create a 1200×1200 test PNG image (triggers large watermark branch)."""
|
||||
img = np.random.randint(0, 255, (1200, 1200, 3), dtype=np.uint8)
|
||||
path = tmp_path / "test_large.png"
|
||||
cv2.imwrite(str(path), img)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_jpeg_path(tmp_path: Path) -> Path:
|
||||
"""Create a minimal JPEG test image."""
|
||||
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
path = tmp_path / "test_image.jpg"
|
||||
cv2.imwrite(str(path), img)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_png_with_ai_metadata(tmp_path: Path) -> Path:
|
||||
"""Create a PNG with AI-related metadata keys."""
|
||||
img = Image.new("RGB", (64, 64), color=(128, 128, 128))
|
||||
pnginfo = PngInfo()
|
||||
pnginfo.add_text("parameters", "Steps: 20, Sampler: Euler, CFG scale: 7")
|
||||
pnginfo.add_text("prompt", "a beautiful landscape")
|
||||
pnginfo.add_text("Author", "Test Author")
|
||||
path = tmp_path / "ai_metadata.png"
|
||||
img.save(path, pnginfo=pnginfo)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_clean_png(tmp_path: Path) -> Path:
|
||||
"""Create a PNG with no AI metadata."""
|
||||
img = Image.new("RGB", (64, 64), color=(200, 100, 50))
|
||||
pnginfo = PngInfo()
|
||||
pnginfo.add_text("Author", "Human Artist")
|
||||
pnginfo.add_text("Title", "Test Artwork")
|
||||
path = tmp_path / "clean.png"
|
||||
img.save(path, pnginfo=pnginfo)
|
||||
return path
|
||||
@@ -0,0 +1,322 @@
|
||||
"""Tests for the CLI entry point."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
from remove_ai_watermarks.cli import main
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sample_png(tmp_path: Path) -> Path:
|
||||
"""Create a sample PNG for CLI testing."""
|
||||
img = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
|
||||
path = tmp_path / "input.png"
|
||||
cv2.imwrite(str(path), img)
|
||||
return path
|
||||
|
||||
|
||||
def _make_batch_dir(tmp_path: Path, count: int = 3) -> Path:
|
||||
"""Create a directory with test images for batch testing."""
|
||||
input_dir = tmp_path / "input"
|
||||
input_dir.mkdir()
|
||||
for i in range(count):
|
||||
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
cv2.imwrite(str(input_dir / f"img_{i}.png"), img)
|
||||
return input_dir
|
||||
|
||||
|
||||
def _make_batch_dir_with_metadata(tmp_path: Path, count: int = 3) -> Path:
|
||||
"""Create a directory with PNG images containing AI metadata."""
|
||||
input_dir = tmp_path / "input"
|
||||
input_dir.mkdir()
|
||||
for i in range(count):
|
||||
img = Image.new("RGB", (64, 64), color=(100 + i, 150, 200))
|
||||
pnginfo = PngInfo()
|
||||
pnginfo.add_text("parameters", f"Steps: 20, Sampler: Euler, img_{i}")
|
||||
pnginfo.add_text("prompt", "a test landscape")
|
||||
img.save(input_dir / f"img_{i}.png", pnginfo=pnginfo)
|
||||
return input_dir
|
||||
|
||||
|
||||
def _mock_invisible_engine():
|
||||
"""Create a mock InvisibleEngine that writes a copy of the input image."""
|
||||
|
||||
def _mock_remove_watermark(image_path, output_path=None, **kwargs):
|
||||
out = output_path or image_path.with_stem(image_path.stem + "_clean")
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
img = Image.open(image_path)
|
||||
img.save(out)
|
||||
return out
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_engine.remove_watermark.side_effect = _mock_remove_watermark
|
||||
mock_cls = MagicMock(return_value=mock_engine)
|
||||
return mock_cls, mock_engine
|
||||
|
||||
|
||||
class TestMainGroup:
|
||||
"""Tests for the top-level CLI group."""
|
||||
|
||||
def test_help(self, runner):
|
||||
result = runner.invoke(main, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Remove visible and invisible" in result.output
|
||||
|
||||
def test_version(self, runner):
|
||||
result = runner.invoke(main, ["--version"])
|
||||
assert result.exit_code == 0
|
||||
assert "0.1.0" in result.output
|
||||
|
||||
def test_no_command_shows_banner(self, runner):
|
||||
result = runner.invoke(main, [])
|
||||
assert result.exit_code == 0
|
||||
assert "Remove-AI-Watermarks" in result.output
|
||||
|
||||
|
||||
class TestVisibleCommand:
|
||||
"""Tests for the 'visible' subcommand."""
|
||||
|
||||
def test_visible_help(self, runner):
|
||||
result = runner.invoke(main, ["visible", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Gemini watermark" in result.output
|
||||
|
||||
def test_visible_basic(self, runner, sample_png, tmp_path):
|
||||
output = tmp_path / "clean.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["visible", str(sample_png), "-o", str(output), "--no-detect"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert output.exists()
|
||||
assert "Saved" in result.output
|
||||
|
||||
def test_visible_default_output_name(self, runner, sample_png):
|
||||
result = runner.invoke(main, ["visible", str(sample_png), "--no-detect"])
|
||||
assert result.exit_code == 0
|
||||
expected = sample_png.with_stem(sample_png.stem + "_clean")
|
||||
assert expected.exists()
|
||||
|
||||
def test_visible_no_inpaint(self, runner, sample_png, tmp_path):
|
||||
output = tmp_path / "clean.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
[
|
||||
"visible",
|
||||
str(sample_png),
|
||||
"-o",
|
||||
str(output),
|
||||
"--no-inpaint",
|
||||
"--no-detect",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert output.exists()
|
||||
|
||||
def test_visible_no_detect(self, runner, sample_png, tmp_path):
|
||||
output = tmp_path / "clean.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["visible", str(sample_png), "-o", str(output), "--no-detect"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_visible_nonexistent_file(self, runner):
|
||||
result = runner.invoke(main, ["visible", "/nonexistent/file.png"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
class TestInvisibleCommand:
|
||||
"""Tests for the 'invisible' subcommand."""
|
||||
|
||||
def test_invisible_help(self, runner):
|
||||
result = runner.invoke(main, ["invisible", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "invisible" in result.output.lower()
|
||||
|
||||
def test_invisible_basic(self, runner, sample_png, tmp_path):
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
output = tmp_path / "clean.png"
|
||||
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
|
||||
):
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["invisible", str(sample_png), "-o", str(output)],
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert output.exists()
|
||||
mock_engine.remove_watermark.assert_called_once()
|
||||
|
||||
def test_invisible_default_output(self, runner, sample_png):
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
|
||||
):
|
||||
result = runner.invoke(main, ["invisible", str(sample_png)])
|
||||
assert result.exit_code == 0, result.output
|
||||
expected = sample_png.with_stem(sample_png.stem + "_clean")
|
||||
assert expected.exists()
|
||||
|
||||
def test_invisible_nonexistent_file(self, runner):
|
||||
result = runner.invoke(main, ["invisible", "/nonexistent/file.png"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
class TestAllCommand:
|
||||
"""Tests for the 'all' subcommand (full pipeline)."""
|
||||
|
||||
def test_all_help(self, runner):
|
||||
result = runner.invoke(main, ["all", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "visible" in result.output.lower()
|
||||
|
||||
def test_all_basic(self, runner, sample_png, tmp_path):
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
output = tmp_path / "clean.png"
|
||||
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
|
||||
):
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["all", str(sample_png), "-o", str(output)],
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert output.exists()
|
||||
|
||||
def test_all_nonexistent_file(self, runner):
|
||||
result = runner.invoke(main, ["all", "/nonexistent/file.png"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
|
||||
class TestMetadataCommand:
|
||||
"""Tests for the 'metadata' subcommand."""
|
||||
|
||||
def test_metadata_help(self, runner):
|
||||
result = runner.invoke(main, ["metadata", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_metadata_check_clean(self, runner, tmp_clean_png):
|
||||
result = runner.invoke(main, ["metadata", str(tmp_clean_png), "--check"])
|
||||
assert result.exit_code == 0
|
||||
assert "No AI metadata" in result.output
|
||||
|
||||
def test_metadata_check_ai(self, runner, tmp_png_with_ai_metadata):
|
||||
result = runner.invoke(main, ["metadata", str(tmp_png_with_ai_metadata), "--check"])
|
||||
assert result.exit_code == 0
|
||||
assert "AI metadata detected" in result.output
|
||||
|
||||
def test_metadata_remove(self, runner, tmp_png_with_ai_metadata, tmp_path):
|
||||
output = tmp_path / "stripped.png"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
[
|
||||
"metadata",
|
||||
str(tmp_png_with_ai_metadata),
|
||||
"--remove",
|
||||
"-o",
|
||||
str(output),
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "stripped" in result.output
|
||||
|
||||
|
||||
class TestBatchCommand:
|
||||
"""Tests for the 'batch' subcommand."""
|
||||
|
||||
def test_batch_help(self, runner):
|
||||
result = runner.invoke(main, ["batch", "--help"])
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_batch_empty_dir(self, runner, tmp_path):
|
||||
empty_dir = tmp_path / "empty"
|
||||
empty_dir.mkdir()
|
||||
result = runner.invoke(main, ["batch", str(empty_dir)])
|
||||
assert result.exit_code == 0
|
||||
assert "No supported images" in result.output
|
||||
|
||||
def test_batch_visible_mode(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
output_dir = tmp_path / "output"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["batch", str(input_dir), "-o", str(output_dir), "--mode", "visible"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "3 processed" in result.output
|
||||
assert output_dir.exists()
|
||||
assert len(list(output_dir.glob("*.png"))) == 3
|
||||
|
||||
def test_batch_metadata_mode(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir_with_metadata(tmp_path)
|
||||
output_dir = tmp_path / "output"
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["batch", str(input_dir), "-o", str(output_dir), "--mode", "metadata"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "3 processed" in result.output
|
||||
assert output_dir.exists()
|
||||
assert len(list(output_dir.glob("*.png"))) == 3
|
||||
# Verify AI metadata was stripped
|
||||
for out_img in output_dir.glob("*.png"):
|
||||
with Image.open(out_img) as img:
|
||||
assert "parameters" not in img.info
|
||||
|
||||
def test_batch_invisible_mode(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
output_dir = tmp_path / "output"
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
|
||||
), patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.is_available", return_value=True
|
||||
):
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["batch", str(input_dir), "-o", str(output_dir), "--mode", "invisible"],
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "3 processed" in result.output
|
||||
|
||||
def test_batch_all_mode(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
output_dir = tmp_path / "output"
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
with patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls
|
||||
), patch("remove_ai_watermarks.cli.invisible_available", return_value=True, create=True), patch(
|
||||
"remove_ai_watermarks.invisible_engine.is_available", return_value=True
|
||||
):
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["batch", str(input_dir), "-o", str(output_dir), "--mode", "all"],
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "3 processed" in result.output
|
||||
|
||||
def test_batch_default_output_dir(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
result = runner.invoke(
|
||||
main,
|
||||
["batch", str(input_dir), "--mode", "visible"],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
expected_dir = tmp_path / "input_clean"
|
||||
assert expected_dir.exists()
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
|
||||
from remove_ai_watermarks.face_protector import FaceProtector
|
||||
|
||||
|
||||
def test_face_protector_initialization():
|
||||
# Will fall back to the Haar cascade if ultralytics is missing
|
||||
fp = FaceProtector(use_yolo=False)
|
||||
assert fp.use_yolo is False
|
||||
assert fp.haar_cascade is not None
|
||||
|
||||
|
||||
def test_face_protector_lifecycle():
|
||||
fp = FaceProtector(use_yolo=False)
|
||||
|
||||
# Create dummy black image
|
||||
img = np.zeros((200, 200, 3), dtype=np.uint8)
|
||||
|
||||
# Since it's a black image, the Haar cascade should find 0 faces
|
||||
faces = fp.extract_faces(img)
|
||||
assert isinstance(faces, list)
|
||||
assert len(faces) == 0
|
||||
|
||||
# Restoring 0 faces should leave the image unchanged
|
||||
restored = fp.restore_faces(img, faces)
|
||||
assert np.array_equal(img, restored)
|
||||
|
||||
|
||||
def test_face_protector_restore_bypass_on_size_mismatch():
|
||||
fp = FaceProtector(use_yolo=False)
|
||||
img_small = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
# Manually mock a face that is OUT OF BOUNDS for img_small
|
||||
mock_bbox = (80, 80, 130, 130)
|
||||
mock_crop = np.ones((50, 50, 3), dtype=np.uint8) * 255
|
||||
mock_faces = [(mock_bbox, mock_crop)]
|
||||
|
||||
# Attempt to restore onto an image too small for this box
|
||||
restored = fp.restore_faces(img_small, mock_faces)
|
||||
|
||||
# Should safely skip restoring and not crash
|
||||
assert np.array_equal(restored, img_small)
|
||||
|
||||
|
||||
def test_face_protector_restore_blending():
|
||||
fp = FaceProtector(use_yolo=False)
|
||||
# Background is black
|
||||
img_target = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
# Face crop is white
|
||||
mock_bbox = (25, 25, 75, 75)
|
||||
mock_crop = np.ones((50, 50, 3), dtype=np.uint8) * 255
|
||||
mock_faces = [(mock_bbox, mock_crop)]
|
||||
|
||||
restored = fp.restore_faces(img_target, mock_faces)
|
||||
|
||||
# The center of the face should be white (254 or 255, allowing for rounding)
|
||||
assert restored[50, 50, 0] >= 254
|
||||
# The corner of the target should remain perfectly black (0)
|
||||
assert restored[0, 0, 0] == 0
|
||||
# The Gaussian-blurred mask should produce a blending gradient between them
|
||||
# e.g. near (28, 28); this bound holds for any uint8 value, so it is only a smoke check
|
||||
assert 0 <= restored[28, 28, 0] <= 255
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Tests for the Gemini visible-watermark engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from remove_ai_watermarks.gemini_engine import (
|
||||
DetectionResult,
|
||||
GeminiEngine,
|
||||
WatermarkPosition,
|
||||
WatermarkSize,
|
||||
_calculate_alpha_map,
|
||||
get_watermark_config,
|
||||
get_watermark_size,
|
||||
)
|
||||
|
||||
# ── WatermarkSize / config helpers ──────────────────────────────────
|
||||
|
||||
|
||||
class TestWatermarkConfig:
|
||||
"""Tests for watermark size detection and position calculation."""
|
||||
|
||||
def test_small_image_gets_small_watermark(self):
|
||||
assert get_watermark_size(800, 600) == WatermarkSize.SMALL
|
||||
|
||||
def test_large_image_gets_large_watermark(self):
|
||||
assert get_watermark_size(1920, 1080) == WatermarkSize.LARGE
|
||||
|
||||
def test_boundary_image_stays_small(self):
|
||||
"""Exactly 1024×1024 should be SMALL (rule: > 1024 for LARGE)."""
|
||||
assert get_watermark_size(1024, 1024) == WatermarkSize.SMALL
|
||||
|
||||
def test_one_dimension_small(self):
|
||||
"""Only one dimension > 1024 → still SMALL."""
|
||||
assert get_watermark_size(2000, 500) == WatermarkSize.SMALL
|
||||
|
||||
def test_config_small_returns_correct_values(self):
|
||||
config = get_watermark_config(800, 600)
|
||||
assert config.margin_right == 32
|
||||
assert config.margin_bottom == 32
|
||||
assert config.logo_size == 48
|
||||
|
||||
def test_config_large_returns_correct_values(self):
|
||||
config = get_watermark_config(1920, 1080)
|
||||
assert config.margin_right == 64
|
||||
assert config.margin_bottom == 64
|
||||
assert config.logo_size == 96
|
||||
|
||||
def test_position_calculation(self):
|
||||
pos = WatermarkPosition(margin_right=32, margin_bottom=32, logo_size=48)
|
||||
x, y = pos.get_position(800, 600)
|
||||
assert x == 800 - 32 - 48 # 720
|
||||
assert y == 600 - 32 - 48 # 520
|
||||
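The sizing and placement rules these tests pin down can be sketched as a standalone function. This is an illustration under stated assumptions, not the `gemini_engine` implementation: `Geometry` and `watermark_geometry` are hypothetical names, and the numbers (margins 32/64, logo sizes 48/96, the "both dimensions above 1024" rule) are taken from the assertions above.

```python
from typing import NamedTuple


class Geometry(NamedTuple):
    """Hypothetical container for the watermark's size class and top-left corner."""

    size: str
    x: int
    y: int
    logo_size: int


def watermark_geometry(width: int, height: int) -> Geometry:
    """Locate the bottom-right watermark box for an image of the given size."""
    # LARGE only when BOTH dimensions exceed 1024; exactly 1024 stays SMALL.
    large = width > 1024 and height > 1024
    margin, logo = (64, 96) if large else (32, 48)
    # The logo sits `margin` pixels in from the right and bottom edges.
    return Geometry(
        size="large" if large else "small",
        x=width - margin - logo,
        y=height - margin - logo,
        logo_size=logo,
    )
```

For an 800×600 image this gives the same `(720, 520)` corner as `test_position_calculation`.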
|
||||
|
||||
# ── Alpha map ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAlphaMap:
|
||||
"""Tests for alpha map calculation."""
|
||||
|
||||
def test_pure_black_gives_zero_alpha(self):
|
||||
black = np.zeros((10, 10, 3), dtype=np.uint8)
|
||||
alpha = _calculate_alpha_map(black)
|
||||
assert alpha.shape == (10, 10)
|
||||
np.testing.assert_array_equal(alpha, 0.0)
|
||||
|
||||
def test_pure_white_gives_one_alpha(self):
|
||||
white = np.full((10, 10, 3), 255, dtype=np.uint8)
|
||||
alpha = _calculate_alpha_map(white)
|
||||
np.testing.assert_allclose(alpha, 1.0)
|
||||
|
||||
def test_grayscale_input(self):
|
||||
gray = np.full((10, 10), 128, dtype=np.uint8)
|
||||
alpha = _calculate_alpha_map(gray)
|
||||
np.testing.assert_allclose(alpha, 128 / 255.0)
|
||||
|
||||
def test_max_channel_used(self):
|
||||
"""Alpha should use max(R, G, B)."""
|
||||
img = np.zeros((1, 1, 3), dtype=np.uint8)
|
||||
img[0, 0] = [50, 200, 100] # BGR
|
||||
alpha = _calculate_alpha_map(img)
|
||||
assert pytest.approx(alpha[0, 0], rel=1e-3) == 200 / 255.0
|
||||
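The alpha-map tests above, together with the "reverse alpha blending" approach named in the commit message, can be illustrated on a single pixel. This is a sketch under assumptions rather than the engine's code: `alpha_from_pixel` and `reverse_alpha_blend` are hypothetical names, the logo colour is assumed to be pure white (255), and the handling of a fully opaque pixel is a choice made here for the example.

```python
def alpha_from_pixel(b: int, g: int, r: int) -> float:
    """Alpha is the brightest channel scaled to 0..1, as the tests expect."""
    return max(b, g, r) / 255.0


def reverse_alpha_blend(observed: float, alpha: float, logo: float = 255.0) -> float:
    """Invert observed = alpha * logo + (1 - alpha) * original for one channel."""
    if alpha >= 1.0:
        # A fully opaque logo leaves no trace of the original; return as-is.
        return observed
    original = (observed - alpha * logo) / (1.0 - alpha)
    # Clamp to the valid 8-bit range to absorb rounding noise.
    return min(255.0, max(0.0, original))


# Round trip: blend a value of 100 under a half-transparent white logo, then undo it.
blended = 0.5 * 255.0 + 0.5 * 100.0
recovered = reverse_alpha_blend(blended, 0.5)
```

Pixels where alpha approaches 1 amplify noise heavily, since the division by `1 - alpha` blows up.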
|
||||
|
||||
# ── GeminiEngine ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGeminiEngine:
|
||||
"""Tests for the GeminiEngine class."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_engine(self):
|
||||
self.engine = GeminiEngine()
|
||||
|
||||
def test_engine_loads_alpha_maps(self):
|
||||
small = self.engine.get_alpha_map(WatermarkSize.SMALL)
|
||||
large = self.engine.get_alpha_map(WatermarkSize.LARGE)
|
||||
assert small.shape == (48, 48)
|
||||
assert large.shape == (96, 96)
|
||||
|
||||
def test_remove_watermark_returns_same_shape(self, tmp_image_path):
|
||||
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
|
||||
result = self.engine.remove_watermark(image)
|
||||
assert result.shape == image.shape
|
||||
assert result.dtype == np.uint8
|
||||
|
||||
def test_remove_watermark_does_not_modify_input(self, tmp_image_path):
|
||||
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
|
||||
original = image.copy()
|
||||
self.engine.remove_watermark(image)
|
||||
np.testing.assert_array_equal(image, original)
|
||||
|
||||
def test_remove_watermark_large_image(self, tmp_large_image_path):
|
||||
image = cv2.imread(str(tmp_large_image_path), cv2.IMREAD_COLOR)
|
||||
result = self.engine.remove_watermark(image)
|
||||
assert result.shape == image.shape
|
||||
|
||||
def test_remove_watermark_custom_region(self, tmp_image_path):
|
||||
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
|
||||
result = self.engine.remove_watermark_custom(image, (10, 10, 48, 48))
|
||||
assert result.shape == image.shape
|
||||
|
||||
def test_remove_watermark_custom_large_region(self, tmp_image_path):
|
||||
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
|
||||
result = self.engine.remove_watermark_custom(image, (10, 10, 96, 96))
|
||||
assert result.shape == image.shape
|
||||
|
||||
def test_remove_watermark_custom_arbitrary_region(self, tmp_image_path):
|
||||
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
|
||||
result = self.engine.remove_watermark_custom(image, (5, 5, 60, 60))
|
||||
assert result.shape == image.shape
|
||||
|
||||
def test_force_size(self, tmp_image_path):
|
||||
image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
|
||||
result = self.engine.remove_watermark(image, force_size=WatermarkSize.LARGE)
|
||||
assert result.shape == image.shape
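The commit description names reverse alpha blending as the removal technique. A minimal sketch of that idea follows, assuming a white logo composited as `out = alpha * logo + (1 - alpha) * original` over a 3-channel image. `reverse_alpha_blend`, `logo_value`, and `max_alpha` are hypothetical names; this is not the code behind `GeminiEngine.remove_watermark`.

```python
import numpy as np


def reverse_alpha_blend(
    watermarked: np.ndarray,
    alpha: np.ndarray,
    logo_value: float = 255.0,
    max_alpha: float = 0.99,
) -> np.ndarray:
    """Hypothetical sketch: invert `out = alpha*logo + (1 - alpha)*original`."""
    # Clamp alpha so the division below stays finite where the logo is opaque.
    a = np.clip(alpha.astype(np.float32), 0.0, max_alpha)[..., np.newaxis]
    restored = (watermarked.astype(np.float32) - a * logo_value) / (1.0 - a)
    # Return a new uint8 array; the input is left untouched.
    return np.clip(np.rint(restored), 0, 255).astype(np.uint8)
```

Because the blended image was already rounded to 8 bits, the recovered pixels can differ from the original by a level or so, and the error grows as alpha approaches 1.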


# ── Detection ───────────────────────────────────────────────────────


class TestDetection:
    """Tests for watermark detection."""

    @pytest.fixture(autouse=True)
    def _setup_engine(self):
        self.engine = GeminiEngine()

    def test_detect_returns_result_object(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.detect_watermark(image)
        assert isinstance(result, DetectionResult)
        assert 0.0 <= result.confidence <= 1.0

    def test_detect_empty_image_returns_no_detection(self):
        empty = np.zeros((0, 0, 3), dtype=np.uint8)
        result = self.engine.detect_watermark(empty)
        assert not result.detected
        assert result.confidence == 0.0

    def test_detect_none_image_returns_no_detection(self):
        result = self.engine.detect_watermark(None)
        assert not result.detected

    def test_detect_random_image_low_confidence(self, tmp_image_path):
        """Random noise should not look like a watermark."""
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.detect_watermark(image)
        # A random image may or may not be flagged, so only check that both scores are floats
        assert isinstance(result.spatial_score, float)
        assert isinstance(result.gradient_score, float)


# ── Inpainting ──────────────────────────────────────────────────────


class TestInpainting:
    """Tests for residual inpainting."""

    @pytest.fixture(autouse=True)
    def _setup_engine(self):
        self.engine = GeminiEngine()

    def test_inpaint_ns(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.inpaint_residual(image, (150, 150, 48, 48), method="ns")
        assert result.shape == image.shape

    def test_inpaint_telea(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.inpaint_residual(image, (150, 150, 48, 48), method="telea")
        assert result.shape == image.shape

    def test_inpaint_gaussian(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.inpaint_residual(image, (150, 150, 48, 48), method="gaussian")
        assert result.shape == image.shape

    def test_inpaint_zero_strength(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.inpaint_residual(image, (150, 150, 48, 48), strength=0.0)
        np.testing.assert_array_equal(result, image)

    def test_inpaint_tiny_region_returns_unchanged(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        result = self.engine.inpaint_residual(image, (10, 10, 2, 2))
        np.testing.assert_array_equal(result, image)

    def test_inpaint_does_not_modify_input(self, tmp_image_path):
        image = cv2.imread(str(tmp_image_path), cv2.IMREAD_COLOR)
        original = image.copy()
        self.engine.inpaint_residual(image, (150, 150, 48, 48))
        np.testing.assert_array_equal(image, original)
@@ -0,0 +1,52 @@
import numpy as np

from remove_ai_watermarks.humanizer import apply_analog_humanizer


def test_humanizer_does_not_modify_original_if_disabled():
    img = np.zeros((100, 100, 3), dtype=np.uint8)
    img[50, 50] = [100, 150, 200]
    org_img = img.copy()

    # grain=0 and shift=0 together disable the effect; chromatic shift is applied even when grain=0.
    result = apply_analog_humanizer(img, grain_intensity=0.0, chromatic_shift=0)
    assert np.array_equal(result, org_img)


def test_chromatic_shift():
    # Only green channel is centered, red/blue should shift.
    img = np.zeros((5, 5, 3), dtype=np.uint8)
    img[2, 2] = [255, 255, 255]  # B, G, R

    # shift=1
    result = apply_analog_humanizer(img, grain_intensity=0.0, chromatic_shift=1)

    # G (index 1) stays at [2,2]
    assert result[2, 2, 1] == 255
    # B (index 0) shifted right (+1 axis 1) -> [2, 3]
    assert result[2, 3, 0] == 255
    # R (index 2) shifted left (-1 axis 1) -> [2, 1]
    assert result[2, 1, 2] == 255
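The channel offsets asserted in `test_chromatic_shift` can be reproduced with a few lines of NumPy. This is a sketch under the assumption that blue moves right and red moves left by the same amount; `shift_channels` is a hypothetical helper, and `np.roll` wraps pixels around the image edge, which the real humanizer may or may not do.

```python
import numpy as np


def shift_channels(image: np.ndarray, shift: int) -> np.ndarray:
    """Hypothetical sketch: offset B and R horizontally in opposite directions."""
    if shift == 0 or image.ndim != 3:
        return image.copy()
    out = image.copy()
    # BGR order: channel 0 is blue, channel 2 is red; green stays in place.
    out[:, :, 0] = np.roll(image[:, :, 0], shift, axis=1)
    out[:, :, 2] = np.roll(image[:, :, 2], -shift, axis=1)
    return out
```

For a single white pixel at column 2, a shift of 1 leaves green at column 2, moves blue to column 3, and moves red to column 1.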


def test_grain_intensity():
    # Gray image
    img = np.full((100, 100, 3), 128, dtype=np.uint8)

    # Add strong noise
    result = apply_analog_humanizer(img, grain_intensity=10.0, chromatic_shift=0)

    # Image should no longer be purely 128
    unique_vals = np.unique(result)
    assert len(unique_vals) > 5

    # Mean should roughly be 128
    assert 126 < np.mean(result) < 130


def test_invalid_shape():
    # Missing color channel
    img = np.zeros((100, 100), dtype=np.uint8)
    img[0, 0] = 50
    result = apply_analog_humanizer(img)
    assert np.array_equal(img, result)
@@ -0,0 +1,27 @@
"""Tests for the invisible watermark engine (unit tests, no GPU required)."""

from __future__ import annotations

from remove_ai_watermarks.invisible_engine import InvisibleEngine, is_available


class TestIsAvailable:
    """Tests for dependency checking."""

    def test_returns_bool(self):
        result = is_available()
        assert isinstance(result, bool)

    def test_available_when_torch_installed(self):
        """torch + diffusers should be installed in dev env."""
        assert is_available() is True


class TestInvisibleEngineInit:
    """Tests for InvisibleEngine construction (no GPU required)."""

    def test_default_model_id(self):
        assert InvisibleEngine.DEFAULT_MODEL_ID == "Lykon/dreamshaper-8"

    def test_ctrlregen_model_id(self):
        assert InvisibleEngine.CTRLREGEN_MODEL_ID == "yepengliu/ctrlregen"
@@ -0,0 +1,150 @@
"""Tests for AI metadata detection and removal."""

from __future__ import annotations

from pathlib import Path

from PIL import Image
from PIL.PngImagePlugin import PngInfo

from remove_ai_watermarks.metadata import (
    _is_ai_key,
    get_ai_metadata,
    has_ai_metadata,
    remove_ai_metadata,
)

# ── Key detection ───────────────────────────────────────────────────


class TestIsAiKey:
    """Tests for _is_ai_key helper."""

    def test_exact_match_lowercase(self):
        assert _is_ai_key("parameters")

    def test_exact_match_mixed_case(self):
        assert _is_ai_key("Parameters")

    def test_keyword_substring(self):
        assert _is_ai_key("stable_diffusion_model_v2")

    def test_c2pa_detected(self):
        assert _is_ai_key("c2pa_chunk")

    def test_standard_key_not_flagged(self):
        assert not _is_ai_key("Author")

    def test_innocuous_key_not_flagged(self):
        assert not _is_ai_key("Title")

    def test_dpi_not_flagged(self):
        assert not _is_ai_key("dpi")
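The matching rule these tests imply is a case-insensitive exact match plus a keyword substring check. A minimal sketch is shown below; the key and keyword sets are hypothetical and chosen only to mirror the tested cases, and `looks_like_ai_key` is not the project's `_is_ai_key`.

```python
# Hypothetical key and keyword sets, chosen only to mirror the cases tested above.
EXACT_KEYS = frozenset({"parameters", "prompt"})
KEYWORDS = ("stable_diffusion", "c2pa")


def looks_like_ai_key(key: str) -> bool:
    """Return True for an exact key match or any keyword substring, ignoring case."""
    lowered = key.lower()
    return lowered in EXACT_KEYS or any(word in lowered for word in KEYWORDS)
```

Substring matching is what lets a key such as `stable_diffusion_model_v2` be flagged without listing every variant, at the cost of needing keywords specific enough not to hit ordinary keys like `Author` or `dpi`.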


# ── has_ai_metadata / get_ai_metadata ───────────────────────────────


class TestHasAiMetadata:
    """Tests for detecting AI metadata in images."""

    def test_detects_ai_metadata(self, tmp_png_with_ai_metadata):
        assert has_ai_metadata(tmp_png_with_ai_metadata)

    def test_clean_image_no_ai(self, tmp_clean_png):
        assert not has_ai_metadata(tmp_clean_png)


class TestGetAiMetadata:
    """Tests for extracting AI metadata."""

    def test_extracts_parameters_key(self, tmp_png_with_ai_metadata):
        meta = get_ai_metadata(tmp_png_with_ai_metadata)
        assert "parameters" in meta
        assert "Euler" in meta["parameters"]

    def test_extracts_prompt_key(self, tmp_png_with_ai_metadata):
        meta = get_ai_metadata(tmp_png_with_ai_metadata)
        assert "prompt" in meta

    def test_does_not_extract_author(self, tmp_png_with_ai_metadata):
        meta = get_ai_metadata(tmp_png_with_ai_metadata)
        assert "Author" not in meta

    def test_clean_image_empty_dict(self, tmp_clean_png):
        meta = get_ai_metadata(tmp_clean_png)
        assert meta == {}


# ── remove_ai_metadata ──────────────────────────────────────────────


class TestRemoveAiMetadata:
    """Tests for stripping AI metadata."""

    def test_removes_ai_keys(self, tmp_png_with_ai_metadata):
        output = tmp_png_with_ai_metadata.parent / "cleaned.png"
        remove_ai_metadata(tmp_png_with_ai_metadata, output)

        with Image.open(output) as img:
            assert "parameters" not in img.info
            assert "prompt" not in img.info

    def test_keeps_standard_metadata(self, tmp_png_with_ai_metadata):
        output = tmp_png_with_ai_metadata.parent / "cleaned.png"
        remove_ai_metadata(tmp_png_with_ai_metadata, output, keep_standard=True)

        with Image.open(output) as img:
            assert "Author" in img.info
            assert img.info["Author"] == "Test Author"

    def test_remove_all_metadata(self, tmp_png_with_ai_metadata):
        output = tmp_png_with_ai_metadata.parent / "cleaned.png"
        remove_ai_metadata(tmp_png_with_ai_metadata, output, keep_standard=False)
        with Image.open(output) as img:
            assert "Author" not in img.info
            assert "parameters" not in img.info

    def test_overwrite_in_place(self, tmp_path):
        """When output_path is None, should overwrite source."""
        img = Image.new("RGB", (32, 32))
        pnginfo = PngInfo()
        pnginfo.add_text("parameters", "test data")
        path = tmp_path / "inplace.png"
        img.save(path, pnginfo=pnginfo)

        result = remove_ai_metadata(path)
        assert result == path

        with Image.open(path) as cleaned:
            assert "parameters" not in cleaned.info

    def test_jpeg_output(self, tmp_path):
        """Test metadata removal for JPEG format."""
        img = Image.new("RGB", (64, 64), color=(100, 150, 200))
        pnginfo = PngInfo()
        pnginfo.add_text("parameters", "test")
        png_path = tmp_path / "source.png"
        img.save(png_path, pnginfo=pnginfo)

        jpg_path = tmp_path / "output.jpg"
        result = remove_ai_metadata(png_path, jpg_path)
        assert result == jpg_path
        assert jpg_path.exists()

    def test_creates_parent_directories(self, tmp_path):
        img = Image.new("RGB", (32, 32))
        pnginfo = PngInfo()
        pnginfo.add_text("prompt", "test")
        path = tmp_path / "source.png"
        img.save(path, pnginfo=pnginfo)

        output = tmp_path / "sub" / "dir" / "cleaned.png"
        remove_ai_metadata(path, output)
        assert output.exists()

    def test_returns_path(self, tmp_clean_png):
        output = tmp_clean_png.parent / "out.png"
        result = remove_ai_metadata(tmp_clean_png, output)
        assert isinstance(result, Path)
        assert result == output
@@ -0,0 +1,130 @@
"""Tests for vendored noai submodules: constants, extractor, cleaner, c2pa."""

from __future__ import annotations

from remove_ai_watermarks.noai.c2pa import (
    extract_c2pa_chunk,
    extract_c2pa_info,
    has_c2pa_metadata,
)
from remove_ai_watermarks.noai.cleaner import (
    has_ai_content,
)
from remove_ai_watermarks.noai.cleaner import (
    remove_ai_metadata as noai_remove_ai_metadata,
)
from remove_ai_watermarks.noai.constants import (
    AI_KEYWORDS,
    AI_METADATA_KEYS,
    C2PA_CHUNK_TYPE,
    PNG_SIGNATURE,
    SUPPORTED_FORMATS,
)
from remove_ai_watermarks.noai.extractor import (
    extract_ai_metadata,
    extract_metadata,
    get_ai_metadata_summary,
    has_ai_metadata,
)

# ── Constants ───────────────────────────────────────────────────────


class TestConstants:
    """Verify constant integrity."""

    def test_supported_formats_include_png(self):
        assert ".png" in SUPPORTED_FORMATS

    def test_supported_formats_include_jpg(self):
        assert ".jpg" in SUPPORTED_FORMATS

    def test_ai_metadata_keys_not_empty(self):
        assert len(AI_METADATA_KEYS) > 0

    def test_ai_keywords_not_empty(self):
        assert len(AI_KEYWORDS) > 0

    def test_png_signature_bytes(self):
        assert PNG_SIGNATURE == b"\x89PNG\r\n\x1a\n"

    def test_c2pa_chunk_type(self):
        assert C2PA_CHUNK_TYPE == b"caBX"


# ── Extractor ───────────────────────────────────────────────────────


class TestExtractor:
    """Tests for noai.extractor functions."""

    def test_extract_metadata_returns_dict(self, tmp_clean_png):
        meta = extract_metadata(tmp_clean_png)
        assert isinstance(meta, dict)

    def test_extract_metadata_gets_standard_keys(self, tmp_clean_png):
        meta = extract_metadata(tmp_clean_png)
        assert "Author" in meta

    def test_extract_ai_metadata_from_ai_image(self, tmp_png_with_ai_metadata):
        meta = extract_ai_metadata(tmp_png_with_ai_metadata)
        assert "parameters" in meta

    def test_extract_ai_metadata_from_clean_image(self, tmp_clean_png):
        meta = extract_ai_metadata(tmp_clean_png)
        assert len(meta) == 0

    def test_has_ai_metadata_detects(self, tmp_png_with_ai_metadata):
        assert has_ai_metadata(tmp_png_with_ai_metadata)

    def test_has_ai_metadata_clean(self, tmp_clean_png):
        assert not has_ai_metadata(tmp_clean_png)

    def test_summary_with_ai(self, tmp_png_with_ai_metadata):
        summary = get_ai_metadata_summary(tmp_png_with_ai_metadata)
        assert "AI Image Metadata" in summary

    def test_summary_clean(self, tmp_clean_png):
        summary = get_ai_metadata_summary(tmp_clean_png)
        assert "No AI metadata" in summary


# ── Cleaner ─────────────────────────────────────────────────────────


class TestCleaner:
    """Tests for noai.cleaner functions."""

    def test_remove_ai_metadata(self, tmp_png_with_ai_metadata, tmp_path):
        output = tmp_path / "cleaned.png"
        noai_remove_ai_metadata(tmp_png_with_ai_metadata, output)
        assert output.exists()
        # Verify AI metadata removed
        meta = extract_ai_metadata(output)
        assert "parameters" not in meta

    def test_has_ai_content(self, tmp_png_with_ai_metadata):
        assert has_ai_content(tmp_png_with_ai_metadata)


# ── C2PA ────────────────────────────────────────────────────────────


class TestC2PA:
    """Tests for C2PA detection on regular (non-C2PA) images."""

    def test_no_c2pa_on_regular_png(self, tmp_clean_png):
        assert not has_c2pa_metadata(tmp_clean_png)

    def test_no_c2pa_on_jpeg(self, tmp_jpeg_path):
        assert not has_c2pa_metadata(tmp_jpeg_path)

    def test_extract_c2pa_none_on_regular(self, tmp_clean_png):
        assert extract_c2pa_chunk(tmp_clean_png) is None

    def test_extract_c2pa_info_empty(self, tmp_clean_png):
        info = extract_c2pa_info(tmp_clean_png)
        assert info == {}

    def test_c2pa_returns_false_for_non_png(self, tmp_jpeg_path):
        assert not has_c2pa_metadata(tmp_jpeg_path)
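The constants tested above (`PNG_SIGNATURE` and the `caBX` chunk type) suggest that C2PA detection walks the PNG chunk list. The sketch below shows one way to do that with the standard library only. `iter_png_chunk_types`, `has_chunk`, and `make_chunk` are hypothetical helpers working on in-memory bytes, not the vendored `noai.c2pa` functions, which take file paths.

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def iter_png_chunk_types(data: bytes):
    """Yield the 4-byte type of each chunk in an in-memory PNG."""
    if not data.startswith(PNG_SIGNATURE):
        return
    offset = len(PNG_SIGNATURE)
    while offset + 8 <= len(data):
        (length,) = struct.unpack(">I", data[offset : offset + 4])
        yield data[offset + 4 : offset + 8]
        # Skip length (4) + type (4) + payload + CRC (4).
        offset += 12 + length


def has_chunk(data: bytes, chunk_type: bytes = b"caBX") -> bool:
    """Return True if any chunk in the PNG has the given type."""
    return chunk_type in iter_png_chunk_types(data)


def make_chunk(chunk_type: bytes, payload: bytes = b"") -> bytes:
    """Serialize one PNG chunk (helper for building test data)."""
    crc = zlib.crc32(chunk_type + payload) & 0xFFFFFFFF
    return struct.pack(">I", len(payload)) + chunk_type + payload + struct.pack(">I", crc)
```

Data that does not start with the PNG signature, such as a JPEG, yields no chunks at all, which matches the expectation that non-PNG inputs report no C2PA chunk.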
@@ -0,0 +1,165 @@
"""Tests for cross-platform and cross-device compatibility.

Verifies that device detection, MPS fallback, and platform-specific
code paths work correctly on CPU, MPS (macOS), and CUDA (Linux/Windows).
"""

from __future__ import annotations

from unittest.mock import patch

import pytest

from remove_ai_watermarks.noai.progress import is_mps_error
from remove_ai_watermarks.noai.utils import get_image_format, is_supported_format
from remove_ai_watermarks.noai.watermark_profiles import (
    detect_model_profile,
    get_model_id_for_profile,
    get_recommended_strength,
)
from remove_ai_watermarks.noai.watermark_remover import get_device, is_watermark_removal_available

# ── Device detection ────────────────────────────────────────────────


class TestDeviceDetection:
    """Tests for get_device() across platforms."""

    def test_returns_valid_device(self):
        device = get_device()
        assert device in ("cpu", "mps", "cuda")

    def test_cpu_fallback_when_no_gpu(self):
        """On CI / machines without GPU, should fall back to cpu or mps."""
        device = get_device()
        # Just verify it doesn't crash and returns a valid string
        assert isinstance(device, str)

    @patch("remove_ai_watermarks.noai.watermark_remover._HAS_TORCH", False)
    def test_no_torch_returns_cpu(self):
        assert get_device() == "cpu"
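A device picker with the behaviour tested above (one of `cpu`, `mps`, `cuda`, and `cpu` when PyTorch is missing) might look like the sketch below. The CUDA-before-MPS priority is an assumption, and `pick_device` is a hypothetical name rather than the vendored `get_device`.

```python
def pick_device() -> str:
    """Hypothetical sketch: prefer CUDA, then MPS, and fall back to CPU."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no accelerator to select.
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # Older PyTorch builds have no `torch.backends.mps`, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"
```

Importing `torch` inside the function keeps the module importable on machines without the ML extras installed.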


class TestMpsErrorDetection:
    """Tests for MPS error detection helper."""

    def test_detects_mps_error(self):
        err = RuntimeError("MPS backend out of memory")
        assert is_mps_error(err) is True

    def test_non_mps_error(self):
        err = RuntimeError("CUDA out of memory")
        assert is_mps_error(err) is False

    def test_generic_error(self):
        err = RuntimeError("something went wrong")
        assert is_mps_error(err) is False
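One plausible way to classify these errors is to look for the backend name in the exception message. The sketch below assumes that approach; `mentions_mps` is a hypothetical helper and the vendored `is_mps_error` may inspect the error differently.

```python
import re

# Match "mps" as a whole word so unrelated words such as "bumps" are not flagged.
_MPS_PATTERN = re.compile(r"\bmps\b", re.IGNORECASE)


def mentions_mps(error: BaseException) -> bool:
    """Hypothetical sketch: True when the error message names the MPS backend."""
    return _MPS_PATTERN.search(str(error)) is not None
```

A caller could use such a check to decide whether to retry a failed pipeline step on the CPU instead of re-raising.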


# ── Model profiles ──────────────────────────────────────────────────


class TestModelProfiles:
    """Tests for watermark_profiles.py."""

    def test_default_profile(self):
        assert get_model_id_for_profile("default") == "Lykon/dreamshaper-8"

    def test_ctrlregen_profile(self):
        assert get_model_id_for_profile("ctrlregen") == "yepengliu/ctrlregen"

    def test_unknown_profile_raises(self):
        with pytest.raises(ValueError, match="Unknown model profile"):
            get_model_id_for_profile("nonexistent")

    def test_detect_default(self):
        assert detect_model_profile("Lykon/dreamshaper-8") == "default"

    def test_detect_ctrlregen(self):
        assert detect_model_profile("yepengliu/ctrlregen") == "ctrlregen"

    def test_recommended_strength_high(self):
        assert get_recommended_strength("treering") == 0.7

    def test_recommended_strength_low(self):
        assert get_recommended_strength("stablesignature") == 0.04

    def test_recommended_strength_medium(self):
        assert get_recommended_strength("unknown_type") == 0.35
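The strength tests describe a lookup with a medium default for unrecognised watermark types. A minimal sketch follows; the three numeric values come from the assertions above, while the table layout and the name `recommended_strength` are assumptions.

```python
# Values mirror the assertions above; the table layout itself is an assumption.
STRENGTH_BY_WATERMARK = {"treering": 0.7, "stablesignature": 0.04}
DEFAULT_STRENGTH = 0.35


def recommended_strength(watermark_type: str) -> float:
    """Look up a regeneration strength, falling back to a medium default."""
    return STRENGTH_BY_WATERMARK.get(watermark_type, DEFAULT_STRENGTH)
```

Falling back to a default instead of raising contrasts with the profile lookup, where an unknown profile name is treated as an error.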


# ── Format utilities ────────────────────────────────────────────────


class TestFormatUtils:
    """Tests for utils.py format helpers."""

    def test_supported_png(self, tmp_path):
        assert is_supported_format(tmp_path / "test.png")

    def test_supported_jpg(self, tmp_path):
        assert is_supported_format(tmp_path / "test.jpg")

    def test_supported_jpeg(self, tmp_path):
        assert is_supported_format(tmp_path / "test.jpeg")

    def test_supported_webp(self, tmp_path):
        assert is_supported_format(tmp_path / "test.webp")

    def test_unsupported_bmp(self, tmp_path):
        assert not is_supported_format(tmp_path / "test.bmp")

    def test_unsupported_gif(self, tmp_path):
        assert not is_supported_format(tmp_path / "test.gif")

    def test_get_format_png(self, tmp_path):
        assert get_image_format(tmp_path / "x.png") == "PNG"

    def test_get_format_jpg(self, tmp_path):
        assert get_image_format(tmp_path / "x.jpg") == "JPEG"

    def test_get_format_jpeg(self, tmp_path):
        assert get_image_format(tmp_path / "x.jpeg") == "JPEG"

    def test_get_format_webp_defaults_png(self, tmp_path):
        # .webp falls through to PNG in current implementation
        assert get_image_format(tmp_path / "x.webp") == "PNG"
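The fall-through noted for `.webp` is what results when only the JPEG suffixes are special-cased. The sketch below illustrates that shape; `format_for_path` is a hypothetical helper, and lower-casing the suffix is an assumption not covered by the tests.

```python
from pathlib import Path

# Assumed mapping: only JPEG extensions are special-cased.
_JPEG_SUFFIXES = {".jpg", ".jpeg"}


def format_for_path(path: Path) -> str:
    """Hypothetical sketch: map a file suffix to a Pillow format name."""
    return "JPEG" if path.suffix.lower() in _JPEG_SUFFIXES else "PNG"
```

With this shape, saving to a `.webp` path would write PNG-encoded data; adding an explicit entry mapping `.webp` to Pillow's `"WEBP"` format name would be the natural extension if WebP output is wanted.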


# ── Availability checks ────────────────────────────────────────────


class TestAvailability:
    """Tests for dependency availability checks."""

    def test_watermark_removal_available(self):
        # In dev env with torch+diffusers installed
        assert is_watermark_removal_available() is True

    def test_invisible_is_available(self):
        from remove_ai_watermarks.invisible_engine import is_available

        assert is_available() is True


# ── Platform-specific path handling ─────────────────────────────────


class TestPlatformPaths:
    """Verify path handling works on current platform."""

    def test_pathlib_works_for_assets(self):
        from pathlib import Path

        asset_dir = Path(__file__).parent.parent / "src" / "remove_ai_watermarks" / "assets"
        assert (asset_dir / "gemini_bg_48.png").exists()
        assert (asset_dir / "gemini_bg_96.png").exists()

    def test_asset_loading_works(self):
        """Verify embedded assets load correctly (critical for packaging)."""
        from remove_ai_watermarks.gemini_engine import GeminiEngine

        engine = GeminiEngine()
        # If we get here without error, asset loading works
        assert engine._alpha_small.shape == (48, 48)
        assert engine._alpha_large.shape == (96, 96)