Introduce feature flag for Unet

This commit is contained in:
henryruhs
2025-02-22 10:08:23 +01:00
parent 094d5cea9e
commit 10ce04ed58
3 changed files with 44 additions and 19 deletions
+1
View File
@@ -47,6 +47,7 @@ motion_extractor_path = .models/motion_extractor.pt
```
[training.model.generator]
encoder_type = unet-pro
num_blocks = 2
id_channels = 512
```
+11 -7
View File
@@ -3,7 +3,7 @@ import configparser
from torch import Tensor, nn
from ..networks.attribute_modulator import AADGenerator
from ..networks.unet import UNet
from ..networks.unet import UNet, UNetPro
from ..types import Attributes, Embedding
CONFIG = configparser.ConfigParser()
@@ -13,21 +13,25 @@ CONFIG.read('config.ini')
class Generator(nn.Module):
def __init__(self) -> None:
super(Generator, self).__init__()
encoder_type = CONFIG.getint('training.model.generator', 'encoder_type')
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
self.unet = UNet()
self.aad_generator = AADGenerator(id_channels, num_blocks)
self.unet.apply(init_weight)
self.aad_generator.apply(init_weight)
if encoder_type == 'unet':
self.encoder = UNet()
if encoder_type == 'unet-pro':
self.encoder = UNetPro()
self.generator = AADGenerator(id_channels, num_blocks)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
target_attributes = self.get_attributes(target_tensor)
output_tensor = self.aad_generator(target_attributes, source_embedding)
output_tensor = self.generator(target_attributes, source_embedding)
return output_tensor
def get_attributes(self, input_tensor : Tensor) -> Attributes:
return self.unet(input_tensor)
return self.encoder(input_tensor)
def init_weight(module : nn.Module) -> None:
+32 -12
View File
@@ -8,7 +8,6 @@ from torchvision import models
class UNet(nn.Module):
def __init__(self) -> None:
super(UNet, self).__init__()
self.resnet = models.resnet34(pretrained = True)
self.down_samples = self.create_down_samples(self)
self.up_samples = self.create_up_samples()
@@ -16,18 +15,11 @@ class UNet(nn.Module):
def create_down_samples(self : nn.Module) -> nn.ModuleList:
return nn.ModuleList(
[
nn.Sequential(
self.resnet.conv1,
self.resnet.bn1,
self.resnet.relu,
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.1, inplace = True)
),
DownSample(3, 32),
DownSample(32, 64),
self.resnet.layer2,
self.resnet.layer3,
self.resnet.layer4,
DownSample(64, 128),
DownSample(128, 256),
DownSample(256, 512),
DownSample(512, 1024),
DownSample(1024, 1024)
])
@@ -65,6 +57,34 @@ class UNet(nn.Module):
return bottleneck_tensor, *up_features, output_tensor
class UNetPro(UNet):
def __init__(self) -> None:
super(UNetPro, self).__init__()
self.resnet = models.resnet34(pretrained = True)
self.down_samples = self.create_down_samples(self)
self.up_samples = self.create_up_samples()
@staticmethod
def create_down_samples(self : nn.Module) -> nn.ModuleList:
return nn.ModuleList(
[
nn.Sequential(
self.resnet.conv1,
self.resnet.bn1,
self.resnet.relu,
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.1, inplace = True)
),
DownSample(32, 64),
self.resnet.layer2,
self.resnet.layer3,
self.resnet.layer4,
DownSample(512, 1024),
DownSample(1024, 1024)
])
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super(UpSample, self).__init__()