mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -46,7 +46,6 @@ split_ratio = 0.9995
|
||||
generator_embedder_path = .models/blendface.pt
|
||||
loss_embedder_path = .models/adaface.pt
|
||||
gazer_path = .models/gazer.pt
|
||||
motion_extractor_path = .models/motion_extractor.pt
|
||||
face_masker_path = .models/face_masker.pt
|
||||
```
|
||||
|
||||
@@ -82,8 +81,6 @@ feature_weight = 10.0
|
||||
reconstruction_weight = 10.0
|
||||
identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
pose_weight = 0.05
|
||||
expression_weight = 0.05
|
||||
mask_weight = 5.0
|
||||
```
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ split_ratio =
|
||||
generator_embedder_path =
|
||||
loss_embedder_path =
|
||||
gazer_path =
|
||||
motion_extractor_path =
|
||||
face_masker_path =
|
||||
|
||||
[training.model.generator]
|
||||
@@ -42,8 +41,6 @@ feature_weight =
|
||||
reconstruction_weight =
|
||||
identity_weight =
|
||||
gaze_weight =
|
||||
pose_weight =
|
||||
expression_weight =
|
||||
mask_weight =
|
||||
|
||||
[training.trainer]
|
||||
|
||||
@@ -7,7 +7,7 @@ from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask, MotionExtractorModule
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
@@ -126,48 +126,6 @@ class IdentityLoss(nn.Module):
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class MotionLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
|
||||
super().__init__()
|
||||
self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight')
|
||||
self.config_expression_weight = config_parser.getfloat('training.losses', 'expression_weight')
|
||||
self.motion_extractor = motion_extractor
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss, Loss, Loss]:
|
||||
target_poses, target_expression = self.detect_motions(target_tensor)
|
||||
output_poses, output_expression = self.detect_motions(output_tensor)
|
||||
pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses)
|
||||
expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression)
|
||||
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
|
||||
|
||||
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_pose, output_pose in zip(target_poses, output_poses):
|
||||
temp_tensor = self.mse_loss(target_pose, output_pose)
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * self.config_pose_weight
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Loss, Loss]:
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * self.config_expression_weight
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
def detect_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
|
||||
with torch.no_grad():
|
||||
pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
|
||||
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
pose = translation, scale, rotation, motion_points
|
||||
return pose, expression
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
@@ -31,7 +31,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.config_generator_embedder_path = config_parser.get('training.model', 'generator_embedder_path')
|
||||
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
|
||||
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
|
||||
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
|
||||
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
@@ -39,7 +38,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
|
||||
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
|
||||
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
|
||||
self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval()
|
||||
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
@@ -49,7 +47,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.feature_loss = FeatureLoss(config_parser)
|
||||
self.reconstruction_loss = ReconstructionLoss(config_parser, self.loss_embedder)
|
||||
self.identity_loss = IdentityLoss(config_parser, self.loss_embedder)
|
||||
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(config_parser, self.gazer)
|
||||
self.mask_loss = MaskLoss(config_parser, self.face_masker)
|
||||
self.automatic_optimization = False
|
||||
@@ -105,10 +102,9 @@ class FaceSwapperTrainer(LightningModule):
|
||||
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
mask_loss, weighted_mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + weighted_mask_loss
|
||||
generator_loss = weighted_adversarial_loss + weighted_cycle_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_gaze_loss + weighted_mask_loss
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
@@ -140,8 +136,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.log('feature_loss', feature_loss)
|
||||
self.log('reconstruction_loss', reconstruction_loss)
|
||||
self.log('identity_loss', identity_loss)
|
||||
self.log('pose_loss', pose_loss)
|
||||
self.log('expression_loss', expression_loss)
|
||||
self.log('gaze_loss', gaze_loss)
|
||||
self.log('mask_loss', mask_loss)
|
||||
return generator_loss
|
||||
|
||||
@@ -16,7 +16,6 @@ Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
GazerModule : TypeAlias = Module
|
||||
MotionExtractorModule : TypeAlias = Module
|
||||
FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
Reference in New Issue
Block a user