mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce Loss type, Remove Gaze type
This commit is contained in:
@@ -7,7 +7,7 @@ from torch import Tensor, nn
 from torchvision import transforms
 
 from ..helper import calc_embedding
-from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule
+from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, MotionExtractorModule
 
 
 class DiscriminatorLoss(nn.Module):
@@ -55,7 +55,7 @@ class AttributeLoss(nn.Module):
 self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
 self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
 
-def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]:
+def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Loss, Loss]:
 temp_tensors = []
 
 for target_attribute, output_attribute in zip(target_attributes, output_attributes):
@@ -74,7 +74,7 @@ class ReconstructionLoss(nn.Module):
 self.embedder = embedder
 self.mse_loss = nn.MSELoss()
 
-def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
+def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
 with torch.no_grad():
 source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
 target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
@@ -97,7 +97,7 @@ class IdentityLoss(nn.Module):
 self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
 self.embedder = embedder
 
-def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
+def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
 output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
 source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
 identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
@@ -113,14 +113,14 @@ class MotionLoss(nn.Module):
 self.motion_extractor = motion_extractor
 self.mse_loss = nn.MSELoss()
 
-def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, ...]:
-target_poses, target_expression = self.get_motions(target_tensor)
-output_poses, output_expression = self.get_motions(output_tensor)
+def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss, Loss, Loss]:
+target_poses, target_expression = self.detect_motions(target_tensor)
+output_poses, output_expression = self.detect_motions(output_tensor)
 pose_loss, weighted_pose_loss = self.calc_pose_loss(target_poses, output_poses)
 expression_loss, weighted_expression_loss = self.calc_expression_loss(target_expression, output_expression)
 return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
 
-def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
+def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Loss, Loss]:
 temp_tensors = []
 
 for target_pose, output_pose in zip(target_poses, output_poses):
@@ -131,12 +131,12 @@ class MotionLoss(nn.Module):
 weighted_pose_loss = pose_loss * self.config_pose_weight
 return pose_loss, weighted_pose_loss
 
-def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
+def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Loss, Loss]:
 expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
 weighted_expression_loss = expression_loss * self.config_expression_weight
 return expression_loss, weighted_expression_loss
 
-def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
+def detect_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
 input_tensor = (input_tensor + 1) * 0.5
 
 with torch.no_grad():
@@ -155,7 +155,7 @@ class GazeLoss(nn.Module):
 self.gazer = gazer
 self.l1_loss = nn.L1Loss()
 
-def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
+def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Loss, Loss]:
 output_pitch, output_yaw = self.detect_gaze(output_tensor)
 target_pitch, target_yaw = self.detect_gaze(target_tensor)
 
@@ -166,7 +166,7 @@ class GazeLoss(nn.Module):
 weighted_gaze_loss = gaze_loss * self.config_gaze_weight
 return gaze_loss, weighted_gaze_loss
 
-def detect_gaze(self, input_tensor : Tensor) -> Gaze:
+def detect_gaze(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor]:
 crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config_output_size).int()
 crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
 crop_tensor = (crop_tensor + 1) * 0.5
@@ -186,7 +186,7 @@ class MaskLoss(nn.Module):
 self.face_parser = face_parser
 self.mse_loss = nn.MSELoss()
 
-def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
+def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Loss:
 target_mask = self.calc_mask(target_tensor)
 target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
 mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size)
|
||||
@@ -4,11 +4,11 @@ from torch import Tensor
 from torch.nn import Module
 
 Batch : TypeAlias = Tuple[Tensor, Tensor]
-BatchMode = Literal['equal', 'same']
+BatchMode = Literal['equal', 'same', 'different']
 
 Attribute : TypeAlias = Tensor
 Embedding : TypeAlias = Tensor
-Gaze : TypeAlias = Tuple[Tensor, Tensor]
+Loss : TypeAlias = Tensor
 
 Padding : TypeAlias = Tuple[int, int, int, int]
 
|
||||
Reference in New Issue
Block a user