diff --git a/src/remove_ai_watermarks/photomaker_restore.py b/src/remove_ai_watermarks/photomaker_restore.py
index a5ff379..d8c4bbd 100644
--- a/src/remove_ai_watermarks/photomaker_restore.py
+++ b/src/remove_ai_watermarks/photomaker_restore.py
@@ -85,6 +85,8 @@ _PHOTOMAKER_FACE_SIZE = 512
 
 _pipeline: Any | None = None
 _pipeline_lock = threading.Lock()
+_face_analyser: Any | None = None
+_face_analyser_lock = threading.Lock()
 
 
 def is_available() -> bool:
@@ -110,6 +112,29 @@ def _select_device() -> str:
     return "cpu"
 
 
+def _get_face_analyser() -> Any:
+    """Return the InsightFace FaceAnalysis2 singleton (downloads model packs on first use).
+
+    **This is the non-commercial step.** Instantiating ``FaceAnalysis2()`` triggers
+    InsightFace's auto-download of the antelopev2/buffalo_l model packs, which are
+    released under a research-only license. See the module docstring NON-COMMERCIAL
+    notice. PhotoMaker-V2 requires this for the ArcFace identity branch.
+    """
+    global _face_analyser
+    if _face_analyser is not None:
+        return _face_analyser
+    with _face_analyser_lock:
+        if _face_analyser is None:
+            import torch
+            from photomaker import FaceAnalysis2
+
+            providers = ["CUDAExecutionProvider"] if torch.cuda.is_available() else ["CPUExecutionProvider"]
+            fa = FaceAnalysis2(providers=providers, allowed_modules=["detection", "recognition"])
+            fa.prepare(ctx_id=0, det_size=(640, 640))
+            _face_analyser = fa
+    return _face_analyser
+
+
 def _get_pipeline() -> Any:
     """Return the lazily-built PhotoMaker pipeline singleton (downloads weights on first use)."""
     global _pipeline
@@ -280,6 +305,9 @@ def restore_faces_photomaker(
         return cleaned_bgr
 
     pipeline = _get_pipeline()
+    face_analyser = _get_face_analyser()  # NON-COMMERCIAL: triggers InsightFace model packs
+    from photomaker import analyze_faces
+
     generator = None
     if seed is not None:
         generator = torch.Generator(device=pipeline.device).manual_seed(seed)
@@ -289,6 +317,14 @@ def restore_faces_photomaker(
         id_crop_bgr, square_box = _face_crop_square(original_bgr, box)
         if id_crop_bgr.size == 0:
             continue
+        # Get the ArcFace embedding for THIS face (V2's required ID branch). InsightFace
+        # expects BGR; analyze_faces returns a list, take the first detection.
+        faces = analyze_faces(face_analyser, id_crop_bgr)
+        if not faces:
+            logger.debug("photomaker_restore: InsightFace did not detect a face in the crop; skipping")
+            continue
+        id_embed = torch.from_numpy(faces[0]["embedding"]).unsqueeze(0).to(pipeline.device, dtype=pipeline.dtype)
+
         id_crop_rgb = cv2.cvtColor(id_crop_bgr, cv2.COLOR_BGR2RGB)
         id_image_pil = Image.fromarray(id_crop_rgb)
 
@@ -302,6 +338,7 @@ def restore_faces_photomaker(
         out = pipeline(
             prompt=_PHOTOMAKER_PROMPT,
             input_id_images=[id_image_pil],
+            id_embeds=id_embed,
             num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
             start_merge_step=style_strength,