Extend Unet with more layers

This commit is contained in:
henryruhs
2025-03-05 10:25:37 +01:00
parent 5056b8df75
commit dcc5ccccd7
3 changed files with 46 additions and 18 deletions
+2 -2
View File
@@ -20,9 +20,9 @@ class Generator(nn.Module):
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
if encoder_type == 'unet':
self.encoder = UNet()
self.encoder = UNet(output_size)
if encoder_type == 'unet-pro':
self.encoder = UNetPro()
self.encoder = UNetPro(output_size)
self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
+1 -4
View File
@@ -23,18 +23,15 @@ class AAD(nn.Module):
AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks)
])
if self.output_size in [ 384, 512, 768, 1024 ]:
layers.append(AdaptiveFeatureModulation(64, 32, 32, self.identity_channels, self.num_blocks))
if self.output_size in [ 512, 768, 1024 ]:
layers.append(AdaptiveFeatureModulation(32, 16, 16, self.identity_channels, self.num_blocks))
if self.output_size in [ 768, 1024 ]:
layers.append(AdaptiveFeatureModulation(16, 8, 8, self.identity_channels, self.num_blocks))
if self.output_size == 1024:
layers.append(AdaptiveFeatureModulation(8, 4, 4, self.identity_channels, self.num_blocks))
+43 -12
View File
@@ -7,14 +7,14 @@ from torchvision.models import ResNet34_Weights
class UNet(nn.Module):
def __init__(self) -> None:
def __init__(self, output_size : int) -> None:
super().__init__()
self.down_samples = self.create_down_samples(self)
self.output_size = output_size
self.down_samples = self.create_down_samples()
self.up_samples = self.create_up_samples()
@staticmethod
def create_down_samples(self : nn.Module) -> nn.ModuleList:
return nn.ModuleList(
def create_down_samples(self) -> nn.ModuleList:
down_samples = nn.ModuleList(
[
DownSample(3, 32),
DownSample(32, 64),
@@ -25,9 +25,19 @@ class UNet(nn.Module):
DownSample(1024, 1024)
])
@staticmethod
def create_up_samples() -> nn.ModuleList:
return nn.ModuleList(
if self.output_size in [ 384, 512, 768, 1024 ]:
down_samples.append(DownSample(1024, 2048))
if self.output_size in [ 512, 768, 1024 ]:
down_samples.append(DownSample(2048, 4096))
if self.output_size in [ 768, 1024 ]:
down_samples.append(DownSample(4096, 8192))
if self.output_size == 1024:
down_samples.append(DownSample(8192, 16384))
return down_samples
def create_up_samples(self) -> nn.ModuleList:
up_samples = nn.ModuleList(
[
UpSample(1024, 1024),
UpSample(2048, 512),
@@ -37,6 +47,17 @@ class UNet(nn.Module):
UpSample(128, 32)
])
if self.output_size in [ 384, 512, 768, 1024 ]:
up_samples.append(UpSample(32, 16))
if self.output_size in [ 512, 768, 1024 ]:
up_samples.append(UpSample(16, 8))
if self.output_size in [ 768, 1024 ]:
up_samples.append(UpSample(8, 4))
if self.output_size == 1024:
up_samples.append(UpSample(4, 2))
return up_samples
def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
down_features = []
up_features = []
@@ -62,12 +83,11 @@ class UNetPro(UNet):
def __init__(self) -> None:
super(UNet, self).__init__()
self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT)
self.down_samples = self.create_down_samples(self)
self.down_samples = self.create_down_samples()
self.up_samples = self.create_up_samples()
@staticmethod
def create_down_samples(self : nn.Module) -> nn.ModuleList:
return nn.ModuleList(
def create_down_samples(self) -> nn.ModuleList:
down_samples = nn.ModuleList(
[
nn.Sequential(
self.resnet.conv1,
@@ -85,6 +105,17 @@ class UNetPro(UNet):
DownSample(1024, 1024)
])
if self.output_size in [ 384, 512, 768, 1024 ]:
down_samples.append(DownSample(1024, 2048))
if self.output_size in [ 512, 768, 1024 ]:
down_samples.append(DownSample(2048, 4096))
if self.output_size in [ 768, 1024 ]:
down_samples.append(DownSample(4096, 8192))
if self.output_size == 1024:
down_samples.append(DownSample(8192, 16384))
return down_samples
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None: