mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix typo
This commit is contained in:
@@ -49,28 +49,7 @@ class AdversarialLoss(nn.Module):
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
class CycleLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
|
||||
self.l1_loss = nn.L1Loss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_feature, output_feature in zip(target_features, cycle_features):
|
||||
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
cycle_feature_loss = torch.stack(temp_tensors).mean()
|
||||
cycle_l1_loss = self.l1_loss(target_tensor, cycle_tensor)
|
||||
cycle_loss = (cycle_feature_loss + cycle_l1_loss) * 0.5
|
||||
weighted_feature_loss = cycle_loss * self.config_cycle_weight
|
||||
return cycle_loss, weighted_feature_loss
|
||||
|
||||
|
||||
class FeautureLoss(nn.Module):
|
||||
class FeatureLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
@@ -45,8 +45,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss(config_parser)
|
||||
self.cycle_loss = CycleLoss(config_parser)
|
||||
self.feature_loss = FeautureLoss(config_parser)
|
||||
self.feature_loss = FeatureLoss(config_parser)
|
||||
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
|
||||
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
|
||||
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
|
||||
@@ -93,15 +92,11 @@ class FaceSwapperTrainer(LightningModule):
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
|
||||
generator_output_features = self.generator.encode_features(generator_output_tensor)
|
||||
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
|
||||
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
|
||||
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
@@ -116,7 +111,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
self.manual_backward(generator_loss)
|
||||
|
||||
if do_update:
|
||||
generator_optimizer.step()
|
||||
generator_optimizer.zero_grad()
|
||||
@@ -124,7 +118,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
self.toggle_optimizer(discriminator_optimizer)
|
||||
self.manual_backward(discriminator_loss)
|
||||
|
||||
if do_update:
|
||||
discriminator_optimizer.step()
|
||||
discriminator_optimizer.zero_grad()
|
||||
@@ -136,7 +129,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss', adversarial_loss)
|
||||
self.log('cycle_loss', cycle_loss)
|
||||
self.log('feature_loss', feature_loss)
|
||||
self.log('reconstruction_loss', reconstruction_loss)
|
||||
self.log('identity_loss', identity_loss)
|
||||
|
||||
Reference in New Issue
Block a user