This commit is contained in:
harisreedhar
2025-04-11 23:15:30 +05:30
parent 76fe5c351c
commit dc2b2dc982
4 changed files with 32 additions and 3 deletions
+1
View File
@@ -77,6 +77,7 @@ num_filters = 16
```
[training.losses]
adversarial_weight = 1.0
cycle_weight = 1.0
feature_weight = 10.0
reconstruction_weight = 10.0
identity_weight = 20.0
+1
View File
@@ -37,6 +37,7 @@ num_filters =
[training.losses]
adversarial_weight =
cycle_weight =
feature_weight =
reconstruction_weight =
identity_weight =
+22 -1
View File
@@ -49,7 +49,28 @@ class AdversarialLoss(nn.Module):
return adversarial_loss, weighted_adversarial_loss
class FeatureLoss(nn.Module):
class CycleLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_feature, output_feature in zip(target_features, cycle_features):
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
cycle_feature_loss = torch.stack(temp_tensors).mean()
cycle_l1_loss = self.l1_loss(target_tensor, cycle_tensor)
cycle_loss = (cycle_feature_loss + cycle_l1_loss) * 0.5
weighted_feature_loss = cycle_loss * self.config_cycle_weight
return cycle_loss, weighted_feature_loss
class FeautureLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
+8 -2
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -45,7 +45,8 @@ class FaceSwapperTrainer(LightningModule):
self.discriminator = Discriminator(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.feature_loss = FeatureLoss(config_parser)
self.cycle_loss = CycleLoss(config_parser)
self.feature_loss = FeautureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
@@ -92,11 +93,15 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
@@ -129,6 +134,7 @@ class FaceSwapperTrainer(LightningModule):
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('cycle_loss', cycle_loss)
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)