diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index f0b5cf7..7040725 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -175,7 +175,7 @@ class FaceSwapperTrainer(LightningModule): mask_tensor = (mask_tensor.repeat(1, 3, 1, 1) - 0.5) * 2 for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]): - preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor] , dim = 2) + preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor ], dim = 2) preview_cells.append(preview_cell) preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)