# Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).
#
# face_swapper/src/networks/masknet.py
#   * MaskNet reads the masker option 'num_filters' instead of 'base_channels';
#     the matching attribute and parameters are renamed accordingly.
#   * BottleNeck builds its conv/batch-norm/ReLU stack in create_layers() and
#     stores it as `self.layers` (previously an inline `self.conv`).
# face_swapper/tests/test_networks.py
#   * Adds test_mask_net: MaskNet in eval mode must return a (1, 1, S, S)
#     tensor for S in 128, 256, 512 given a 3-channel target and a 64-channel
#     attribute tensor.
#
# Differences from the originally recorded commit:
#   * Line structure and indentation were reconstructed (4-space indent); blank
#     lines were placed so each hunk matches its @@ line counts.
#   * BottleNeck.create_layers is a @staticmethod, matching create_down_samples
#     and create_up_samples; it never used `self`. The second masknet.py hunk
#     therefore has 24 new-side lines, and the post-image hash on that file's
#     "index" line describes the original commit, not this result.
#
# NOTE(review): config files that still define only 'base_channels' will fail
#   on getint('training.model.masker', 'num_filters') - update them with this patch.
# NOTE(review): renaming BottleNeck.conv to BottleNeck.layers presumably changes
#   the state_dict keys (conv.* -> layers.*); confirm whether existing
#   checkpoints need their keys remapped.

diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py
index 5b8305d..e2ddfab 100644
--- a/face_swapper/src/networks/masknet.py
+++ b/face_swapper/src/networks/masknet.py
@@ -9,29 +9,29 @@ class MaskNet(nn.Module):
         super().__init__()
         self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
         self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
-        self.config_base_channels = config_parser.getint('training.model.masker', 'base_channels')
-        self.down_samples = self.create_down_samples(self.config_input_channels, self.config_base_channels)
-        self.up_samples = self.create_up_samples(self.config_base_channels)
-        self.bottleneck = BottleNeck(self.config_base_channels * 4)
-        self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1)
+        self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
+        self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
+        self.up_samples = self.create_up_samples(self.config_num_filters)
+        self.bottleneck = BottleNeck(self.config_num_filters * 4)
+        self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
         self.sigmoid = nn.Sigmoid()
 
     @staticmethod
-    def create_down_samples(input_channels : int, base_channels : int) -> nn.ModuleList:
+    def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
         return nn.ModuleList(
         [
-            DownSample(input_channels, base_channels),
-            DownSample(base_channels, base_channels * 2),
-            DownSample(base_channels * 2, base_channels * 4)
+            DownSample(input_channels, num_filters),
+            DownSample(num_filters, num_filters * 2),
+            DownSample(num_filters * 2, num_filters * 4)
         ])
 
     @staticmethod
-    def create_up_samples(base_channels : int) -> nn.ModuleList:
+    def create_up_samples(num_filters : int) -> nn.ModuleList:
         return nn.ModuleList(
         [
-            UpSample(base_channels * 4, base_channels * 2),
-            UpSample(base_channels * 2, base_channels),
-            UpSample(base_channels, base_channels)
+            UpSample(num_filters * 4, num_filters * 2),
+            UpSample(num_filters * 2, num_filters),
+            UpSample(num_filters, num_filters)
         ])
 
     def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
@@ -51,20 +51,24 @@ class MaskNet(nn.Module):
 
 
 class BottleNeck(nn.Module):
-    def __init__(self, base_channels : int):
+    def __init__(self, num_filters : int):
         super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
-            nn.BatchNorm2d(base_channels),
-            nn.ReLU(inplace = True),
-            nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
-            nn.BatchNorm2d(base_channels),
-            nn.ReLU(inplace = True)
-        )
+        self.layers = self.create_layers(num_filters)
         self.relu = nn.ReLU(inplace = True)
 
+    @staticmethod
+    def create_layers(num_filters : int) -> nn.Sequential:
+        return nn.Sequential(
+            nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
+            nn.BatchNorm2d(num_filters),
+            nn.ReLU(inplace = True),
+            nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
+            nn.BatchNorm2d(num_filters),
+            nn.ReLU(inplace = True)
+        )
+
     def forward(self, input_tensor : Tensor) -> Tensor:
-        output_tensor = self.conv(input_tensor) + input_tensor
+        output_tensor = self.layers(input_tensor) + input_tensor
         output_tensor = self.relu(output_tensor)
         return output_tensor
 
diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py
index 96f1c6a..7ad12dc 100644
--- a/face_swapper/tests/test_networks.py
+++ b/face_swapper/tests/test_networks.py
@@ -4,6 +4,7 @@ import pytest
 import torch
 
 from face_swapper.src.networks.aad import AAD
+from face_swapper.src.networks.masknet import MaskNet
 from face_swapper.src.networks.unet import UNet
 
 
@@ -31,3 +32,26 @@ def test_aad_with_unet(output_size : int) -> None:
     output_tensor = generator(source_tensor, target_attributes)
 
     assert output_tensor.shape == (1, 3, output_size, output_size)
+
+
+@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
+def test_mask_net(output_size : int) -> None:
+    config_parser = ConfigParser()
+    config_parser.read_dict(
+    {
+        'training.model.masker':
+        {
+            'input_channels': '67',
+            'output_channels': '1',
+            'num_filters': '16',
+        }
+    })
+
+    masker = MaskNet(config_parser).eval()
+
+    target_tensor = torch.randn(1, 3, output_size, output_size)
+    target_attribute = torch.randn(1, 64, output_size, output_size)
+
+    output_tensor = masker(target_tensor, target_attribute)
+
+    assert output_tensor.shape == (1, 1, output_size, output_size)