This commit is contained in:
harisreedhar
2025-04-21 15:37:13 +05:30
parent 5b41d8e91f
commit 982a94b535
3 changed files with 30 additions and 1 deletions
+11
View File
@@ -47,6 +47,17 @@ class AAD(nn.Module):
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
if self.config_output_size == 1024:
layers.extend(
[
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
])
layers.extend(
[
AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
+18
View File
@@ -45,6 +45,15 @@ class UNet(nn.Module):
DownSample(2048, 2048)
])
if self.config_output_size == 1024:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 2048),
DownSample(2048, 4096),
DownSample(4096, 4096)
])
return down_samples
def create_up_samples(self) -> nn.ModuleList:
@@ -71,6 +80,15 @@ class UNet(nn.Module):
UpSample(2048, 512)
])
if self.config_output_size == 1024:
up_samples.extend(
[
UpSample(4096, 4096),
UpSample(8192, 2048),
UpSample(4096, 1024),
UpSample(2048, 512)
])
up_samples.extend(
[
UpSample(1024, 256),
+1 -1
View File
@@ -8,7 +8,7 @@ from face_swapper.src.networks.masknet import MaskNet
from face_swapper.src.networks.unet import UNet
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
@pytest.mark.parametrize('output_size', [ 128, 256, 512, 1024 ])
def test_aad_with_unet(output_size : int) -> None:
config_parser = ConfigParser()
config_parser.read_dict(