This commit is contained in:
harisreedhar
2025-02-27 13:13:20 +05:30
committed by henryruhs
parent 2ddcf52b66
commit 5d1b90ff19
+1 -1
View File
@@ -78,7 +78,7 @@ class ReconstructionLoss(nn.Module):
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
same_person = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * same_person).mean() * 0.5