From 29e82f909a635525b07370a1bb71ebf0550345ad Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 13 Feb 2025 21:29:56 +0100 Subject: [PATCH] Move nld to networks --- face_swapper/src/models/discriminator.py | 53 +++--------------------- face_swapper/src/networks/nld.py | 43 +++++++++++++++++++ face_swapper/src/training.py | 4 +- 3 files changed, 51 insertions(+), 49 deletions(-) create mode 100644 face_swapper/src/networks/nld.py diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index 60eb8e1..ea7b814 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -1,19 +1,19 @@ import configparser from typing import List -import numpy from torch import nn +from face_swapper.src.networks.nld import NLD from face_swapper.src.types import VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class MultiscaleDiscriminator(nn.Module): +class Discriminator(nn.Module): def __init__(self) -> None: - super(MultiscaleDiscriminator, self).__init__() - self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) + super(Discriminator, self).__init__() + self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) self.discriminators = self.create_discriminators() @staticmethod @@ -26,7 +26,7 @@ class MultiscaleDiscriminator(nn.Module): discriminators = nn.ModuleList() for _ in range(num_discriminators): - discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size).discriminator + discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).nld discriminators.append(discriminator) return discriminators @@ -37,47 +37,6 @@ class MultiscaleDiscriminator(nn.Module): for discriminator in self.discriminators: output_tensors.append([ discriminator(temp_tensor) ]) - temp_tensor = self.downsample(temp_tensor) + temp_tensor = self.avg_pool(temp_tensor) return output_tensors - - -class NLayerDiscriminator(nn.Module): - def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None: - super(NLayerDiscriminator, self).__init__() - layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size) - self.discriminator = nn.Sequential(*layers) - - @staticmethod - def create_layers(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> List[nn.Module]: - padding = int(numpy.ceil((kernel_size - 1) / 2)) - current_filters = num_filters - layers =\ - [ - nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding), - nn.LeakyReLU(0.2, True) - ] - - for _ in range(1, num_layers): - previous_filters = current_filters - current_filters = min(current_filters * 2, 512) - layers +=\ - [ - nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding), - nn.InstanceNorm2d(current_filters), - nn.LeakyReLU(0.2, True) - ] - - previous_filters = current_filters - current_filters = min(current_filters * 2, 512) - layers +=\ - [ - nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding), - nn.InstanceNorm2d(current_filters), - nn.LeakyReLU(0.2, True), - nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding) - ] - return layers - - def forward(self, input_tensor : VisionTensor) -> VisionTensor: - return self.discriminator(input_tensor) diff --git a/face_swapper/src/networks/nld.py b/face_swapper/src/networks/nld.py new file mode 100644 index 0000000..73fc5a3 --- /dev/null +++ b/face_swapper/src/networks/nld.py @@ -0,0 +1,43 @@ +import math + +from torch import Tensor, nn + + +class NLD(nn.Module): + def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None: + super(NLD, self).__init__() + self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size) + + @staticmethod + def create_nld(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> nn.Sequential: + padding = math.ceil((kernel_size - 1) / 2) + current_filters = num_filters + layers =\ + [ + nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding), + nn.LeakyReLU(0.2, True) + ] + + for _ in range(1, num_layers): + previous_filters = current_filters + current_filters = min(current_filters * 2, 512) + layers +=\ + [ + nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding), + nn.InstanceNorm2d(current_filters), + nn.LeakyReLU(0.2, True) + ] + + previous_filters = current_filters + current_filters = min(current_filters * 2, 512) + layers +=\ + [ + nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding), + nn.InstanceNorm2d(current_filters), + nn.LeakyReLU(0.2, True), + nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding) + ] + return nn.Sequential(*layers) + + def forward(self, input_tensor : Tensor) -> Tensor: + return self.nld(input_tensor) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 6233e9a..a68dc7b 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader from .data_loader import DataLoaderVGG from .helper import calc_id_embedding -from .models.discriminator import MultiscaleDiscriminator +from .models.discriminator import Discriminator from .models.generator import AdaptiveEmbeddingIntegrationNetwork from .models.loss import FaceSwapperLoss from .types import Batch, Embedding, TargetAttributes, VisionTensor @@ -26,7 +26,7 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() self.generator = AdaptiveEmbeddingIntegrationNetwork() - self.discriminator = MultiscaleDiscriminator() + self.discriminator = Discriminator() self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]: