Debug path

This commit is contained in:
henryruhs
2025-03-05 14:59:59 +01:00
parent 94ad33cb1e
commit 866019d44f
+7 -7
View File
@@ -7,8 +7,8 @@ import torch
package_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(package_directory)
from face_swapper.src.networks.aad import AAD
from face_swapper.src.networks.unet import UNet
#from face_swapper.src.networks.aad import AAD
#from face_swapper.src.networks.unet import UNet
@pytest.mark.parametrize('output_size', [ 256 ])
@@ -20,14 +20,14 @@ def test_aad_with_unet(output_size : int) -> None:
output_channels = 8192
num_blocks = 2
generator = AAD(identity_channels, output_channels, output_size, num_blocks).eval()
encoder = UNet(output_size).eval()
#generator = AAD(identity_channels, output_channels, output_size, num_blocks).eval()
#encoder = UNet(output_size).eval()
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, output_size, output_size)
target_attributes = encoder(target_tensor)
output_tensor = generator(source_tensor, target_attributes)
#target_attributes = encoder(target_tensor)
#output_tensor = generator(source_tensor, target_attributes)
assert package_directory == None
assert output_tensor.shape == (1, 3, output_size, output_size)
#assert output_tensor.shape == (1, 3, output_size, output_size)