This commit is contained in:
harisreedhar
2025-04-08 12:29:07 +05:30
parent 2e3c3517cb
commit 4f5ac00a7b
3 changed files with 13 additions and 9 deletions
+2 -1
View File
@@ -43,7 +43,8 @@ split_ratio = 0.9995
```
[training.model]
embedder_path = .models/blendface.pt
generator_embedder_path = .models/blendface.pt
loss_embedder_path = .models/adaface.pt
gazer_path = .models/gazer.pt
motion_extractor_path = .models/motion_extractor.pt
face_masker_path = .models/face_masker.pt
+2 -1
View File
@@ -11,7 +11,8 @@ num_workers =
split_ratio =
[training.model]
embedder_path =
generator_embedder_path =
loss_embedder_path =
gazer_path =
motion_extractor_path =
face_masker_path =
+9 -7
View File
@@ -28,14 +28,16 @@ CONFIG_PARSER.read('config.ini')
class FaceSwapperTrainer(LightningModule):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_embedder_path = config_parser.get('training.model', 'embedder_path')
self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval()
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval()
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
@@ -44,8 +46,8 @@ class FaceSwapperTrainer(LightningModule):
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.feature_loss = FeautureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder)
self.identity_loss = IdentityLoss(config_parser, self.embedder)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
self.gaze_loss = GazeLoss(config_parser, self.gazer)
self.mask_loss = MaskLoss(config_parser, self.face_masker)
@@ -89,7 +91,7 @@ class FaceSwapperTrainer(LightningModule):
do_update = (batch_index + 1) % self.config_accumulate_size == 0
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)
@@ -138,9 +140,9 @@ class FaceSwapperTrainer(LightningModule):
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
output_tensor, _ = self.forward(source_embedding, target_tensor)
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
output_embedding = calc_embedding(self.generator_embedder, output_tensor, (0, 0, 0, 0))
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
return validation_score