From 4d8433f54a316129aefe8e83f04e6e2652314ca8 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Fri, 21 Feb 2025 20:12:48 +0530 Subject: [PATCH] changes --- face_swapper/src/helper.py | 5 ++--- face_swapper/src/models/loss.py | 7 ++++--- face_swapper/src/training.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 8770e97..3811118 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,7 +1,7 @@ import numpy import torch -from torch import Tensor, nn from pytorch_msssim import ssim +from torch import Tensor, nn from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor @@ -42,8 +42,7 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 crop_vision_tensor[:, :, :, :padding[2]] = 0 crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 - with torch.no_grad(): - source_embedding = id_embedder(crop_vision_tensor) + source_embedding = id_embedder(crop_vision_tensor) source_embedding = nn.functional.normalize(source_embedding, p = 2) return source_embedding diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 0d5c49d..9ef34b9 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -95,14 +95,15 @@ class FaceSwapperLoss: return loss_attribute def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) - target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0)) + with torch.no_grad(): + source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0)) face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5 loss_reconstructions = [] for index, face_similarity in enumerate(face_similarities): if face_similarity.item() > 0.9: - loss_mse = self.mse_loss(swap_tensor[index], target_tensor[index]) + loss_mse = self.mse_loss(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0)) loss_ssim = calc_structural_similarity(swap_tensor[index].unsqueeze(0), target_tensor[index].unsqueeze(0)) loss_reconstruction = (loss_mse + loss_ssim) * 0.5 loss_reconstructions.append(loss_reconstruction) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index f47093d..ecc70b6 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -94,14 +94,14 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined] -def create_loaders(dataset : Dataset[Any]): # type:ignore[type-arg] +def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]: batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') training_dataset, validate_dataset = split_dataset(dataset) training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True) - return training_loader, validation_loader # type:ignore[return-value] + return training_loader, validation_loader def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]: