From ad675ae633c5db73eed81b67fa2f4e8d8989f80b Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sun, 16 Mar 2025 12:23:08 +0100 Subject: [PATCH] Join MaskNet to guide generator --- face_swapper/README.md | 1 + face_swapper/config.ini | 1 + face_swapper/src/helper.py | 1 + face_swapper/src/inferencing.py | 2 +- face_swapper/src/models/generator.py | 11 +++++--- face_swapper/src/models/loss.py | 6 +++-- face_swapper/src/training.py | 40 ++++++---------------------- 7 files changed, 24 insertions(+), 38 deletions(-) diff --git a/face_swapper/README.md b/face_swapper/README.md index e95d853..e80e484 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -82,6 +82,7 @@ identity_weight = 20.0 gaze_weight = 0.05 pose_weight = 0.05 expression_weight = 0.05 +mask_weight = 0.5 ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 113a0e1..29443e6 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -42,6 +42,7 @@ identity_weight = gaze_weight = pose_weight = expression_weight = +mask_weight = [training.trainer] accumulate_size = diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index b1377d5..695d541 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -39,6 +39,7 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor: + input_mask = input_mask.mean(dim = 1, keepdim = True) overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device) overlay_tensor[:, 2, :, :] = 1 input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8) diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 5182e00..2643195 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -26,5 +26,5 @@ def infer() -> None: source_tensor = io.read_image(config_source_path) target_tensor = io.read_image(config_target_path) source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0)) - output_tensor = generator(source_embedding, target_tensor)[0] + output_tensor, _ = generator(source_embedding, target_tensor) io.write_jpeg(output_tensor, config_output_path) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 2d9183d..73d4fea 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -5,7 +5,8 @@ from torch import Tensor, nn from ..networks.aad import AAD from ..networks.unet import UNet -from ..types import Embedding, Feature +from ..networks.masknet import MaskNet +from ..types import Embedding, Feature, Mask class Generator(nn.Module): @@ -13,13 +14,17 @@ class Generator(nn.Module): super().__init__() self.encoder = UNet(config_parser) self.generator = AAD(config_parser) + self.masker = MaskNet(config_parser) self.encoder.apply(init_weight) self.generator.apply(init_weight) + self.masker.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Feature, ...]]: + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: target_features = self.encode_features(target_tensor) output_tensor = self.generator(source_embedding, target_features) - return output_tensor, target_features + target_feature = target_features[-1] + output_mask = self.masker(target_tensor, target_feature) + return output_tensor, output_mask def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]: return self.encoder(input_tensor) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index ef7a296..2b1a36e 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -182,16 +182,18 @@ class GazeLoss(nn.Module): class MaskLoss(nn.Module): def __init__(self, config_parser : ConfigParser, face_parser : FaceParserModule) -> None: super().__init__() + self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight') self.config_output_size = config_parser.getint('training.model.generator', 'output_size') self.face_parser = face_parser self.mse_loss = nn.MSELoss() - def forward(self, target_tensor : Tensor, output_mask : Mask) -> Loss: + def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]: target_mask = self.calc_mask(target_tensor) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size) mask_loss = self.mse_loss(target_mask, output_mask) - return mask_loss + weighted_mask_loss = mask_loss * self.config_mask_weight + return mask_loss, weighted_mask_loss def calc_mask(self, target_tensor : Tensor) -> Tensor: target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear') diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 7ab2936..a1ef480 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -17,7 +17,6 @@ from .helper import calc_embedding, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss -from .networks.masknet import MaskNet from .types import Batch, Embedding, Mask, OptimizerSet warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -42,7 +41,6 @@ class FaceSwapperTrainer(LightningModule): self.face_parser = torch.jit.load(self.config_face_parser_path, map_location ='cpu').eval() self.generator = Generator(config_parser) self.discriminator = Discriminator(config_parser) - self.masker = MaskNet(config_parser) self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss(config_parser) self.feature_loss = FeautureLoss(config_parser) @@ -55,19 +53,15 @@ class FaceSwapperTrainer(LightningModule): def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: with torch.no_grad(): - output_tensor, target_features = self.generator(source_embedding, target_tensor) - target_feature = target_features[-1] - output_mask = self.masker(target_tensor, target_feature) + output_tensor, output_mask = self.generator(source_embedding, target_tensor) return output_tensor, output_mask - def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]: + def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - masker_optimizer = torch.optim.AdamW(self.masker.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) - masker_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(masker_optimizer, T_0 = 300, T_mult = 2) generator_config =\ { @@ -87,24 +81,16 @@ class FaceSwapperTrainer(LightningModule): 'interval': 'step' } } - masker_config =\ - { - 'optimizer': masker_optimizer, - 'lr_scheduler': - { - 'scheduler': masker_scheduler, - 'interval': 'step' - } - } - return generator_config, discriminator_config, masker_config + return generator_config, discriminator_config def training_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch do_update = (batch_index + 1) % self.config_accumulate_size == 0 - generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined] + generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - generator_output_tensor, generator_target_features = self.generator(source_embedding, target_tensor) + generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor) + generator_target_features = self.generator.encode_features(target_tensor) generator_output_features = self.generator.encode_features(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) @@ -113,16 +99,13 @@ class FaceSwapperTrainer(LightningModule): identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor) gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) - generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask) + generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) - generator_output_feature = generator_output_features[-1] - generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_feature.detach()) - mask_loss = self.mask_loss(target_tensor, generator_output_mask) - self.toggle_optimizer(generator_optimizer) self.manual_backward(generator_loss) if do_update: @@ -137,13 +120,6 @@ class FaceSwapperTrainer(LightningModule): discriminator_optimizer.zero_grad() self.untoggle_optimizer(discriminator_optimizer) - self.toggle_optimizer(masker_optimizer) - self.manual_backward(mask_loss) - if do_update: - masker_optimizer.step() - masker_optimizer.zero_grad() - self.untoggle_optimizer(masker_optimizer) - if self.global_step % self.config_preview_frequency == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)