Revert the config dicts

This commit is contained in:
henryruhs
2025-03-06 23:28:11 +01:00
parent e5f983b2bf
commit 1dfd230fc5
7 changed files with 60 additions and 96 deletions
+2 -5
View File
@@ -9,10 +9,7 @@ from ..networks.nld import NLD
class Discriminator(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config =\
{
'num_discriminators': config_parser.getint('training.model.discriminator', 'num_discriminators')
}
self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
self.config_parser = config_parser
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
self.discriminators = self.create_discriminators()
@@ -20,7 +17,7 @@ class Discriminator(nn.Module):
def create_discriminators(self) -> nn.ModuleList:
discriminators = nn.ModuleList()
for _ in range(self.config.get('num_discriminators')):
for _ in range(self.config_num_discriminators):
discriminator = NLD(self.config_parser).sequences
discriminators.append(discriminator)
+6 -12
View File
@@ -101,28 +101,22 @@ class ReconstructionLoss(nn.Module):
class IdentityLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
super().__init__()
self.config =\
{
'identity_weight': config_parser.getfloat('training.losses', 'identity_weight')
}
self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
self.embedder = embedder
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
weighted_identity_loss = identity_loss * self.config.get('identity_weight')
weighted_identity_loss = identity_loss * self.config_identity_weight
return identity_loss, weighted_identity_loss
class MotionLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
super().__init__()
self.config =\
{
'pose_weight': config_parser.getfloat('training.losses', 'pose_weight'),
'expression_weight': config_parser.getfloat('training.losses', 'expression_weight')
}
self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight')
self.expression_weight = config_parser.getfloat('training.losses', 'expression_weight')
self.motion_extractor = motion_extractor
self.mse_loss = nn.MSELoss()
@@ -141,12 +135,12 @@ class MotionLoss(nn.Module):
temp_tensors.append(temp_tensor)
pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * self.config.get('pose_weight')
weighted_pose_loss = pose_loss * self.config_pose_weight
return pose_loss, weighted_pose_loss
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
weighted_expression_loss = expression_loss * self.config.get('expression_weight')
weighted_expression_loss = expression_loss * self.config_expression_weight
return expression_loss, weighted_expression_loss
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: