From bd762c4c38af6ab8a39861a76d5fbf9b4ac80b2a Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 17 Jun 2025 15:15:50 +0530 Subject: [PATCH 1/2] dilate --- hyperswap/README.md | 1 + hyperswap/config.ini | 1 + hyperswap/src/helper.py | 9 +++++++++ hyperswap/src/models/loss.py | 4 +++- 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/hyperswap/README.md b/hyperswap/README.md index f8592fe..623a8c5 100644 --- a/hyperswap/README.md +++ b/hyperswap/README.md @@ -84,6 +84,7 @@ reconstruction_weight = 10.0 identity_weight = 20.0 gaze_weight = 0.05 mask_weight = 5.0 +mask_dilate = 0.01 ``` ``` diff --git a/hyperswap/config.ini b/hyperswap/config.ini index 46a5556..f06ae24 100644 --- a/hyperswap/config.ini +++ b/hyperswap/config.ini @@ -44,6 +44,7 @@ reconstruction_weight = identity_weight = gaze_weight = mask_weight = +mask_dilate = [training.trainer] accumulate_size = diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index 7d99e03..d34d005 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -64,3 +64,12 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor: @lru_cache(maxsize = None) def resolve_static_file_pattern(file_pattern : str) -> List[str]: return sorted(glob.glob(file_pattern)) + + +def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor: + padding = round(input_tensor.shape[2] * factor) + kernel_size = 1 + 2 * padding + kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype = input_tensor.dtype, device = input_tensor.device) + dilate_tensor = nn.functional.conv2d(input_tensor, kernel, padding = padding) + dilate_tensor = torch.sigmoid(2 * (dilate_tensor - 0.5)) + return dilate_tensor diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index 49e2c7b..7fbfbd6 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -6,7 +6,7 @@ from pytorch_msssim import ssim from torch import Tensor, nn from torchvision import transforms -from ..helper import calc_embedding +from ..helper import calc_embedding, dilate_mask from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask @@ -163,12 +163,14 @@ class MaskLoss(nn.Module): def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None: super().__init__() self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight') + self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate') self.config_output_size = config_parser.getint('training.model.generator', 'output_size') self.face_masker = face_masker self.mse_loss = nn.MSELoss() def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]: target_mask = self.calc_mask(target_tensor) + target_mask = dilate_mask(target_mask, self.config_mask_dilate) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size) mask_loss = self.mse_loss(target_mask, output_mask) From 56b71048e3185136a1378c89fc166505a45c199f Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Thu, 19 Jun 2025 16:29:30 +0530 Subject: [PATCH 2/2] add erode for export and make it conditional --- hyperswap/src/helper.py | 13 ++++++++++--- hyperswap/src/models/loss.py | 5 ++++- hyperswap/src/training.py | 6 +++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index d34d005..3741b4d 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -69,7 +69,14 @@ def resolve_static_file_pattern(file_pattern : str) -> List[str]: def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor: padding = round(input_tensor.shape[2] * factor) kernel_size = 1 + 2 * padding - kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype = input_tensor.dtype, device = input_tensor.device) - dilate_tensor = nn.functional.conv2d(input_tensor, kernel, padding = padding) - dilate_tensor = torch.sigmoid(2 * (dilate_tensor - 0.5)) + pad_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate') + dilate_tensor = nn.functional.max_pool2d(pad_tensor, kernel_size = kernel_size, stride = 1, padding = 0) + return dilate_tensor + + +def erode_mask(input_tensor : Tensor, factor : float) -> Tensor: + padding = round(input_tensor.shape[2] * factor) + kernel_size = 1 + 2 * padding + pad_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate') + dilate_tensor = 1 - nn.functional.max_pool2d(pad_tensor, kernel_size = kernel_size, stride = 1, padding = 0) return dilate_tensor diff --git a/hyperswap/src/models/loss.py b/hyperswap/src/models/loss.py index 7fbfbd6..8aa4f09 100644 --- a/hyperswap/src/models/loss.py +++ b/hyperswap/src/models/loss.py @@ -170,7 +170,10 @@ class MaskLoss(nn.Module): def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]: target_mask = self.calc_mask(target_tensor) - target_mask = dilate_mask(target_mask, self.config_mask_dilate) + + if self.config_mask_dilate > 0: + target_mask = dilate_mask(target_mask, self.config_mask_dilate) + target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size) mask_loss = self.mse_loss(target_mask, output_mask) diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py index 621faaa..6aac4fc 100644 --- a/hyperswap/src/training.py +++ b/hyperswap/src/training.py @@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset, Dataset, random_split from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import DynamicDataset -from .helper import apply_noise, calc_embedding, overlay_mask +from .helper import apply_noise, calc_embedding, erode_mask, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss @@ -45,6 +45,7 @@ class HyperSwapTrainer(LightningModule): self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum') self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor') self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience') + self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate') self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval() self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval() self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() @@ -66,6 +67,9 @@ class HyperSwapTrainer(LightningModule): generator_target_features = self.generator.encode_features(target_tensor) output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features) + if self.config_mask_dilate > 0: + output_mask = erode_mask(output_mask, self.config_mask_dilate) + return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: