Improve generate_preview

This commit is contained in:
henryruhs
2025-02-12 13:07:14 +01:00
parent 11c038cb81
commit b42d2b06e7
+9 -8
View File
@@ -69,14 +69,15 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
return generator_losses.get('loss_generator')
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
max_preview = 8
source_tensors = source_tensor[:max_preview]
target_tensors = target_tensor[:max_preview]
swap_tensors = swap_tensor[:max_preview]
rows = [ torch.cat([ source_tensor, target_tensor, swap_tensor ], dim = 2) for source_tensor, target_tensor, swap_tensor in zip(source_tensors, target_tensors, swap_tensors) ]
grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True)
self.logger.experiment.add_image('preview', grid, self.global_step)
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
preview_limit = 8
preview_items = []
for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]):
preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2))
preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True)
self.logger.experiment.add_image('preview', preview_grid, self.global_step)
def create_trainer() -> Trainer: