From 5234874bc70a36bb38ec66bdc9800643393f8afd Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 14 Mar 2025 08:09:47 +0100 Subject: [PATCH] Mask typing and naming related updates --- face_swapper/src/helper.py | 8 ++++---- face_swapper/src/models/loss.py | 8 ++++---- face_swapper/src/networks/masknet.py | 24 +++++++++++++----------- face_swapper/src/training.py | 16 ++++++++-------- face_swapper/src/types.py | 1 + 5 files changed, 30 insertions(+), 27 deletions(-) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 98eb971..b1377d5 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,7 +1,7 @@ import torch from torch import Tensor, nn -from .types import EmbedderModule, Embedding, Padding, WarpTemplate, WarpTemplateSet +from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet WARP_TEMPLATE_SET : WarpTemplateSet =\ { @@ -38,9 +38,9 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P return embedding -def overlay_mask(input_tensor : Tensor, mask_tensor : Tensor) -> Tensor: +def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor: overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device) overlay_tensor[:, 2, :, :] = 1 - mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8) - output_tensor = input_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor + input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8) + output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask return output_tensor diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index e340924..08835f6 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, MotionExtractorModule +from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule class DiscriminatorLoss(nn.Module): @@ -186,11 +186,11 @@ class MaskLoss(nn.Module): self.face_parser = face_parser self.mse_loss = nn.MSELoss() - def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Loss: + def forward(self, target_tensor : Tensor, output_mask : Mask) -> Loss: target_mask = self.calc_mask(target_tensor) target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size) - mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size) - mask_loss = self.mse_loss(target_mask, mask_tensor) + output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size) + mask_loss = self.mse_loss(target_mask, output_mask) return mask_loss def calc_mask(self, target_tensor : Tensor) -> Tensor: diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index a26d56b..0449fca 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -3,7 +3,7 @@ from configparser import ConfigParser import torch from torch import Tensor, nn -from ..types import Attribute +from ..types import Attribute, Mask class MaskNet(nn.Module): @@ -14,7 +14,7 @@ class MaskNet(nn.Module): self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters') self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters) self.up_samples = self.create_up_samples(self.config_num_filters) - self.bottleneck = BottleNeck(self.config_num_filters * 2) + self.bottleneck = BottleNeck(self.config_num_filters * 4) self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1) self.sigmoid = nn.Sigmoid() @@ -23,31 +23,33 @@ class MaskNet(nn.Module): return nn.ModuleList( [ DownSample(input_channels, num_filters), - DownSample(num_filters, num_filters * 2) + DownSample(num_filters, num_filters * 2), + DownSample(num_filters, num_filters * 4) ]) @staticmethod def create_up_samples(num_filters : int) -> nn.ModuleList: return nn.ModuleList( [ + UpSample(num_filters * 4, num_filters), UpSample(num_filters * 2, num_filters), UpSample(num_filters, num_filters) ]) - def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Tensor: - output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1) + def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Mask: + output_mask = torch.cat([ input_tensor, input_attribute ], dim = 1) for down_sample in self.down_samples: - output_tensor = down_sample(output_tensor) + output_mask = down_sample(output_mask) - output_tensor = self.bottleneck(output_tensor) + output_mask = self.bottleneck(output_mask) for up_sample in self.up_samples: - output_tensor = up_sample(output_tensor) + output_mask = up_sample(output_mask) - output_tensor = self.conv(output_tensor) - output_tensor = self.sigmoid(output_tensor) - return output_tensor + output_mask = self.conv(output_mask) + output_mask = self.sigmoid(output_mask) + return output_mask class BottleNeck(nn.Module): diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index ae7addd..16f1790 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -18,7 +18,7 @@ from .models.discriminator import Discriminator from .models.generator import Generator from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss from .networks.masknet import MaskNet -from .types import Batch, Embedding, OptimizerSet +from .types import Batch, Embedding, Mask, OptimizerSet warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') @@ -57,9 +57,9 @@ class FaceSwapperTrainer(LightningModule): with torch.no_grad(): output_tensor, target_attributes = self.generator(source_embedding, target_tensor) target_attribute = target_attributes[-1] - mask_tensor = self.masker(target_tensor, target_attribute) + output_mask = self.masker(target_tensor, target_attribute) - return output_tensor, mask_tensor + return output_tensor, output_mask def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]: generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) @@ -120,8 +120,8 @@ class FaceSwapperTrainer(LightningModule): discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) generator_output_attribute = generator_output_attributes[-1] - mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach()) - mask_loss = self.mask_loss(target_tensor, mask_tensor) + generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach()) + mask_loss = self.mask_loss(target_tensor, generator_output_mask) self.toggle_optimizer(generator_optimizer) self.manual_backward(generator_loss) @@ -145,7 +145,7 @@ class FaceSwapperTrainer(LightningModule): self.untoggle_optimizer(masker_optimizer) if self.global_step % self.config_preview_frequency == 0: - self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor) + self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask) self.log('generator_loss', generator_loss, prog_bar = True) self.log('discriminator_loss', discriminator_loss, prog_bar = True) @@ -167,10 +167,10 @@ class FaceSwapperTrainer(LightningModule): self.log('validation_score', validation_score, prog_bar = True) return validation_score - def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None: + def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None: preview_limit = 8 preview_cells = [] - overlay_tensor = overlay_mask(output_tensor, mask_tensor) + overlay_tensor = overlay_mask(output_tensor, output_mask) for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]): preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index a794ee1..3789995 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -8,6 +8,7 @@ BatchMode = Literal['equal', 'same', 'different'] Attribute : TypeAlias = Tensor Embedding : TypeAlias = Tensor +Mask : TypeAlias = Tensor Loss : TypeAlias = Tensor Padding : TypeAlias = Tuple[int, int, int, int]