mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
change face-parser to face-masker
This commit is contained in:
@@ -46,7 +46,7 @@ split_ratio = 0.9995
|
||||
embedder_path = .models/arcface.pt
|
||||
gazer_path = .models/gazer.pt
|
||||
motion_extractor_path = .models/motion_extractor.pt
|
||||
face_parser_path = .models/face_parser.pt
|
||||
face_masker_path = .models/face_masker.pt
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -14,7 +14,7 @@ split_ratio =
|
||||
embedder_path =
|
||||
gazer_path =
|
||||
motion_extractor_path =
|
||||
face_parser_path =
|
||||
face_masker_path =
|
||||
|
||||
[training.model.generator]
|
||||
source_channels =
|
||||
|
||||
@@ -7,7 +7,7 @@ from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import EmbedderModule, FaceParserModule, Feature, GazerModule, Loss, Mask, MotionExtractorModule
|
||||
from ..types import EmbedderModule, FaceMaskerModule, Feature, GazerModule, Loss, Mask, MotionExtractorModule
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
@@ -180,11 +180,11 @@ class GazeLoss(nn.Module):
|
||||
|
||||
|
||||
class MaskLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser, face_parser : FaceParserModule) -> None:
|
||||
def __init__(self, config_parser : ConfigParser, face_masker : FaceMaskerModule) -> None:
|
||||
super().__init__()
|
||||
self.config_mask_weight = config_parser.getfloat('training.losses', 'mask_weight')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.face_parser = face_parser
|
||||
self.face_masker = face_masker
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Tuple[Loss, Loss]:
|
||||
@@ -200,7 +200,7 @@ class MaskLoss(nn.Module):
|
||||
target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensor = self.face_parser(target_tensor)
|
||||
output_tensor = self.face_masker(target_tensor)
|
||||
output_tensor = output_tensor.clamp(0, 1)
|
||||
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
|
||||
|
||||
|
||||
@@ -31,14 +31,14 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.config_embedder_path = config_parser.get('training.model', 'embedder_path')
|
||||
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
|
||||
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
|
||||
self.config_face_parser_path = config_parser.get('training.model', 'face_parser_path')
|
||||
self.config_face_masker_path = config_parser.get('training.model', 'face_masker_path')
|
||||
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
|
||||
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
|
||||
self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval()
|
||||
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
|
||||
self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval()
|
||||
self.face_parser = torch.jit.load(self.config_face_parser_path, map_location ='cpu').eval()
|
||||
self.face_masker = torch.jit.load(self.config_face_masker_path, map_location ='cpu').eval()
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
@@ -48,7 +48,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.identity_loss = IdentityLoss(config_parser, self.embedder)
|
||||
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(config_parser, self.gazer)
|
||||
self.mask_loss = MaskLoss(config_parser, self.face_parser)
|
||||
self.mask_loss = MaskLoss(config_parser, self.face_masker)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
|
||||
@@ -17,7 +17,7 @@ GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
GazerModule : TypeAlias = Module
|
||||
MotionExtractorModule : TypeAlias = Module
|
||||
FaceParserModule : TypeAlias = Module
|
||||
FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
|
||||
Reference in New Issue
Block a user