mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new ReconstructionLoss class
This commit is contained in:
@@ -147,7 +147,6 @@ class ReconstructionLoss(torch.nn.Module):
|
||||
|
||||
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tensor:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
|
||||
loss_tensor = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
|
||||
loss_tensor = torch.mean(loss_tensor, dim = 1) * 0.5
|
||||
|
||||
|
||||
Reference in New Issue
Block a user