Adjust some namings

This commit is contained in:
henryruhs
2025-03-01 16:06:38 +01:00
parent 0055c0c97f
commit a22adaf51f
+4 -4
View File
@@ -79,14 +79,14 @@ class ReconstructionLoss(nn.Module):
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
is_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
reconstruction_loss = (reconstruction_loss * is_similar_identity).mean() * 0.5
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
return reconstruction_loss, weighted_reconstruction_loss