Merge pull request #85 from facefusion/feat/dilate-mask

Feat/dilate mask
This commit is contained in:
Harisreedhar
2025-06-20 18:28:08 +05:30
committed by GitHub
5 changed files with 29 additions and 2 deletions
+1
View File
@@ -84,6 +84,7 @@ reconstruction_weight = 10.0
identity_weight = 20.0
gaze_weight = 0.05
mask_weight = 5.0
mask_dilate = 0.01
```
```
+1
View File
@@ -44,6 +44,7 @@ reconstruction_weight =
identity_weight =
gaze_weight =
mask_weight =
mask_dilate =
[training.trainer]
accumulate_size =
+16
View File
@@ -64,3 +64,19 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
@lru_cache(maxsize = None)
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
return sorted(glob.glob(file_pattern))
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = round(input_tensor.shape[2] * factor)
kernel_size = 1 + 2 * padding
pad_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
dilate_tensor = nn.functional.max_pool2d(pad_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
return dilate_tensor
def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
padding = round(input_tensor.shape[2] * factor)
kernel_size = 1 + 2 * padding
pad_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
dilate_tensor = 1 - nn.functional.max_pool2d(pad_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
return dilate_tensor
+6 -1
View File
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..helper import calc_embedding, dilate_mask
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
@@ -163,12 +163,17 @@ class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
super().__init__()
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.face_masker = face_masker
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
target_mask = self.calc_mask(target_tensor)
if self.config_mask_dilate > 0:
target_mask = dilate_mask(target_mask, self.config_mask_dilate)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, output_mask)
+5 -1
View File
@@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset, Dataset, random_split
from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import DynamicDataset
from .helper import apply_noise, calc_embedding, overlay_mask
from .helper import apply_noise, calc_embedding, erode_mask, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
@@ -45,6 +45,7 @@ class HyperSwapTrainer(LightningModule):
self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum')
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor')
self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience')
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
@@ -66,6 +67,9 @@ class HyperSwapTrainer(LightningModule):
generator_target_features = self.generator.encode_features(target_tensor)
output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
if self.config_mask_dilate > 0:
output_mask = erode_mask(output_mask, self.config_mask_dilate)
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: