mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
dilate
This commit is contained in:
@@ -84,6 +84,7 @@ reconstruction_weight = 10.0
|
||||
identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
mask_weight = 5.0
|
||||
mask_dilate = 0.01
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -44,6 +44,7 @@ reconstruction_weight =
|
||||
identity_weight =
|
||||
gaze_weight =
|
||||
mask_weight =
|
||||
mask_dilate =
|
||||
|
||||
[training.trainer]
|
||||
accumulate_size =
|
||||
|
||||
@@ -64,3 +64,12 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
@lru_cache(maxsize = None)
|
||||
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
|
||||
return sorted(glob.glob(file_pattern))
|
||||
|
||||
|
||||
def dilate_mask(input_tensor : Tensor, factor : float) -> Tensor:
|
||||
padding = round(input_tensor.shape[2] * factor)
|
||||
kernel_size = 1 + 2 * padding
|
||||
kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype = input_tensor.dtype, device = input_tensor.device)
|
||||
dilate_tensor = nn.functional.conv2d(input_tensor, kernel, padding = padding)
|
||||
dilate_tensor = torch.sigmoid(2 * (dilate_tensor - 0.5))
|
||||
return dilate_tensor
|
||||
|
||||
@@ -6,7 +6,7 @@ from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..helper import calc_embedding, dilate_mask
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask
|
||||
|
||||
|
||||
@@ -163,12 +163,14 @@ class MaskLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
|
||||
super().__init__()
|
||||
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
|
||||
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.face_masker = face_masker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
|
||||
target_mask = self.calc_mask(target_tensor)
|
||||
target_mask = dilate_mask(target_mask, self.config_mask_dilate)
|
||||
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
mask_loss = self.mse_loss(target_mask, output_mask)
|
||||
|
||||
Reference in New Issue
Block a user