From 10ce04ed58b87c492a7578fb190f2f66f6aa37af Mon Sep 17 00:00:00 2001
From: henryruhs
Date: Sat, 22 Feb 2025 10:08:23 +0100
Subject: [PATCH] Introduce feature flag for Unet

---
Review notes (text between `---` and the first `diff --git` is not part of
the commit):
- This patch was reflowed from a whitespace-collapsed copy. Tab indentation
  and the position of blank context lines are reconstructed from the hunk
  line counts -- TODO confirm against the target tree before applying.
- generator.py: `encoder_type` is read with CONFIG.get() instead of
  CONFIG.getint(); the value documented in the README (`unet-pro`) and the
  values it is compared against are strings, so getint() raises ValueError.
- unet.py: UNetPro.__init__ initialises nn.Module directly. UNet.__init__
  calls self.create_down_samples(self), which dispatches to the UNetPro
  override and reads self.resnet before it has been assigned.
- An `encoder_type` other than `unet` or `unet-pro` still leaves
  self.encoder unset; that behaviour is not changed here.
- The `index` blob hashes below predate the two fixes above.

 face_swapper/README.md               |  1 +
 face_swapper/src/models/generator.py | 18 +++++++-----
 face_swapper/src/networks/unet.py    | 44 ++++++++++++++++++++--------
 3 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/face_swapper/README.md b/face_swapper/README.md
index ccd0d6d..5f6370b 100644
--- a/face_swapper/README.md
+++ b/face_swapper/README.md
@@ -47,6 +47,7 @@ motion_extractor_path = .models/motion_extractor.pt
 
 ```
 [training.model.generator]
+encoder_type = unet-pro
 num_blocks = 2
 id_channels = 512
 ```
diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py
index d6fa9b6..be9c236 100644
--- a/face_swapper/src/models/generator.py
+++ b/face_swapper/src/models/generator.py
@@ -3,7 +3,7 @@ import configparser
 from torch import Tensor, nn
 
 from ..networks.attribute_modulator import AADGenerator
-from ..networks.unet import UNet
+from ..networks.unet import UNet, UNetPro
 from ..types import Attributes, Embedding
 
 CONFIG = configparser.ConfigParser()
@@ -13,21 +13,25 @@ CONFIG.read('config.ini')
 class Generator(nn.Module):
 	def __init__(self) -> None:
 		super(Generator, self).__init__()
+		encoder_type = CONFIG.get('training.model.generator', 'encoder_type')
 		id_channels = CONFIG.getint('training.model.generator', 'id_channels')
 		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
 
-		self.unet = UNet()
-		self.aad_generator = AADGenerator(id_channels, num_blocks)
-		self.unet.apply(init_weight)
-		self.aad_generator.apply(init_weight)
+		if encoder_type == 'unet':
+			self.encoder = UNet()
+		if encoder_type == 'unet-pro':
+			self.encoder = UNetPro()
+		self.generator = AADGenerator(id_channels, num_blocks)
+		self.encoder.apply(init_weight)
+		self.generator.apply(init_weight)
 
 	def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
 		target_attributes = self.get_attributes(target_tensor)
-		output_tensor = self.aad_generator(target_attributes, source_embedding)
+		output_tensor = self.generator(target_attributes, source_embedding)
 		return output_tensor
 
 	def get_attributes(self, input_tensor : Tensor) -> Attributes:
-		return self.unet(input_tensor)
+		return self.encoder(input_tensor)
 
 
 def init_weight(module : nn.Module) -> None:
diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py
index b70b44d..127fdef 100644
--- a/face_swapper/src/networks/unet.py
+++ b/face_swapper/src/networks/unet.py
@@ -8,7 +8,6 @@ from torchvision import models
 class UNet(nn.Module):
 	def __init__(self) -> None:
 		super(UNet, self).__init__()
-		self.resnet = models.resnet34(pretrained = True)
 		self.down_samples = self.create_down_samples(self)
 		self.up_samples = self.create_up_samples()
 
@@ -16,18 +15,11 @@ class UNet(nn.Module):
 	def create_down_samples(self : nn.Module) -> nn.ModuleList:
 		return nn.ModuleList(
 		[
-			nn.Sequential(
-				self.resnet.conv1,
-				self.resnet.bn1,
-				self.resnet.relu,
-				nn.Conv2d(64, 32, kernel_size = 1, bias = False),
-				nn.BatchNorm2d(32),
-				nn.LeakyReLU(0.1, inplace = True)
-			),
+			DownSample(3, 32),
 			DownSample(32, 64),
-			self.resnet.layer2,
-			self.resnet.layer3,
-			self.resnet.layer4,
+			DownSample(64, 128),
+			DownSample(128, 256),
+			DownSample(256, 512),
 			DownSample(512, 1024),
 			DownSample(1024, 1024)
 		])
@@ -65,6 +57,34 @@ class UNet(nn.Module):
 		return bottleneck_tensor, *up_features, output_tensor
 
 
+class UNetPro(UNet):
+	def __init__(self) -> None:
+		nn.Module.__init__(self) # bypass UNet.__init__: it builds the down samples before self.resnet exists
+		self.resnet = models.resnet34(pretrained = True)
+		self.down_samples = self.create_down_samples(self)
+		self.up_samples = self.create_up_samples()
+
+	@staticmethod
+	def create_down_samples(self : nn.Module) -> nn.ModuleList:
+		return nn.ModuleList(
+		[
+			nn.Sequential(
+				self.resnet.conv1,
+				self.resnet.bn1,
+				self.resnet.relu,
+				nn.Conv2d(64, 32, kernel_size = 1, bias = False),
+				nn.BatchNorm2d(32),
+				nn.LeakyReLU(0.1, inplace = True)
+			),
+			DownSample(32, 64),
+			self.resnet.layer2,
+			self.resnet.layer3,
+			self.resnet.layer4,
+			DownSample(512, 1024),
+			DownSample(1024, 1024)
+		])
+
+
 class UpSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super(UpSample, self).__init__()