From 2322b6539f5d36aaf8aab19dfda813297c8dd61a Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 11 Mar 2025 14:10:39 +0530 Subject: [PATCH] changes --- face_swapper/src/training.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 11b95a1..f0b5cf7 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -147,7 +147,7 @@ class FaceSwapperTrainer(LightningModule): self.untoggle_optimizer(discriminator_optimizer) if self.global_step % self.config_preview_frequency == 0: - self.generate_preview(source_tensor, target_tensor, generator_output_tensor) + self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor) self.log('generator_loss', generator_loss, prog_bar = True) self.log('discriminator_loss', discriminator_loss, prog_bar = True) @@ -169,12 +169,13 @@ class FaceSwapperTrainer(LightningModule): self.log('validation_score', validation_score, prog_bar = True) return validation_score - def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> None: + def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None: preview_limit = 8 preview_cells = [] + mask_tensor = (mask_tensor.repeat(1, 3, 1, 1) - 0.5) * 2 - for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]): - preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2) + for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]): + preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor] , dim = 2) preview_cells.append(preview_cell) preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)