mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Remove the condition from reconstruction loss
This commit is contained in:
@@ -79,12 +79,8 @@ class ReconstructionLoss(nn.Module):
|
||||
temp_tensors = []
|
||||
|
||||
for _source_tensor, _target_tensor in zip(source_tensor, target_tensor):
|
||||
temp_tensor = self.mse_loss(_source_tensor, _target_tensor)
|
||||
|
||||
if torch.equal(_source_tensor, _target_tensor):
|
||||
temp_tensors.append(temp_tensor)
|
||||
else:
|
||||
temp_tensors.append(temp_tensor * 0)
|
||||
temp_tensor = self.mse_loss(_source_tensor, _target_tensor) * torch.equal(_source_tensor, _target_tensor)
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
|
||||
Reference in New Issue
Block a user