diff --git a/face_swapper/README.md b/face_swapper/README.md index f55308d..a6c65a4 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -46,7 +46,6 @@ split_ratio = 0.9995 generator_embedder_path = .models/blendface.pt loss_embedder_path = .models/adaface.pt gazer_path = .models/gazer.pt -motion_extractor_path = .models/motion_extractor.pt face_masker_path = .models/face_masker.pt ``` @@ -82,8 +81,6 @@ feature_weight = 10.0 reconstruction_weight = 10.0 identity_weight = 20.0 gaze_weight = 0.05 -pose_weight = 0.05 -expression_weight = 0.05 mask_weight = 5.0 ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 37f4627..fd94c01 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -14,7 +14,6 @@ split_ratio = generator_embedder_path = loss_embedder_path = gazer_path = -motion_extractor_path = face_masker_path = [training.model.generator] @@ -42,8 +41,6 @@ feature_weight = reconstruction_weight = identity_weight = gaze_weight = -pose_weight = -expression_weight = mask_weight = [training.trainer] diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 82642bb..dbe8e10 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask, MotionExtractorModule +from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask class DiscriminatorLoss(nn.Module): @@ -126,48 +126,6 @@ class IdentityLoss(nn.Module): return identity_loss, weighted_identity_loss -class MotionLoss(nn.Module): - def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule): - super().__init__() - self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight') - self.config_expression_weight = config_parser.getfloat('training.losses', 'expression_weight') - self.motion_extractor = motion_extractor - self.mse_loss = nn.MSELoss() - - def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss, Loss, Loss]: - target_poses, target_expression = self.detect_motions(target_tensor) - output_poses, output_expression = self.detect_motions(output_tensor) - pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses) - expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression) - return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss - - def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Loss, Loss]: - temp_tensors = [] - - for target_pose, output_pose in zip(target_poses, output_poses): - temp_tensor = self.mse_loss(target_pose, output_pose) - temp_tensors.append(temp_tensor) - - pose_loss = torch.stack(temp_tensors).mean() - weighted_pose_loss = pose_loss * self.config_pose_weight - return pose_loss, weighted_pose_loss - - def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Loss, Loss]: - expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean() - weighted_expression_loss = expression_loss * self.config_expression_weight - return expression_loss, weighted_expression_loss - - def detect_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: - input_tensor = (input_tensor + 1) * 0.5 - - with torch.no_grad(): - pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor) - - rotation = torch.cat([ pitch, yaw, roll ], dim = 1) - pose = translation, scale, rotation, motion_points - return pose, expression - - class GazeLoss(nn.Module): def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None: super().__init__() diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 8d1d749..67662b4 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss from .types import Batch, Embedding, Mask, OptimizerSet warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -31,7 +31,6 @@ class FaceSwapperTrainer(LightningModule): self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path') self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path') self.config_gazer_path = config_parser.get('training.model', 'gazer_path') - self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path') self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') @@ -39,7 +38,6 @@ class FaceSwapperTrainer(LightningModule): self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() - self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval() self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval() self.generator = Generator(config_parser) self.discriminator = Discriminator(config_parser) @@ -49,7 +47,6 @@ class FaceSwapperTrainer(LightningModule): self.feature_loss = FeatureLoss(config_parser) self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder) self.identity_loss = IdentityLoss(config_parser, self.loss_embedder) - self.motion_loss = MotionLoss(config_parser, self.motion_extractor) self.gaze_loss = GazeLoss(config_parser, self.gazer) self.mask_loss = MaskLoss(config_parser, self.face_masker) self.automatic_optimization = False @@ -105,10 +102,9 @@ class FaceSwapperTrainer(LightningModule): feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) - pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor) gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask) - generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss + generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) @@ -140,8 +136,6 @@ class FaceSwapperTrainer(LightningModule): self.log('feature_loss', feature_loss) self.log('reconstruction_loss', reconstruction_loss) self.log('identity_loss', identity_loss) - self.log('pose_loss', pose_loss) - self.log('expression_loss', expression_loss) self.log('gaze_loss', gaze_loss) self.log('mask_loss', mask_loss) return generator_loss diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 342dc6f..d840897 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -16,7 +16,6 @@ Padding : TypeAlias = Tuple[int, int, int, int] GeneratorModule : TypeAlias = Module EmbedderModule : TypeAlias = Module GazerModule : TypeAlias = Module -MotionExtractorModule : TypeAlias = Module FaceMaskerModule : TypeAlias = Module OptimizerSet : TypeAlias = Any