Introduce new ReconstructionLoss class

This commit is contained in:
henryruhs
2025-02-22 14:30:52 +01:00
parent 086d9eed87
commit 3b7d3b6688
-1
View File
@@ -147,7 +147,6 @@ class ReconstructionLoss(torch.nn.Module):
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tensor:
batch_size = CONFIG.getint('training.loader', 'batch_size')
loss_tensor = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1)
loss_tensor = torch.mean(loss_tensor, dim = 1) * 0.5