Merge pull request #70 from facefusion/cycle-loss

Cycle loss
This commit is contained in:
Harisreedhar
2025-04-15 14:17:28 +05:30
committed by GitHub
5 changed files with 34 additions and 3 deletions
+1
View File
@@ -77,6 +77,7 @@ num_filters = 16
```
[training.losses]
adversarial_weight = 1.0
cycle_weight = 1.0
feature_weight = 10.0
reconstruction_weight = 10.0
identity_weight = 20.0
+1
View File
@@ -37,6 +37,7 @@ num_filters =
[training.losses]
adversarial_weight =
cycle_weight =
feature_weight =
reconstruction_weight =
identity_weight =
+2 -2
View File
@@ -88,13 +88,13 @@ class AugmentTransform:
albumentations.OneOf(
[
albumentations.MotionBlur(p = 0.1),
albumentations.ZoomBlur(p = 0.1)
albumentations.ZoomBlur(max_factor = (1.0, 1.1), p = 0.1)
], p = 0.2),
albumentations.RandomBrightnessContrast(p = 0.7),
albumentations.ColorJitter(p = 0.2),
albumentations.RGBShift(p = 0.7),
albumentations.Illumination(p = 0.2),
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.7)
albumentations.Affine(translate_percent = (-0.03, 0.03), scale = (0.98, 1.02), rotate = (-2, 2), border_mode = 1, p = 0.3)
])
+21
View File
@@ -49,6 +49,27 @@ class AdversarialLoss(nn.Module):
return adversarial_loss, weighted_adversarial_loss
class CycleLoss(nn.Module):
    """
    Cycle-consistency loss between a target and its cycled counterpart.

    The loss is the average of two terms:
    the mean squared difference between paired feature tensors and
    the L1 distance between the target tensor and the cycle tensor.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        Read the loss settings from the given configuration.

        :param config_parser: parser providing ``training.loader.batch_size``
            and ``training.losses.cycle_weight``.
        """
        super().__init__()
        # Kept for backward compatibility; forward() no longer relies on it.
        self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
        self.config_cycle_weight = config_parser.getfloat('training.losses', 'cycle_weight')
        self.l1_loss = nn.L1Loss()

    def forward(self, target_tensor : Tensor, cycle_tensor : Tensor, target_features : Tuple[Feature, ...], cycle_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
        """
        Compute the cycle loss and its weighted variant.

        :param target_tensor: reference tensor the cycle should reproduce.
        :param cycle_tensor: tensor obtained after the cycle pass.
        :param target_features: features paired position-wise with ``cycle_features``.
        :param cycle_features: features of the cycled output.
        :return: ``(cycle_loss, cycle_loss * cycle_weight)``.
        """
        feature_losses = []

        # zip() stops at the shorter tuple, exactly as before.
        for target_feature, cycle_feature in zip(target_features, cycle_features):
            squared_difference = torch.pow(cycle_feature - target_feature, 2)
            # Use the tensor's own leading dimension instead of the configured
            # batch size, so a smaller batch (e.g. the last one of an epoch)
            # does not make reshape() fail.
            # assumes the leading dimension is the batch axis - TODO confirm
            per_sample_loss = torch.mean(squared_difference.reshape(squared_difference.shape[0], -1), dim = 1)
            feature_losses.append(per_sample_loss.mean())

        feature_loss = torch.stack(feature_losses).mean()
        reconstruction_loss = self.l1_loss(target_tensor, cycle_tensor)
        # Both terms contribute equally to the unweighted loss.
        cycle_loss = (feature_loss + reconstruction_loss) * 0.5
        weighted_cycle_loss = cycle_loss * self.config_cycle_weight
        return cycle_loss, weighted_cycle_loss
class FeatureLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
+9 -1
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -45,6 +45,7 @@ class FaceSwapperTrainer(LightningModule):
self.discriminator = Discriminator(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.cycle_loss = CycleLoss(config_parser)
self.feature_loss = FeatureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
@@ -92,11 +93,15 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)
cycle_output_tensor, cycle_output_mask = self.generator(target_embedding, generator_output_tensor, generator_output_features)
cycle_output_features = self.generator.encode_features(cycle_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
cycle_loss, weighted_cycle_loss = self.cycle_loss(target_tensor, cycle_output_tensor, generator_target_features, cycle_output_features)
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
@@ -111,6 +116,7 @@ class FaceSwapperTrainer(LightningModule):
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
if do_update:
generator_optimizer.step()
generator_optimizer.zero_grad()
@@ -118,6 +124,7 @@ class FaceSwapperTrainer(LightningModule):
self.toggle_optimizer(discriminator_optimizer)
self.manual_backward(discriminator_loss)
if do_update:
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
@@ -129,6 +136,7 @@ class FaceSwapperTrainer(LightningModule):
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('cycle_loss', cycle_loss)
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)