diff --git a/src/remove_ai_watermarks/cli.py b/src/remove_ai_watermarks/cli.py index 222106b..648f56a 100644 --- a/src/remove_ai_watermarks/cli.py +++ b/src/remove_ai_watermarks/cli.py @@ -181,8 +181,9 @@ def cmd_visible( from remove_ai_watermarks.metadata import remove_ai_metadata remove_ai_metadata(output, output) - except Exception: - pass + except Exception as e: + if ctx.obj.get("verbose"): + console.print(f" [yellow]⚠[/] Failed to strip metadata: {e}") size_kb = output.stat().st_size / 1024 console.print(f" [green]✓[/] Saved: {output} [dim]({size_kb:.0f} KB, {elapsed:.2f}s)[/]") @@ -493,6 +494,94 @@ def cmd_all( # ── Batch command ──────────────────────────────────────────────────── +def _process_batch_image( + ctx: click.Context, + img_path: Path, + out_path: Path, + mode: str, + inpaint: bool, + strength: float | None, + steps: int, + pipeline: str, + device: str, + seed: int | None, + hf_token: str | None, + humanize: float, +) -> None: + """Process a single image for batch mode. + + Applies the requested watermark removal steps (visible, invisible, + metadata) to *img_path* and writes the result to *out_path*. + + Raises: + ValueError: If the image cannot be opened. + """ + if mode in ("visible", "all"): + import cv2 + + from remove_ai_watermarks.gemini_engine import ( + GeminiEngine, + get_watermark_config, + ) + + if "_vis_engine" not in ctx.obj: + ctx.obj["_vis_engine"] = GeminiEngine() + engine = ctx.obj["_vis_engine"] + read_path = img_path + if mode == "all" and out_path.exists(): + read_path = out_path + image = cv2.imread(str(read_path), cv2.IMREAD_COLOR) + if image is None: + raise ValueError("Failed to read image") + + det = engine.detect_watermark(image) + if det.detected: + result = engine.remove_watermark(image) + if inpaint: + if det.confidence > 0.15: + region = det.region + else: + h, w = image.shape[:2] + config = get_watermark_config(w, h) + pos = config.get_position(w, h) + region = (pos[0], pos[1], config.logo_size, config.logo_size) + + result = engine.inpaint_residual(result, region) + else: + result = image.copy() + + cv2.imwrite(str(out_path), result) + + if mode in ("invisible", "all"): + from remove_ai_watermarks.invisible_engine import ( + is_available as invisible_available, + ) + + if invisible_available(): + from remove_ai_watermarks.invisible_engine import InvisibleEngine + + if "_inv_engine" not in ctx.obj: + ctx.obj["_inv_engine"] = InvisibleEngine( + device=None if device == "auto" else device, + pipeline=pipeline, + hf_token=hf_token, + ) + engine_inv = ctx.obj["_inv_engine"] + engine_inv.remove_watermark( + img_path if mode == "invisible" else out_path, + out_path, + strength=strength, + num_inference_steps=steps, + seed=seed, + humanize=humanize, + ) + + if mode in ("metadata", "all"): + from remove_ai_watermarks.metadata import remove_ai_metadata + + remove_ai_metadata(img_path if mode == "metadata" else out_path, out_path) + + @main.command("batch") @click.argument("directory", type=click.Path(exists=True, file_okay=False, path_type=Path)) @click.option( @@ -537,10 +626,7 @@ def cmd_batch( output_dir = directory.parent / (directory.name + "_clean") output_dir.mkdir(parents=True, exist_ok=True) - # Assuming supported_formats is defined elsewhere - supported_formats = [".jpg", ".jpeg", ".png", ".webp"] - - images = sorted(p for p in directory.iterdir() if p.suffix.lower() in supported_formats) + images = sorted(p for p in directory.iterdir() if p.suffix.lower() in SUPPORTED_FORMATS) if not images: console.print(f"[yellow]No supported images found in {directory}[/]") @@ -568,71 +654,20 @@ def cmd_batch( progress.update(task, description=f"[cyan]{img_path.name}[/]") try: - if mode in ("visible", "all"): - import cv2 - - from remove_ai_watermarks.gemini_engine import ( - GeminiEngine, - get_watermark_config, - ) - - if "_vis_engine" not in ctx.obj: - ctx.obj["_vis_engine"] = GeminiEngine() - engine = ctx.obj["_vis_engine"] - read_path = img_path - if mode == "all" and out_path.exists(): - read_path = out_path - image = cv2.imread(str(read_path), cv2.IMREAD_COLOR) - if image is None: - raise ValueError("Failed to read image") - - det = engine.detect_watermark(image) - if det.detected: - result = engine.remove_watermark(image) - if inpaint: - if det.confidence > 0.15: - region = det.region - else: - h, w = image.shape[:2] - config = get_watermark_config(w, h) - pos = config.get_position(w, h) - region = (pos[0], pos[1], config.logo_size, config.logo_size) - - result = engine.inpaint_residual(result, region) - else: - result = image.copy() - - cv2.imwrite(str(out_path), result) - - if mode in ("invisible", "all"): - from remove_ai_watermarks.invisible_engine import ( - is_available as invisible_available, - ) - - if invisible_available(): - from remove_ai_watermarks.invisible_engine import InvisibleEngine - - if "_inv_engine" not in ctx.obj: - ctx.obj["_inv_engine"] = InvisibleEngine( - device=None if device == "auto" else device, - pipeline=pipeline, - hf_token=hf_token, - ) - engine_inv = ctx.obj["_inv_engine"] - engine_inv.remove_watermark( - img_path if mode == "invisible" else out_path, - out_path, - strength=strength, - num_inference_steps=steps, - seed=seed, - humanize=humanize, - ) - - if mode in ("metadata", "all"): - from remove_ai_watermarks.metadata import remove_ai_metadata - - remove_ai_metadata(img_path if mode == "metadata" else out_path, out_path) - + _process_batch_image( + ctx=ctx, + img_path=img_path, + out_path=out_path, + mode=mode, + inpaint=inpaint, + strength=strength, + steps=steps, + pipeline=pipeline, + device=device, + seed=seed, + hf_token=hf_token, + humanize=humanize, + ) processed += 1 except Exception as e: diff --git a/src/remove_ai_watermarks/metadata.py b/src/remove_ai_watermarks/metadata.py index 64f57ba..c4635d1 100644 --- a/src/remove_ai_watermarks/metadata.py +++ b/src/remove_ai_watermarks/metadata.py @@ -103,8 +103,9 @@ def has_ai_metadata(image_path: Path) -> bool: if has_c2pa_metadata(image_path): return True except ImportError: - # Try simple binary scan - data = image_path.read_bytes() + # Try simple binary scan (read only first 512KB to avoid OOM on huge files) + with open(image_path, "rb") as f: + data = f.read(512 * 1024) if b"c2pa" in data.lower() or b"C2PA" in data: return True