This commit is contained in:
harisreedhar
2025-04-23 22:00:42 +05:30
parent d990ce4575
commit d44ac98e38
5 changed files with 3 additions and 58 deletions
-3
View File
@@ -46,7 +46,6 @@ split_ratio = 0.9995
generator_embedder_path = .models/blendface.pt
loss_embedder_path = .models/adaface.pt
gazer_path = .models/gazer.pt
motion_extractor_path = .models/motion_extractor.pt
face_masker_path = .models/face_masker.pt
```
@@ -82,8 +81,6 @@ feature_weight = 10.0
reconstruction_weight = 10.0
identity_weight = 20.0
gaze_weight = 0.05
pose_weight = 0.05
expression_weight = 0.05
mask_weight = 5.0
```
-3
View File
@@ -14,7 +14,6 @@ split_ratio =
generator_embedder_path =
loss_embedder_path =
gazer_path =
motion_extractor_path =
face_masker_path =
[training.model.generator]
@@ -42,8 +41,6 @@ feature_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
pose_weight =
expression_weight =
mask_weight =
[training.trainer]
+1 -43
View File
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask, MotionExtractorModule
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
class DiscriminatorLoss(nn.Module):
@@ -126,48 +126,6 @@ class IdentityLoss(nn.Module):
return identity_loss, weighted_identity_loss
class MotionLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
super().__init__()
self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight')
self.config_expression_weight = config_parser.getfloat('training.losses', 'expression_weight')
self.motion_extractor = motion_extractor
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss, Loss, Loss]:
target_poses, target_expression = self.detect_motions(target_tensor)
output_poses, output_expression = self.detect_motions(output_tensor)
pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses)
expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression)
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_pose, output_pose in zip(target_poses, output_poses):
temp_tensor = self.mse_loss(target_pose, output_pose)
temp_tensors.append(temp_tensor)
pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * self.config_pose_weight
return pose_loss, weighted_pose_loss
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Loss, Loss]:
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
weighted_expression_loss = expression_loss * self.config_expression_weight
return expression_loss, weighted_expression_loss
def detect_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
input_tensor = (input_tensor + 1) * 0.5
with torch.no_grad():
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
pose = translation, scale, rotation, motion_points
return pose, expression
class GazeLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
super().__init__()
+2 -8
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -31,7 +31,6 @@ class FaceSwapperTrainer(LightningModule):
self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
@@ -39,7 +38,6 @@ class FaceSwapperTrainer(LightningModule):
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval()
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
@@ -49,7 +47,6 @@ class FaceSwapperTrainer(LightningModule):
self.feature_loss = FeatureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
self.gaze_loss = GazeLoss(config_parser, self.gazer)
self.mask_loss = MaskLoss(config_parser, self.face_masker)
self.automatic_optimization = False
@@ -105,10 +102,9 @@ class FaceSwapperTrainer(LightningModule):
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
@@ -140,8 +136,6 @@ class FaceSwapperTrainer(LightningModule):
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)
self.log('pose_loss', pose_loss)
self.log('expression_loss', expression_loss)
self.log('gaze_loss', gaze_loss)
self.log('mask_loss', mask_loss)
return generator_loss
-1
View File
@@ -16,7 +16,6 @@ Padding : TypeAlias = Tuple[int, int, int, int]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
GazerModule : TypeAlias = Module
MotionExtractorModule : TypeAlias = Module
FaceMaskerModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any