This commit is contained in:
harisreedhar
2025-03-11 14:10:39 +05:30
committed by henryruhs
parent b6e131e4c1
commit 2322b6539f
+5 -4
View File
@@ -147,7 +147,7 @@ class FaceSwapperTrainer(LightningModule):
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
@@ -169,12 +169,13 @@ class FaceSwapperTrainer(LightningModule):
self.log('validation_score', validation_score, prog_bar = True)
return validation_score
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> None:
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
preview_limit = 8
preview_cells = []
mask_tensor = (mask_tensor.repeat(1, 3, 1, 1) - 0.5) * 2
for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]):
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2)
for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]):
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor] , dim = 2)
preview_cells.append(preview_cell)
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)