diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 8520a61..e340924 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -14,7 +14,7 @@ class DiscriminatorLoss(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor: + def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss: positive_tensors = [] negative_tensors = [] @@ -37,7 +37,7 @@ class AdversarialLoss(nn.Module): super().__init__() self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight') - def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]: + def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]: temp_tensors = [] for discriminator_output_tensor in discriminator_output_tensors: