mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
add 1024
This commit is contained in:
@@ -47,6 +47,17 @@ class AAD(nn.Module):
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 4096, 8192, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 4096, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(4096, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
|
||||
|
||||
@@ -45,6 +45,15 @@ class UNet(nn.Module):
|
||||
DownSample(2048, 2048)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 2048),
|
||||
DownSample(2048, 4096),
|
||||
DownSample(4096, 4096)
|
||||
])
|
||||
|
||||
return down_samples
|
||||
|
||||
def create_up_samples(self) -> nn.ModuleList:
|
||||
@@ -71,6 +80,15 @@ class UNet(nn.Module):
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
if self.config_output_size == 1024:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(4096, 4096),
|
||||
UpSample(8192, 2048),
|
||||
UpSample(4096, 1024),
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 256),
|
||||
|
||||
@@ -8,7 +8,7 @@ from face_swapper.src.networks.masknet import MaskNet
|
||||
from face_swapper.src.networks.unet import UNet
|
||||
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512, 1024 ])
|
||||
def test_aad_with_unet(output_size : int) -> None:
|
||||
config_parser = ConfigParser()
|
||||
config_parser.read_dict(
|
||||
|
||||
Reference in New Issue
Block a user