mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
changes
This commit is contained in:
@@ -72,6 +72,7 @@ reconstruction_weight = 20
|
||||
identity_weight = 20
|
||||
pose_weight = 0
|
||||
gaze_weight = 0
|
||||
expression_weight = 0
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -34,6 +34,7 @@ reconstruction_weight =
|
||||
identity_weight =
|
||||
pose_weight =
|
||||
gaze_weight =
|
||||
expression_weight =
|
||||
|
||||
[training.trainer]
|
||||
learning_rate =
|
||||
|
||||
@@ -36,9 +36,9 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
def compose_transforms(self) -> transforms:
|
||||
return transforms.Compose(
|
||||
[
|
||||
AugmentTransform(),
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
AugmentTransform(),
|
||||
transforms.ToTensor(),
|
||||
WarpTransform(self.warp_template),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
@@ -74,7 +74,7 @@ class AugmentTransform:
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
|
||||
return self.transforms(temp_tensor).get('image')
|
||||
return self.transforms(image = temp_tensor).get('image')
|
||||
|
||||
@staticmethod
|
||||
def compose_transforms() -> albumentations.Compose:
|
||||
|
||||
@@ -106,32 +106,56 @@ class IdentityLoss(nn.Module):
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class PoseLoss(nn.Module):
|
||||
def __init__(self, motion_extractor : MotionExtractorModule) -> None:
|
||||
class MotionLoss(nn.Module):
|
||||
def __init__(self, motion_extractor : MotionExtractorModule):
|
||||
super().__init__()
|
||||
self.motion_extractor = motion_extractor
|
||||
self.pose_loss = PoseLoss()
|
||||
self.expression_loss = ExpressionLoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
target_pose_features, target_expression = self.get_motion_features(target_tensor)
|
||||
output_pose_features, output_expression = self.get_motion_features(output_tensor)
|
||||
pose_loss, weighted_pose_loss = self.pose_loss(target_pose_features, output_pose_features)
|
||||
expression_loss, weighted_expression_loss = self.expression_loss(target_expression, output_expression)
|
||||
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
|
||||
|
||||
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
pose_features = translation, scale, rotation, motion_points
|
||||
return pose_features, expression
|
||||
|
||||
|
||||
class ExpressionLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, target_expression : Tensor, output_expression : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * expression_weight
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
|
||||
class PoseLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_pose_features : Tensor, output_pose_features : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
|
||||
output_motion_features = self.get_motion_features(output_tensor)
|
||||
target_motion_features = self.get_motion_features(target_tensor)
|
||||
temp_tensors = []
|
||||
|
||||
for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features):
|
||||
temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature)
|
||||
for target_pose_feature, output_pose_feature in zip(target_pose_features, output_pose_features):
|
||||
temp_tensor = self.mse_loss(target_pose_feature, output_pose_feature)
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * pose_weight
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, gazer : GazerModule) -> None:
|
||||
|
||||
@@ -17,7 +17,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
|
||||
from .types import Batch, BatchMode, Embedding, OptimizerConfig, WarpTemplate
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
@@ -44,7 +44,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss(self.embedder)
|
||||
self.identity_loss = IdentityLoss(self.embedder)
|
||||
self.pose_loss = PoseLoss(self.motion_extractor)
|
||||
self.motion_loss = MotionLoss(self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(self.gazer)
|
||||
self.automatic_optimization = False
|
||||
|
||||
@@ -95,9 +95,9 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss = self.pose_loss(target_tensor, generator_output_tensor)
|
||||
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
||||
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_loss)
|
||||
|
||||
Reference in New Issue
Block a user