From 2ddcf52b6677c49a960b778128b69c95fce41757 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Thu, 27 Feb 2025 12:18:09 +0530 Subject: [PATCH] changes --- face_swapper/src/models/loss.py | 21 +++++++++------------ face_swapper/src/training.py | 6 +++--- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index a41d76e..803286a 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -70,32 +70,29 @@ class AttributeLoss(nn.Module): class ReconstructionLoss(nn.Module): - def __init__(self) -> None: + def __init__(self, embedder : nn.Module) -> None: super().__init__() + self.embedder = embedder self.mse_loss = nn.MSELoss() def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') - temp_tensors = [] - - for __source_tensor__, __target_tensor__ in zip(source_tensor, target_tensor): - temp_tensor = self.mse_loss(__source_tensor__, __target_tensor__) * torch.equal(__source_tensor__, __target_tensor__) - temp_tensors.append(temp_tensor) - - reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5 + source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) + same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 + reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) + reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5 data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() - reconstruction_loss = (reconstruction_loss + similarity) * 0.5 weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight return reconstruction_loss, weighted_reconstruction_loss class IdentityLoss(nn.Module): - def __init__(self) -> None: + def __init__(self, embedder : nn.Module) -> None: super().__init__() - embedder_path = CONFIG.get('training.model', 'embedder_path') - self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.embedder = embedder def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: identity_weight = CONFIG.getfloat('training.losses', 'identity_weight') diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index e016b80..44d635d 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -28,16 +28,16 @@ class FaceSwapperTrainer(lightning.LightningModule): super().__init__() embedder_path = CONFIG.get('training.model', 'embedder_path') + self.embedder = torch.jit.load(embedder_path, map_location='cpu') # type:ignore[no-untyped-call] self.generator = Generator() self.discriminator = Discriminator() self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss() self.attribute_loss = AttributeLoss() - self.reconstruction_loss = ReconstructionLoss() - self.identity_loss = IdentityLoss() + self.reconstruction_loss = ReconstructionLoss(self.embedder) + self.identity_loss = IdentityLoss(self.embedder) self.pose_loss = PoseLoss() self.gaze_loss = GazeLoss() - self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.automatic_optimization = False def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: