diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 8a97df3..a5c3d9a 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -2,23 +2,32 @@ from typing import Tuple import torch from torch import Tensor, nn +from torchvision import models class UNet(nn.Module): def __init__(self) -> None: super(UNet, self).__init__() - self.down_samples = self.create_down_samples() + self.resnet = models.resnet34(pretrained = True) + self.down_samples = self.create_down_samples(self) self.up_samples = self.create_up_samples() @staticmethod - def create_down_samples() -> nn.ModuleList: + def create_down_samples(self) -> nn.ModuleList: return nn.ModuleList( [ - DownSample(3, 32), + nn.Sequential( + self.resnet.conv1, + self.resnet.bn1, + self.resnet.relu, + nn.Conv2d(64, 32, kernel_size = 1, bias = False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.1, inplace = True) + ), DownSample(32, 64), - DownSample(64, 128), - DownSample(128, 256), - DownSample(256, 512), + self.resnet.layer2, + self.resnet.layer3, + self.resnet.layer4, DownSample(512, 1024), DownSample(1024, 1024) ])