This commit is contained in:
henryruhs
2025-04-13 10:17:52 +02:00
committed by harisreedhar
parent b7a6f00e8b
commit d9fe667ced
2 changed files with 3 additions and 32 deletions
+1 -22
View File
@@ -49,28 +49,7 @@ class AdversarialLoss(nn.Module):
return adversarial_loss, weighted_adversarial_loss
class CycleLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_feature, output_feature in zip(target_features, cycle_features):
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
cycle_feature_loss = torch.stack(temp_tensors).mean()
cycle_l1_loss = self.l1_loss(target_tensor, cycle_tensor)
cycle_loss = (cycle_feature_loss + cycle_l1_loss) * 0.5
weighted_feature_loss = cycle_loss * self.config_cycle_weight
return cycle_loss, weighted_feature_loss
class FeautureLoss(nn.Module):
class FeatureLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
+2 -10
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -45,8 +45,7 @@ class FaceSwapperTrainer(LightningModule):
self.discriminator = Discriminator(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.cycle_loss = CycleLoss(config_parser)
self.feature_loss = FeautureLoss(config_parser)
self.feature_loss = FeatureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
@@ -93,15 +92,11 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
@@ -116,7 +111,6 @@ class FaceSwapperTrainer(LightningModule):
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
if do_update:
generator_optimizer.step()
generator_optimizer.zero_grad()
@@ -124,7 +118,6 @@ class FaceSwapperTrainer(LightningModule):
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
@@ -136,7 +129,6 @@ class FaceSwapperTrainer(LightningModule):
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('cycle_loss', cycle_loss)
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)