feat/noise injection (#78)

* changes

* add to config.ini

* changes

* changes

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
This commit is contained in:
Henry Ruhs
2025-06-02 11:04:53 +02:00
committed by GitHub
parent 0722db91f1
commit 24f45877f5
4 changed files with 15 additions and 1 deletions
+1
View File
@@ -89,6 +89,7 @@ mask_weight = 5.0
accumulate_size = 4
learning_rate = 0.0004
gradient_clip = 20.0
noise_factor = 0.05
max_epochs = 50
strategy = auto
precision = 16-mixed
+1
View File
@@ -47,6 +47,7 @@ mask_weight =
accumulate_size =
learning_rate =
gradient_clip =
noise_factor =
max_epochs =
strategy =
precision =
+6
View File
@@ -49,3 +49,9 @@ def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
return output_tensor
def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
    """
    Return a copy of the input with scaled Gaussian noise added.

    The noise is drawn with torch.randn_like, so it has the same shape,
    dtype and device as input_tensor, and is multiplied by factor before
    being added. The input tensor itself is not modified.

    :param input_tensor: tensor to perturb
    :param factor: multiplier applied to the standard-normal noise
    :return: new tensor equal to input_tensor plus the scaled noise
    """
    # out-of-place ops only, so the caller's tensor stays untouched
    noise_tensor = torch.randn_like(input_tensor) * factor
    output_tensor = input_tensor + noise_tensor
    return output_tensor
+7 -1
View File
@@ -13,7 +13,7 @@ from torch.utils.data import ConcatDataset, Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .helper import apply_noise, calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
@@ -32,6 +32,7 @@ class HyperSwapTrainer(LightningModule):
self.config_loss_embedder_path = config_parser.get('training.model', 'loss_embedder_path')
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
self.config_noise_factor = config_parser.getfloat('training.trainer', 'noise_factor')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.config_gradient_clip = config_parser.getfloat('training.trainer', 'gradient_clip')
@@ -91,6 +92,11 @@ class HyperSwapTrainer(LightningModule):
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
if self.config_noise_factor > 0:
source_embedding = apply_noise(source_embedding, self.config_noise_factor)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
generator_target_features = self.generator.encode_features(target_tensor)
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
generator_output_features = self.generator.encode_features(generator_output_tensor)