diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py index b00e2b0..37ccc7a 100644 --- a/face_swapper/tests/test_networks.py +++ b/face_swapper/tests/test_networks.py @@ -7,7 +7,14 @@ from face_swapper.src.networks.unet import UNet @pytest.mark.parametrize('output_size', [ 256 ]) def test_aad_with_unet(output_size : int) -> None: - generator = AAD(512, 4096, output_size, 2).eval() + identity_channels = 512 + if output_size == 256: + output_channels = 4096 + if output_size == 512: + output_channels = 8192 + num_blocks = 2 + + generator = AAD(identity_channels, output_channels, output_size, num_blocks).eval() encoder = UNet(output_size).eval() source_tensor = torch.randn(1, 512)