Partial use Resnet34 as a DownSample replacement

This commit is contained in:
henryruhs
2025-02-16 23:12:23 +01:00
parent 83ef075b1d
commit f6c59257d9
+15 -6
View File
@@ -2,23 +2,32 @@ from typing import Tuple
import torch
from torch import Tensor, nn
from torchvision import models
class UNet(nn.Module):
def __init__(self) -> None:
super(UNet, self).__init__()
self.down_samples = self.create_down_samples()
self.resnet = models.resnet34(pretrained = True)
self.down_samples = self.create_down_samples(self)
self.up_samples = self.create_up_samples()
@staticmethod
def create_down_samples() -> nn.ModuleList:
def create_down_samples(self) -> nn.ModuleList:
return nn.ModuleList(
[
DownSample(3, 32),
nn.Sequential(
self.resnet.conv1,
self.resnet.bn1,
self.resnet.relu,
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.1, inplace = True)
),
DownSample(32, 64),
DownSample(64, 128),
DownSample(128, 256),
DownSample(256, 512),
self.resnet.layer2,
self.resnet.layer3,
self.resnet.layer4,
DownSample(512, 1024),
DownSample(1024, 1024)
])