From b1fed810fd81b3f0bbe685a167eaf65a9ae286b6 Mon Sep 17 00:00:00 2001 From: Victor Kuznetsov Date: Mon, 8 Jun 2026 19:03:04 -0700 Subject: [PATCH] fix(photomaker-v2): don't pre-unsqueeze id_embeds (the pipeline does it) V2's pipeline forward at line 705 of upstream pipeline.py calls `id_embeds.unsqueeze(0)` itself to add a batch dim, so callers pass a 2-D (N_faces, 512) tensor and the pipeline turns it into 3-D. Upstream inference_pmv2.py shows the canonical form: torch.stack([...]) of per-image embeddings. Our previous call .unsqueeze(0)'d on the way in, which the pipeline then .unsqueeze(0)'d again, giving a (1, 1, 512) shape that the V2 id_encoder consumed as garbage -- the resulting output was a training-time face collage (verified visually 2026-06-04 against tatsunari + gemini_3 + the 9-face grid). Fix: pass torch.stack([torch.from_numpy(embedding)]) -- shape (1, 512) -- so the pipeline's internal unsqueeze gives the expected (1, 1, 512) inside the forward. Don't pre-cast dtype either; the pipeline handles that internally. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/remove_ai_watermarks/photomaker_restore.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/remove_ai_watermarks/photomaker_restore.py b/src/remove_ai_watermarks/photomaker_restore.py index d8c4bbd..3a0d610 100644 --- a/src/remove_ai_watermarks/photomaker_restore.py +++ b/src/remove_ai_watermarks/photomaker_restore.py @@ -319,11 +319,17 @@ def restore_faces_photomaker( continue # Get the ArcFace embedding for THIS face (V2's required ID branch). InsightFace # expects BGR; analyze_faces returns a list, take the first detection. + # Shape: upstream's inference_pmv2.py stacks per-image embeddings into a 2-D + # tensor (N_images, 512). The pipeline forward then calls `.unsqueeze(0)` ITSELF + # (line 705 of pipeline.py) to add a batch dim, so we must NOT pre-unsqueeze -- + # giving `(1, 1, 512)` to the V2 forward made the id_encoder consume garbage and + # the pipeline output the training-time face collage (caught visually 2026-06-04). + # Dtype stays float32 here; the pipeline casts internally. faces = analyze_faces(face_analyser, id_crop_bgr) if not faces: logger.debug("photomaker_restore: InsightFace did not detect a face in the crop; skipping") continue - id_embed = torch.from_numpy(faces[0]["embedding"]).unsqueeze(0).to(pipeline.device, dtype=pipeline.dtype) + id_embeds = torch.stack([torch.from_numpy(faces[0]["embedding"])]) id_crop_rgb = cv2.cvtColor(id_crop_bgr, cv2.COLOR_BGR2RGB) id_image_pil = Image.fromarray(id_crop_rgb) @@ -338,7 +344,7 @@ def restore_faces_photomaker( out = pipeline( prompt=_PHOTOMAKER_PROMPT, input_id_images=[id_image_pil], - id_embeds=id_embed, + id_embeds=id_embeds, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, start_merge_step=style_strength,