From b42d2b06e7fb12d556db5578f797066d2a327a63 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Feb 2025 13:07:14 +0100 Subject: [PATCH] Improve generate_preview --- face_swapper/src/training.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 381cfff..c936772 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -69,14 +69,15 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): self.log('loss_reconstruction', generator_losses.get('loss_reconstruction')) return generator_losses.get('loss_generator') - def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: - max_preview = 8 - source_tensors = source_tensor[:max_preview] - target_tensors = target_tensor[:max_preview] - swap_tensors = swap_tensor[:max_preview] - rows = [ torch.cat([ source_tensor, target_tensor, swap_tensor ], dim = 2) for source_tensor, target_tensor, swap_tensor in zip(source_tensors, target_tensors, swap_tensors) ] - grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True) - self.logger.experiment.add_image('preview', grid, self.global_step) + def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None: + preview_limit = 8 + preview_items = [] + + for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]): + preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2)) + + preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True) + self.logger.experiment.add_image('preview', preview_grid, self.global_step) def create_trainer() -> Trainer: