diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 7fb479e..ec3eadf 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -2,7 +2,7 @@ import configparser from torch import Tensor, nn -from ..networks.aienet import AADGenerator +from ..networks.aienet import AIENet from ..networks.unet import UNet, UNetPro from ..types import Attributes, Embedding @@ -22,13 +22,13 @@ class Generator(nn.Module): self.encoder = UNet() if encoder_type == 'unet-pro': self.encoder = UNetPro() - self.generator = AADGenerator(identity_channels, output_channels, num_blocks) + self.generator = AIENet(identity_channels, output_channels, num_blocks) self.encoder.apply(init_weight) self.generator.apply(init_weight) def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: target_attributes = self.get_attributes(target_tensor) - output_tensor = self.generator(target_attributes, source_embedding) + output_tensor = self.generator(source_embedding, target_attributes) return output_tensor def get_attributes(self, input_tensor : Tensor) -> Attributes: diff --git a/face_swapper/src/networks/aienet.py b/face_swapper/src/networks/aienet.py index 52930be..1b2bee3 100644 --- a/face_swapper/src/networks/aienet.py +++ b/face_swapper/src/networks/aienet.py @@ -4,68 +4,36 @@ from torch import Tensor, nn from ..types import Attributes, Embedding -class AADGenerator(nn.Module): +class AIENet(nn.Module): def __init__(self, identity_channels : int, output_channels : int, num_blocks : int) -> None: super().__init__() self.pixel_shuffle_up_sample = PixelShuffleUpSample(identity_channels, output_channels) - self.res_block_1 = AADResBlock(1024, 1024, 1024, identity_channels, num_blocks) - self.res_block_2 = AADResBlock(1024, 1024, 2048, identity_channels, num_blocks) - self.res_block_3 = AADResBlock(1024, 1024, 1024, identity_channels, num_blocks) - self.res_block_4 = AADResBlock(1024, 512, 512, identity_channels, num_blocks) - self.res_block_5 = AADResBlock(512, 256, 256, identity_channels, num_blocks) - self.res_block_6 = AADResBlock(256, 128, 128, identity_channels, num_blocks) - self.res_block_7 = AADResBlock(128, 64, 64, identity_channels, num_blocks) - self.res_block_8 = AADResBlock(64, 3, 64, identity_channels, num_blocks) + self.layers = self.create_layers(identity_channels, num_blocks) - def forward(self, target_attributes : Attributes, source_embedding : Embedding) -> Tensor: - feature_map = self.pixel_shuffle_up_sample(source_embedding) - feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_3 = nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_4 = nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_5 = nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_6 = nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - feature_map_7 = nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) - output = self.res_block_8(feature_map_7, target_attributes[7], source_embedding) - return torch.tanh(output) + @staticmethod + def create_layers(identity_channels : int, num_blocks : int) -> nn.ModuleList: + return nn.ModuleList( + [ + AADResBlock(1024, 1024, 1024, identity_channels, num_blocks), + AADResBlock(1024, 1024, 2048, identity_channels, num_blocks), + AADResBlock(1024, 1024, 1024, identity_channels, num_blocks), + AADResBlock(1024, 512, 512, identity_channels, num_blocks), + AADResBlock(512, 256, 256, identity_channels, num_blocks), + AADResBlock(256, 128, 128, identity_channels, num_blocks), + AADResBlock(128, 64, 64, identity_channels, num_blocks), + AADResBlock(64, 3, 64, identity_channels, num_blocks) + ]) + def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor: + temp_tensors = self.pixel_shuffle_up_sample(source_embedding) -class AADLayer(nn.Module): - def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None: - super().__init__() - self.input_channels = input_channels - self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) - self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) - self.fc_beta = nn.Linear(identity_channels, input_channels) - self.fc_gamma = nn.Linear(identity_channels, input_channels) - self.instance_norm = nn.InstanceNorm2d(input_channels) - self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1) + for index, layer in enumerate(self.layers[:-1]): + temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding) + temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False) - def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: - feature_map = self.instance_norm(feature_map) - gamma_attribute = self.conv_gamma(attribute_embedding) - beta_attribute = self.conv_beta(attribute_embedding) - attribute_modulation = gamma_attribute * feature_map + beta_attribute - identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) - identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) - identity_modulation = identity_gamma * feature_map + identity_beta - feature_mask = torch.sigmoid(self.conv_mask(feature_map)) - feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation - return feature_blend - - -class AADSequential(nn.Module): - def __init__(self, *args : nn.Module) -> None: - super().__init__() - self.layers = nn.ModuleList(args) - - def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: - for layer in self.layers: - if isinstance(layer, AADLayer): - feature_map = layer(feature_map, attribute_embedding, identity_embedding) - else: - feature_map = layer(feature_map) - return feature_map + temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding) + output_tensor = torch.tanh(temp_tensors) + return output_tensor class AADResBlock(nn.Module): @@ -109,6 +77,44 @@ class AADResBlock(nn.Module): return output_feature +class AADSequential(nn.Module): + def __init__(self, *args : nn.Module) -> None: + super().__init__() + self.layers = nn.ModuleList(args) + + def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: + for layer in self.layers: + if isinstance(layer, AADLayer): + feature_map = layer(feature_map, attribute_embedding, identity_embedding) + else: + feature_map = layer(feature_map) + return feature_map + + +class AADLayer(nn.Module): + def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None: + super().__init__() + self.input_channels = input_channels + self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) + self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) + self.fc_beta = nn.Linear(identity_channels, input_channels) + self.fc_gamma = nn.Linear(identity_channels, input_channels) + self.instance_norm = nn.InstanceNorm2d(input_channels) + self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1) + + def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: + feature_map = self.instance_norm(feature_map) + gamma_attribute = self.conv_gamma(attribute_embedding) + beta_attribute = self.conv_beta(attribute_embedding) + attribute_modulation = gamma_attribute * feature_map + beta_attribute + identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) + identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) + identity_modulation = identity_gamma * feature_map + identity_beta + feature_mask = torch.sigmoid(self.conv_mask(feature_map)) + feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation + return feature_blend + + class PixelShuffleUpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__()