diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index a51fada..8ea5b3a 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -33,7 +33,6 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P crop_tensor[:, :, :, :padding[2]] = 0 crop_tensor[:, :, :, 112 - padding[3]:] = 0 - with torch.no_grad(): - embedding = embedder(crop_tensor) + embedding = embedder(crop_tensor) embedding = nn.functional.normalize(embedding, p = 2) return embedding diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 6d7a0e5..52f049f 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -19,11 +19,11 @@ class DiscriminatorLoss(nn.Module): negative_tensors = [] for discriminator_source_tensor in discriminator_source_tensors: - positive_tensor = torch.relu(discriminator_source_tensor + 1).mean(dim = [ 1, 2, 3 ]) + positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ]) positive_tensors.append(positive_tensor) for discriminator_output_tensor in discriminator_output_tensors: - negative_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ]) + negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ]) negative_tensors.append(negative_tensor) positive_loss = torch.stack(positive_tensors).mean() @@ -75,8 +75,10 @@ class ReconstructionLoss(nn.Module): self.mse_loss = nn.MSELoss() def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: - source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) + + with torch.no_grad(): + source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))