mirror of
https://github.com/wiltodelta/remove-ai-watermarks.git
synced 2026-06-04 18:18:00 +02:00
b0aad476fb
The cli refactor dropped rich from dependencies, but four scripts still did `from rich.console import Console` / `rich.table import Table`. Their test modules import the scripts, so a clean `uv sync --frozen` (CI: core+dev, no rich) failed at collection with ModuleNotFoundError on macOS/Windows/Linux. Add a shared plain-text shim `scripts/_plain_console.py` (Console/Table via click.echo, markup stripped) and switch all four scripts to it. Verified: all four import with rich blocked, and tests/test_synthid_corpus.py + tests/test_synthid_pixel_probe.py pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
151 lines
5.7 KiB
Python
151 lines
5.7 KiB
Python
"""Multilingual recall benchmark for the text-protection detector.
|
|
|
|
Measures the core lever of text protection (`text_protector.TextProtector`): if
|
|
the PP-OCRv3 DB detector misses a text region, that text is NOT preserved during
|
|
the SDXL watermark-removal pass and gets deformed (issue #14). This renders short
|
|
text in several scripts at several font sizes on two canvas sizes, runs detection,
|
|
and reports the fraction of each known text bbox the detector covers.
|
|
|
|
Findings (2026-05-29):
|
|
- Detection is script-agnostic: DB segments text *regions*, not characters, so
|
|
Latin / Cyrillic / CJK / Hangul / Arabic / digits score identically. Language
|
|
was never the lever.
|
|
- The only lever is resolution. A fixed small detector input downscaled large
|
|
canvases so far that small text was missed. Detecting at the native long side
|
|
(capped, see ``text_protector._DET_MAX_LONG_SIDE``) lifts overall hit-rate
|
|
from 0.91 to 1.00 and the worst cell (~16 px text on a 2048 canvas) from
|
|
0.06 to 1.00.
|
|
|
|
This needs the detector model (downloaded on first use) and a font that covers
|
|
all the scripts (macOS "Arial Unicode"; on Linux install a Noto super-font).
|
|
No GPU. Run:
|
|
|
|
uv run python scripts/text_detection_benchmark.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sys
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from _plain_console import Console
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
from remove_ai_watermarks import text_protector as tp
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Single shared console (the ``Console`` imported from ``_plain_console``);
# every report line in ``main`` goes through ``console.print``.
console = Console()
|
|
|
|
# One font that covers every tested script keeps "language" separate from
# "font" as a variable. Paths are probed in order by ``_find_font`` and the
# first one that exists on disk is used.
_FONT_CANDIDATES = [
    # macOS locations of "Arial Unicode"
    "/System/Library/Fonts/Supplemental/Arial Unicode.ttf",
    "/Library/Fonts/Arial Unicode.ttf",
    # Linux locations
    "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
    "/usr/share/fonts/truetype/unifont/unifont.ttf",
]
|
|
# Script label -> sample string rendered for that script. The labels double as
# row names in the report, and dict order is the order rows are printed in.
SCRIPTS = {
    "Latin": "Generated by AI",
    "Cyrillic": "Сгенерировано ИИ",
    "CJK": "豆包AI生成内容",
    "Hangul": "AI로 생성됨",
    "Arabic": "أنشئ بالذكاء",
    "Digits": "0123456789",
}
|
|
# Font sizes handed to ``ImageFont.truetype`` for each rendered sample.
FONT_SIZES = [16, 24, 32, 48, 64]

# Side length of the square canvas each sample is drawn on.
CANVASES = [1024, 2048]

# Text origin as (x, y) fractions of the canvas side; ``_render`` multiplies
# each fraction by the canvas size.
PLACEMENTS = [
    (0.08, 0.15),
    (0.30, 0.55),
    (0.10, 0.82),
]

# Solid RGB background colours; ``main`` cycles through them by placement index.
BG_COLORS = [
    (35, 40, 60),
    (210, 205, 200),
]
|
|
|
|
|
|
def _find_font() -> str:
    """Return the first path in ``_FONT_CANDIDATES`` that exists on disk.

    Returns:
        The matching candidate path, unchanged.

    Raises:
        SystemExit: if none of the candidate paths exists; the message tells
            the user how to install a suitable font.
    """
    # Candidates are tried in list order; ``None`` marks "nothing found".
    found = next((candidate for candidate in _FONT_CANDIDATES if Path(candidate).exists()), None)
    if found is None:
        raise SystemExit(
            "No multi-script font found. Install one (macOS ships 'Arial Unicode'; "
            "on Linux: a Noto CJK/super font) and add its path to _FONT_CANDIDATES."
        )
    return found
|
|
|
|
|
|
def _render(
    font_path: str,
    canvas: int,
    text: str,
    font_size: int,
    place: tuple[float, float],
    bg: tuple[int, int, int],
) -> tuple[Image.Image, tuple[int, int, int, int]]:
    """Draw ``text`` once on a solid-colour square canvas.

    Args:
        font_path: Font file passed to ``ImageFont.truetype``.
        canvas: Width and height of the RGB image that is created.
        text: The string to draw.
        font_size: Size passed to ``ImageFont.truetype``.
        place: Text origin as (x, y) fractions of ``canvas``.
        bg: RGB fill colour of the canvas.

    Returns:
        ``(image, bbox)`` where ``bbox`` is what ``ImageDraw.textbbox`` reports
        for the same origin, text and font that were drawn.
    """
    image = Image.new("RGB", (canvas, canvas), bg)
    pen = ImageDraw.Draw(image)
    face = ImageFont.truetype(font_path, font_size)

    # Fractional placement -> pixel origin (truncated towards zero).
    origin = (int(place[0] * canvas), int(place[1] * canvas))

    # Pick a contrasting ink: near-white when the channel sum of the
    # background is below 360, near-black otherwise.
    if sum(bg) < 360:
        ink = (245, 245, 245)
    else:
        ink = (20, 20, 20)

    pen.text(origin, text, font=face, fill=ink)
    text_box = pen.textbbox(origin, text, font=face)
    return image, text_box
|
|
|
|
|
|
def _coverage(boxes: list[Any], bbox: tuple[int, int, int, int], h: int, w: int) -> float | None:
    """Return the fraction of the known text rectangle covered by detections.

    Args:
        boxes: Detected regions; each is converted to an ``int32`` array and
            filled as a polygon. May be empty.
        bbox: Known text rectangle as ``(x0, y0, x1, y1)``.
        h: Mask height in pixels.
        w: Mask width in pixels.

    Returns:
        Covered pixels divided by rectangle pixels, or ``None`` when the
        rectangle rasterises to zero pixels on the ``h`` x ``w`` mask.
    """
    # Rasterise the known rectangle (thickness -1 = filled) into a 0/1 mask.
    truth_mask = np.zeros((h, w), np.uint8)
    top_left = (bbox[0], bbox[1])
    bottom_right = (bbox[2], bbox[3])
    cv2.rectangle(truth_mask, top_left, bottom_right, 1, -1)

    truth_pixels = int(truth_mask.sum())
    if truth_pixels == 0:
        return None

    # Rasterise every detected polygon into a second 0/1 mask.
    detected_mask = np.zeros((h, w), np.uint8)
    if boxes:
        polygons = [np.asarray(quad, np.int32) for quad in boxes]
        cv2.fillPoly(detected_mask, polygons, 1)

    overlap_pixels = int((truth_mask & detected_mask).sum())
    return overlap_pixels / truth_pixels
|
|
|
|
|
|
def _hitrate(values: list[float], thr: float = 0.5) -> float:
    """Return the share of ``values`` that are greater than or equal to ``thr``.

    Args:
        values: Coverage scores to evaluate.
        thr: Inclusive threshold a score must reach to count as a hit.

    Returns:
        Hits divided by ``len(values)``, or NaN when ``values`` is empty.
    """
    if not values:
        # No samples: report NaN instead of dividing by zero.
        return float("nan")
    hits = len([score for score in values if score >= thr])
    return hits / len(values)
|
|
|
|
|
|
def main() -> int:
    """Run the benchmark over the whole sample grid and print the hit-rate report.

    Every combination of canvas size, script, font size and placement is
    rendered, passed to the detector, and scored with ``_coverage``. Scores are
    grouped three ways and summarised with ``_hitrate`` through
    ``console.print``.

    Returns:
        ``0`` after the report has been printed.

    Raises:
        SystemExit: if ``tp.is_available()`` is false, or from ``_find_font``
            when no candidate font exists.
    """
    logging.basicConfig(level=logging.WARNING)
    if not tp.is_available():
        raise SystemExit("text detector unavailable (need opencv with cv2.dnn.TextDetectionModel_DB)")
    font_path = _find_font()
    detector = tp.TextProtector()

    # Coverage scores grouped by the three views the report needs.
    by_script_size: dict[tuple[str, int], list[float]] = defaultdict(list)
    by_size_canvas: dict[tuple[int, int], list[float]] = defaultdict(list)
    by_script: dict[str, list[float]] = defaultdict(list)

    # Flattened sample grid; iteration order is canvas, script, font size, placement.
    grid = (
        (canvas, script, text, font_size, idx, place)
        for canvas in CANVASES
        for script, text in SCRIPTS.items()
        for font_size in FONT_SIZES
        for idx, place in enumerate(PLACEMENTS)
    )
    for canvas, script, text, font_size, idx, place in grid:
        # Background colour cycles with the placement index.
        background = BG_COLORS[idx % len(BG_COLORS)]
        img, bbox = _render(font_path, canvas, text, font_size, place, background)
        # The rendered image is RGB; convert to BGR before detection.
        bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        detected = detector.detect_text_boxes(bgr)
        cov = _coverage(detected, bbox, canvas, canvas)
        if cov is None:
            # Nothing to score for this sample.
            continue
        by_script_size[(script, font_size)].append(cov)
        by_size_canvas[(font_size, canvas)].append(cov)
        by_script[script].append(cov)

    # Table 1: one row per script, one column per font size.
    console.print("=== hit-rate (coverage>=0.5) by script x font-size ===")
    size_header = "".join(f"{fs:>7}" for fs in FONT_SIZES)
    console.print("script".ljust(10) + size_header)
    for script in SCRIPTS:
        cells = [f"{_hitrate(by_script_size[(script, fs)]):>7.2f}" for fs in FONT_SIZES]
        console.print(script.ljust(10) + "".join(cells))

    # Table 2: one row per font size, one column per canvas size.
    console.print("\n=== hit-rate by font-size x canvas (the downscale effect) ===")
    canvas_header = "".join(f"{c:>8}" for c in CANVASES)
    console.print("size".ljust(8) + canvas_header)
    for fs in FONT_SIZES:
        cells = [f"{_hitrate(by_size_canvas[(fs, c)]):>8.2f}" for c in CANVASES]
        console.print(str(fs).ljust(8) + "".join(cells))

    # Overall figure across every scored sample.
    all_scores: list[float] = []
    for vals in by_script.values():
        all_scores.extend(vals)
    overall = _hitrate(all_scores)
    console.print(f"\nOVERALL hit-rate: {overall:.2f} (detector max long side = {tp._DET_MAX_LONG_SIDE})")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark and hand its return value to the interpreter as the
    # process exit status.
    exit_code = main()
    sys.exit(exit_code)
|