mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
reconstruction-loss fix
This commit is contained in:
@@ -72,21 +72,22 @@ class AttributeLoss(nn.Module):
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
reconstruction_loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
|
||||
reconstruction_loss = torch.mean(reconstruction_loss, dim = 1) * 0.5
|
||||
temp_tensors = []
|
||||
|
||||
if torch.equal(source_tensor, target_tensor):
|
||||
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4)
|
||||
else:
|
||||
reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4)
|
||||
for _source_tensor, _target_tensor in zip(source_tensor, target_tensor):
|
||||
temp_tensor = self.mse_loss(_source_tensor, _target_tensor)
|
||||
|
||||
if torch.equal(_source_tensor, _target_tensor):
|
||||
temp_tensors.append(temp_tensor)
|
||||
else:
|
||||
temp_tensors.append(temp_tensor * 0)
|
||||
reconstruction_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
|
||||
reconstruction_loss = (reconstruction_loss + similarity) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
Reference in New Issue
Block a user