Fix UnetPro

This commit is contained in:
henryruhs
2025-03-05 10:27:55 +01:00
parent dcc5ccccd7
commit 786adf73a2
+2 -2
View File
@@ -80,8 +80,8 @@ class UNet(nn.Module):
class UNetPro(UNet):
def __init__(self) -> None:
super(UNet, self).__init__()
def __init__(self, output_size : int) -> None:
super().__init__(output_size)
self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT)
self.down_samples = self.create_down_samples()
self.up_samples = self.create_up_samples()