diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index c018e96..6d2e17d 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -14,14 +14,14 @@ class AAD(nn.Module): def create_layers(identity_channels : int, num_blocks : int) -> nn.ModuleList: return nn.ModuleList( [ - AADResBlock(1024, 1024, 1024, identity_channels, num_blocks), - AADResBlock(1024, 1024, 2048, identity_channels, num_blocks), - AADResBlock(1024, 1024, 1024, identity_channels, num_blocks), - AADResBlock(1024, 512, 512, identity_channels, num_blocks), - AADResBlock(512, 256, 256, identity_channels, num_blocks), - AADResBlock(256, 128, 128, identity_channels, num_blocks), - AADResBlock(128, 64, 64, identity_channels, num_blocks), - AADResBlock(64, 3, 64, identity_channels, num_blocks) + AdaptiveFeatureModulation(1024, 1024, 1024, identity_channels, num_blocks), + AdaptiveFeatureModulation(1024, 1024, 2048, identity_channels, num_blocks), + AdaptiveFeatureModulation(1024, 1024, 1024, identity_channels, num_blocks), + AdaptiveFeatureModulation(1024, 512, 512, identity_channels, num_blocks), + AdaptiveFeatureModulation(512, 256, 256, identity_channels, num_blocks), + AdaptiveFeatureModulation(256, 128, 128, identity_channels, num_blocks), + AdaptiveFeatureModulation(128, 64, 64, identity_channels, num_blocks), + AdaptiveFeatureModulation(64, 3, 64, identity_channels, num_blocks) ]) def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor: @@ -36,59 +36,85 @@ class AAD(nn.Module): return output_tensor -class AADResBlock(nn.Module): +class AdaptiveFeatureModulation(nn.Module): def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None: super().__init__() self.input_channels = input_channels self.output_channels = output_channels - self.prepare_primary_add_blocks(input_channels, attribute_channels, identity_channels, output_channels, num_blocks) - self.prepare_auxiliary_add_blocks(input_channels, attribute_channels, identity_channels, output_channels) + self.primary_layers = self.create_primary_layers(input_channels, output_channels, attribute_channels, identity_channels, num_blocks) + self.shortcut_layers = self.create_shortcut_layers(input_channels, output_channels, attribute_channels, identity_channels) - def prepare_primary_add_blocks(self, input_channels : int, attribute_channels : int, identity_channels : int, output_channels : int, num_blocks : int) -> None: - primary_add_blocks = [] + @staticmethod + def create_primary_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> nn.ModuleList: + primary_layers = nn.ModuleList() for index in range(num_blocks): - intermediate_channels = input_channels if index < (num_blocks - 1) else output_channels - primary_add_blocks.extend( - [ - FeatureModulation(input_channels, attribute_channels, identity_channels), - nn.ReLU(inplace = True), - nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False) - ] - ) - self.primary_add_blocks = AADSequential(*primary_add_blocks) + primary_layers.extend( + [ + FeatureModulation(input_channels, attribute_channels, identity_channels), + nn.ReLU(inplace = True) + ]) + + if index < num_blocks - 1: + primary_layers.append(nn.Conv2d(input_channels, input_channels, kernel_size = 3, padding = 1, bias = False)) + else: + primary_layers.append(nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)) + + return primary_layers + + @staticmethod + def _create_primary_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> nn.ModuleList: + primary_layers = nn.ModuleList() + + for index in range(num_blocks): + primary_layers.extend( + [ + FeatureModulation(input_channels, attribute_channels, identity_channels), + nn.ReLU(inplace = True) + ]) + + if index < num_blocks - 1: + primary_layers.append(nn.Conv2d(input_channels, input_channels, kernel_size = 3, padding = 1, bias = False)) + else: + primary_layers.append(nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)) + + return primary_layers + + @staticmethod + def create_shortcut_layers(input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int) -> nn.ModuleList: + shortcut_layers = nn.ModuleList() - def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, identity_channels : int, output_channels : int) -> None: if input_channels > output_channels: - auxiliary_add_blocks = AADSequential( + shortcut_layers.extend( + [ FeatureModulation(input_channels, attribute_channels, identity_channels), nn.ReLU(inplace = True), nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False) - ) - self.auxiliary_add_blocks = auxiliary_add_blocks + ]) - def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: - primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, identity_embedding) + return shortcut_layers + + def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: + primary_tensor = input_tensor + + for primary_layer in self.primary_layers: + if isinstance(primary_layer, FeatureModulation): + primary_tensor = primary_layer(primary_tensor, attribute_embedding, identity_embedding) + else: + primary_tensor = primary_layer(primary_tensor) if self.input_channels > self.output_channels: - feature_map = self.auxiliary_add_blocks(feature_map, attribute_embedding, identity_embedding) + shortcut_tensor = input_tensor - output_feature = primary_feature + feature_map - return output_feature + for shortcut_layer in self.shortcut_layers: + if isinstance(shortcut_layer, FeatureModulation): + shortcut_tensor = shortcut_layer(shortcut_tensor, attribute_embedding, identity_embedding) + else: + shortcut_tensor = shortcut_layer(shortcut_tensor) + input_tensor = shortcut_tensor -class AADSequential(nn.Module): - def __init__(self, *args : nn.Module) -> None: - super().__init__() - self.layers = nn.ModuleList(args) - - def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: - for layer in self.layers: - if isinstance(layer, FeatureModulation): - feature_map = layer(feature_map, attribute_embedding, identity_embedding) - else: - feature_map = layer(feature_map) - return feature_map + return primary_tensor + input_tensor class FeatureModulation(nn.Module): diff --git a/face_swapper/src/networks/nld.py b/face_swapper/src/networks/nld.py index 5fee9e1..97e3ba4 100644 --- a/face_swapper/src/networks/nld.py +++ b/face_swapper/src/networks/nld.py @@ -9,7 +9,7 @@ class NLD(nn.Module): self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size) @staticmethod - def create_nld(input_channels : int, num_filters : int, num_layers: int, kernel_size : int) -> nn.Sequential: + def create_nld(input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> nn.Sequential: padding = math.ceil((kernel_size - 1) / 2) current_filters = num_filters layers =\