diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 2f185bd..7a6ebd6 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -80,8 +80,8 @@ class UNet(nn.Module): class UNetPro(UNet): - def __init__(self) -> None: - super(UNet, self).__init__() + def __init__(self, output_size : int) -> None: + super().__init__(output_size) self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT) self.down_samples = self.create_down_samples() self.up_samples = self.create_up_samples()