fix discriminator & restore grad for ID embedder

This commit is contained in:
harisreedhar
2025-03-08 15:11:55 +05:30
committed by henryruhs
parent 9ff30a0268
commit 4af22832db
2 changed files with 7 additions and 6 deletions
+1 -2
View File
@@ -33,7 +33,6 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0
with torch.no_grad():
embedding = embedder(crop_tensor)
embedding = embedder(crop_tensor)
embedding = nn.functional.normalize(embedding, p = 2)
return embedding
+6 -4
View File
@@ -19,11 +19,11 @@ class DiscriminatorLoss(nn.Module):
negative_tensors = []
for discriminator_source_tensor in discriminator_source_tensors:
positive_tensor = torch.relu(discriminator_source_tensor + 1).mean(dim = [ 1, 2, 3 ])
positive_tensor = torch.relu(1 - discriminator_source_tensor).mean(dim = [ 1, 2, 3 ])
positive_tensors.append(positive_tensor)
for discriminator_output_tensor in discriminator_output_tensors:
negative_tensor = torch.relu(1 - discriminator_output_tensor).mean(dim = [ 1, 2, 3 ])
negative_tensor = torch.relu(discriminator_output_tensor + 1).mean(dim = [ 1, 2, 3 ])
negative_tensors.append(negative_tensor)
positive_loss = torch.stack(positive_tensors).mean()
@@ -75,8 +75,10 @@ class ReconstructionLoss(nn.Module):
self.mse_loss = nn.MSELoss()
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
with torch.no_grad():
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))