Refactor discriminator to use ModuleList, reduce complexity of layer creation

This commit is contained in:
henryruhs
2025-02-12 15:43:35 +01:00
parent 860771e482
commit 494b84aecb
2 changed files with 48 additions and 57 deletions
+47 -56
View File
@@ -3,11 +3,9 @@ from itertools import chain
from typing import List
import numpy
import torch.nn
import torch.nn as nn
from torch import Tensor
from face_swapper.src.types import DiscriminatorOutputs
from face_swapper.src.types import VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -16,77 +14,70 @@ CONFIG.read('config.ini')
class MultiscaleDiscriminator(nn.Module):
def __init__(self) -> None:
super(MultiscaleDiscriminator, self).__init__()
self.input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
self.num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
self.kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
self.num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
self.num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
self.discriminators = self.create_discriminators()
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
self.prepare_discriminators()
def create_discriminators(self) -> nn.ModuleList:
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
discriminators = nn.ModuleList()
def prepare_discriminators(self) -> None:
for discriminator_index in range(self.num_discriminators):
single_discriminator = NLayerDiscriminator(self.input_channels, self.num_filters, self.num_layers, self.kernel_size)
setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
for _ in range(num_discriminators):
discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size).discriminator
self.discriminators.append(discriminator)
def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
discriminator_outputs = []
return discriminators
def forward(self, input_tensor : VisionTensor) -> List[List[VisionTensor]]:
temp_tensor = input_tensor
output_tensors = []
for discriminator_index in range(self.num_discriminators):
model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
discriminator_outputs.append([ model_layers(temp_tensor) ])
for discriminator in self.discriminators:
output_tensors.append([ discriminator(temp_tensor) ])
temp_tensor = self.downsample(temp_tensor)
if discriminator_index < (self.num_discriminators - 1):
temp_tensor = self.downsample(temp_tensor)
return discriminator_outputs
return output_tensors
class NLayerDiscriminator(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
super(NLayerDiscriminator, self).__init__()
self.num_layers = num_layers
model_layers = self.prepare_model_layers(input_channels, num_filters, num_layers, kernel_size)
self.model = nn.Sequential(*list(chain(*model_layers)))
layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
self.discriminator = nn.Sequential(*layers)
def prepare_model_layers(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> List[List[torch.nn.Module]]:
padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
model_layers =\
[
[
nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
nn.LeakyReLU(0.2, True)
]
]
@staticmethod
def create_layers(self, input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> List[nn.Module]:
padding = int(numpy.ceil((kernel_size - 1) / 2))
current_filters = num_filters
layers =\
[
nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
nn.LeakyReLU(0.2, True)
]
for layer_index in range(1, num_layers):
for _ in range(1, num_layers):
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
model_layers +=\
layers +=\
[
[
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True)
]
]
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
model_layers +=\
[
[
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding_size),
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
],
[
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding_size)
]
]
return model_layers
def forward(self, input_tensor : Tensor) -> Tensor:
return self.model(input_tensor)
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True),
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
]
return layers
def forward(self, input_tensor : VisionTensor) -> VisionTensor:
return self.discriminator(input_tensor)
+1 -1
View File
@@ -85,8 +85,8 @@ def create_trainer() -> Trainer:
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
trainer_precision = CONFIG.get('training.trainer', 'precision')
os.makedirs(output_directory_path, exist_ok = True)
os.makedirs(output_directory_path, exist_ok = True)
return Trainer(
max_epochs = trainer_max_epochs,
precision = trainer_precision,