mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
changes
This commit is contained in:
@@ -147,7 +147,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
if self.global_step % self.config_preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
|
||||
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
|
||||
@@ -169,12 +169,13 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.log('validation_score', validation_score, prog_bar = True)
|
||||
return validation_score
|
||||
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> None:
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_cells = []
|
||||
mask_tensor = (mask_tensor.repeat(1, 3, 1, 1) - 0.5) * 2
|
||||
|
||||
for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]):
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2)
|
||||
for source_tensor, target_tensor, output_tensor, mask_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], mask_tensor[:preview_limit]):
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, mask_tensor] , dim = 2)
|
||||
preview_cells.append(preview_cell)
|
||||
|
||||
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
|
||||
|
||||
Reference in New Issue
Block a user