mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Merge pull request #85 from facefusion/feat/dilate-mask
Feat/dilate mask
This commit is contained in:
@@ -84,6 +84,7 @@ reconstruction_weight = 10.0
|
||||
identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
mask_weight = 5.0
|
||||
mask_dilate = 0.01
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -44,6 +44,7 @@ reconstruction_weight =
|
||||
identity_weight =
|
||||
gaze_weight =
|
||||
mask_weight =
|
||||
mask_dilate =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
|
||||
@@ -64,3 +64,19 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
@lru_cache(maxsize = None)
|
||||
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
|
||||
return sorted(glob.glob(file_pattern))
|
||||
|
||||
|
||||
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
padding = round(input_tensor.shape[2] * factor)
|
||||
kernel_size = 1 + 2 * padding
|
||||
pad_tensor = nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
|
||||
dilate_tensor = nn.functional.max_pool2d(pad_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
|
||||
return dilate_tensor
|
||||
|
||||
|
||||
def erode_mask(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
padding = round(input_tensor.shape[2] * factor)
|
||||
kernel_size = 1 + 2 * padding
|
||||
pad_tensor = 1 - nn.functional.pad(input_tensor, (padding, padding, padding, padding), mode = 'replicate')
|
||||
dilate_tensor = 1 - nn.functional.max_pool2d(pad_tensor, kernel_size = kernel_size, stride = 1, padding = 0)
|
||||
return dilate_tensor
|
||||
|
||||
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..helper import calc_embedding, dilate_mask
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
|
||||
|
||||
|
||||
@@ -163,12 +163,17 @@ class MaskLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
|
||||
super().__init__()
|
||||
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
|
||||
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.face_masker = face_masker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
|
||||
target_mask = self.calc_mask(target_tensor)
|
||||
|
||||
if self.config_mask_dilate > 0:
|
||||
target_mask = dilate_mask(target_mask, self.config_mask_dilate)
|
||||
|
||||
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_loss = self.mse_loss(target_mask, output_mask)
|
||||
|
||||
@@ -14,7 +14,7 @@ from torch.utils.data import ConcatDataset, Dataset, random_split
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import apply_noise, calc_embedding, overlay_mask
|
||||
from .helper import apply_noise, calc_embedding, erode_mask, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, CycleLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, ReconstructionLoss
|
||||
@@ -45,6 +45,7 @@ class HyperSwapTrainer(LightningModule):
|
||||
self.config_discriminator_momentum = config_parser.getfloat('training.optimizer.discriminator', 'momentum')
|
||||
self.config_discriminator_scheduler_factor = config_parser.getfloat('training.optimizer.discriminator', 'scheduler_factor')
|
||||
self.config_discriminator_scheduler_patience = config_parser.getint('training.optimizer.discriminator', 'scheduler_patience')
|
||||
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
|
||||
self.generator_embedder = torch.jit.load(self.config_generator_embedder_path, map_location = 'cpu').eval()
|
||||
self.loss_embedder = torch.jit.load(self.config_loss_embedder_path, map_location = 'cpu').eval()
|
||||
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
|
||||
@@ -66,6 +67,9 @@ class HyperSwapTrainer(LightningModule):
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
|
||||
|
||||
if self.config_mask_dilate > 0:
|
||||
output_mask = erode_mask(output_mask, self.config_mask_dilate)
|
||||
|
||||
return output_tensor, output_mask
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
|
||||
Reference in New Issue
Block a user