mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Revert the config dicts
This commit is contained in:
@@ -9,10 +9,7 @@ from ..networks.nld import NLD
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'num_discriminators': config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
}
|
||||
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
self.config_parser = config_parser
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
@@ -20,7 +17,7 @@ class Discriminator(nn.Module):
|
||||
def create_discriminators(self) -> nn.ModuleList:
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
for _ in range(self.config.get('num_discriminators')):
|
||||
for _ in range(self.config_num_discriminators):
|
||||
discriminator = NLD(self.config_parser).sequences
|
||||
discriminators.append(discriminator)
|
||||
|
||||
|
||||
@@ -101,28 +101,22 @@ class ReconstructionLoss(nn.Module):
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'identity_weight': config_parser.getfloat('training.losses', 'identity_weight')
|
||||
}
|
||||
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
|
||||
self.embedder = embedder
|
||||
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * self.config.get('identity_weight')
|
||||
weighted_identity_loss = identity_loss * self.config_identity_weight
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class MotionLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'pose_weight': config_parser.getfloat('training.losses', 'pose_weight'),
|
||||
'expression_weight': config_parser.getfloat('training.losses', 'expression_weight')
|
||||
}
|
||||
self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight')
|
||||
self.expression_weight = config_parser.getfloat('training.losses', 'expression_weight')
|
||||
self.motion_extractor = motion_extractor
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
@@ -141,12 +135,12 @@ class MotionLoss(nn.Module):
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * self.config.get('pose_weight')
|
||||
weighted_pose_loss = pose_loss * self.config_pose_weight
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * self.config.get('expression_weight')
|
||||
weighted_expression_loss = expression_loss * self.config_expression_weight
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user