mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Follow the concept of layers and sequences
This commit is contained in:
@@ -25,7 +25,7 @@ class Discriminator(nn.Module):
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
for _ in range(num_discriminators):
|
||||
discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).nld
|
||||
discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).sequences
|
||||
discriminators.append(discriminator)
|
||||
|
||||
return discriminators
|
||||
|
||||
@@ -6,17 +6,18 @@ from torch import Tensor, nn
|
||||
class NLD(nn.Module):
|
||||
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
|
||||
super().__init__()
|
||||
self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size)
|
||||
self.layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
|
||||
self.sequences = nn.Sequential(*self.layers)
|
||||
|
||||
@staticmethod
|
||||
def create_nld(input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> nn.Sequential:
|
||||
def create_layers(input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> nn.ModuleList:
|
||||
padding = math.ceil((kernel_size - 1) / 2)
|
||||
current_filters = num_filters
|
||||
layers =\
|
||||
layers = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
])
|
||||
|
||||
for _ in range(1, num_layers):
|
||||
previous_filters = current_filters
|
||||
@@ -37,7 +38,7 @@ class NLD(nn.Module):
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
|
||||
]
|
||||
return nn.Sequential(*layers)
|
||||
return layers
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
return self.nld(input_tensor)
|
||||
return self.sequences(input_tensor)
|
||||
|
||||
Reference in New Issue
Block a user