Fix config for mask factor

This commit is contained in:
henryruhs
2025-06-24 21:47:53 +02:00
parent 2809a59704
commit dbe79aa3b9
+3 -3
View File
@@ -163,7 +163,7 @@ class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
super().__init__()
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.face_masker = face_masker
self.mse_loss = nn.MSELoss()
@@ -171,8 +171,8 @@ class MaskLoss(nn.Module):
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
target_mask = self.calc_mask(target_tensor)
if self.config_mask_dilate > 0:
target_mask = dilate_mask(target_mask, self.config_mask_dilate)
if self.config_mask_factor > 0:
target_mask = dilate_mask(target_mask, self.config_mask_factor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)