From 24f45877f56107535f282d6f18d49eab5fdc4238 Mon Sep 17 00:00:00 2001
From: Henry Ruhs
Date: Mon, 2 Jun 2025 11:04:53 +0200
Subject: [PATCH] feat/noise injection (#78)

* changes

* add to config.ini

* changes

* changes

---------

Co-authored-by: harisreedhar
---
 hyperswap/README.md       | 1 +
 hyperswap/config.ini      | 1 +
 hyperswap/src/helper.py   | 6 ++++++
 hyperswap/src/training.py | 8 +++++++-
 4 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/hyperswap/README.md b/hyperswap/README.md
index 5cc2baf..babd7e9 100644
--- a/hyperswap/README.md
+++ b/hyperswap/README.md
@@ -89,6 +89,7 @@ mask_weight = 5.0
 accumulate_size = 4
 learning_rate = 0.0004
 gradient_clip = 20.0
+noise_factor = 0.05
 max_epochs = 50
 strategy = auto
 precision = 16-mixed
diff --git a/hyperswap/config.ini b/hyperswap/config.ini
index 2c626a5..f969fe9 100644
--- a/hyperswap/config.ini
+++ b/hyperswap/config.ini
@@ -47,6 +47,7 @@ mask_weight =
 accumulate_size =
 learning_rate =
 gradient_clip =
+noise_factor =
 max_epochs =
 strategy =
 precision =
diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py
index eb740c4..a14ea72 100644
--- a/hyperswap/src/helper.py
+++ b/hyperswap/src/helper.py
@@ -49,3 +49,9 @@ def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
 	input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
 	output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
 	return output_tensor
+
+
+def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
+	noise_tensor = torch.randn_like(input_tensor) * factor
+	output_tensor = input_tensor + noise_tensor
+	return output_tensor
diff --git a/hyperswap/src/training.py b/hyperswap/src/training.py
index c7c2c32..cab9977 100644
--- a/hyperswap/src/training.py
+++ b/hyperswap/src/training.py
@@ -13,7 +13,7 @@ from torch.utils.data import ConcatDataset, Dataset, random_split
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from .dataset import DynamicDataset
-from .helper import calc_embedding, overlay_mask
+from .helper import apply_noise, calc_embedding, overlay_mask
 from .models.discriminator import Discriminator
 from .models.generator import Generator
 from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
@@ -32,6 +32,7 @@ class HyperSwapTrainer(LightningModule):
 		self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
 		self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
 		self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
+		self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor')
 		self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
 		self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
 		self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
@@ -91,6 +92,11 @@ class HyperSwapTrainer(LightningModule):
 		generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
 		source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
 		target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
+
+		if self.config_noise_factor > 0:
+			source_embedding = apply_noise(source_embedding, self.config_noise_factor)
+			source_embedding = nn.functional.normalize(source_embedding, p = 2)
+
 		generator_target_features = self.generator.encode_features(target_tensor)
 		generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
 		generator_output_features = self.generator.encode_features(generator_output_tensor)