Migrate most to self.config and self.context

This commit is contained in:
henryruhs
2025-03-06 18:27:56 +01:00
parent ab3b699124
commit b829d5e42c
5 changed files with 105 additions and 78 deletions
+36 -21
View File
@@ -53,31 +53,37 @@ class AdversarialLoss(nn.Module):
class AttributeLoss(nn.Module):
    """
    Mean squared difference between paired target and output attribute tensors.

    The batch size and the loss weight are read from the given config parser
    once at construction and kept in ``self.config``, so ``forward`` does not
    depend on a module level ``CONFIG`` object.
    """

    def __init__(self, config_parser : ConfigParser) -> None:
        """
        :param config_parser: parser providing ``training.loader/batch_size``
            and ``training.losses/attribute_weight``.
        """
        super().__init__()
        self.config =\
        {
            'batch_size': config_parser.getint('training.loader', 'batch_size'),
            'attribute_weight': config_parser.getfloat('training.losses', 'attribute_weight')
        }

    def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
        """
        Compute the attribute loss over all attribute pairs.

        :param target_attributes: iterable of tensors taken from the target.
        :param output_attributes: iterable of tensors taken from the output,
            paired with ``target_attributes`` by position.
        :return: ``(attribute_loss, weighted_attribute_loss)`` where the second
            value is the first multiplied by the configured ``attribute_weight``.
        """
        temp_tensors = []

        for target_attribute, output_attribute in zip(target_attributes, output_attributes):
            # The squared error is viewed as (batch_size, -1), so the element count
            # of every attribute tensor must be divisible by the configured batch size.
            # NOTE(review): presumably dim 0 of each attribute is the batch dimension - confirm.
            temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config.get('batch_size'), -1), dim = 1).mean()
            temp_tensors.append(temp_tensor)

        # Average over all attribute pairs, then halve.
        attribute_loss = torch.stack(temp_tensors).mean() * 0.5
        weighted_attribute_loss = attribute_loss * self.config.get('attribute_weight')
        return attribute_loss, weighted_attribute_loss
class ReconstructionLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
    """
    Store the reconstruction weight, the embedder and the MSE criterion.

    :param config_parser: parser providing ``training.losses/reconstruction_weight``.
    :param embedder: embedder module kept on the instance as ``self.embedder``.
    """
    super().__init__()
    # Read the weight once here so later calls do not need a global config object.
    self.config =\
    {
        'reconstruction_weight': config_parser.getfloat('training.losses', 'reconstruction_weight')
    }
    self.embedder = embedder
    self.mse_loss = nn.MSELoss()
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
@@ -88,27 +94,35 @@ class ReconstructionLoss(nn.Module):
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
weighted_reconstruction_loss = reconstruction_loss * self.config.get('reconstruction_weight')
return reconstruction_loss, weighted_reconstruction_loss
class IdentityLoss(nn.Module):
    """
    Cosine distance between the embeddings of the source and the output tensor.

    The loss weight is read from the given config parser once at construction
    and kept in ``self.config``, so ``forward`` does not depend on a module
    level ``CONFIG`` object.
    """

    def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
        """
        :param config_parser: parser providing ``training.losses/identity_weight``.
        :param embedder: embedder module passed to ``calc_embedding``.
        """
        super().__init__()
        self.config =\
        {
            'identity_weight': config_parser.getfloat('training.losses', 'identity_weight')
        }
        self.embedder = embedder

    def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
        """
        Compute the identity loss between source and output.

        :param source_tensor: tensor the reference embedding is computed from.
        :param output_tensor: tensor the compared embedding is computed from.
        :return: ``(identity_loss, weighted_identity_loss)`` where the second
            value is the first multiplied by the configured ``identity_weight``.
        """
        # NOTE(review): the 4-tuple is forwarded unchanged to calc_embedding; it
        # presumably describes padding or crop margins - confirm against that helper.
        output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
        source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
        # One minus the cosine similarity, averaged over all compared embeddings.
        identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
        weighted_identity_loss = identity_loss * self.config.get('identity_weight')
        return identity_loss, weighted_identity_loss
class MotionLoss(nn.Module):
def __init__(self, motion_extractor : MotionExtractorModule):
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
super().__init__()
self.config =\
{
'pose_weight': config_parser.getfloat('training.losses', 'pose_weight'),
'expression_weight': config_parser.getfloat('training.losses', 'expression_weight')
}
self.motion_extractor = motion_extractor
self.mse_loss = nn.MSELoss()
@@ -120,7 +134,6 @@ class MotionLoss(nn.Module):
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
temp_tensors = []
for target_pose, output_pose in zip(target_poses, output_poses):
@@ -128,13 +141,12 @@ class MotionLoss(nn.Module):
temp_tensors.append(temp_tensor)
pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * pose_weight
weighted_pose_loss = pose_loss * self.config.get('pose_weight')
return pose_loss, weighted_pose_loss
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
    """
    Compute the cosine distance between the target and output expression tensors.

    :param target_expression: expression tensor of the target.
    :param output_expression: expression tensor of the output.
    :return: ``(expression_loss, weighted_expression_loss)`` where the second
        value is the first multiplied by ``self.config['expression_weight']``.
    """
    # One minus the cosine similarity, averaged over all compared vectors.
    expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
    # The weight comes from the instance config rather than a global CONFIG lookup.
    weighted_expression_loss = expression_loss * self.config.get('expression_weight')
    return expression_loss, weighted_expression_loss
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
@@ -148,13 +160,17 @@ class MotionLoss(nn.Module):
class GazeLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
    """
    Store the gaze settings, the gazer module and the L1 criterion.

    :param config_parser: parser providing ``training.losses/gaze_weight`` and
        ``training.model.generator/output_size``.
    :param gazer: gazer module kept on the instance as ``self.gazer``.
    """
    super().__init__()
    # Read both values once here so later calls do not need a global config object.
    self.config =\
    {
        'gaze_weight': config_parser.getfloat('training.losses', 'gaze_weight'),
        'output_size': config_parser.getint('training.model.generator', 'output_size')
    }
    self.gazer = gazer
    self.l1_loss = nn.L1Loss()
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
output_pitch, output_yaw = self.detect_gaze(output_tensor)
target_pitch, target_yaw = self.detect_gaze(target_tensor)
@@ -162,12 +178,11 @@ class GazeLoss(nn.Module):
yaw_loss = self.l1_loss(output_yaw, target_yaw)
gaze_loss = (pitch_loss + yaw_loss) * 0.5
weighted_gaze_loss = gaze_loss * gaze_weight
weighted_gaze_loss = gaze_loss * self.config.get('gaze_weight')
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
output_size = CONFIG.getint('training.model.generator', 'output_size')
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * output_size).int()
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config.get('output_size')).int()
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)