Variable AAD layer according to output size

This commit is contained in:
henryruhs
2025-03-04 22:13:37 +01:00
parent 430c71d031
commit 5056b8df75
6 changed files with 77 additions and 43 deletions
+2
View File
@@ -53,6 +53,7 @@ motion_extractor_path = .models/motion_extractor.pt
encoder_type = unet-pro
identity_channels = 512
output_channels = 4096
output_size = 256
num_blocks = 2
```
@@ -97,6 +98,7 @@ resume_path = .outputs/last.ckpt
directory_path = .exports
source_path = .outputs/last.ckpt
target_path = .exports/face_swapper.onnx
target_size = 256
ir_version = 10
opset_version = 15
```
+2
View File
@@ -19,6 +19,7 @@ motion_extractor_path =
encoder_type =
identity_channels =
output_channels =
output_size =
num_blocks =
[training.model.discriminator]
@@ -53,6 +54,7 @@ resume_path =
directory_path =
source_path =
target_path =
target_size =
ir_version =
opset_version =
+2 -1
View File
@@ -13,6 +13,7 @@ def export() -> None:
directory_path = CONFIG.get('exporting', 'directory_path')
source_path = CONFIG.get('exporting', 'source_path')
target_path = CONFIG.get('exporting', 'target_path')
target_size = CONFIG.getint('exporting', 'target_size')
ir_version = CONFIG.getint('exporting', 'ir_version')
opset_version = CONFIG.getint('exporting', 'opset_version')
@@ -21,5 +22,5 @@ def export() -> None:
model.eval()
model.ir_version = torch.tensor(ir_version)
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, 256, 256)
target_tensor = torch.randn(1, 3, target_size, target_size)
torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version)
+2 -1
View File
@@ -16,13 +16,14 @@ class Generator(nn.Module):
encoder_type = CONFIG.get('training.model.generator', 'encoder_type')
identity_channels = CONFIG.getint('training.model.generator', 'identity_channels')
output_channels = CONFIG.getint('training.model.generator', 'output_channels')
output_size = CONFIG.getint('training.model.generator', 'output_size')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
if encoder_type == 'unet':
self.encoder = UNet()
if encoder_type == 'unet-pro':
self.encoder = UNetPro()
self.generator = AAD(identity_channels, output_channels, num_blocks)
self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
+56 -31
View File
@@ -5,32 +5,58 @@ from ..types import Attributes, Embedding
class AAD(nn.Module):
def __init__(self, identity_channels : int, output_channels : int, num_blocks : int) -> None:
def __init__(self, identity_channels : int, output_channels : int, output_size : int, num_blocks : int) -> None:
super().__init__()
self.identity_channels = identity_channels
self.output_channels = output_channels
self.output_size = output_size
self.num_blocks = num_blocks
self.pixel_shuffle_up_sample = PixelShuffleUpSample(identity_channels, output_channels)
self.layers = self.create_layers(identity_channels, num_blocks)
self.layers = self.create_layers()
@staticmethod
def create_layers(identity_channels : int, num_blocks : int) -> nn.ModuleList:
return nn.ModuleList(
def create_layers(self) -> nn.ModuleList:
layers = nn.ModuleList(
[
AdaptiveFeatureModulation(1024, 1024, 1024, identity_channels, num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, identity_channels, num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, identity_channels, num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, identity_channels, num_blocks),
AdaptiveFeatureModulation(512, 256, 256, identity_channels, num_blocks),
AdaptiveFeatureModulation(256, 128, 128, identity_channels, num_blocks),
AdaptiveFeatureModulation(128, 64, 64, identity_channels, num_blocks),
AdaptiveFeatureModulation(64, 3, 64, identity_channels, num_blocks)
AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(1024, 1024, 2048, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks),
AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks),
])
if self.output_size in [ 384, 512, 768, 1024 ]:
layers.append(AdaptiveFeatureModulation(64, 32, 32, self.identity_channels, self.num_blocks))
if self.output_size in [ 512, 768, 1024 ]:
layers.append(AdaptiveFeatureModulation(32, 16, 16, self.identity_channels, self.num_blocks))
if self.output_size in [ 768, 1024 ]:
layers.append(AdaptiveFeatureModulation(16, 8, 8, self.identity_channels, self.num_blocks))
if self.output_size == 1024:
layers.append(AdaptiveFeatureModulation(8, 4, 4, self.identity_channels, self.num_blocks))
if self.output_size == 256:
layers.append(AdaptiveFeatureModulation(64, 3, 64, self.identity_channels, self.num_blocks))
if self.output_size == 384:
layers.append(AdaptiveFeatureModulation(32, 3, 32, self.identity_channels, self.num_blocks))
if self.output_size == 512:
layers.append(AdaptiveFeatureModulation(16, 3, 16, self.identity_channels, self.num_blocks))
if self.output_size == 768:
layers.append(AdaptiveFeatureModulation(8, 3, 8, self.identity_channels, self.num_blocks))
if self.output_size == 1024:
layers.append(AdaptiveFeatureModulation(4, 3, 4, self.identity_channels, self.num_blocks))
return layers
def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor:
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
for index, layer in enumerate(self.layers[:-1]):
temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
temp_size = target_attributes[index + 1].shape[2:]
temp_tensors = nn.functional.interpolate(temp_tensor, temp_size, mode = 'bilinear', align_corners = False)
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
output_tensor = torch.tanh(temp_tensors)
@@ -42,37 +68,38 @@ class AdaptiveFeatureModulation(nn.Module):
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.primary_layers = self.create_primary_layers(input_channels, output_channels, attribute_channels, identity_channels, num_blocks)
self.shortcut_layers = self.create_shortcut_layers(input_channels, output_channels, attribute_channels, identity_channels)
self.attribute_channels = attribute_channels
self.identity_channels = identity_channels
self.num_blocks = num_blocks
self.primary_layers = self.create_primary_layers()
self.shortcut_layers = self.create_shortcut_layers()
@staticmethod
def create_primary_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> nn.ModuleList:
def create_primary_layers(self) -> nn.ModuleList:
primary_layers = nn.ModuleList()
for index in range(num_blocks):
for index in range(self.num_blocks):
primary_layers.extend(
[
FeatureModulation(input_channels, attribute_channels, identity_channels),
FeatureModulation(self.input_channels, self.attribute_channels, self.identity_channels),
nn.ReLU(inplace = True)
])
if index < num_blocks - 1:
primary_layers.append(nn.Conv2d(input_channels, input_channels, kernel_size = 3, padding = 1, bias = False))
if index < self.num_blocks - 1:
primary_layers.append(nn.Conv2d(self.input_channels, self.input_channels, kernel_size = 3, padding = 1, bias = False))
else:
primary_layers.append(nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False))
primary_layers.append(nn.Conv2d(self.input_channels, self.output_channels, kernel_size = 3, padding = 1, bias = False))
return primary_layers
@staticmethod
def create_shortcut_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int) -> nn.ModuleList:
def create_shortcut_layers(self) -> nn.ModuleList:
shortcut_layers = nn.ModuleList()
if input_channels > output_channels:
if self.input_channels > self.output_channels:
shortcut_layers.extend(
[
FeatureModulation(input_channels, attribute_channels, identity_channels),
FeatureModulation(self.input_channels, self.attribute_channels, self.identity_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
nn.Conv2d(self.input_channels, self.output_channels, kernel_size = 3, padding = 1, bias = False)
])
return shortcut_layers
@@ -113,9 +140,7 @@ class FeatureModulation(nn.Module):
def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
temp_size = temp_tensor.shape[2:]
attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_size, mode = 'bilinear')
attribute_scale = self.conv1(attribute_embedding)
attribute_shift = self.conv2(attribute_embedding)
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
+13 -10
View File
@@ -6,25 +6,28 @@ from torch import Tensor, nn
class NLD(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
super().__init__()
self.layers = self.create_layers(input_channels, num_filters, num_layers, kernel_size)
self.input_channels = input_channels
self.num_filters = num_filters
self.num_layers = num_layers
self.kernel_size = kernel_size
self.layers = self.create_layers()
self.sequences = nn.Sequential(*self.layers)
@staticmethod
def create_layers(input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> nn.ModuleList:
padding = math.ceil((kernel_size - 1) / 2)
current_filters = num_filters
def create_layers(self) -> nn.ModuleList:
padding = math.ceil((self.kernel_size - 1) / 2)
current_filters = self.num_filters
layers = nn.ModuleList(
[
nn.Conv2d(input_channels, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
nn.Conv2d(self.input_channels, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding),
nn.LeakyReLU(0.2, True)
])
for _ in range(1, num_layers):
for _ in range(1, self.num_layers):
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding),
nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
]
@@ -33,10 +36,10 @@ class NLD(nn.Module):
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, padding = padding),
nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True),
nn.Conv2d(current_filters, 1, kernel_size = kernel_size, padding = padding)
nn.Conv2d(current_filters, 1, kernel_size = self.kernel_size, padding = padding)
]
return layers