Introduce new GazeLoss class (switched to mean)

This commit is contained in:
henryruhs
2025-02-22 23:48:50 +01:00
parent a797548329
commit 579d3ef51c
3 changed files with 34 additions and 10 deletions
-2
View File
@@ -28,8 +28,6 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
self.embedding_converter = EmbeddingConverter()
self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.source_embedder.eval()
self.target_embedder.eval()
self.mse_loss = nn.MSELoss()
def forward(self, source_embedding : Embedding) -> Embedding:
+26 -4
View File
@@ -22,9 +22,6 @@ class FaceSwapperLoss:
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
@@ -197,7 +194,6 @@ class IdentityLoss(torch.nn.Module):
super(IdentityLoss, self).__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder.eval()
def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
@@ -234,3 +230,29 @@ class PoseLoss(torch.nn.Module):
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
class GazeLoss(torch.nn.Module):
    """
    Loss comparing two landmark points detected on the output against the
    same two points detected on the target.

    The points are rows 198 and 197 of the 203-point landmark set returned
    by the landmarker; the code labels them "left" and "right" gaze.
    NOTE(review): which anatomical points these indices denote is not
    visible here - confirm against the landmarker's layout.
    """

    def __init__(self) -> None:
        """Load the TorchScript landmarker named in the config and create the MSE criterion."""
        super().__init__()
        model_path = CONFIG.get('training.model', 'landmarker_path')
        # Deserialised onto the CPU; any later device move is left to the owner.
        self.landmarker = torch.jit.load(model_path, map_location = 'cpu') # type:ignore[no-untyped-call]
        self.mse_loss = nn.MSELoss()

    def calc(self, target_tensor : VisionTensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute the gaze loss between target and output.

        Returns a tuple of (raw loss, raw loss multiplied by the configured
        `gaze_weight`). The raw loss is the sum of two MSE terms, one per
        landmark index.
        """
        # The weight is read before any model call, on every invocation.
        weight = CONFIG.getfloat('training.losses', 'gaze_weight')
        # The output is processed before the target.
        output_landmarks = self.detect_face_landmark(output_tensor)
        target_landmarks = self.detect_face_landmark(target_tensor)
        left_term = self.mse_loss(output_landmarks[:, 198], target_landmarks[:, 198])
        right_term = self.mse_loss(output_landmarks[:, 197], target_landmarks[:, 197])
        gaze_loss = left_term + right_term
        return gaze_loss, gaze_loss * weight

    def detect_face_landmark(self, input_tensor : Tensor) -> FaceLandmark203:
        """
        Run the landmarker on `input_tensor` and return its landmarks
        reshaped to (-1, 203, 2).
        """
        # Presumably maps a [-1, 1] image range to [0, 1] - TODO confirm input range.
        rescaled_tensor = (input_tensor + 1) * 0.5
        # The input is resized to 224x224 before it is fed to the landmarker.
        resized_tensor = nn.functional.interpolate(rescaled_tensor, size = (224, 224), mode = 'bilinear')
        landmarker_outputs = self.landmarker(resized_tensor)
        # Only the element at index 2 of the landmarker's result is used.
        return landmarker_outputs[2].view(-1, 203, 2)
+8 -4
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -36,6 +36,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.reconstruction_loss = ReconstructionLoss()
self.identity_loss = IdentityLoss()
self.pose_loss = PoseLoss()
self.gaze_loss = GazeLoss()
self.automatic_optimization = automatic_optimization
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
@@ -77,11 +78,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True)
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_identity', generator_loss_set.get('loss_identity'))
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
self.log('loss_pose', generator_loss_set.get('loss_pose'), prog_bar = True)
self.log('loss_pose', generator_loss_set.get('loss_pose'))
self.log('loss_gaze', generator_loss_set.get('loss_gaze'), prog_bar = True)
###############################################
@@ -90,6 +92,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
self.log('generator_loss_new', generator_loss, prog_bar = True)
@@ -97,7 +100,8 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.log('attribute_loss_new', attribute_loss)
self.log('reconstruction_loss_new', reconstruction_loss)
self.log('identity_loss_new', identity_loss)
self.log('pose_loss_new', pose_loss, prog_bar = True)
self.log('pose_loss_new', pose_loss)
self.log('gaze_loss_new', gaze_loss, prog_bar = True)
return generator_loss_set.get('loss_generator')
def validation_step(self, batch : Batch, batch_index : int) -> Tensor: