From bd762c4c38af6ab8a39861a76d5fbf9b4ac80b2a Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 15:15:50 +0530 Subject: [PATCH] dilate --- hyperswap/README.md | 1 + hyperswap/config.ini | 1 + hyperswap/src/helper.py | 9 +++++++++ hyperswap/src/models/loss.py | 4 +++- 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index f8592fe..623a8c5 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -84,6 +84,7 @@ reconstruction_weight = 10.0 identity_weight = 20.0 gaze_weight = 0.05 mask_weight = 5.0 +mask_dilate = 0.01 ``` ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 46a5556..f06ae24 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -44,6 +44,7 @@ reconstruction_weight = identity_weight = gaze_weight = mask_weight = +mask_dilate = [training.trainer] accumulate_size = diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index 7d99e03..d34d005 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -64,3 +64,12 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor: @lru_cache(maxsize = None) def resolve_static_file_pattern(file_pattern : str) -> List[str]: return sorted(glob.glob(file_pattern)) + + +def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor: + padding = round(input_tensor.shape[2] * factor) + kernel_size = 1 + 2 * padding + kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype = input_tensor.dtype, device = input_tensor.device) + dilate_tensor = nn.functional.conv2d(input_tensor, kernel, padding = padding) + dilate_tensor = torch.sigmoid(2 * (dilate_tensor - 0.5)) + return dilate_tensor diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index 49e2c7b..7fbfbd6 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -6,7 +6,7 @@ from pytorch_msssim import ssim from torch import Tensor, nn from torchvision import transforms -from ..helper import calc_embedding +from ..helper import calc_embedding, dilate_mask from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask @@ -163,12 +163,14 @@ class MaskLoss(nn.Module): def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None: super().__init__() self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight') + self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate') self.config_output_size = config_parser.getint('training.model.generator', 'output_size') self.face_masker = face_masker self.mse_loss = nn.MSELoss() def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]: target_mask = self.calc_mask(target_tensor) + target_mask = dilate_mask(target_mask, self.config_mask_dilate) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size) mask_loss = self.mse_loss(target_mask, output_mask)