mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Introduce new PoseLoss class (switched to mean)
This commit is contained in:
@@ -3,7 +3,6 @@ from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from sqlalchemy.dialects.mssql.information_schema import identity_columns
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
|
||||
@@ -44,15 +43,8 @@ class FaceSwapperLoss:
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
|
||||
}
|
||||
|
||||
if weight_pose > 0:
|
||||
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
|
||||
else:
|
||||
generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
|
||||
|
||||
if weight_gaze > 0:
|
||||
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
|
||||
else:
|
||||
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
|
||||
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
|
||||
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
|
||||
|
||||
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity
|
||||
@@ -214,3 +206,31 @@ class IdentityLoss(torch.nn.Module):
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * identity_weight
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class PoseLoss(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(PoseLoss, self).__init__()
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
|
||||
output_motion_features = self.get_motion_features(output_tensor)
|
||||
target_motion_features = self.get_motion_features(target_tensor)
|
||||
temp_tensors = []
|
||||
|
||||
for target_motion_feature, output_motion_feature in zip(target_motion_features, output_motion_features):
|
||||
temp_tensor = self.mse_loss(target_motion_feature, output_motion_feature)
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * pose_weight
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -35,6 +35,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.pose_loss = PoseLoss()
|
||||
self.automatic_optimization = automatic_optimization
|
||||
|
||||
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
|
||||
@@ -77,9 +78,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True)
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'), prog_bar = True)
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
self.log('loss_pose', generator_loss_set.get('loss_pose'), prog_bar = True)
|
||||
|
||||
###############################################
|
||||
|
||||
@@ -87,13 +89,15 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss
|
||||
pose_loss, weighted_pose_loss = self.pose_loss.calc(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
|
||||
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss_new', adversarial_loss)
|
||||
self.log('attribute_loss_new', attribute_loss, prog_bar = True)
|
||||
self.log('attribute_loss_new', attribute_loss)
|
||||
self.log('reconstruction_loss_new', reconstruction_loss)
|
||||
self.log('identity_loss_new', identity_loss)
|
||||
self.log('pose_loss_new', pose_loss, prog_bar = True)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
@@ -107,12 +111,14 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
|
||||
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_items = []
|
||||
preview_cells = []
|
||||
|
||||
for source_tensor, target_tensor, output_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit]):
|
||||
preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2))
|
||||
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2)
|
||||
preview_cells.append(preview_cell)
|
||||
|
||||
preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True)
|
||||
preview_cells = torch.cat(preview_cells, dim = 1).unsqueeze(0)
|
||||
preview_grid = torchvision.utils.make_grid(preview_cells, normalize = True, scale_each = True)
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user