mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
fix ssim
This commit is contained in:
@@ -105,8 +105,9 @@ class ReconstructionLoss(nn.Module):
|
||||
reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
|
||||
reconstruction_loss = (reconstruction_loss * has_similar_identity).mean() * 0.5
|
||||
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = 2.0)
|
||||
visual_loss = (visual_loss * has_similar_identity).mean()
|
||||
|
||||
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * self.config_reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
Reference in New Issue
Block a user