diff --git a/face_swapper/README.md b/face_swapper/README.md index 99cfe2b..98b8375 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -43,7 +43,8 @@ split_ratio = 0.9995 ``` [training.model] -embedder_path = .models/blendface.pt +generator_embedder_path = .models/blendface.pt +loss_embedder_path = .models/adaface.pt gazer_path = .models/gazer.pt motion_extractor_path = .models/motion_extractor.pt face_masker_path = .models/face_masker.pt diff --git a/face_swapper/config.ini b/face_swapper/config.ini index dade0ca..36ca46d 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -11,7 +11,8 @@ num_workers = split_ratio = [training.model] -embedder_path = +generator_embedder_path = +loss_embedder_path = gazer_path = motion_extractor_path = face_masker_path = diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 026f2a8..84a16d4 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -28,14 +28,16 @@ CONFIG_PARSER.read('config.ini') class FaceSwapperTrainer(LightningModule): def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - self.config_embedder_path = config_parser.get('training.model', 'embedder_path') + self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path') + self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path') self.config_gazer_path = config_parser.get('training.model', 'gazer_path') self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path') self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') - self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval() + self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() + self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval() self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval() @@ -44,8 +46,8 @@ class FaceSwapperTrainer(LightningModule): self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss(config_parser) self.feature_loss = FeautureLoss(config_parser) - self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder) - self.identity_loss = IdentityLoss(config_parser, self.embedder) + self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder) + self.identity_loss = IdentityLoss(config_parser, self.loss_embedder) self.motion_loss = MotionLoss(config_parser, self.motion_extractor) self.gaze_loss = GazeLoss(config_parser, self.gazer) self.mask_loss = MaskLoss(config_parser, self.face_masker) @@ -89,7 +91,7 @@ class FaceSwapperTrainer(LightningModule): do_update = (batch_index + 1) % self.config_accumulate_size == 0 generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] - source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) generator_target_features = self.generator.encode_features(target_tensor) generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features) generator_output_features = self.generator.encode_features(generator_output_tensor) @@ -138,9 +140,9 @@ class FaceSwapperTrainer(LightningModule): def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch - source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) + source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0)) output_tensor, _ = self.forward(source_embedding, target_tensor) - output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) + output_embedding = calc_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0)) validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation_score', validation_score, sync_dist = True, prog_bar = True) return validation_score