Rename parser to face parser

This commit is contained in:
henryruhs
2025-03-11 12:51:32 +01:00
parent e758eb3e19
commit afab997ffc
5 changed files with 10 additions and 10 deletions
+1 -1
View File
@@ -46,7 +46,7 @@ split_ratio = 0.9995
embedder_path = .models/arcface.pt
gazer_path = .models/gazer.pt
motion_extractor_path = .models/motion_extractor.pt
parser_path = .models/parser.pt
face_parser_path = .models/face_parser.pt
```
```
+1 -1
View File
@@ -14,7 +14,7 @@ split_ratio =
embedder_path =
gazer_path =
motion_extractor_path =
parser_path =
face_parser_path =
[training.model.generator]
identity_channels =
+4 -4
View File
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule, ParserModule
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule, FaceParserModule
class DiscriminatorLoss(nn.Module):
@@ -180,10 +180,10 @@ class GazeLoss(nn.Module):
class MaskLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, parser : ParserModule) -> None:
def __init__(self, config_parser : ConfigParser, face_parser : FaceParserModule) -> None:
super().__init__()
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
self.parser = parser
self.face_parser = face_parser
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
@@ -198,7 +198,7 @@ class MaskLoss(nn.Module):
face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device)
with torch.no_grad():
output_tensor = self.parser(target_tensor)[0]
output_tensor = self.face_parser(target_tensor)[0]
output_tensor = output_tensor.argmax(1)
output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype)
output_tensor = output_tensor.view(-1, 1, 512, 512)
+3 -3
View File
@@ -32,14 +32,14 @@ class FaceSwapperTrainer(LightningModule):
self.config_embedder_path = config_parser.get('training.model', 'embedder_path')
self.config_gazer_path = config_parser.get('training.model', 'gazer_path')
self.config_motion_extractor_path = config_parser.get('training.model', 'motion_extractor_path')
self.config_parser_path = config_parser.get('training.model', 'parser_path')
self.config_face_parser_path = config_parser.get('training.model', 'parser_path')
self.config_accumulate_size = config_parser.getfloat('training.trainer', 'accumulate_size')
self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
self.config_preview_frequency = config_parser.getint('training.trainer', 'preview_frequency')
self.embedder = torch.jit.load(self.config_embedder_path, map_location = 'cpu').eval()
self.gazer = torch.jit.load(self.config_gazer_path, map_location = 'cpu').eval()
self.motion_extractor = torch.jit.load(self.config_motion_extractor_path, map_location = 'cpu').eval()
self.parser = torch.jit.load(self.config_parser_path, map_location = 'cpu').eval()
self.face_parser = torch.jit.load(self.config_face_parser_path, map_location ='cpu').eval()
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
self.masker = MaskNet(config_parser)
@@ -50,7 +50,7 @@ class FaceSwapperTrainer(LightningModule):
self.identity_loss = IdentityLoss(config_parser, self.embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
self.gaze_loss = GazeLoss(config_parser, self.gazer)
self.mask_loss = MaskLoss(config_parser, self.parser)
self.mask_loss = MaskLoss(config_parser, self.face_parser)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]:
+1 -1
View File
@@ -16,7 +16,7 @@ GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module
GazerModule : TypeAlias = Module
MotionExtractorModule : TypeAlias = Module
ParserModule : TypeAlias = Module
FaceParserModule : TypeAlias = Module
OptimizerSet : TypeAlias = Any