From dbe79aa3b9f04bead569a4ff989e69f441a7bf2c Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 24 Jun 2025 21:47:53 +0200 Subject: [PATCH] Fix config for mask factor --- hyperswap/src/models/loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index 8aa4f09..494e7f9 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -163,7 +163,7 @@ class MaskLoss(nn.Module): def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None: super().__init__() self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight') - self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate') + self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor') self.config_output_size = config_parser.getint('training.model.generator', 'output_size') self.face_masker = face_masker self.mse_loss = nn.MSELoss() @@ -171,8 +171,8 @@ class MaskLoss(nn.Module): def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]: target_mask = self.calc_mask(target_tensor) - if self.config_mask_dilate > 0: - target_mask = dilate_mask(target_mask, self.config_mask_dilate) + if self.config_mask_factor > 0: + target_mask = dilate_mask(target_mask, self.config_mask_factor) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)