Introduce new PoseLoss class (switched to mean)

henryruhs
2025-02-22 23:15:39 +01:00
parent 6eabcad1d0
commit a797548329
2 changed files with 43 additions and 17 deletions
+30 -10
@@ -3,7 +3,6 @@ from typing import List, Tuple
import torch
from pytorch_msssim import ssim
from sqlalchemy.dialects.mssql.information_schema import identity_columns
from torch import Tensor, nn
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
@@ -44,15 +43,8 @@ class FaceSwapperLoss:
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
}
if weight_pose > 0:
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
if weight_gaze > 0:
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity
@@ -214,3 +206,31 @@ class IdentityLoss(torch.nn.Module):
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_identity_loss = identity_loss * identity_weight
return identity_loss, weighted_identity_loss
class PoseLoss(torch.nn.Module):
def __init__(self) -> None:
super(PoseLoss, self).__init__()
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
def calc(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
output_motion_features = self.get_motion_features(output_tensor)
target_motion_features = self.get_motion_features(target_tensor)
temp_tensors = []
for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features):
temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature)
temp_tensors.append(temp_tensor)
pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * pose_weight
return pose_loss, weighted_pose_loss
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (input_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
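The new class pulls its dependencies from the training config read via configparser: a TorchScript motion extractor checkpoint under the training.model section and a scalar weight under training.losses. A minimal sketch of the two entries it expects follows; the path and value are placeholder assumptions, not values taken from this repository:

[training.model]
motion_extractor_path = assets/motion_extractor.pt

[training.losses]
pose_weight = 10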
+13 -7
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -35,6 +35,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss()
self.identity_loss = IdentityLoss()
self.pose_loss = PoseLoss()
self.automatic_optimization = automatic_optimization
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
@@ -77,9 +78,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True)
self.log('loss_attribute', generator_loss_set.get('loss_attribute'), prog_bar = True)
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_identity', generator_loss_set.get('loss_identity'))
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
self.log('loss_pose', generator_loss_set.get('loss_pose'), prog_bar = True)
###############################################
@@ -87,13 +89,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss
pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
self.log('generator_loss_new', generator_loss, prog_bar = True)
self.log('adversarial_loss_new', adversarial_loss)
self.log('attribute_loss_new', attribute_loss, prog_bar = True)
self.log('attribute_loss_new', attribute_loss)
self.log('reconstruction_loss_new', reconstruction_loss)
self.log('identity_loss_new', identity_loss)
self.log('pose_loss_new', pose_loss, prog_bar = True)
return generator_loss_set.get('loss_generator')
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
@@ -107,12 +111,14 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
preview_limit = 8
preview_items = []
preview_cells = []
for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]):
preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor ], dim = 2))
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor ], dim = 2)
preview_cells.append(preview_cell)
preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True)
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True)
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
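For reference, the reworked preview assembly produces the same grid as the removed one-liner. A standalone sketch of the shape flow, using hypothetical 3x256x256 crops and a batch of eight (the sizes are illustrative assumptions, not the trainer's actual resolution):

import torch
import torchvision

source_tensor = torch.rand(8, 3, 256, 256)  # hypothetical source crops
target_tensor = torch.rand(8, 3, 256, 256)  # hypothetical target crops
output_tensor = torch.rand(8, 3, 256, 256)  # hypothetical swapped outputs

preview_cells = []
for source_cell, target_cell, output_cell in zip(source_tensor[:8], target_tensor[:8], output_tensor[:8]):
    preview_cells.append(torch.cat([ source_cell, target_cell, output_cell ], dim = 2))  # 3 x 256 x 768, one row per sample
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)  # 1 x 3 x 2048 x 768, rows stacked along the height
preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True)  # single image for the TensorBoard logger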