Add more Loss types

This commit is contained in:
henryruhs
2025-03-12 08:42:38 +01:00
parent 564cc7b127
commit 738e00d59e
+2 -2
View File
@@ -14,7 +14,7 @@ class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss:
positive_tensors = []
negative_tensors = []
@@ -37,7 +37,7 @@ class AdversarialLoss(nn.Module):
super().__init__()
self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]:
temp_tensors = []
for discriminator_output_tensor in discriminator_output_tensors: