This commit is contained in:
harisreedhar
2025-06-17 15:15:50 +05:30
parent f5c49a02cb
commit bd762c4c38
4 changed files with 14 additions and 1 deletions
+1
View File
@@ -84,6 +84,7 @@ reconstruction_weight = 10.0
identity_weight = 20.0
gaze_weight = 0.05
mask_weight = 5.0
mask_dilate = 0.01
```
```
+1
View File
@@ -44,6 +44,7 @@ reconstruction_weight =
identity_weight =
gaze_weight =
mask_weight =
mask_dilate =
[training.trainer]
accumulate_size =
+9
View File
@@ -64,3 +64,12 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
@lru_cache(maxsize = None)
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
    """
    Expand a glob pattern into an ascending-sorted list of matching paths.

    The result is memoised per pattern without an eviction limit, so files
    created or removed after the first call for a given pattern are not
    picked up, and every caller receives the same cached list object.

    :param file_pattern: glob expression handed to ``glob.glob``.
    :return: the matching paths in sorted order (empty when nothing matches).
    """
    file_paths = glob.glob(file_pattern)
    file_paths.sort()
    return file_paths
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
    """
    Grow and soften a mask with a square box filter followed by a sigmoid.

    The filter radius is ``round(input_tensor.shape[2] * factor)`` pixels.
    Every output pixel is the sum of the input over that zero-padded window,
    squashed through ``sigmoid(2 * (window_sum - 0.5))``. The output therefore
    never reaches exactly 0 or 1: an all-zero window maps to
    ``sigmoid(-1)`` (about 0.27), and this also holds for a radius of zero.

    :param input_tensor: mask tensor; the kernel has a single input channel,
        so the channel dimension must be 1 (presumably (N, 1, H, W) - confirm
        against the caller).
    :param factor: filter radius expressed as a fraction of dimension 2.
    :return: tensor with the same shape, dtype and device as the input.
    :raises ValueError: if the factor resolves to a negative radius.
    """
    padding = round(input_tensor.shape[2] * factor)

    # A negative radius would otherwise surface as an opaque torch error about
    # negative tensor dimensions, far away from the misconfigured value.
    if padding < 0:
        raise ValueError(f'dilate factor must not resolve to a negative radius, got factor {factor}')

    # An odd kernel size combined with matching padding keeps the spatial size unchanged.
    kernel_size = 1 + 2 * padding
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype = input_tensor.dtype, device = input_tensor.device)
    window_sum = nn.functional.conv2d(input_tensor, kernel, padding = padding)
    return torch.sigmoid(2 * (window_sum - 0.5))
+3 -1
View File
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..helper import calc_embedding, dilate_mask
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
@@ -163,12 +163,14 @@ class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
    """
    Read the mask loss settings from the config and store the face masker.

    :param config_parser: parsed config providing ``mask_weight`` and
        ``mask_dilate`` under ``training.losses`` and ``output_size`` under
        ``training.model.generator``.
    :param face_masker: face masker module kept on the instance.
    """
    super().__init__()
    losses_section = 'training.losses'
    generator_section = 'training.model.generator'

    # Lookups stay in their original order so a missing option fails on the same key as before.
    self.config_mask_weight = config_parser.getfloat(losses_section, 'mask_weight')
    self.config_mask_dilate = config_parser.getfloat(losses_section, 'mask_dilate')
    self.config_output_size = config_parser.getint(generator_section, 'output_size')

    self.face_masker = face_masker
    self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
target_mask = self.calc_mask(target_tensor)
target_mask = dilate_mask(target_mask, self.config_mask_dilate)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, output_mask)