From 494b84aecbdaa3f145a26ea746ac423f28740720 Mon Sep 17 00:00:00 2001
From: henryruhs
Date: Wed, 12 Feb 2025 15:43:35 +0100
Subject: [PATCH] Refactor discriminator to use ModuleList, Reduce complexity
 of layer creation

---
NOTE(review): two fixes applied to this patch before merge:
1. create_discriminators() appended to self.discriminators, which is only
   assigned from this method's return value, so nn.Module raised
   AttributeError on construction — it now appends to the local ModuleList.
2. create_layers() was declared @staticmethod but still took `self`, while
   the call site passes four arguments — the stray `self` parameter is
   removed (and `num_layers: int` re-spaced to match the file's style).

 face_swapper/src/models/discriminator.py | 103 +++++++++++------------
 face_swapper/src/training.py              |   2 +-
 2 files changed, 48 insertions(+), 57 deletions(-)

diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py
index aa2df6b..707985a 100644
--- a/face_swapper/src/models/discriminator.py
+++ b/face_swapper/src/models/discriminator.py
@@ -3,11 +3,9 @@ from itertools import chain
 from typing import List
 
 import numpy
-import torch.nn
 import torch.nn as nn
-from torch import Tensor
 
-from face_swapper.src.types import DiscriminatorOutputs
+from face_swapper.src.types import VisionTensor
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
@@ -16,77 +14,70 @@ CONFIG.read('config.ini')
 class MultiscaleDiscriminator(nn.Module):
 	def __init__(self) -> None:
 		super(MultiscaleDiscriminator, self).__init__()
-		self.input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
-		self.num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
-		self.kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
-		self.num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
-		self.num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
+		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
+		self.discriminators = self.create_discriminators()
 
-		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
-		self.prepare_discriminators()
+	def create_discriminators(self) -> nn.ModuleList:
+		num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
+		input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
+		num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
+		kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
+		num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
+		discriminators = nn.ModuleList()
 
-	def prepare_discriminators(self) -> None:
-		for discriminator_index in range(self.num_discriminators):
-			single_discriminator = NLayerDiscriminator(self.input_channels, self.num_filters, self.num_layers, self.kernel_size)
-			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
+		for _ in range(num_discriminators):
+			discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size).discriminator
+			discriminators.append(discriminator)
 
-	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
-		discriminator_outputs = []
+		return discriminators
+
+	def forward(self, input_tensor : VisionTensor) -> List[List[VisionTensor]]:
 		temp_tensor = input_tensor
+		output_tensors = []
 
-		for discriminator_index in range(self.num_discriminators):
-			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
-			discriminator_outputs.append([ model_layers(temp_tensor) ])
+		for discriminator in self.discriminators:
+			output_tensors.append([ discriminator(temp_tensor) ])
+			temp_tensor = self.downsample(temp_tensor)
 
-			if discriminator_index < (self.num_discriminators - 1):
-				temp_tensor = self.downsample(temp_tensor)
-
-		return discriminator_outputs
+		return output_tensors
 
 
 class NLayerDiscriminator(nn.Module):
 	def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
 		super(NLayerDiscriminator, self).__init__()
-		self.num_layers = num_layers
-		model_layers = self.prepare_model_layers(input_channels, num_filters, num_layers, kernel_size)
-		self.model = nn.Sequential(*list(chain(*model_layers)))
+		layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
+		self.discriminator = nn.Sequential(*layers)
 
-	def prepare_model_layers(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> List[List[torch.nn.Module]]:
-		padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
-		model_layers =\
-		[
-			[
-				nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
-				nn.LeakyReLU(0.2, True)
-			]
-		]
+	@staticmethod
+	def create_layers(input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> List[nn.Module]:
+		padding = int(numpy.ceil((kernel_size - 1) / 2))
 		current_filters = num_filters
+		layers =\
+		[
+			nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
+			nn.LeakyReLU(0.2, True)
+		]
 
-		for layer_index in range(1, num_layers):
+		for _ in range(1, num_layers):
 			previous_filters = current_filters
 			current_filters = min(current_filters * 2, 512)
-			model_layers +=\
+			layers +=\
 			[
-				[
-					nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
-					nn.InstanceNorm2d(current_filters),
-					nn.LeakyReLU(0.2, True)
-				]
-			]
-		previous_filters = current_filters
-		current_filters = min(current_filters * 2, 512)
-		model_layers +=\
-		[
-			[
-				nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding_size),
+				nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
 				nn.InstanceNorm2d(current_filters),
 				nn.LeakyReLU(0.2, True)
-			],
-			[
-				nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding_size)
 			]
-		]
-		return model_layers
 
-	def forward(self, input_tensor : Tensor) -> Tensor:
-		return self.model(input_tensor)
+		previous_filters = current_filters
+		current_filters = min(current_filters * 2, 512)
+		layers +=\
+		[
+			nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding),
+			nn.InstanceNorm2d(current_filters),
+			nn.LeakyReLU(0.2, True),
+			nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
+		]
+		return layers
+
+	def forward(self, input_tensor : VisionTensor) -> VisionTensor:
+		return self.discriminator(input_tensor)
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index c936772..6233e9a 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -85,8 +85,8 @@
 	output_directory_path = CONFIG.get('training.output', 'directory_path')
 	output_file_pattern = CONFIG.get('training.output', 'file_pattern')
 	trainer_precision = CONFIG.get('training.trainer', 'precision')
-	os.makedirs(output_directory_path, exist_ok = True)
 
+	os.makedirs(output_directory_path, exist_ok = True)
 	return Trainer(
 		max_epochs = trainer_max_epochs,
 		precision = trainer_precision,