From 738e00d59e9ae876031c9505ce5eea9aa0adf465 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Mar 2025 08:42:38 +0100 Subject: [PATCH] Add more Loss types --- face_swapper/src/models/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 8520a61..e340924 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -14,7 +14,7 @@ class DiscriminatorLoss(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor: + def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss: positive_tensors = [] negative_tensors = [] @@ -37,7 +37,7 @@ class AdversarialLoss(nn.Module): super().__init__() self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight') - def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]: + def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]: temp_tensors = [] for discriminator_output_tensor in discriminator_output_tensors: