mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce new GazeLoss class (switched to mean)
This commit is contained in:
@@ -28,8 +28,6 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.source_embedder.eval()
|
||||
self.target_embedder.eval()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_embedding : Embedding) -> Embedding:
|
||||
|
||||
@@ -22,9 +22,6 @@ class FaceSwapperLoss:
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.embedder.eval()
|
||||
self.landmarker.eval()
|
||||
self.motion_extractor.eval()
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
|
||||
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
|
||||
@@ -197,7 +194,6 @@ class IdentityLoss(torch.nn.Module):
|
||||
super(IdentityLoss, self).__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.embedder.eval()
|
||||
|
||||
def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
@@ -234,3 +230,29 @@ class PoseLoss(torch.nn.Module):
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class GazeLoss(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(GazeLoss, self).__init__()
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, target_tensor : VisionTensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_face_landmark = self.detect_face_landmark(output_tensor)
|
||||
target_face_landmark = self.detect_face_landmark(target_tensor)
|
||||
|
||||
left_gaze_loss = self.mse_loss(output_face_landmark[:, 198], target_face_landmark[:, 198])
|
||||
right_gaze_loss = self.mse_loss(output_face_landmark[:, 197], target_face_landmark[:, 197])
|
||||
|
||||
gaze_loss = left_gaze_loss + right_gaze_loss
|
||||
weighted_gaze_loss = gaze_loss * gaze_weight
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
def detect_face_landmark(self, input_tensor : Tensor) -> FaceLandmark203:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
input_tensor = nn.functional.interpolate(input_tensor, size = (224, 224), mode = 'bilinear')
|
||||
face_landmarks_203 = self.landmarker(input_tensor)[2].view(-1, 203, 2)
|
||||
return face_landmarks_203
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -36,6 +36,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.pose_loss = PoseLoss()
|
||||
self.gaze_loss = GazeLoss()
|
||||
self.automatic_optimization = automatic_optimization
|
||||
|
||||
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
|
||||
@@ -77,11 +78,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
|
||||
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
self.log('loss_pose', generator_loss_set.get('loss_pose'), prog_bar = True)
|
||||
self.log('loss_pose', generator_loss_set.get('loss_pose'))
|
||||
self.log('loss_gaze', generator_loss_set.get('loss_gaze'), prog_bar = True)
|
||||
|
||||
###############################################
|
||||
|
||||
@@ -90,6 +92,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
|
||||
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
@@ -97,7 +100,8 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.log('attribute_loss_new', attribute_loss)
|
||||
self.log('reconstruction_loss_new', reconstruction_loss)
|
||||
self.log('identity_loss_new', identity_loss)
|
||||
self.log('pose_loss_new', pose_loss, prog_bar = True)
|
||||
self.log('pose_loss_new', pose_loss)
|
||||
self.log('gaze_loss_new', gaze_loss, prog_bar = True)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user