mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -9,29 +9,29 @@ class MaskNet(nn.Module):
|
||||
super().__init__()
|
||||
self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
|
||||
self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
|
||||
self.config_base_channels = config_parser.getint('training.model.masker', 'base_channels')
|
||||
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_base_channels)
|
||||
self.up_samples = self.create_up_samples(self.config_base_channels)
|
||||
self.bottleneck = BottleNeck(self.config_base_channels * 4)
|
||||
self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1)
|
||||
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
|
||||
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
|
||||
self.up_samples = self.create_up_samples(self.config_num_filters)
|
||||
self.bottleneck = BottleNeck(self.config_num_filters * 4)
|
||||
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
@staticmethod
|
||||
def create_down_samples(input_channels : int, base_channels : int) -> nn.ModuleList:
|
||||
def create_down_samples(input_channels : int, num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
DownSample(input_channels, base_channels),
|
||||
DownSample(base_channels, base_channels * 2),
|
||||
DownSample(base_channels * 2, base_channels * 4)
|
||||
DownSample(input_channels, num_filters),
|
||||
DownSample(num_filters, num_filters * 2),
|
||||
DownSample(num_filters * 2, num_filters * 4)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def create_up_samples(base_channels : int) -> nn.ModuleList:
|
||||
def create_up_samples(num_filters : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
UpSample(base_channels * 4, base_channels * 2),
|
||||
UpSample(base_channels * 2, base_channels),
|
||||
UpSample(base_channels, base_channels)
|
||||
UpSample(num_filters * 4, num_filters * 2),
|
||||
UpSample(num_filters * 2, num_filters),
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
|
||||
@@ -51,20 +51,23 @@ class MaskNet(nn.Module):
|
||||
|
||||
|
||||
class BottleNeck(nn.Module):
|
||||
def __init__(self, base_channels : int):
|
||||
def __init__(self, num_filters : int):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(base_channels, base_channels, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(base_channels),
|
||||
nn.ReLU(inplace = True)
|
||||
)
|
||||
self.layers = self.create_layers(num_filters)
|
||||
self.relu = nn.ReLU(inplace = True)
|
||||
|
||||
def create_layers(self, num_filters : int) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(num_filters, num_filters, kernel_size = 3, padding = 1, bias = False),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(inplace = True)
|
||||
)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
output_tensor = self.conv(input_tensor) + input_tensor
|
||||
output_tensor = self.layers(input_tensor) + input_tensor
|
||||
output_tensor = self.relu(output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from face_swapper.src.networks.aad import AAD
|
||||
from face_swapper.src.networks.masknet import MaskNet
|
||||
from face_swapper.src.networks.unet import UNet
|
||||
|
||||
|
||||
@@ -31,3 +32,26 @@ def test_aad_with_unet(output_size : int) -> None:
|
||||
output_tensor = generator(source_tensor, target_attributes)
|
||||
|
||||
assert output_tensor.shape == (1, 3, output_size, output_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
|
||||
def test_mask_net(output_size : int) -> None:
|
||||
config_parser = ConfigParser()
|
||||
config_parser.read_dict(
|
||||
{
|
||||
'training.model.masker':
|
||||
{
|
||||
'input_channels': '67',
|
||||
'output_channels': '1',
|
||||
'num_filters': '16',
|
||||
}
|
||||
})
|
||||
|
||||
masker = MaskNet(config_parser).eval()
|
||||
|
||||
target_tensor = torch.randn(1, 3, output_size, output_size)
|
||||
target_attribute = torch.randn(1, 64, output_size, output_size)
|
||||
|
||||
output_tensor = masker(target_tensor, target_attribute)
|
||||
|
||||
assert output_tensor.shape == (1, 1, output_size, output_size)
|
||||
|
||||
Reference in New Issue
Block a user