diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index ab880fb..64648dd 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -171,7 +171,7 @@ class FaceSwapperTrainer(LightningModule): preview_cells = [] overlay_tensor = overlay_mask(output_tensor, mask_tensor) - for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]): + for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]): preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2) preview_cells.append(preview_cell)