diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 31fc078..4b2f4f6 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -1,7 +1,6 @@ import os from configparser import ConfigParser - import torch from .training import FaceSwapperTrainer diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 2643195..6f81b2d 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -4,7 +4,7 @@ import torch from torchvision import io from .helper import calc_embedding -from .models.generator import Generator +from .training import FaceSwapperTrainer CONFIG_PARSER = configparser.ConfigParser() CONFIG_PARSER.read('config.ini') @@ -17,10 +17,7 @@ def infer() -> None: config_target_path = CONFIG_PARSER.get('inferencing', 'target_path') config_output_path = CONFIG_PARSER.get('inferencing', 'output_path') - state_dict = torch.load(config_generator_path).get('state_dict').get('generator') - generator = Generator(CONFIG_PARSER) - generator.load_state_dict(state_dict) - generator.eval() + generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval() embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval() source_tensor = io.read_image(config_source_path) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 4f5a2f5..ab76154 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -19,8 +19,7 @@ class Generator(nn.Module): self.generator.apply(init_weight) self.masker.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: - target_features = self.encode_features(target_tensor) + def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]: output_tensor = self.generator(source_embedding, target_features) target_feature = target_features[-1] output_mask = self.masker(target_tensor, target_feature) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index c105652..b928679 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -53,7 +53,8 @@ class FaceSwapperTrainer(LightningModule): def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: with torch.no_grad(): - output_tensor, output_mask = self.generator(source_embedding, target_tensor) + generator_target_features = self.generator.encode_features(target_tensor) + output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features) return output_tensor, output_mask @@ -89,8 +90,8 @@ class FaceSwapperTrainer(LightningModule): generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor) generator_target_features = self.generator.encode_features(target_tensor) + generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features) generator_output_features = self.generator.encode_features(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) @@ -138,7 +139,7 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - output_tensor, _ = self.generator(source_embedding, target_tensor) + output_tensor, _ = self.forward(source_embedding, target_tensor) output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)