From dcc5ccccd7467ce0fc0705be216a9595a631ca33 Mon Sep 17 00:00:00 2001
From: henryruhs
Date: Wed, 5 Mar 2025 10:25:37 +0100
Subject: [PATCH] Extend Unet with more layers

---
 face_swapper/src/models/generator.py |  4 +-
 face_swapper/src/networks/aad.py     |  5 +--
 face_swapper/src/networks/unet.py    | 55 ++++++++++++++++++++++------
 3 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py
index 4c057cf..369edc2 100644
--- a/face_swapper/src/models/generator.py
+++ b/face_swapper/src/models/generator.py
@@ -20,9 +20,9 @@ class Generator(nn.Module):
 		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
 
 		if encoder_type == 'unet':
-			self.encoder = UNet()
+			self.encoder = UNet(output_size)
 		if encoder_type == 'unet-pro':
-			self.encoder = UNetPro()
+			self.encoder = UNetPro(output_size)
 		self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
 		self.encoder.apply(init_weight)
 		self.generator.apply(init_weight)
diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py
index 1f12986..53c9c8b 100644
--- a/face_swapper/src/networks/aad.py
+++ b/face_swapper/src/networks/aad.py
@@ -23,18 +23,15 @@ class AAD(nn.Module):
 			AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks),
 			AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks),
 			AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks),
-			AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks),
+			AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks)
 		])
 
 		if self.output_size in [ 384, 512, 768, 1024 ]:
 			layers.append(AdaptiveFeatureModulation(64, 32, 32, self.identity_channels, self.num_blocks))
-
 		if self.output_size in [ 512, 768, 1024 ]:
 			layers.append(AdaptiveFeatureModulation(32, 16, 16, self.identity_channels, self.num_blocks))
-
 		if self.output_size in [ 768, 1024 ]:
 			layers.append(AdaptiveFeatureModulation(16, 8, 8, self.identity_channels, self.num_blocks))
-
 		if self.output_size == 1024:
 			layers.append(AdaptiveFeatureModulation(8, 4, 4, self.identity_channels, self.num_blocks))
 
diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py
index 2d6147f..2f185bd 100644
--- a/face_swapper/src/networks/unet.py
+++ b/face_swapper/src/networks/unet.py
@@ -7,14 +7,14 @@ from torchvision.models import ResNet34_Weights
 
 
 class UNet(nn.Module):
-	def __init__(self) -> None:
+	def __init__(self, output_size : int) -> None:
 		super().__init__()
-		self.down_samples = self.create_down_samples(self)
+		self.output_size = output_size
+		self.down_samples = self.create_down_samples()
 		self.up_samples = self.create_up_samples()
 
-	@staticmethod
-	def create_down_samples(self : nn.Module) -> nn.ModuleList:
-		return nn.ModuleList(
+	def create_down_samples(self) -> nn.ModuleList:
+		down_samples = nn.ModuleList(
 		[
 			DownSample(3, 32),
 			DownSample(32, 64),
@@ -25,9 +25,19 @@ class UNet(nn.Module):
 			DownSample(1024, 1024)
 		])
 
-	@staticmethod
-	def create_up_samples() -> nn.ModuleList:
-		return nn.ModuleList(
+		if self.output_size in [ 384, 512, 768, 1024 ]:
+			down_samples.append(DownSample(1024, 2048))
+		if self.output_size in [ 512, 768, 1024 ]:
+			down_samples.append(DownSample(2048, 4096))
+		if self.output_size in [ 768, 1024 ]:
+			down_samples.append(DownSample(4096, 8192))
+		if self.output_size == 1024:
+			down_samples.append(DownSample(8192, 16384))
+
+		return down_samples
+
+	def create_up_samples(self) -> nn.ModuleList:
+		up_samples = nn.ModuleList(
 		[
 			UpSample(1024, 1024),
 			UpSample(2048, 512),
@@ -37,6 +47,17 @@ class UNet(nn.Module):
 			UpSample(128, 32)
 		])
 
+		if self.output_size in [ 384, 512, 768, 1024 ]:
+			up_samples.append(UpSample(32, 16))
+		if self.output_size in [ 512, 768, 1024 ]:
+			up_samples.append(UpSample(16, 8))
+		if self.output_size in [ 768, 1024 ]:
+			up_samples.append(UpSample(8, 4))
+		if self.output_size == 1024:
+			up_samples.append(UpSample(4, 2))
+
+		return up_samples
+
 	def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
 		down_features = []
 		up_features = []
@@ -62,12 +83,11 @@ class UNetPro(UNet):
 	def __init__(self) -> None:
 		super(UNet, self).__init__()
 		self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT)
-		self.down_samples = self.create_down_samples(self)
+		self.down_samples = self.create_down_samples()
 		self.up_samples = self.create_up_samples()
 
-	@staticmethod
-	def create_down_samples(self : nn.Module) -> nn.ModuleList:
-		return nn.ModuleList(
+	def create_down_samples(self) -> nn.ModuleList:
+		down_samples = nn.ModuleList(
 		[
 			nn.Sequential(
 				self.resnet.conv1,
@@ -85,6 +105,17 @@ class UNetPro(UNet):
 			DownSample(1024, 1024)
 		])
 
+		if self.output_size in [ 384, 512, 768, 1024 ]:
+			down_samples.append(DownSample(1024, 2048))
+		if self.output_size in [ 512, 768, 1024 ]:
+			down_samples.append(DownSample(2048, 4096))
+		if self.output_size in [ 768, 1024 ]:
+			down_samples.append(DownSample(4096, 8192))
+		if self.output_size == 1024:
+			down_samples.append(DownSample(8192, 16384))
+
+		return down_samples
+
 
 class UpSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None: