diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 97a361c..2d6147f 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -3,6 +3,7 @@ from typing import Tuple import torch from torch import Tensor, nn from torchvision import models +from torchvision.models import ResNet34_Weights class UNet(nn.Module): @@ -60,7 +61,7 @@ class UNet(nn.Module): class UNetPro(UNet): def __init__(self) -> None: super(UNet, self).__init__() - self.resnet = models.resnet34() + self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT) self.down_samples = self.create_down_samples(self) self.up_samples = self.create_up_samples()