mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Improve generate_preview
This commit is contained in:
@@ -69,14 +69,15 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
|
||||
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
|
||||
max_preview = 8
|
||||
source_tensors = source_tensor[:max_preview]
|
||||
target_tensors = target_tensor[:max_preview]
|
||||
swap_tensors = swap_tensor[:max_preview]
|
||||
rows = [ torch.cat([ source_tensor, target_tensor, swap_tensor ], dim = 2) for source_tensor, target_tensor, swap_tensor in zip(source_tensors, target_tensors, swap_tensors) ]
|
||||
grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True)
|
||||
self.logger.experiment.add_image('preview', grid, self.global_step)
|
||||
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_items = []
|
||||
|
||||
for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]):
|
||||
preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2))
|
||||
|
||||
preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True)
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step)
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
|
||||
Reference in New Issue
Block a user