mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
fix discriminator & restore grad for ID embedder
This commit is contained in:
@@ -33,7 +33,6 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
|
||||
crop_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
|
||||
with torch.no_grad():
|
||||
embedding = embedder(crop_tensor)
|
||||
embedding = embedder(crop_tensor)
|
||||
embedding = nn.functional.normalize(embedding, p = 2)
|
||||
return embedding
|
||||
|
||||
@@ -19,11 +19,11 @@ class DiscriminatorLoss(nn.Module):
|
||||
negative_tensors = []
|
||||
|
||||
for discriminator_source_tensor in discriminator_source_tensors:
|
||||
positive_tensor = torch.relu(discriminator_source_tensor + 1).mean(dim = [ 1, 2, 3 ])
|
||||
positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ])
|
||||
positive_tensors.append(positive_tensor)
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
negative_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ])
|
||||
negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ])
|
||||
negative_tensors.append(negative_tensor)
|
||||
|
||||
positive_loss = torch.stack(positive_tensors).mean()
|
||||
@@ -75,8 +75,10 @@ class ReconstructionLoss(nn.Module):
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
|
||||
with torch.no_grad():
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
|
||||
|
||||
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
|
||||
|
||||
Reference in New Issue
Block a user