From a3ac4d5ddd219307f268fd0445f2310a221298c4 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 10 Mar 2025 13:43:52 +0530 Subject: [PATCH] changes --- face_swapper/README.md | 1 - face_swapper/config.ini | 1 - face_swapper/src/models/generator.py | 9 ++----- face_swapper/src/models/loss.py | 6 ++--- face_swapper/src/training.py | 35 ++++++++++++++++++++++------ 5 files changed, 32 insertions(+), 20 deletions(-) diff --git a/face_swapper/README.md b/face_swapper/README.md index 3abe315..be3f497 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -81,7 +81,6 @@ identity_weight = 20.0 gaze_weight = 0.0 pose_weight = 0.0 expression_weight = 0.0 -mask_weight = 1.0 ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 088b6be..022c809 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -41,7 +41,6 @@ identity_weight = gaze_weight = pose_weight = expression_weight = -mask_weight = [training.trainer] learning_rate = diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 615dcd1..d02be0b 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -1,10 +1,8 @@ from configparser import ConfigParser -from typing import Tuple from torch import Tensor, nn from ..networks.aad import AAD -from ..networks.masknet import MaskNet from ..networks.unet import UNet from ..types import Attributes, Embedding @@ -14,16 +12,13 @@ class Generator(nn.Module): super().__init__() self.encoder = UNet(config_parser) self.generator = AAD(config_parser) - self.masker = MaskNet(config_parser) self.encoder.apply(init_weight) self.generator.apply(init_weight) - self.masker.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: target_attributes = self.get_attributes(target_tensor) output_tensor = self.generator(source_embedding, target_attributes) - mask_tensor = self.masker(target_tensor, target_attributes[-1]) - return output_tensor, mask_tensor + return output_tensor def get_attributes(self, input_tensor : Tensor) -> Attributes: return self.encoder(input_tensor) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index aec44e3..7e8f871 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -180,18 +180,16 @@ class GazeLoss(nn.Module): class MaskLoss(nn.Module): def __init__(self, config_parser : ConfigParser, parser : ParserModule) -> None: super().__init__() - self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight') self.config_output_size = config_parser.getint('training.model.generator', 'output_size') self.parser = parser self.mse_loss = nn.MSELoss() - def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor: target_mask = self.calc_mask(target_tensor) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size) mask_loss = self.mse_loss(target_mask, mask_tensor) - weighted_mask_loss = mask_loss * self.config_mask_weight - return mask_loss, weighted_mask_loss + return mask_loss def calc_mask(self, target_tensor : Tensor) -> Tensor: target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear') diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index ada30c0..6ba9604 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -17,6 +17,7 @@ from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss +from .networks.masknet import MaskNet from .types import Batch, Embedding, OptimizerSet warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -40,6 +41,7 @@ class FaceSwapperTrainer(LightningModule): self.parser = torch.jit.load(self.config_parser_path, map_location = 'cpu').eval() self.generator = Generator(config_parser) self.discriminator = Discriminator(config_parser) + self.masker = MaskNet(config_parser) self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss(config_parser) self.attribute_loss = AttributeLoss(config_parser) @@ -54,11 +56,13 @@ class FaceSwapperTrainer(LightningModule): output_tensor, mask_tensor = self.generator(source_embedding, target_tensor) return output_tensor, mask_tensor - def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: + def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]: generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + masker_optimizer = torch.optim.AdamW(self.masker.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) + masker_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(masker_optimizer, T_0 = 300, T_mult = 2) generator_config =\ { @@ -78,14 +82,23 @@ class FaceSwapperTrainer(LightningModule): 'interval': 'step' } } - return generator_config, discriminator_config + masker_config =\ + { + 'optimizer': masker_optimizer, + 'lr_scheduler': + { + 'scheduler': masker_scheduler, + 'interval': 'step' + } + } + return generator_config, discriminator_config, masker_config def training_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch - generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_attributes = self.generator.get_attributes(target_tensor) - generator_output_tensor, generator_mask_tensor = self.generator(source_embedding, target_tensor) + generator_output_tensor = self.generator(source_embedding, target_tensor) generator_output_attributes = self.generator.get_attributes(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) @@ -96,14 +109,22 @@ class FaceSwapperTrainer(LightningModule): identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor) gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) - mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_mask_tensor) - generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss + generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss generator_optimizer.zero_grad() self.manual_backward(generator_loss) generator_optimizer.step() self.untoggle_optimizer(generator_optimizer) + self.toggle_optimizer(masker_optimizer) + mask_tensor = self.masker(target_tensor, target_attributes[-1].detach()) + mask_loss = self.mask_loss(target_tensor, mask_tensor) + + masker_optimizer.zero_grad() + self.manual_backward(mask_loss) + masker_optimizer.step() + self.untoggle_optimizer(masker_optimizer) + self.toggle_optimizer(discriminator_optimizer) discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) @@ -131,7 +152,7 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - output_tensor, mask_tensor = self.generator(source_embedding, target_tensor) + output_tensor = self.generator(source_embedding, target_tensor) output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, prog_bar = True)