From 1659805b0874c50d232cac02e88441c0895d8d0c Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 10 Mar 2025 13:24:39 +0530 Subject: [PATCH] changes --- face_swapper/README.md | 7 ++++ face_swapper/config.ini | 5 +++ face_swapper/src/models/generator.py | 2 +- face_swapper/src/models/loss.py | 4 +- face_swapper/src/networks/masknet.py | 57 ++++++++++++++++------------ 5 files changed, 47 insertions(+), 28 deletions(-) diff --git a/face_swapper/README.md b/face_swapper/README.md index b7bf47f..3abe315 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -65,6 +65,13 @@ num_discriminators = 3 kernel_size = 4 ``` +``` +[training.model.masker] +input_channels = 67 +output_channels = 1 +base_channels = 16 +``` + ``` [training.losses] adversarial_weight = 1.0 diff --git a/face_swapper/config.ini b/face_swapper/config.ini index dc20073..088b6be 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -28,6 +28,11 @@ num_layers = num_discriminators = kernel_size = +[training.model.masker] +input_channels = +output_channels = +base_channels = + [training.losses] adversarial_weight = attribute_weight = diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 84a9303..615dcd1 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -14,7 +14,7 @@ class Generator(nn.Module): super().__init__() self.encoder = UNet(config_parser) self.generator = AAD(config_parser) - self.masker = MaskNet(67, 1, 16) + self.masker = MaskNet(config_parser) self.encoder.apply(init_weight) self.generator.apply(init_weight) self.masker.apply(init_weight) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 3341d81..aec44e3 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -195,12 +195,12 @@ class MaskLoss(nn.Module): def calc_mask(self, target_tensor : Tensor) -> Tensor: target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear') - face_indices = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device) + face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device) with torch.no_grad(): output_tensor = self.parser(target_tensor)[0] output_tensor = output_tensor.argmax(1) - output_tensor = torch.isin(output_tensor, face_indices).to(target_tensor.dtype) + output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype) output_tensor = output_tensor.view(-1, 1, 512, 512) output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear') return output_tensor diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index a7d83d4..859597f 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -1,62 +1,69 @@ +from configparser import ConfigParser + import torch from torch import Tensor, nn class MaskNet(nn.Module): - def __init__(self, input_channels : int, output_channels : int, base_channels : int): + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - self.down_samples = self.create_down_samples(input_channels, base_channels) - self.up_samples = self.create_up_samples(base_channels) - self.bottleneck = ResBlock(base_channels * 4) - self.conv = nn.Conv2d(base_channels, output_channels, kernel_size = 1) + self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels') + self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels') + self.config_base_channels = config_parser.getint('training.model.masker', 'base_channels') + self.down_samples = self.create_down_samples(self.config_input_channels, self.config_base_channels) + self.up_samples = self.create_up_samples(self.config_base_channels) + self.bottleneck = BottleNeck(self.config_base_channels * 4) + self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1) self.sigmoid = nn.Sigmoid() def create_down_samples(self, input_channels : int, base_channels: int) -> nn.ModuleList: down_samples = nn.ModuleList( - [ - DownSample(input_channels, base_channels), - DownSample(base_channels, base_channels * 2), - DownSample(base_channels * 2, base_channels * 4) - ]) + [ + DownSample(input_channels, base_channels), + DownSample(base_channels, base_channels * 2), + DownSample(base_channels * 2, base_channels * 4) + ]) return down_samples def create_up_samples(self, base_channels : int) -> nn.ModuleList: down_samples = nn.ModuleList( - [ - UpSample(base_channels * 4, base_channels * 2), - UpSample(base_channels * 2, base_channels), - UpSample(base_channels, base_channels) - ]) + [ + UpSample(base_channels * 4, base_channels * 2), + UpSample(base_channels * 2, base_channels), + UpSample(base_channels, base_channels) + ]) return down_samples def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor: - output_tensor = torch.cat([ target_tensor, target_attribute ], dim=1) + output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1) for down_sample in self.down_samples: output_tensor = down_sample(output_tensor) + output_tensor = self.bottleneck(output_tensor) for up_sample in self.up_samples: output_tensor = up_sample(output_tensor) + output_tensor = self.conv(output_tensor) output_tensor = self.activation(output_tensor) return output_tensor -class ResBlock(nn.Module): - def __init__(self, channels: int): +class BottleNeck(nn.Module): + def __init__(self, channels : int): super().__init__() self.conv = nn.Sequential( - nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), + nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(channels), - nn.ReLU(inplace=True), - nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), + nn.ReLU(inplace = True), + nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False), nn.BatchNorm2d(channels), - nn.ReLU(inplace=True) + nn.ReLU(inplace = True) ) - self.relu = nn.ReLU(inplace=True) + self.relu = nn.ReLU(inplace = True) - def forward(self, input_tensor: Tensor) -> Tensor: + def forward(self, input_tensor : Tensor) -> Tensor: output_tensor = self.conv(input_tensor) + input_tensor output_tensor = self.relu(output_tensor) return output_tensor @@ -66,7 +73,7 @@ class UpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2) - self.relu = nn.ReLU(inplace=True) + self.relu = nn.ReLU(inplace = True) def forward(self, input_tensor : Tensor) -> Tensor: output_tensor = self.conv_transpose(input_tensor)