diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 803286a..f2936df 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -78,7 +78,7 @@ class ReconstructionLoss(nn.Module): def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - target_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) + target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5