This commit is contained in:
henryruhs
2025-02-12 15:47:04 +01:00
parent 71e0ae34c0
commit dd320ea5be
+2 -2
View File
@@ -44,11 +44,11 @@ class MultiscaleDiscriminator(nn.Module):
class NLayerDiscriminator(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
super(NLayerDiscriminator, self).__init__()
layers = self.create_layers(self, input_channels, num_filters, num_layers, kernel_size)
layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
self.discriminator = nn.Sequential(*layers)
@staticmethod
def create_layers(self, input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> List[nn.Module]:
def create_layers(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> List[nn.Module]:
padding = int(numpy.ceil((kernel_size - 1) / 2))
current_filters = num_filters
layers =\