mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Fix config for mask factor
This commit is contained in:
@@ -163,7 +163,7 @@ class MaskLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
|
||||
super().__init__()
|
||||
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
|
||||
self.config_mask_dilate = config_parser.getfloat('training.losses', 'mask_dilate')
|
||||
self.config_mask_factor = config_parser.getfloat('training.modifier', 'mask_factor')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.face_masker = face_masker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
@@ -171,8 +171,8 @@ class MaskLoss(nn.Module):
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
|
||||
target_mask = self.calc_mask(target_tensor)
|
||||
|
||||
if self.config_mask_dilate > 0:
|
||||
target_mask = dilate_mask(target_mask, self.config_mask_dilate)
|
||||
if self.config_mask_factor > 0:
|
||||
target_mask = dilate_mask(target_mask, self.config_mask_factor)
|
||||
|
||||
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
|
||||
|
||||
Reference in New Issue
Block a user