mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -4,9 +4,10 @@ from typing import List, Tuple
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule
|
||||
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -133,25 +134,32 @@ class PoseLoss(nn.Module):
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
    """
    Gaze consistency loss between a target batch and a generated (output) batch.

    A gaze estimator is run on an identically cropped region of both batches
    and the predicted pitch and yaw are compared with an L1 (mean absolute
    error) loss.
    """

    def __init__(self, gazer : GazerModule) -> None:
        """
        :param gazer: module invoked as ``gazer(crop_tensor)``; it must return
            a ``(pitch, yaw)`` pair of tensors.
        """
        super().__init__()
        self.gazer = gazer
        self.mae_loss = nn.L1Loss()
        # NOTE(review): detect_gaze() normalizes and resizes inline and does not
        # read this pipeline; it is kept so the attribute remains available.
        self.transform = transforms.Compose(
        [
            transforms.Resize(448),
            transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
        ])

    def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute the gaze loss for one batch.

        :param target_tensor: reference images.
        :param output_tensor: generated images, same layout as ``target_tensor``.
        :return: ``(gaze_loss, weighted_gaze_loss)`` where ``gaze_loss`` is the
            average of the pitch and yaw L1 losses and ``weighted_gaze_loss``
            is that value multiplied by ``gaze_weight`` from the
            ``training.losses`` section of the config.
        """
        gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
        output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
        target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)

        pitch_gaze_loss = self.mae_loss(output_pitch_tensor, target_pitch_tensor)
        yaw_gaze_loss = self.mae_loss(output_yaw_tensor, target_yaw_tensor)

        # pitch and yaw contribute equally
        gaze_loss = (pitch_gaze_loss + yaw_gaze_loss) * 0.5
        weighted_gaze_loss = gaze_loss * gaze_weight
        return gaze_loss, weighted_gaze_loss

    def detect_gaze(self, input_tensor : Tensor) -> Gaze:
        """
        Run the gaze estimator on a fixed crop of ``input_tensor``.

        :param input_tensor: image batch; assumed to be 3-channel
            ``(batch, channel, height, width)`` with values in ``[-1, 1]`` and
            large enough to contain the crop window - TODO confirm with caller.
        :return: ``(pitch_tensor, yaw_tensor)`` exactly as returned by the gazer.
        """
        # fixed window: rows 60..223, columns 16..204
        crop_tensor = input_tensor[:, :, 60: 224, 16: 205]
        # shift/scale by (x + 1) / 2 before normalizing (maps [-1, 1] onto [0, 1])
        crop_tensor = (crop_tensor + 1) * 0.5
        # per-channel normalization with the commonly used ImageNet mean / std
        crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
        # the crop is not square; resample it to the 448 x 448 input fed to the gazer
        crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')
        pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
        return pitch_tensor, yaw_tensor
|
||||
|
||||
Reference in New Issue
Block a user