diff --git a/face_swapper/README.md b/face_swapper/README.md
index fa85087..f18a7ae 100644
--- a/face_swapper/README.md
+++ b/face_swapper/README.md
@@ -51,7 +51,7 @@ face_parser_path = .models/face_parser.pt
 
 ```
 [training.model.generator]
-identity_channels = 512
+source_channels = 512
 output_channels = 4096
 output_size = 256
 num_blocks = 2
diff --git a/face_swapper/config.ini b/face_swapper/config.ini
index 79ac834..eec38bd 100644
--- a/face_swapper/config.ini
+++ b/face_swapper/config.ini
@@ -17,7 +17,7 @@ motion_extractor_path =
 face_parser_path =
 
 [training.model.generator]
-identity_channels =
+source_channels =
 output_channels =
 output_size =
 num_blocks =
diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py
index 671ff14..7babfbe 100644
--- a/face_swapper/src/networks/aad.py
+++ b/face_swapper/src/networks/aad.py
@@ -10,11 +10,11 @@ from ..types import Attribute, Embedding
 class AAD(nn.Module):
 	def __init__(self, config_parser : ConfigParser) -> None:
 		super().__init__()
-		self.config_identity_channels = config_parser.getint('training.model.generator', 'identity_channels')
+		self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
 		self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
 		self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
 		self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
-		self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_identity_channels, self.config_output_channels)
+		self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
 		self.layers = self.create_layers()
 
 	def create_layers(self) -> nn.ModuleList:
@@ -23,36 +23,36 @@ class AAD(nn.Module):
 		layers = nn.ModuleList()
 
 		if self.config_output_size == 128:
 			layers.extend(
 			[
-				AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(512, 512, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks)
+				AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
 			])
 
 		if self.config_output_size == 256:
 			layers.extend(
 			[
-				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 1024, 2048, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
+				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
 			])
 
 		if self.config_output_size == 512:
 			layers.extend(
 			[
-				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(2048, 2048, 4096, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(2048, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
+				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
 			])
 
 		layers.extend(
 		[
-			AdaptiveFeatureModulation(512, 256, 256, self.config_identity_channels, self.config_num_blocks),
-			AdaptiveFeatureModulation(256, 128, 128, self.config_identity_channels, self.config_num_blocks),
-			AdaptiveFeatureModulation(128, 64, 64, self.config_identity_channels, self.config_num_blocks),
-			AdaptiveFeatureModulation(64, 3, 64, self.config_identity_channels, self.config_num_blocks)
+			AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
+			AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks),
+			AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks),
+			AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks)
 		])
 		return layers
@@ -61,21 +61,23 @@ class AAD(nn.Module):
 		temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
 
 		for index, layer in enumerate(self.layers[:-1]):
-			temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
+			target_attribute = target_attributes[index]
+			temp_tensor = layer(temp_tensors, source_embedding, target_attribute)
 			temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
 
-		temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
+		target_attribute = target_attributes[-1]
+		temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_attribute)
 		output_tensor = torch.tanh(temp_tensors)
 		return output_tensor
 
 
 class AdaptiveFeatureModulation(nn.Module):
-	def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None:
+	def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
 		super().__init__()
 		self.context_input_channels = input_channels
 		self.context_output_channels = output_channels
-		self.context_attribute_channels = attribute_channels
-		self.context_identity_channels = identity_channels
+		self.context_target_channels = target_channels
+		self.context_source_channels = source_channels
 		self.context_num_blocks = num_blocks
 		self.primary_layers = self.create_primary_layers()
 		self.shortcut_layers = self.create_shortcut_layers()
@@ -86,7 +88,7 @@ class AdaptiveFeatureModulation(nn.Module):
 		for index in range(self.context_num_blocks):
 			primary_layers.extend(
 			[
-				FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
+				FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
 				nn.ReLU(inplace = True)
 			])
 
@@ -103,19 +105,19 @@ class AdaptiveFeatureModulation(nn.Module):
 		if self.context_input_channels > self.context_output_channels:
 			shortcut_layers.extend(
 			[
-				FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
+				FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
 				nn.ReLU(inplace = True),
 				nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
 			])
 
 		return shortcut_layers
 
-	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
+	def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
 		primary_tensor = input_tensor
 
 		for primary_layer in self.primary_layers:
 			if isinstance(primary_layer, FeatureModulation):
-				primary_tensor = primary_layer(primary_tensor, attribute_embedding, identity_embedding)
+				primary_tensor = primary_layer(primary_tensor, source_embedding, target_attribute)
 			else:
 				primary_tensor = primary_layer(primary_tensor)
 
@@ -124,7 +126,7 @@ class AdaptiveFeatureModulation(nn.Module):
 
 		for shortcut_layer in self.shortcut_layers:
 			if isinstance(shortcut_layer, FeatureModulation):
-				shortcut_tensor = shortcut_layer(shortcut_tensor, attribute_embedding, identity_embedding)
+				shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_attribute)
 			else:
 				shortcut_tensor = shortcut_layer(shortcut_tensor)
 
@@ -134,29 +136,29 @@ class FeatureModulation(nn.Module):
-	def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
+	def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
 		super().__init__()
 		self.context_input_channels = input_channels
-		self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
-		self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
+		self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
+		self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
 		self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
-		self.linear1 = nn.Linear(identity_channels, input_channels)
-		self.linear2 = nn.Linear(identity_channels, input_channels)
+		self.linear1 = nn.Linear(source_channels, input_channels)
+		self.linear2 = nn.Linear(source_channels, input_channels)
 		self.instance_norm = nn.InstanceNorm2d(input_channels)
 
-	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
+	def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
 		temp_tensor = self.instance_norm(input_tensor)
 
-		attribute_scale = self.conv1(attribute_embedding)
-		attribute_shift = self.conv2(attribute_embedding)
-		attribute_modulation = attribute_scale * temp_tensor + attribute_shift
+		source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
+		source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
+		source_modulation = source_scale * temp_tensor + source_shift
 
-		identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
-		identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
-		identity_modulation = identity_scale * temp_tensor + identity_shift
+		target_scale = self.conv1(target_attribute)
+		target_shift = self.conv2(target_attribute)
+		target_modulation = target_scale * temp_tensor + target_shift
 
 		temp_mask = torch.sigmoid(self.conv3(temp_tensor))
-		output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation
+		output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation
 		return output_tensor
diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py
index 20c645e..37a9c1a 100644
--- a/face_swapper/tests/test_networks.py
+++ b/face_swapper/tests/test_networks.py
@@ -15,7 +15,7 @@ def test_aad_with_unet(output_size : int) -> None:
 {
 	'training.model.generator':
 	{
-		'identity_channels': '512',
+		'source_channels': '512',
 		'output_channels': str(output_size * 16),
 		'output_size': str(output_size),
 		'num_blocks': '2'