mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Move nld to networks
This commit is contained in:
@@ -1,19 +1,19 @@
|
||||
import configparser
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.nld import NLD
|
||||
from face_swapper.src.types import VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class MultiscaleDiscriminator(nn.Module):
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(MultiscaleDiscriminator, self).__init__()
|
||||
self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
super(Discriminator, self).__init__()
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
|
||||
@staticmethod
|
||||
@@ -26,7 +26,7 @@ class MultiscaleDiscriminator(nn.Module):
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
for _ in range(num_discriminators):
|
||||
discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size).discriminator
|
||||
discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).nld
|
||||
discriminators.append(discriminator)
|
||||
|
||||
return discriminators
|
||||
@@ -37,47 +37,6 @@ class MultiscaleDiscriminator(nn.Module):
|
||||
|
||||
for discriminator in self.discriminators:
|
||||
output_tensors.append([ discriminator(temp_tensor) ])
|
||||
temp_tensor = self.downsample(temp_tensor)
|
||||
temp_tensor = self.avg_pool(temp_tensor)
|
||||
|
||||
return output_tensors
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
|
||||
super(NLayerDiscriminator, self).__init__()
|
||||
layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
|
||||
self.discriminator = nn.Sequential(*layers)
|
||||
|
||||
@staticmethod
|
||||
def create_layers(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> List[nn.Module]:
|
||||
padding = int(numpy.ceil((kernel_size - 1) / 2))
|
||||
current_filters = num_filters
|
||||
layers =\
|
||||
[
|
||||
nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
for _ in range(1, num_layers):
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
|
||||
]
|
||||
return layers
|
||||
|
||||
def forward(self, input_tensor : VisionTensor) -> VisionTensor:
|
||||
return self.discriminator(input_tensor)
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import math
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class NLD(nn.Module):
|
||||
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
|
||||
super(NLD, self).__init__()
|
||||
self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size)
|
||||
|
||||
@staticmethod
|
||||
def create_nld(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> nn.Sequential:
|
||||
padding = math.ceil((kernel_size - 1) / 2)
|
||||
current_filters = num_filters
|
||||
layers =\
|
||||
[
|
||||
nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
for _ in range(1, num_layers):
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
|
||||
]
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
return self.nld(input_tensor)
|
||||
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from .data_loader import DataLoaderVGG
|
||||
from .helper import calc_id_embedding
|
||||
from .models.discriminator import MultiscaleDiscriminator
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from .models.loss import FaceSwapperLoss
|
||||
from .types import Batch, Embedding, TargetAttributes, VisionTensor
|
||||
@@ -26,7 +26,7 @@ class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.generator = AdaptiveEmbeddingIntegrationNetwork()
|
||||
self.discriminator = MultiscaleDiscriminator()
|
||||
self.discriminator = Discriminator()
|
||||
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
|
||||
|
||||
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
|
||||
|
||||
Reference in New Issue
Block a user