mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
changes
This commit is contained in:
@@ -10,7 +10,7 @@ split_ratio =
|
||||
|
||||
[training.model]
|
||||
embedder_path =
|
||||
landmarker_path =
|
||||
gazer_path =
|
||||
motion_extractor_path =
|
||||
|
||||
[training.model.generator]
|
||||
|
||||
@@ -4,9 +4,10 @@ from typing import List, Tuple
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, EmbedderModule, FaceLandmark203, LandmarkerModule, MotionExtractorModule
|
||||
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -133,25 +134,32 @@ class PoseLoss(nn.Module):
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, landmarker : LandmarkerModule) -> None:
|
||||
def __init__(self, gazer : GazerModule) -> None:
|
||||
super().__init__()
|
||||
self.landmarker = landmarker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.gazer = gazer
|
||||
self.mae_loss = nn.L1Loss()
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(448),
|
||||
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_face_landmark = self.detect_face_landmark(output_tensor)
|
||||
target_face_landmark = self.detect_face_landmark(target_tensor)
|
||||
output_pitch_tensor, output_yaw_tensor = self.detect_gaze(output_tensor)
|
||||
target_pitch_tensor, target_yaw_tensor = self.detect_gaze(target_tensor)
|
||||
|
||||
left_gaze_loss = self.mse_loss(output_face_landmark[:, 198], target_face_landmark[:, 198])
|
||||
right_gaze_loss = self.mse_loss(output_face_landmark[:, 197], target_face_landmark[:, 197])
|
||||
pitch_gaze_loss = self.mae_loss(output_pitch_tensor, target_pitch_tensor)
|
||||
yaw_gaze_loss = self.mae_loss(output_yaw_tensor, target_yaw_tensor)
|
||||
|
||||
gaze_loss = left_gaze_loss + right_gaze_loss
|
||||
gaze_loss = (pitch_gaze_loss + yaw_gaze_loss) * 0.5
|
||||
weighted_gaze_loss = gaze_loss * gaze_weight
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
def detect_face_landmark(self, input_tensor : Tensor) -> FaceLandmark203:
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
input_tensor = nn.functional.interpolate(input_tensor, size = (224, 224), mode = 'bilinear')
|
||||
face_landmarks_203 = self.landmarker(input_tensor)[2].view(-1, 203, 2)
|
||||
return face_landmarks_203
|
||||
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
|
||||
crop_tensor = input_tensor[:, :, 60: 224, 16: 205]
|
||||
crop_tensor = (crop_tensor + 1) * 0.5
|
||||
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')
|
||||
pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
|
||||
return pitch_tensor, yaw_tensor
|
||||
|
||||
@@ -30,11 +30,11 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
gazer_path = CONFIG.get('training.model', 'gazer_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.gazer = torch.jit.load(gazer_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
|
||||
self.generator = Generator()
|
||||
@@ -45,7 +45,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
self.reconstruction_loss = ReconstructionLoss(self.embedder)
|
||||
self.identity_loss = IdentityLoss(self.embedder)
|
||||
self.pose_loss = PoseLoss(self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(self.landmarker)
|
||||
self.gaze_loss = GazeLoss(self.gazer)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
|
||||
@@ -7,13 +7,13 @@ Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
|
||||
Attributes : TypeAlias = Tuple[Tensor, ...]
|
||||
Embedding : TypeAlias = Tensor
|
||||
FaceLandmark203 : TypeAlias = Tensor
|
||||
Gaze : TypeAlias = Tuple[Tensor, Tensor]
|
||||
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
LandmarkerModule : TypeAlias = Module
|
||||
GazerModule : TypeAlias = Module
|
||||
MotionExtractorModule : TypeAlias = Module
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
|
||||
Reference in New Issue
Block a user