diff --git a/face_swapper/README.md b/face_swapper/README.md index 33f3e6d..5776b31 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -46,7 +46,7 @@ split_ratio = 0.9995 embedder_path = .models/arcface.pt gazer_path = .models/gazer.pt motion_extractor_path = .models/motion_extractor.pt -parser_path = .models/parser.pt +face_parser_path = .models/face_parser.pt ``` ``` diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 1566739..1ff4100 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -14,7 +14,7 @@ split_ratio = embedder_path = gazer_path = motion_extractor_path = -parser_path = +face_parser_path = [training.model.generator] identity_channels = diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index ac53975..ec84563 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule, ParserModule +from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule, FaceParserModule class DiscriminatorLoss(nn.Module): @@ -180,10 +180,10 @@ class GazeLoss(nn.Module): class MaskLoss(nn.Module): - def __init__(self, config_parser : ConfigParser, parser : ParserModule) -> None: + def __init__(self, config_parser : ConfigParser, face_parser : FaceParserModule) -> None: super().__init__() self.config_output_size = config_parser.getint('training.model.generator', 'output_size') - self.parser = parser + self.face_parser = face_parser self.mse_loss = nn.MSELoss() def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor: @@ -198,7 +198,7 @@ class MaskLoss(nn.Module): face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device) with torch.no_grad(): - output_tensor = self.parser(target_tensor)[0] + output_tensor = self.face_parser(target_tensor)[0] output_tensor = output_tensor.argmax(1) output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype) output_tensor = output_tensor.view(-1, 1, 512, 512) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a79d142..0f8ff94 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -32,14 +32,14 @@ class FaceSwapperTrainer(LightningModule): self.config_embedder_path = config_parser.get('training.model', 'embedder_path') self.config_gazer_path = config_parser.get('training.model', 'gazer_path') self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path') - self.config_parser_path = config_parser.get('training.model', 'parser_path') + self.config_face_parser_path = config_parser.get('training.model', 'parser_path') self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size') self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate') self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency') self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval() self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval() self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval() - self.parser = torch.jit.load(self.config_parser_path, map_location = 'cpu').eval() + self.face_parser = torch.jit.load(self.config_face_parser_path, map_location ='cpu').eval() self.generator = Generator(config_parser) self.discriminator = Discriminator(config_parser) self.masker = MaskNet(config_parser) @@ -50,7 +50,7 @@ class FaceSwapperTrainer(LightningModule): self.identity_loss = IdentityLoss(config_parser, self.embedder) self.motion_loss = MotionLoss(config_parser, self.motion_extractor) self.gaze_loss = GazeLoss(config_parser, self.gazer) - self.mask_loss = MaskLoss(config_parser, self.parser) + self.mask_loss = MaskLoss(config_parser, self.face_parser) self.automatic_optimization = False def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]: diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index bce6a9e..27708b7 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -16,7 +16,7 @@ GeneratorModule : TypeAlias = Module EmbedderModule : TypeAlias = Module GazerModule : TypeAlias = Module MotionExtractorModule : TypeAlias = Module -ParserModule : TypeAlias = Module +FaceParserModule : TypeAlias = Module OptimizerSet : TypeAlias = Any