Remove deprecated argument

This commit is contained in:
henryruhs
2025-02-24 00:49:38 +01:00
parent 8b53c76a0a
commit 84be7d1ffb
+2 -1
View File
@@ -3,6 +3,7 @@ from typing import Tuple
import torch
from torch import Tensor, nn
from torchvision import models
from torchvision.models import ResNet34_Weights
class UNet(nn.Module):
@@ -60,7 +61,7 @@ class UNet(nn.Module):
class UNetPro(UNet):
def __init__(self) -> None:
super(UNet, self).__init__()
self.resnet = models.resnet34()
self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT)
self.down_samples = self.create_down_samples(self)
self.up_samples = self.create_up_samples()