mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Add more Loss types
This commit is contained in:
@@ -14,7 +14,7 @@ class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
|
||||
def forward(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Loss:
|
||||
positive_tensors = []
|
||||
negative_tensors = []
|
||||
|
||||
@@ -37,7 +37,7 @@ class AdversarialLoss(nn.Module):
|
||||
super().__init__()
|
||||
self.config_adversarial_weight = config_parser.getfloat('training.losses', 'adversarial_weight')
|
||||
|
||||
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
|
||||
Reference in New Issue
Block a user