diff --git a/face_swapper/README.md b/face_swapper/README.md
index c52cf60..147245d 100644
--- a/face_swapper/README.md
+++ b/face_swapper/README.md
@@ -72,6 +72,7 @@ reconstruction_weight = 20
 identity_weight = 20
 pose_weight = 0
 gaze_weight = 0
+expression_weight = 0
 ```
 
 ```
diff --git a/face_swapper/config.ini b/face_swapper/config.ini
index 0b585f4..db2437a 100644
--- a/face_swapper/config.ini
+++ b/face_swapper/config.ini
@@ -34,6 +34,7 @@ reconstruction_weight =
 identity_weight =
 pose_weight =
 gaze_weight =
+expression_weight =
 
 [training.trainer]
 learning_rate =
diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py
index a0f00e7..852648b 100644
--- a/face_swapper/src/dataset.py
+++ b/face_swapper/src/dataset.py
@@ -36,9 +36,9 @@ class DynamicDataset(Dataset[Tensor]):
 	def compose_transforms(self) -> transforms:
 		return transforms.Compose(
 		[
+			AugmentTransform(),
 			transforms.ToPILImage(),
 			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
-			AugmentTransform(),
 			transforms.ToTensor(),
 			WarpTransform(self.warp_template),
 			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
@@ -74,7 +74,7 @@ class AugmentTransform:
 
 	def __call__(self, input_tensor : Tensor) -> Tensor:
 		temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
-		return self.transforms(temp_tensor).get('image')
+		return self.transforms(image = temp_tensor).get('image')
 
 	@staticmethod
 	def compose_transforms() -> albumentations.Compose:
diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py
index 8ab6533..093e260 100644
--- a/face_swapper/src/models/loss.py
+++ b/face_swapper/src/models/loss.py
@@ -106,32 +106,56 @@ class IdentityLoss(nn.Module):
 		return identity_loss, weighted_identity_loss
 
 
-class PoseLoss(nn.Module):
-	def __init__(self, motion_extractor : MotionExtractorModule) -> None:
+class MotionLoss(nn.Module):
+	def __init__(self, motion_extractor : MotionExtractorModule) -> None:
 		super().__init__()
 		self.motion_extractor = motion_extractor
+		self.pose_loss = PoseLoss()
+		self.expression_loss = ExpressionLoss()
+
+	def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+		target_pose_features, target_expression = self.get_motion_features(target_tensor)
+		output_pose_features, output_expression = self.get_motion_features(output_tensor)
+		pose_loss, weighted_pose_loss = self.pose_loss(target_pose_features, output_pose_features)
+		expression_loss, weighted_expression_loss = self.expression_loss(target_expression, output_expression)
+		return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
+
+	def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
+		input_tensor = (input_tensor + 1) * 0.5
+		pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
+		rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
+		pose_features = translation, scale, rotation, motion_points
+		return pose_features, expression
+
+
+class ExpressionLoss(nn.Module):
+	def __init__(self) -> None:
+		super().__init__()
+
+	def forward(self, target_expression : Tensor, output_expression : Tensor, ) -> Tuple[Tensor, Tensor]:
+		expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
+		expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
+		weighted_expression_loss = expression_loss * expression_weight
+		return expression_loss, weighted_expression_loss
+
+
+class PoseLoss(nn.Module):
+	def __init__(self) -> None:
+		super().__init__()
 		self.mse_loss = nn.MSELoss()
 
-	def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
+	def forward(self, target_pose_features : Tuple[Tensor, Tensor, Tensor, Tensor], output_pose_features : Tuple[Tensor, Tensor, Tensor, Tensor], ) -> Tuple[Tensor, Tensor]:
 		pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
-		output_motion_features = self.get_motion_features(output_tensor)
-		target_motion_features = self.get_motion_features(target_tensor)
 		temp_tensors = []
 
-		for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features):
-			temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature)
+		for target_pose_feature, output_pose_feature in zip(target_pose_features, output_pose_features):
+			temp_tensor = self.mse_loss(target_pose_feature, output_pose_feature)
 			temp_tensors.append(temp_tensor)
 
 		pose_loss = torch.stack(temp_tensors).mean()
 		weighted_pose_loss = pose_loss * pose_weight
 		return pose_loss, weighted_pose_loss
 
-	def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-		input_tensor = (input_tensor + 1) * 0.5
-		pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor)
-		rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
-		return translation, scale, rotation
-
 
 class GazeLoss(nn.Module):
 	def __init__(self, gazer : GazerModule) -> None:
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index 90dab65..49b7612 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -17,7 +17,7 @@ from .dataset import DynamicDataset
 from .helper import calc_embedding
 from .models.discriminator import Discriminator
 from .models.generator import Generator
-from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
+from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
 from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
 
 warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -44,7 +44,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
 		self.attribute_loss = AttributeLoss()
 		self.reconstruction_loss = ReconstructionLoss(self.embedder)
 		self.identity_loss = IdentityLoss(self.embedder)
-		self.pose_loss = PoseLoss(self.motion_extractor)
+		self.motion_loss = MotionLoss(self.motion_extractor)
 		self.gaze_loss = GazeLoss(self.gazer)
 		self.automatic_optimization = False
 
@@ -95,9 +95,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
 		attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
 		reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
 		identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
-		pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor)
+		pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
 		gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
-		generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss
+		generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
 		generator_optimizer.zero_grad()
 		self.manual_backward(generator_loss)