From 569df9e96d74e2c9adab9a18cb1b94ae54464a9c Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Mar 2025 15:33:31 +0100 Subject: [PATCH 1/3] Let the generator return target attributes --- face_swapper/src/models/generator.py | 9 +++------ face_swapper/src/training.py | 15 +++++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 8953354..6fd0097 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -16,13 +16,10 @@ class Generator(nn.Module): self.encoder.apply(init_weight) self.generator.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: - target_attributes = self.get_attributes(target_tensor) + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]: + target_attributes = self.encoder(target_tensor) output_tensor = self.generator(source_embedding, target_attributes) - return output_tensor - - def get_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]: - return self.encoder(input_tensor) + return output_tensor, target_attributes def init_weight(module : nn.Module) -> None: diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 0baea47..29a5000 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -55,8 +55,8 @@ class FaceSwapperTrainer(LightningModule): def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]: with torch.no_grad(): - output_tensor = self.generator(source_embedding, target_tensor) - target_attribute = self.generator.get_attributes(target_tensor)[-1] + output_tensor, target_attributes = self.generator(source_embedding, target_tensor) + target_attribute = target_attributes[-1] mask_tensor = self.masker(target_tensor, target_attribute) return output_tensor, mask_tensor @@ -105,12 +105,11 @@ class FaceSwapperTrainer(LightningModule): self.toggle_optimizer(generator_optimizer) source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - target_attributes = self.generator.get_attributes(target_tensor) - generator_output_tensor = self.generator(source_embedding, target_tensor) - generator_output_attributes = self.generator.get_attributes(generator_output_tensor) + _, generator_target_attributes = self.generator(target_tensor) + generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) - attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes) + attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor) @@ -126,7 +125,7 @@ class FaceSwapperTrainer(LightningModule): self.untoggle_optimizer(generator_optimizer) self.toggle_optimizer(masker_optimizer) - target_attribute = target_attributes[-1].detach() + target_attribute = generator_target_attributes[-1].detach() mask_tensor = self.masker(target_tensor, target_attribute) mask_loss = self.mask_loss(target_tensor, mask_tensor) @@ -168,7 +167,7 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - output_tensor = self.generator(source_embedding, target_tensor) + output_tensor, _ = self.generator(source_embedding, target_tensor) output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, prog_bar = True) From d212e2fe124af7d0974024bc2c354a287c5f4aaf Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Mar 2025 15:48:53 +0100 Subject: [PATCH 2/3] Fix import --- face_swapper/src/networks/masknet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index b80655e..0c43c4d 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -3,7 +3,7 @@ from configparser import ConfigParser import torch from torch import Tensor, nn -from face_swapper.src.types import Attribute +from ..types import Attribute class MaskNet(nn.Module): From 8b465fce038442a0fbf8e47926a1042ad97c4175 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Mar 2025 15:52:56 +0100 Subject: [PATCH 3/3] Fix generator call --- face_swapper/src/training.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 29a5000..68a0aad 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -105,11 +105,10 @@ class FaceSwapperTrainer(LightningModule): self.toggle_optimizer(generator_optimizer) source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - _, generator_target_attributes = self.generator(target_tensor) generator_output_tensor, generator_output_attributes = self.generator(source_embedding, target_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) - attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes) + attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_output_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor) @@ -125,7 +124,7 @@ class FaceSwapperTrainer(LightningModule): self.untoggle_optimizer(generator_optimizer) self.toggle_optimizer(masker_optimizer) - target_attribute = generator_target_attributes[-1].detach() + target_attribute = generator_output_attributes[-1].detach() mask_tensor = self.masker(target_tensor, target_attribute) mask_loss = self.mask_loss(target_tensor, mask_tensor)