mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-05-20 11:54:40 +02:00
refactor: enforce strict linting and type checking across codebase
- Expand ruff rules (B, S, SIM, RET, COM, C4, G, PT, PIE, T20, DTZ, ICN, TCH, RUF, ANN) - Switch pyright to strict mode with relaxed test environment - Replace try-except-pass with contextlib.suppress throughout - Move type-only imports into TYPE_CHECKING blocks - Replace ambiguous Unicode chars (en dash, multiplication sign, Greek alpha) with ASCII - Move color-matcher from base deps to [gpu], remove unused requests dep - Add pyright to dev deps, update dependabot to uv ecosystem - Fix hardcoded version in test_version, unused unpacked vars in tests - Update maintain.sh, CLAUDE.md, .gitignore, .claude/settings.json - Remove obsolete .agents/rules/project.md - Upgrade all dependencies (Pygments vulnerability fix) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,65 +0,0 @@
|
||||
---
|
||||
trigger: always_on
|
||||
description: Project rules for remove-ai-watermarks
|
||||
---
|
||||
|
||||
# Project rules
|
||||
|
||||
## Role
|
||||
|
||||
You are a **principal Python engineer** working on `remove-ai-watermarks` — a CLI tool for removing visible and invisible AI watermarks from images.
|
||||
|
||||
## Python environment
|
||||
|
||||
- This project uses `uv` for Python package management.
|
||||
- Use `uv pip install` instead of `pip install`.
|
||||
- Use `uv run` to run Python scripts (e.g., `uv run python scripts/example.py`).
|
||||
- Project uses **src-layout**: source code is in `src/remove_ai_watermarks/`.
|
||||
|
||||
## Project structure
|
||||
|
||||
- `src/remove_ai_watermarks/` — main package (CLI, engines, metadata)
|
||||
- `src/remove_ai_watermarks/noai/` — vendored noai-watermark code (invisible watermark removal)
|
||||
- `src/remove_ai_watermarks/noai/ctrlregen/` — vendored CtrlRegen pipeline
|
||||
- `src/remove_ai_watermarks/assets/` — embedded watermark alpha maps (PNG)
|
||||
- `tests/` — pytest test suite
|
||||
- `data/samples/` — sample images for manual testing
|
||||
|
||||
## Testing
|
||||
|
||||
- Run tests: `uv run pytest`
|
||||
- Run with coverage: `uv run pytest --cov=remove_ai_watermarks --cov-report=term-missing`
|
||||
- Tests create temporary files via `tmp_path` — they do NOT depend on `data/samples/`.
|
||||
- ML pipeline modules (invisible_engine, ctrlregen, watermark_remover) require GPU and multi-GB model downloads — unit tests for these are limited to availability checks and constants.
|
||||
|
||||
## Git
|
||||
|
||||
- Do not create commits or push unless explicitly asked by the user.
|
||||
|
||||
## Code quality
|
||||
|
||||
- Always run `./maintain.sh` before committing to ensure code quality (ruff, pyright).
|
||||
- **CRITICAL**: Before using any method on a client or class, ALWAYS verify that the method exists by reading the class file first.
|
||||
- **NEVER** assume or invent methods. If the method doesn't exist, either use an existing alternative method or explicitly create the new method first.
|
||||
|
||||
## Documentation
|
||||
|
||||
- Update documentation (README.md) when you change functionality or add new features.
|
||||
- **DO NOT** create artifact documentation (walkthrough.md, verification.md) for bug fixes or small corrections.
|
||||
|
||||
## Language
|
||||
|
||||
- Use only English for all code, comments, docstrings, documentation, commit messages, and project artifacts.
|
||||
- Communicate with users in Russian, but all technical content must be in English.
|
||||
|
||||
## API integrations
|
||||
|
||||
- Do not assume or invent API request/response structures.
|
||||
- Always verify API payloads against official documentation before implementing.
|
||||
- Use Context7 MCP to retrieve up-to-date library documentation when working with external APIs or packages.
|
||||
|
||||
## Work completion
|
||||
|
||||
- You have **NO time limits**. Always complete the full task in one go.
|
||||
- Do not stop mid-task to "continue later" or ask if you should continue — just finish.
|
||||
- Do not split work into multiple commits unless the task is genuinely large.
|
||||
@@ -6,6 +6,7 @@
|
||||
]
|
||||
},
|
||||
"enabledPlugins": {
|
||||
"pyright-lsp@claude-plugins-official": true,
|
||||
"context7@claude-plugins-official": true,
|
||||
"code-simplifier@claude-plugins-official": true,
|
||||
"claude-md-management@claude-plugins-official": true
|
||||
|
||||
+15
-7
@@ -1,21 +1,29 @@
|
||||
version: 2
|
||||
|
||||
updates:
|
||||
- package-ecosystem: "pip"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "deps"
|
||||
open-pull-requests-limit: 10
|
||||
labels:
|
||||
- "dependencies"
|
||||
open-pull-requests-limit: 10
|
||||
- "python"
|
||||
groups:
|
||||
minor-and-patch:
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "ci"
|
||||
open-pull-requests-limit: 5
|
||||
labels:
|
||||
- "ci"
|
||||
- "dependencies"
|
||||
- "github-actions"
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
@@ -27,3 +27,6 @@ _refs/
|
||||
# Downloaded model weights
|
||||
yolov8n.pt
|
||||
.coverage
|
||||
|
||||
# Claude Code local settings
|
||||
.claude/settings.local.json
|
||||
|
||||
@@ -35,11 +35,27 @@ uv run pyright # type check
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- Python 3.10+, ruff line-length 120, type hints everywhere
|
||||
- Python 3.10+, ruff line-length 120, pyright strict mode, type hints everywhere
|
||||
- GPU/ML modules (invisible_engine, ctrlregen, watermark_remover) are optional — guard imports with `is_available()` checks
|
||||
- Tests for ML modules are limited to availability checks (require multi-GB downloads)
|
||||
- Always run `./maintain.sh` before committing
|
||||
- Use `uv` for all package operations, never raw `pip`
|
||||
- `_refs/` directory is excluded from all checks — contains third-party reference code
|
||||
|
||||
## Release Process
|
||||
|
||||
To create a new release, run:
|
||||
|
||||
```bash
|
||||
./release.sh <version> # e.g. ./release.sh 0.4.0
|
||||
```
|
||||
|
||||
The script will: validate version format (X.Y.Z) → run all checks (ruff, pytest, pyright) → update version in `pyproject.toml` and `src/remove_ai_watermarks/__init__.py` → commit → create git tag `vX.Y.Z` → push. GitHub Actions will then automatically create a GitHub Release with build artifacts.
|
||||
|
||||
When asked to make a release, always use `./release.sh`.
|
||||
|
||||
## Pre-commit Hook
|
||||
|
||||
Git hooks live in `.githooks/`. Run `./maintain.sh` once to activate them (sets `core.hooksPath`). The pre-commit hook runs ruff check, ruff format --check, and pytest.
|
||||
|
||||
## Language
|
||||
|
||||
|
||||
+6
-1
@@ -1,9 +1,14 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Ensure git hooks are active
|
||||
git config core.hooksPath .githooks
|
||||
|
||||
uv sync --all-extras
|
||||
uv-outdated
|
||||
uv run uv-outdated
|
||||
uv run uv-secure --ignore-unfixed
|
||||
uv run ruff check --fix
|
||||
uv run ruff format
|
||||
uv run pytest
|
||||
uv run pyright
|
||||
|
||||
+35
-18
@@ -12,9 +12,7 @@ dependencies = [
|
||||
"opencv-python-headless>=4.8.0",
|
||||
"click>=8.0.0",
|
||||
"rich>=13.0.0",
|
||||
"color-matcher",
|
||||
"python-dotenv>=1.0.0",
|
||||
"requests>=2.33.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -26,11 +24,13 @@ gpu = [
|
||||
"controlnet-aux>=0.0.9",
|
||||
"safetensors",
|
||||
"ultralytics>=8.0.0",
|
||||
"color-matcher>=0.5.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"ruff>=0.4.0",
|
||||
"pyright>=1.1.0",
|
||||
]
|
||||
all = ["remove-ai-watermarks[gpu,dev]"]
|
||||
|
||||
@@ -55,21 +55,38 @@ addopts = "-v --tb=short"
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 120
|
||||
|
||||
[tool.pyright]
|
||||
exclude = ["tests"]
|
||||
reportOptionalMemberAccess = "warning"
|
||||
reportPrivateImportUsage = false
|
||||
reportInvalidTypeForm = false
|
||||
reportOptionalCall = false
|
||||
reportMissingImports = "warning"
|
||||
reportCallIssue = "warning"
|
||||
reportAttributeAccessIssue = "warning"
|
||||
reportArgumentType = "warning"
|
||||
reportPossiblyUnboundVariable = "warning"
|
||||
reportAssignmentType = "warning"
|
||||
reportOperatorIssue = "warning"
|
||||
exclude = ["_refs"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I", "N", "UP"]
|
||||
ignore = ["E501"]
|
||||
select = ["E", "F", "B", "I", "S", "UP", "SIM", "RET", "COM", "C4", "G", "PT", "PIE", "T20", "DTZ", "ICN", "TCH", "RUF", "ANN"]
|
||||
ignore = [
|
||||
"COM812", # missing trailing comma (conflicts with ruff formatter)
|
||||
"ANN401", # typing.Any — sometimes unavoidable with third-party libs
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*.py" = ["ANN", "S101", "S105", "S106", "S108"]
|
||||
"src/remove_ai_watermarks/noai/watermark_remover.py" = ["S603", "S606", "S607", "T201"] # subprocess calls for auto-install/CUDA fix
|
||||
"src/remove_ai_watermarks/noai/c2pa.py" = ["S110"] # try-except-pass for corrupt file handling
|
||||
"src/remove_ai_watermarks/noai/ctrlregen/engine.py" = ["S101", "S603"] # assert for loaded state, subprocess for auto-install
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
|
||||
[tool.pyright]
|
||||
pythonVersion = "3.10"
|
||||
typeCheckingMode = "strict"
|
||||
exclude = ["_refs"]
|
||||
|
||||
[[tool.pyright.executionEnvironments]]
|
||||
root = "tests"
|
||||
extraPaths = ["."]
|
||||
reportAttributeAccessIssue = false
|
||||
reportOptionalSubscript = false
|
||||
reportOptionalMemberAccess = false
|
||||
reportArgumentType = false
|
||||
reportUnknownMemberType = false
|
||||
reportUnknownArgumentType = false
|
||||
reportUnknownVariableType = false
|
||||
reportMissingTypeArgument = false
|
||||
|
||||
@@ -89,7 +89,7 @@ def main(ctx: click.Context, verbose: bool) -> None:
|
||||
@click.option(
|
||||
"--inpaint-method", type=click.Choice(["ns", "telea", "gaussian"]), default="ns", help="Inpainting method."
|
||||
)
|
||||
@click.option("--inpaint-strength", type=float, default=0.85, help="Inpainting blend strength (0.0–1.0).")
|
||||
@click.option("--inpaint-strength", type=float, default=0.85, help="Inpainting blend strength (0.0-1.0).")
|
||||
@click.option("--detect/--no-detect", default=True, help="Detect watermark before removal.")
|
||||
@click.option("--detect-threshold", type=float, default=0.25, help="Detection confidence threshold.")
|
||||
@click.option("--strip-metadata/--keep-metadata", default=True, help="Strip AI metadata from output.")
|
||||
@@ -128,7 +128,7 @@ def cmd_visible(
|
||||
raise SystemExit(1)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
console.print(f" [dim]Input:[/] {source.name} ({w}×{h})")
|
||||
console.print(f" [dim]Input:[/] {source.name} ({w}x{h})")
|
||||
|
||||
# Detection (we always detect softly, to find dynamic region for inpainting)
|
||||
with console.status("[cyan]Detecting watermark…[/]"):
|
||||
@@ -197,14 +197,14 @@ def cmd_visible(
|
||||
@click.option(
|
||||
"-o", "--output", type=click.Path(path_type=Path), default=None, help="Output path (default: <source>_clean.<ext>)."
|
||||
)
|
||||
@click.option("--strength", type=float, default=0.02, help="Denoising strength (0.0–1.0). Default: 0.02.")
|
||||
@click.option("--strength", type=float, default=0.02, help="Denoising strength (0.0-1.0). Default: 0.02.")
|
||||
@click.option("--steps", type=int, default=100, help="Number of denoising steps. Default: 100.")
|
||||
@click.option("--pipeline", type=click.Choice(["default", "ctrlregen"]), default="default", help="Pipeline profile.")
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
"--humanize", type=float, default=0.0, help="Analog Humanizer film grain intensity (0 = off, typical: 2.0–6.0)."
|
||||
"--humanize", type=float, default=0.0, help="Analog Humanizer film grain intensity (0 = off, typical: 2.0-6.0)."
|
||||
)
|
||||
@click.pass_context
|
||||
def cmd_invisible(
|
||||
@@ -334,7 +334,7 @@ def cmd_metadata(
|
||||
@click.option(
|
||||
"--inpaint-method", type=click.Choice(["ns", "telea", "gaussian"]), default="ns", help="Inpainting method."
|
||||
)
|
||||
@click.option("--strength", type=float, default=0.02, help="Invisible watermark denoising strength (0.0–1.0).")
|
||||
@click.option("--strength", type=float, default=0.02, help="Invisible watermark denoising strength (0.0-1.0).")
|
||||
@click.option("--steps", type=int, default=100, help="Number of denoising steps for invisible removal.")
|
||||
@click.option(
|
||||
"--pipeline",
|
||||
@@ -347,7 +347,7 @@ def cmd_metadata(
|
||||
@click.option("--seed", type=int, default=None, help="Random seed for reproducibility.")
|
||||
@click.option("--hf-token", type=str, default=None, help="HuggingFace API token.")
|
||||
@click.option(
|
||||
"--humanize", type=float, default=0.0, help="Analog Humanizer film grain intensity (0 = off, typical: 2.0–6.0)."
|
||||
"--humanize", type=float, default=0.0, help="Analog Humanizer film grain intensity (0 = off, typical: 2.0-6.0)."
|
||||
)
|
||||
@click.pass_context
|
||||
def cmd_all(
|
||||
@@ -406,7 +406,7 @@ def cmd_all(
|
||||
raise SystemExit(1)
|
||||
|
||||
h, w = image.shape[:2]
|
||||
console.print(f" [dim]Input:[/] {source.name} ({w}×{h})")
|
||||
console.print(f" [dim]Input:[/] {source.name} ({w}x{h})")
|
||||
|
||||
with console.status("[cyan]Removing visible watermark…[/]"):
|
||||
det = engine.detect_watermark(image)
|
||||
@@ -598,7 +598,7 @@ def _process_batch_image(
|
||||
@click.option("--steps", type=int, default=50, help="Number of denoising steps (invisible mode).")
|
||||
@click.option("--inpaint/--no-inpaint", default=True, help="Apply inpainting (visible mode).")
|
||||
@click.option(
|
||||
"--humanize", type=float, default=0.0, help="Analog Humanizer film grain intensity (0 = off, typical: 2.0–6.0)."
|
||||
"--humanize", type=float, default=0.0, help="Analog Humanizer film grain intensity (0 = off, typical: 2.0-6.0)."
|
||||
)
|
||||
@click.option("--pipeline", type=click.Choice(["default", "ctrlregen"]), default="default", help="Pipeline profile.")
|
||||
@click.option("--device", type=click.Choice(["auto", "cpu", "mps", "cuda"]), default="auto", help="Inference device.")
|
||||
|
||||
@@ -23,7 +23,7 @@ class FaceProtector:
|
||||
been destroyed by latent diffusion or other algorithms.
|
||||
"""
|
||||
|
||||
def __init__(self, use_yolo: bool = True, model_name: str = "yolov8n.pt"):
|
||||
def __init__(self, use_yolo: bool = True, model_name: str = "yolov8n.pt") -> None:
|
||||
self.use_yolo = use_yolo and HAS_YOLO
|
||||
self.detector = None
|
||||
self.haar_cascade = None
|
||||
@@ -62,20 +62,19 @@ class FaceProtector:
|
||||
bboxes.append((int(x1), int(y1), int(x2), int(y2)))
|
||||
return bboxes
|
||||
|
||||
else:
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
faces = self.haar_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
|
||||
bboxes = []
|
||||
for x, y, w, h in faces:
|
||||
# Add a 20% margin around the haar cascade face box
|
||||
margin_x = int(w * 0.2)
|
||||
margin_y = int(h * 0.2)
|
||||
x1 = max(0, x - margin_x)
|
||||
y1 = max(0, y - int(margin_y * 1.5)) # more margin on top for hair
|
||||
x2 = min(image.shape[1], x + w + margin_x)
|
||||
y2 = min(image.shape[0], y + h + margin_y)
|
||||
bboxes.append((x1, y1, x2, y2))
|
||||
return bboxes
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
faces = self.haar_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
|
||||
bboxes = []
|
||||
for x, y, w, h in faces:
|
||||
# Add a 20% margin around the haar cascade face box
|
||||
margin_x = int(w * 0.2)
|
||||
margin_y = int(h * 0.2)
|
||||
x1 = max(0, x - margin_x)
|
||||
y1 = max(0, y - int(margin_y * 1.5)) # more margin on top for hair
|
||||
x2 = min(image.shape[1], x + w + margin_x)
|
||||
y2 = min(image.shape[0], y + h + margin_y)
|
||||
bboxes.append((x1, y1, x2, y2))
|
||||
return bboxes
|
||||
|
||||
@staticmethod
|
||||
def _fix_ssl_certs() -> None:
|
||||
|
||||
@@ -4,13 +4,13 @@ Port of the GeminiWatermarkTool reverse-alpha-blending algorithm from C++ to Pyt
|
||||
Original author: Allen Kuo (allenk) — https://github.com/allenk/GeminiWatermarkTool
|
||||
|
||||
The Gemini AI watermark is applied using alpha blending:
|
||||
watermarked = α × logo + (1 - α) × original
|
||||
watermarked = a * logo + (1 - a) * original
|
||||
|
||||
We reverse this to recover the original:
|
||||
original = (watermarked - α × logo) / (1 - α)
|
||||
original = (watermarked - a * logo) / (1 - a)
|
||||
|
||||
The alpha maps are derived from background captures of the Gemini watermark
|
||||
on pure-black backgrounds (48×48 for small images, 96×96 for large images).
|
||||
on pure-black backgrounds (48x48 for small images, 96x96 for large images).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -19,11 +19,13 @@ import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,8 +33,8 @@ logger = logging.getLogger(__name__)
|
||||
class WatermarkSize(Enum):
|
||||
"""Watermark size mode based on image dimensions."""
|
||||
|
||||
SMALL = "small" # 48×48, for images ≤ 1024×1024
|
||||
LARGE = "large" # 96×96, for images > 1024×1024
|
||||
SMALL = "small" # 48x48, for images <= 1024x1024
|
||||
LARGE = "large" # 96x96, for images > 1024x1024
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -69,8 +71,8 @@ def get_watermark_config(width: int, height: int) -> WatermarkPosition:
|
||||
"""Get the appropriate watermark configuration based on image size.
|
||||
|
||||
Rules discovered from Gemini:
|
||||
- W > 1024 AND H > 1024: 96×96 logo at (W-64-96, H-64-96)
|
||||
- Otherwise: 48×48 logo at (W-32-48, H-32-48)
|
||||
- W > 1024 AND H > 1024: 96x96 logo at (W-64-96, H-64-96)
|
||||
- Otherwise: 48x48 logo at (W-32-48, H-32-48)
|
||||
"""
|
||||
if width > 1024 and height > 1024:
|
||||
return WatermarkPosition(margin_right=64, margin_bottom=64, logo_size=96)
|
||||
@@ -272,10 +274,7 @@ class GeminiEngine:
|
||||
|
||||
if ref_h > 8:
|
||||
ref_region = image[y1 - ref_h : y1, x1:x2]
|
||||
if len(ref_region.shape) == 3:
|
||||
gray_ref = cv2.cvtColor(ref_region, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
gray_ref = ref_region
|
||||
gray_ref = cv2.cvtColor(ref_region, cv2.COLOR_BGR2GRAY) if len(ref_region.shape) == 3 else ref_region
|
||||
|
||||
_, s_wm = cv2.meanStdDev(gray_region)
|
||||
_, s_ref = cv2.meanStdDev(gray_ref)
|
||||
@@ -397,7 +396,7 @@ class GeminiEngine:
|
||||
) -> None:
|
||||
"""Apply reverse alpha blending in-place.
|
||||
|
||||
Formula: original = (watermarked - α × logo) / (1 - α)
|
||||
Formula: original = (watermarked - a * logo) / (1 - a)
|
||||
"""
|
||||
x, y = position
|
||||
ah, aw = alpha_map.shape[:2]
|
||||
|
||||
@@ -12,8 +12,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
# Suppress verbose deprecation warnings from diffusers/transformers/huggingface_hub
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
@@ -106,7 +109,7 @@ class InvisibleEngine:
|
||||
Args:
|
||||
image_path: Path to the watermarked image.
|
||||
output_path: Output path (None = overwrite source).
|
||||
strength: Denoising strength (0.0–1.0). Default 0.04.
|
||||
strength: Denoising strength (0.0-1.0). Default 0.04.
|
||||
steps: Number of denoising steps.
|
||||
guidance_scale: Classifier-free guidance scale.
|
||||
seed: Random seed for reproducibility.
|
||||
|
||||
@@ -8,8 +8,12 @@ For metadata-only operations, the heavy ML dependencies are NOT required.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -184,10 +188,8 @@ def remove_ai_metadata(
|
||||
if _is_ai_key(key):
|
||||
continue
|
||||
if key == "exif":
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
exif_data = piexif.load(value)
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
if key in ("dpi", "gamma"):
|
||||
save_kwargs[key] = value
|
||||
@@ -203,10 +205,8 @@ def remove_ai_metadata(
|
||||
save_kwargs["pnginfo"] = pnginfo
|
||||
|
||||
if exif_data and save_kwargs["format"] == "JPEG":
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
save_kwargs["exif"] = piexif.dump(exif_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(output_path, **save_kwargs)
|
||||
|
||||
@@ -6,4 +6,4 @@ Original: https://github.com/mertizci/noai-watermark (MIT License)
|
||||
from remove_ai_watermarks.noai.cleaner import remove_ai_metadata
|
||||
from remove_ai_watermarks.noai.watermark_remover import WatermarkRemover, remove_watermark
|
||||
|
||||
__all__ = ["WatermarkRemover", "remove_watermark", "remove_ai_metadata"]
|
||||
__all__ = ["WatermarkRemover", "remove_ai_metadata", "remove_watermark"]
|
||||
|
||||
@@ -267,31 +267,30 @@ def inject_c2pa_chunk(target_path: Path, output_path: Path, c2pa_chunk: bytes) -
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(target_path, "rb") as f_in:
|
||||
with open(output_path, "wb") as f_out:
|
||||
f_out.write(f_in.read(8))
|
||||
with open(target_path, "rb") as f_in, open(output_path, "wb") as f_out:
|
||||
f_out.write(f_in.read(8))
|
||||
|
||||
c2pa_injected = False
|
||||
while True:
|
||||
chunk_header = f_in.read(8)
|
||||
if len(chunk_header) < 8:
|
||||
break
|
||||
c2pa_injected = False
|
||||
while True:
|
||||
chunk_header = f_in.read(8)
|
||||
if len(chunk_header) < 8:
|
||||
break
|
||||
|
||||
length = struct.unpack(">I", chunk_header[:4])[0]
|
||||
chunk_type = chunk_header[4:8]
|
||||
chunk_data = f_in.read(length)
|
||||
crc = f_in.read(4)
|
||||
length = struct.unpack(">I", chunk_header[:4])[0]
|
||||
chunk_type = chunk_header[4:8]
|
||||
chunk_data = f_in.read(length)
|
||||
crc = f_in.read(4)
|
||||
|
||||
if chunk_type == b"IDAT" and not c2pa_injected:
|
||||
f_out.write(c2pa_chunk)
|
||||
c2pa_injected = True
|
||||
if chunk_type == b"IDAT" and not c2pa_injected:
|
||||
f_out.write(c2pa_chunk)
|
||||
c2pa_injected = True
|
||||
|
||||
if chunk_type == C2PA_CHUNK_TYPE:
|
||||
continue
|
||||
if chunk_type == C2PA_CHUNK_TYPE:
|
||||
continue
|
||||
|
||||
f_out.write(chunk_header)
|
||||
f_out.write(chunk_data)
|
||||
f_out.write(crc)
|
||||
f_out.write(chunk_header)
|
||||
f_out.write(chunk_data)
|
||||
f_out.write(crc)
|
||||
|
||||
if chunk_type == b"IEND":
|
||||
break
|
||||
if chunk_type == b"IEND":
|
||||
break
|
||||
|
||||
@@ -13,8 +13,11 @@ The removal pipeline:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
import contextlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
import piexif
|
||||
from PIL import Image
|
||||
@@ -85,23 +88,19 @@ def _extract_non_ai_metadata(source_path: Path, keep_standard: bool) -> dict[str
|
||||
with Image.open(source_path) as img:
|
||||
# Handle EXIF data
|
||||
if "exif" in img.info:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
exif_dict = piexif.load(img.info["exif"])
|
||||
cleaned_metadata["exif"] = exif_dict
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Extract non-AI metadata
|
||||
for key, value in img.info.items():
|
||||
if _is_ai_metadata_key(key):
|
||||
continue
|
||||
|
||||
if keep_standard and key in PNG_METADATA_KEYS:
|
||||
is_standard = keep_standard and key in PNG_METADATA_KEYS
|
||||
is_nonstandard = not keep_standard and key not in ["exif", "dpi", "gamma"] and key not in PNG_METADATA_KEYS
|
||||
if is_standard or is_nonstandard:
|
||||
cleaned_metadata[key] = value
|
||||
elif not keep_standard:
|
||||
# Remove standard metadata while still preserving non-standard fields.
|
||||
if key not in ["exif", "dpi", "gamma"] and key not in PNG_METADATA_KEYS:
|
||||
cleaned_metadata[key] = value
|
||||
|
||||
# Keep DPI and gamma
|
||||
if "dpi" in img.info:
|
||||
@@ -154,11 +153,9 @@ def _prepare_clean_jpeg_kwargs(save_kwargs: dict[str, Any], metadata: dict[str,
|
||||
"""Prepare save kwargs for clean JPEG."""
|
||||
exif_dict = metadata.get("exif", {"0th": {}, "Exif": {}, "1st": {}, "GPS": {}, "Interop": {}})
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if "dpi" in metadata:
|
||||
save_kwargs["dpi"] = metadata["dpi"]
|
||||
|
||||
@@ -10,12 +10,15 @@ Attribution:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -114,12 +117,12 @@ class CtrlRegenEngine:
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
raise ImportError(
|
||||
"Failed to auto-install missing dependencies: "
|
||||
+ ", ".join(missing)
|
||||
+ ". Try manually: pip install --force-reinstall noai-watermark"
|
||||
)
|
||||
) from exc
|
||||
|
||||
self.base_model_id = base_model_id or DEFAULT_BASE_MODEL
|
||||
self.device = device
|
||||
@@ -132,10 +135,8 @@ class CtrlRegenEngine:
|
||||
def _set_progress(self, message: str) -> None:
|
||||
if self._progress_callback is None:
|
||||
return
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
self._progress_callback(message)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Loading
|
||||
@@ -202,10 +203,8 @@ class CtrlRegenEngine:
|
||||
pipe = pipe.to(self.device)
|
||||
|
||||
if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._pipeline = pipe
|
||||
self._canny_detector = CannyDetector()
|
||||
|
||||
@@ -12,6 +12,7 @@ Attribution:
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
@@ -43,7 +44,7 @@ class CustomIPAdapterMixin:
|
||||
subfolder: str | list[str],
|
||||
weight_name: str | list[str],
|
||||
image_encoder_folder: str | None = "image_encoder",
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Load CtrlRegen IP-Adapter weights and DINOv2 image encoder.
|
||||
|
||||
@@ -93,7 +94,7 @@ class CustomIPAdapterMixin:
|
||||
}
|
||||
|
||||
state_dicts: list[dict] = []
|
||||
for path_or_dict, wn, sf in zip(pretrained_model_name_or_path_or_dict, weight_name, subfolder):
|
||||
for path_or_dict, wn, sf in zip(pretrained_model_name_or_path_or_dict, weight_name, subfolder, strict=False):
|
||||
if not isinstance(path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
path_or_dict,
|
||||
@@ -110,7 +111,7 @@ class CustomIPAdapterMixin:
|
||||
if wn.endswith(".safetensors"):
|
||||
state_dict: dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
@@ -126,15 +127,15 @@ class CustomIPAdapterMixin:
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# Always use DINOv2-giant as the image encoder.
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
logger.info("Loading DINOv2-giant image encoder for CtrlRegen")
|
||||
enc_dtype = getattr(self, "dtype", torch.float32) # type: ignore[attr-defined]
|
||||
image_encoder = AutoModel.from_pretrained(DINOV2_MODEL_ID).to(
|
||||
self.device,
|
||||
dtype=enc_dtype, # type: ignore[attr-defined]
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder) # type: ignore[attr-defined]
|
||||
has_encoder_attr = hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None
|
||||
if has_encoder_attr and image_encoder_folder is not None:
|
||||
logger.info("Loading DINOv2-giant image encoder for CtrlRegen")
|
||||
enc_dtype = getattr(self, "dtype", torch.float32) # type: ignore[attr-defined]
|
||||
image_encoder = AutoModel.from_pretrained(DINOV2_MODEL_ID).to(
|
||||
self.device,
|
||||
dtype=enc_dtype, # type: ignore[attr-defined]
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder) # type: ignore[attr-defined]
|
||||
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
feature_extractor = AutoImageProcessor.from_pretrained(DINOV2_MODEL_ID)
|
||||
|
||||
@@ -33,5 +33,3 @@ class CustomCtrlRegenPipeline(
|
||||
while ``CustomIPAdapterMixin`` only adds the
|
||||
``load_ctrlregen_ip_adapter`` method.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -8,8 +8,10 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -116,7 +118,7 @@ def run_tiled(
|
||||
_t0: float = tile_t0,
|
||||
_es: int = effective_steps,
|
||||
) -> Callable:
|
||||
def _cb(step: int, timestep: int, latents: Any) -> None: # noqa: ARG001
|
||||
def _cb(step: int, timestep: int, latents: Any) -> None:
|
||||
elapsed = time.monotonic() - _t0
|
||||
cur = step + 1
|
||||
per = elapsed / max(1, cur)
|
||||
|
||||
@@ -6,8 +6,10 @@ human-readable summary without modifying the source file.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
import piexif
|
||||
from PIL import Image
|
||||
@@ -82,9 +84,8 @@ def extract_ai_metadata(source_path: Path) -> dict[str, Any]:
|
||||
|
||||
for key, value in img.info.items():
|
||||
key_lower = key.lower()
|
||||
if key not in ai_metadata:
|
||||
if any(kw in key_lower for kw in AI_KEYWORDS):
|
||||
ai_metadata[key] = value
|
||||
if key not in ai_metadata and any(kw in key_lower for kw in AI_KEYWORDS):
|
||||
ai_metadata[key] = value
|
||||
|
||||
# Check for C2PA metadata
|
||||
if has_c2pa_metadata(source_path):
|
||||
@@ -111,10 +112,7 @@ def has_ai_metadata(image_path: Path) -> bool:
|
||||
if key in img.info:
|
||||
return True
|
||||
|
||||
if has_c2pa_metadata(image_path):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(has_c2pa_metadata(image_path))
|
||||
|
||||
|
||||
def get_ai_metadata_summary(source_path: Path) -> str:
|
||||
@@ -138,7 +136,7 @@ def get_ai_metadata_summary(source_path: Path) -> str:
|
||||
for key, value in ai_meta.items():
|
||||
if key == "c2pa_chunk":
|
||||
continue
|
||||
elif key == "c2pa" and isinstance(value, dict):
|
||||
if key == "c2pa" and isinstance(value, dict):
|
||||
lines.append("C2PA Metadata:")
|
||||
for ck, cv in value.items():
|
||||
lines.append(f" {ck}: {cv}")
|
||||
|
||||
@@ -6,11 +6,14 @@ class focused on orchestration.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from PIL import Image
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from remove_ai_watermarks.noai.progress import is_mps_error, make_pipeline_progress
|
||||
|
||||
@@ -28,7 +31,6 @@ def run_img2img(
|
||||
set_progress: Callable[[str], None],
|
||||
) -> Image.Image:
|
||||
"""Execute img2img with live progress and return the generated image."""
|
||||
w, h = image.size
|
||||
effective_steps = max(1, int(num_inference_steps * strength))
|
||||
|
||||
step_cb, first_step, done_ev, start_updater = make_pipeline_progress(
|
||||
@@ -143,10 +145,8 @@ def _call_pipeline(
|
||||
|
||||
|
||||
def _try_clear_mps_cache() -> None:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
import torch
|
||||
|
||||
if hasattr(torch, "mps"):
|
||||
torch.mps.empty_cache() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -20,8 +20,10 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
# ── ANSI color constants ────────────────────────────────────────────
|
||||
_CYAN = "\033[36m"
|
||||
@@ -99,7 +101,7 @@ def run_with_progress(
|
||||
def worker() -> None:
|
||||
try:
|
||||
output_holder["result"] = task()
|
||||
except Exception as error: # pragma: no cover – passthrough
|
||||
except Exception as error: # pragma: no cover - passthrough
|
||||
output_holder["error"] = error
|
||||
finally:
|
||||
done.set()
|
||||
@@ -202,18 +204,15 @@ def silence_library_output(
|
||||
lambda: _silence_diffusers(),
|
||||
lambda: __import__("huggingface_hub").logging.set_verbosity_error(),
|
||||
):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
_silence()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
with contextlib.redirect_stdout(io.StringIO()):
|
||||
with contextlib.redirect_stderr(io.StringIO()):
|
||||
if set_progress:
|
||||
set_progress("Executing watermark removal pipeline...")
|
||||
return run_func()
|
||||
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
||||
if set_progress:
|
||||
set_progress("Executing watermark removal pipeline...")
|
||||
return run_func()
|
||||
|
||||
return wrapped
|
||||
|
||||
@@ -300,7 +299,7 @@ def make_pipeline_progress(
|
||||
idx = 0
|
||||
pipeline_done.wait(timeout=0.4)
|
||||
|
||||
def step_callback(step: int, timestep: int, latents: Any) -> None: # noqa: ARG001
|
||||
def step_callback(step: int, timestep: int, latents: Any) -> None:
|
||||
first_step.set()
|
||||
last_cb_time[0] = time.monotonic()
|
||||
elapsed = time.monotonic() - t0_holder[0]
|
||||
|
||||
@@ -6,7 +6,10 @@ higher-level modules can import without circular dependencies.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from remove_ai_watermarks.noai.constants import SUPPORTED_FORMATS
|
||||
|
||||
|
||||
@@ -12,13 +12,16 @@ This module implements a simple regeneration attack that:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -105,7 +108,7 @@ def _detect_cuda_index_url() -> str:
|
||||
major, minor = version_str.split(".")[:2]
|
||||
cuda_tag = f"cu{major}{minor}"
|
||||
return f"https://download.pytorch.org/whl/{cuda_tag}"
|
||||
except Exception:
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
return "https://download.pytorch.org/whl/cu121"
|
||||
|
||||
@@ -195,9 +198,9 @@ def _ensure_watermark_deps() -> None:
|
||||
|
||||
torch = _torch
|
||||
_HAS_TORCH = True
|
||||
from diffusers import AutoPipelineForImage2Image # noqa: N813
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
AutoImg2ImgPipeline = AutoPipelineForImage2Image # noqa: N806
|
||||
AutoImg2ImgPipeline = AutoPipelineForImage2Image
|
||||
_HAS_DIFFUSERS = True
|
||||
|
||||
|
||||
@@ -245,7 +248,7 @@ class WatermarkRemover:
|
||||
torch_dtype: Any = None,
|
||||
progress_callback: Callable[[str], None] | None = None,
|
||||
hf_token: str | None = None,
|
||||
):
|
||||
) -> None:
|
||||
self.model_id = model_id or self.DEFAULT_MODEL_ID
|
||||
self.model_profile = detect_model_profile(self.model_id)
|
||||
|
||||
@@ -273,10 +276,8 @@ class WatermarkRemover:
|
||||
"""Send a progress update through callback when available."""
|
||||
if self._progress_callback is None:
|
||||
return
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
self._progress_callback(message)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Preload ──────────────────────────────────────────────────────
|
||||
|
||||
@@ -351,18 +352,14 @@ class WatermarkRemover:
|
||||
) from exc
|
||||
|
||||
if hasattr(self._pipeline, "enable_xformers_memory_efficient_attention"):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
self._set_progress("Enabling memory optimizations...")
|
||||
self._pipeline.enable_xformers_memory_efficient_attention() # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Mac Float32 memory slicing
|
||||
if self.device == "mps" and hasattr(self._pipeline, "enable_attention_slicing"):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
self._pipeline.enable_attention_slicing("max")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("Model loaded successfully")
|
||||
self._set_progress("Model initialized. Preparing input image...")
|
||||
@@ -562,11 +559,9 @@ class WatermarkRemover:
|
||||
if self.device == "mps" and is_mps_error(error):
|
||||
logger.warning("MPS out of memory during CtrlRegen. Falling back to CPU.")
|
||||
self._set_progress("MPS out of memory! Retrying CtrlRegen on CPU...")
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
if _HAS_TORCH and hasattr(torch, "mps"):
|
||||
torch.mps.empty_cache() # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.device = "cpu"
|
||||
self.torch_dtype = torch.float32 # type: ignore[assignment]
|
||||
|
||||
+11
-8
@@ -2,7 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -11,25 +14,25 @@ from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def tmp_image_path(tmp_path: Path) -> Path:
|
||||
"""Create a minimal 200×200 test PNG image and return its path."""
|
||||
"""Create a minimal 200x200 test PNG image and return its path."""
|
||||
img = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
|
||||
path = tmp_path / "test_image.png"
|
||||
cv2.imwrite(str(path), img)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def tmp_large_image_path(tmp_path: Path) -> Path:
|
||||
"""Create a 1200×1200 test PNG image (triggers large watermark branch)."""
|
||||
"""Create a 1200x1200 test PNG image (triggers large watermark branch)."""
|
||||
img = np.random.randint(0, 255, (1200, 1200, 3), dtype=np.uint8)
|
||||
path = tmp_path / "test_large.png"
|
||||
cv2.imwrite(str(path), img)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def tmp_jpeg_path(tmp_path: Path) -> Path:
|
||||
"""Create a minimal JPEG test image."""
|
||||
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
@@ -38,7 +41,7 @@ def tmp_jpeg_path(tmp_path: Path) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def tmp_png_with_ai_metadata(tmp_path: Path) -> Path:
|
||||
"""Create a PNG with AI-related metadata keys."""
|
||||
img = Image.new("RGB", (64, 64), color=(128, 128, 128))
|
||||
@@ -51,7 +54,7 @@ def tmp_png_with_ai_metadata(tmp_path: Path) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def tmp_clean_png(tmp_path: Path) -> Path:
|
||||
"""Create a PNG with no AI metadata."""
|
||||
img = Image.new("RGB", (64, 64), color=(200, 100, 50))
|
||||
|
||||
+12
-8
@@ -2,9 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -15,12 +18,12 @@ from PIL.PngImagePlugin import PngInfo
|
||||
from remove_ai_watermarks.cli import main
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture
|
||||
def sample_png(tmp_path: Path) -> Path:
|
||||
"""Create a sample PNG for CLI testing."""
|
||||
img = np.random.randint(0, 255, (200, 200, 3), dtype=np.uint8)
|
||||
@@ -79,7 +82,8 @@ class TestMainGroup:
|
||||
def test_version(self, runner):
|
||||
result = runner.invoke(main, ["--version"])
|
||||
assert result.exit_code == 0
|
||||
assert "0.3.1" in result.output
|
||||
assert "remove-ai-watermarks" in result.output
|
||||
assert "version" in result.output
|
||||
|
||||
def test_no_command_shows_banner(self, runner):
|
||||
result = runner.invoke(main, [])
|
||||
@@ -164,7 +168,7 @@ class TestInvisibleCommand:
|
||||
mock_engine.remove_watermark.assert_called_once()
|
||||
|
||||
def test_invisible_default_output(self, runner, sample_png):
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
mock_cls, _mock_engine = _mock_invisible_engine()
|
||||
with (
|
||||
patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls),
|
||||
@@ -188,7 +192,7 @@ class TestAllCommand:
|
||||
assert "visible" in result.output.lower()
|
||||
|
||||
def test_all_basic(self, runner, sample_png, tmp_path):
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
mock_cls, _mock_engine = _mock_invisible_engine()
|
||||
output = tmp_path / "clean.png"
|
||||
with (
|
||||
patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True),
|
||||
@@ -284,7 +288,7 @@ class TestBatchCommand:
|
||||
def test_batch_invisible_mode(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
output_dir = tmp_path / "output"
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
mock_cls, _mock_engine = _mock_invisible_engine()
|
||||
with (
|
||||
patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls),
|
||||
@@ -301,7 +305,7 @@ class TestBatchCommand:
|
||||
def test_batch_all_mode(self, runner, tmp_path):
|
||||
input_dir = _make_batch_dir(tmp_path)
|
||||
output_dir = tmp_path / "output"
|
||||
mock_cls, mock_engine = _mock_invisible_engine()
|
||||
mock_cls, _mock_engine = _mock_invisible_engine()
|
||||
with (
|
||||
patch("remove_ai_watermarks.cli.InvisibleEngine", mock_cls, create=True),
|
||||
patch("remove_ai_watermarks.invisible_engine.InvisibleEngine", mock_cls),
|
||||
|
||||
@@ -29,7 +29,7 @@ class TestWatermarkConfig:
|
||||
assert get_watermark_size(1920, 1080) == WatermarkSize.LARGE
|
||||
|
||||
def test_boundary_image_stays_small(self):
|
||||
"""Exactly 1024×1024 should be SMALL (rule: > 1024 for LARGE)."""
|
||||
"""Exactly 1024x1024 should be SMALL (rule: > 1024 for LARGE)."""
|
||||
assert get_watermark_size(1024, 1024) == WatermarkSize.SMALL
|
||||
|
||||
def test_one_dimension_small(self):
|
||||
|
||||
@@ -8,7 +8,7 @@ def test_humanizer_does_not_modify_original_if_disabled():
|
||||
img[50, 50] = [100, 150, 200]
|
||||
org_img = img.copy()
|
||||
|
||||
# grain=0, shift=0 means disabled essentially. But wait, apply_analog_humanizer currently applies chromatic shift even if grain=0.
|
||||
# grain=0, shift=0 means disabled — result should match original.
|
||||
result = apply_analog_humanizer(img, grain_intensity=0.0, chromatic_shift=0)
|
||||
assert np.array_equal(result, org_img)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user