From c85c755e00777d113d100e29437b0d39d180f40e Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Mar 2025 18:34:23 +0530 Subject: [PATCH 1/4] changes --- face_swapper/src/models/generator.py | 5 ++--- face_swapper/src/training.py | 7 ++++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 4f5a2f5..81c81d9 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -19,13 +19,12 @@ class Generator(nn.Module): self.generator.apply(init_weight) self.masker.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: - target_features = self.encode_features(target_tensor) + def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask, ]: output_tensor = self.generator(source_embedding, target_features) target_feature = target_features[-1] output_mask = self.masker(target_tensor, target_feature) output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask) - return output_tensor, output_mask + return output_tensor, output_mask, def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]: return self.encoder(input_tensor) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index c105652..19502cd 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -53,7 +53,8 @@ class FaceSwapperTrainer(LightningModule): def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: with torch.no_grad(): - output_tensor, output_mask = self.generator(source_embedding, target_tensor) + generator_target_features = self.generator.encode_features(target_tensor) + output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features) return output_tensor, output_mask @@ -89,8 +90,8 @@ class FaceSwapperTrainer(LightningModule): generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor) generator_target_features = self.generator.encode_features(target_tensor) + generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features) generator_output_features = self.generator.encode_features(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) @@ -138,7 +139,7 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - output_tensor, _ = self.generator(source_embedding, target_tensor) + output_tensor, _ = self(source_embedding, target_tensor) output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, sync_dist = True, prog_bar = True) From 9ede8a2a7def5b021f8f71f6362587418110a823 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Mar 2025 18:38:48 +0530 Subject: [PATCH 2/4] changes --- face_swapper/src/models/generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 81c81d9..ab76154 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -19,12 +19,12 @@ class Generator(nn.Module): self.generator.apply(init_weight) self.masker.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask, ]: + def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]: output_tensor = self.generator(source_embedding, target_features) target_feature = target_features[-1] output_mask = self.masker(target_tensor, target_feature) output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask) - return output_tensor, output_mask, + return output_tensor, output_mask def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]: return self.encoder(input_tensor) From 602e890af2ceed43fe59b6319878240ee2b86cca Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Mar 2025 18:45:35 +0530 Subject: [PATCH 3/4] changes --- face_swapper/src/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 19502cd..b928679 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -139,7 +139,7 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - output_tensor, _ = self(source_embedding, target_tensor) + output_tensor, _ = self.forward(source_embedding, target_tensor) output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, sync_dist = True, prog_bar = True) From 99a8527e247bbdd5042fc7ecbe03133ab31a532b Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Mar 2025 18:48:45 +0530 Subject: [PATCH 4/4] changes --- face_swapper/src/exporting.py | 1 - face_swapper/src/inferencing.py | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 31fc078..4b2f4f6 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -1,7 +1,6 @@ import os from configparser import ConfigParser - import torch from .training import FaceSwapperTrainer diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 2643195..6f81b2d 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -4,7 +4,7 @@ import torch from torchvision import io from .helper import calc_embedding -from .models.generator import Generator +from .training import FaceSwapperTrainer CONFIG_PARSER = configparser.ConfigParser() CONFIG_PARSER.read('config.ini') @@ -17,10 +17,7 @@ def infer() -> None: config_target_path = CONFIG_PARSER.get('inferencing', 'target_path') config_output_path = CONFIG_PARSER.get('inferencing', 'output_path') - state_dict = torch.load(config_generator_path).get('state_dict').get('generator') - generator = Generator(CONFIG_PARSER) - generator.load_state_dict(state_dict) - generator.eval() + generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval() source_tensor = io.read_image(config_source_path)