diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 6454455..ebf88ba 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -11,7 +11,7 @@ class UNet(nn.Module): self.up = self.create_up() @staticmethod - def create_down(): + def create_down() -> nn.ModuleList: return nn.ModuleList( [ Down(3, 32), @@ -24,7 +24,7 @@ class UNet(nn.Module): ]) @staticmethod - def create_up(): + def create_up() -> nn.ModuleList: return nn.ModuleList( [ Up(1024, 1024),