This commit is contained in:
harisreedhar
2025-03-10 22:54:23 +05:30
committed by henryruhs
parent f9d105ea2b
commit c2a639229f
2 changed files with 50 additions and 23 deletions
+26 -23
View File
@@ -9,29 +9,29 @@ class MaskNet(nn.Module):
super().__init__()
self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
self.config_base_channels = config_parser.getint('training.model.masker', 'base_channels')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_base_channels)
self.up_samples = self.create_up_samples(self.config_base_channels)
self.bottleneck = BottleNeck(self.config_base_channels * 4)
self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1)
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
self.up_samples = self.create_up_samples(self.config_num_filters)
self.bottleneck = BottleNeck(self.config_num_filters * 4)
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
@staticmethod
def create_down_samples(input_channels : int, base_channels : int) -> nn.ModuleList:
def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
DownSample(input_channels, base_channels),
DownSample(base_channels, base_channels * 2),
DownSample(base_channels * 2, base_channels * 4)
DownSample(input_channels, num_filters),
DownSample(num_filters, num_filters * 2),
DownSample(num_filters * 2, num_filters * 4)
])
@staticmethod
def create_up_samples(base_channels : int) -> nn.ModuleList:
def create_up_samples(num_filters : int) -> nn.ModuleList:
return nn.ModuleList(
[
UpSample(base_channels * 4, base_channels * 2),
UpSample(base_channels * 2, base_channels),
UpSample(base_channels, base_channels)
UpSample(num_filters * 4, num_filters * 2),
UpSample(num_filters * 2, num_filters),
UpSample(num_filters, num_filters)
])
def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
@@ -51,20 +51,23 @@ class MaskNet(nn.Module):
class BottleNeck(nn.Module):
def __init__(self, base_channels : int):
def __init__(self, num_filters : int):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace = True),
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace = True)
)
self.layers = self.create_layers(num_filters)
self.relu = nn.ReLU(inplace = True)
def create_layers(self, num_filters : int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace = True),
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(num_filters),
nn.ReLU(inplace = True)
)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.conv(input_tensor) + input_tensor
output_tensor = self.layers(input_tensor) + input_tensor
output_tensor = self.relu(output_tensor)
return output_tensor
+24
View File
@@ -4,6 +4,7 @@ import pytest
import torch
from face_swapper.src.networks.aad import AAD
from face_swapper.src.networks.masknet import MaskNet
from face_swapper.src.networks.unet import UNet
@@ -31,3 +32,26 @@ def test_aad_with_unet(output_size : int) -> None:
output_tensor = generator(source_tensor, target_attributes)
assert output_tensor.shape == (1, 3, output_size, output_size)
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
def test_mask_net(output_size : int) -> None:
config_parser = ConfigParser()
config_parser.read_dict(
{
'training.model.masker':
{
'input_channels': '67',
'output_channels': '1',
'num_filters': '16',
}
})
masker = MaskNet(config_parser).eval()
target_tensor = torch.randn(1, 3, output_size, output_size)
target_attribute = torch.randn(1, 64, output_size, output_size)
output_tensor = masker(target_tensor, target_attribute)
assert output_tensor.shape == (1, 1, output_size, output_size)