fix(photomaker-v2): don't pre-unsqueeze id_embeds (the pipeline does it)

V2's pipeline forward at line 705 of upstream pipeline.py calls
`id_embeds.unsqueeze(0)` itself to add a batch dim, so callers pass a 2-D
(N_faces, 512) tensor and the pipeline turns it into 3-D. Upstream
inference_pmv2.py shows the canonical form: torch.stack([...]) of per-image
embeddings.

Our previous call .unsqueeze(0)'d on the way in, which the pipeline then
.unsqueeze(0)'d again, giving a (1, 1, 512) shape that the V2 id_encoder
consumed as garbage -- the resulting output was a training-time face collage
(verified visually 2026-06-04 against tatsunari + gemini_3 + the 9-face grid).

Fix: pass torch.stack([torch.from_numpy(embedding)]) -- shape (1, 512) -- so
the pipeline's internal unsqueeze gives the expected (1, 1, 512) inside the
forward. Don't pre-cast dtype either; the pipeline handles that internally.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
Victor Kuznetsov
2026-06-08 19:03:04 -07:00
parent 37817a610f
commit b1fed810fd
@@ -319,11 +319,17 @@ def restore_faces_photomaker(
continue
# Get the ArcFace embedding for THIS face (V2's required ID branch). InsightFace
# expects BGR; analyze_faces returns a list, take the first detection.
# Shape: upstream's inference_pmv2.py stacks per-image embeddings into a 2-D
# tensor (N_images, 512). The pipeline forward then calls `.unsqueeze(0)` ITSELF
# (line 705 of pipeline.py) to add a batch dim, so we must NOT pre-unsqueeze --
# giving `(1, 1, 512)` to the V2 forward made the id_encoder consume garbage and
# the pipeline output the training-time face collage (caught visually 2026-06-04).
# Dtype stays float32 here; the pipeline casts internally.
faces = analyze_faces(face_analyser, id_crop_bgr)
if not faces:
logger.debug("photomaker_restore: InsightFace did not detect a face in the crop; skipping")
continue
id_embed = torch.from_numpy(faces[0]["embedding"]).unsqueeze(0).to(pipeline.device, dtype=pipeline.dtype)
id_embeds = torch.stack([torch.from_numpy(faces[0]["embedding"])])
id_crop_rgb = cv2.cvtColor(id_crop_bgr, cv2.COLOR_BGR2RGB)
id_image_pil = Image.fromarray(id_crop_rgb)
@@ -338,7 +344,7 @@ def restore_faces_photomaker(
out = pipeline(
prompt=_PHOTOMAKER_PROMPT,
input_id_images=[id_image_pil],
id_embeds=id_embed,
id_embeds=id_embeds,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
start_merge_step=style_strength,