Mirror of https://github.com/facefusion/facefusion-labs.git, synced 2026-04-19 15:56:37 +02:00.
Commit: "Refactor discriminator to use ModuleList; reduce complexity of layer creation."
This commit is contained in:
@@ -3,11 +3,9 @@ from itertools import chain
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch.nn
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from face_swapper.src.types import DiscriminatorOutputs
|
||||
from face_swapper.src.types import VisionTensor
|
||||
|
||||
# Shared training configuration, read once at import time.
# NOTE(review): ConfigParser.read() silently ignores a missing config.ini — confirm callers expect that.
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -16,77 +14,70 @@ CONFIG.read('config.ini')
|
||||
class MultiscaleDiscriminator(nn.Module):
	"""
	Multi-scale PatchGAN-style discriminator.

	Holds N identical sub-discriminators in an nn.ModuleList and applies each one
	to a progressively downsampled copy of the input image, returning one list of
	outputs per scale.
	"""
	def __init__(self) -> None:
		super(MultiscaleDiscriminator, self).__init__()
		# count_include_pad = False keeps border averages unbiased when downsampling
		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
		self.discriminators = self.create_discriminators()

	def create_discriminators(self) -> nn.ModuleList:
		# all hyper-parameters come from the shared config.ini
		num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
		input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
		num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
		kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
		num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
		discriminators = nn.ModuleList()

		for _ in range(num_discriminators):
			# register only the Sequential body so parameters live in the ModuleList
			discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size).discriminator
			discriminators.append(discriminator)
		return discriminators

	def forward(self, input_tensor : VisionTensor) -> List[List[VisionTensor]]:
		temp_tensor = input_tensor
		output_tensors = []

		# each discriminator sees the input at half the resolution of the previous one
		for discriminator in self.discriminators:
			output_tensors.append([ discriminator(temp_tensor) ])
			temp_tensor = self.downsample(temp_tensor)
		return output_tensors
class NLayerDiscriminator(nn.Module):
	"""
	Single PatchGAN discriminator built from a configurable number of strided
	convolution stages, ending in a one-channel patch prediction map.
	"""
	def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
		super(NLayerDiscriminator, self).__init__()
		# instance method (not @staticmethod) — a static method taking self would
		# mis-bind the arguments of this call
		layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
		self.discriminator = nn.Sequential(*layers)

	def create_layers(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> List[nn.Module]:
		# "same-ish" padding for the chosen kernel size
		padding = int(numpy.ceil((kernel_size - 1) / 2))
		current_filters = num_filters
		layers =\
		[
			nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
			nn.LeakyReLU(0.2, True)
		]

		# num_layers - 1 additional stride-2 stages, doubling channels up to 512
		for _ in range(1, num_layers):
			previous_filters = current_filters
			current_filters = min(current_filters * 2, 512)
			layers +=\
			[
				nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
				nn.InstanceNorm2d(current_filters),
				nn.LeakyReLU(0.2, True)
			]
		previous_filters = current_filters
		current_filters = min(current_filters * 2, 512)
		# final stride-1 stage followed by the one-channel patch prediction map
		layers +=\
		[
			nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding),
			nn.InstanceNorm2d(current_filters),
			nn.LeakyReLU(0.2, True),
			nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
		]
		return layers

	def forward(self, input_tensor : VisionTensor) -> VisionTensor:
		return self.discriminator(input_tensor)
|
||||
@@ -85,8 +85,8 @@ def create_trainer() -> Trainer:
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
|
||||
trainer_precision = CONFIG.get('training.trainer', 'precision')
|
||||
os.makedirs(output_directory_path, exist_ok = True)
|
||||
|
||||
os.makedirs(output_directory_path, exist_ok = True)
|
||||
return Trainer(
|
||||
max_epochs = trainer_max_epochs,
|
||||
precision = trainer_precision,
|
||||
|
||||
Reference in New Issue
Block a user