Prepare test for 512

This commit is contained in:
henryruhs
2025-03-05 13:50:23 +01:00
parent 64ebfa7b84
commit 2148e9b701
+8 -1
View File
@@ -7,7 +7,14 @@ from face_swapper.src.networks.unet import UNet
@pytest.mark.parametrize('output_size', [ 256 ])
def test_aad_with_unet(output_size : int) -> None:
generator = AAD(512, 4096, output_size, 2).eval()
identity_channels = 512
if output_size == 256:
output_channels = 4096
if output_size == 512:
output_channels = 8192
num_blocks = 2
generator = AAD(identity_channels, output_channels, output_size, num_blocks).eval()
encoder = UNet(output_size).eval()
source_tensor = torch.randn(1, 512)