From 982a94b53584b26e9a41d1487daacc7ab10c075b Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 21 Apr 2025 15:37:13 +0530 Subject: [PATCH] add 1024 --- face_swapper/src/networks/aad.py | 11 +++++++++++ face_swapper/src/networks/unet.py | 18 ++++++++++++++++++ face_swapper/tests/test_networks.py | 2 +- 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index 880065e..10c1231 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -47,6 +47,17 @@ class AAD(nn.Module): AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) ]) + if self.config_output_size == 1024: + layers.extend( + [ + AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks), + AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks) + ]) + layers.extend( [ AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks), diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index c12f3da..5d63561 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -45,6 +45,15 @@ class UNet(nn.Module): DownSample(2048, 2048) ]) + if self.config_output_size == 1024: + down_samples.extend( + [ + DownSample(512, 1024), + DownSample(1024, 2048), + DownSample(2048, 4096), + DownSample(4096, 4096) + ]) + return down_samples def create_up_samples(self) -> nn.ModuleList: @@ -71,6 +80,15 @@ class UNet(nn.Module): UpSample(2048, 512) ]) + if self.config_output_size == 1024: + up_samples.extend( + [ + UpSample(4096, 4096), + UpSample(8192, 2048), + UpSample(4096, 1024), + UpSample(2048, 512) + ]) + up_samples.extend( [ UpSample(1024, 256), diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py index 654a172..ddd1397 100644 --- a/face_swapper/tests/test_networks.py +++ b/face_swapper/tests/test_networks.py @@ -8,7 +8,7 @@ from face_swapper.src.networks.masknet import MaskNet from face_swapper.src.networks.unet import UNet -@pytest.mark.parametrize('output_size', [ 128, 256, 512 ]) +@pytest.mark.parametrize('output_size', [ 128, 256, 512, 1024 ]) def test_aad_with_unet(output_size : int) -> None: config_parser = ConfigParser() config_parser.read_dict(