This commit is contained in:
harisreedhar
2025-03-04 14:23:20 +05:30
committed by henryruhs
parent ceb3c0cfdf
commit f2d3f8a19f
5 changed files with 45 additions and 19 deletions
+1
View File
@@ -72,6 +72,7 @@ reconstruction_weight = 20
identity_weight = 20
pose_weight = 0
gaze_weight = 0
+expression_weight = 0
```
```
+1
View File
@@ -34,6 +34,7 @@ reconstruction_weight =
identity_weight =
pose_weight =
gaze_weight =
+expression_weight =
[training.trainer]
learning_rate =
+2 -2
View File
@@ -36,9 +36,9 @@ class DynamicDataset(Dataset[Tensor]):
def compose_transforms(self) -> transforms:
return transforms.Compose(
[
-AugmentTransform(),
 transforms.ToPILImage(),
 transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
+AugmentTransform(),
transforms.ToTensor(),
WarpTransform(self.warp_template),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
@@ -74,7 +74,7 @@ class AugmentTransform:
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
-return self.transforms(temp_tensor).get('image')
+return self.transforms(image = temp_tensor).get('image')
@staticmethod
def compose_transforms() -> albumentations.Compose:
+37 -13
View File
@@ -106,32 +106,56 @@ class IdentityLoss(nn.Module):
return identity_loss, weighted_identity_loss
-class PoseLoss(nn.Module):
-	def __init__(self, motion_extractor : MotionExtractorModule) -> None:
+class MotionLoss(nn.Module):
+	def __init__(self, motion_extractor : MotionExtractorModule):
super().__init__()
self.motion_extractor = motion_extractor
self.pose_loss = PoseLoss()
self.expression_loss = ExpressionLoss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
target_pose_features, target_expression = self.get_motion_features(target_tensor)
output_pose_features, output_expression = self.get_motion_features(output_tensor)
pose_loss, weighted_pose_loss = self.pose_loss(target_pose_features, output_pose_features)
expression_loss, weighted_expression_loss = self.expression_loss(target_expression, output_expression)
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
input_tensor = (input_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
pose_features = translation, scale, rotation, motion_points
return pose_features, expression
class ExpressionLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, target_expression : Tensor, output_expression : Tensor, ) -> Tuple[Tensor, Tensor]:
expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
weighted_expression_loss = expression_loss * expression_weight
return expression_loss, weighted_expression_loss
class PoseLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
self.mse_loss = nn.MSELoss()
-def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
+def forward(self, target_pose_features : Tensor, output_pose_features : Tensor, ) -> Tuple[Tensor, Tensor]:
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
-output_motion_features = self.get_motion_features(output_tensor)
-target_motion_features = self.get_motion_features(target_tensor)
 temp_tensors = []
-for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features):
-	temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature)
+for target_pose_feature, output_pose_feature in zip(target_pose_features, output_pose_features):
+	temp_tensor = self.mse_loss(target_pose_feature, output_pose_feature)
temp_tensors.append(temp_tensor)
pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * pose_weight
return pose_loss, weighted_pose_loss
-def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-	input_tensor = (input_tensor + 1) * 0.5
-	pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor)
-	rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
-	return translation, scale, rotation
class GazeLoss(nn.Module):
def __init__(self, gazer : GazerModule) -> None:
+4 -4
View File
@@ -17,7 +17,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
-from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
+from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -44,7 +44,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
-self.pose_loss = PoseLoss(self.motion_extractor)
+self.motion_loss = MotionLoss(self.motion_extractor)
self.gaze_loss = GazeLoss(self.gazer)
self.automatic_optimization = False
@@ -95,9 +95,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
-pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor)
+pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
-generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss
+generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)