diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index de8852c..b766979 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -95,24 +95,27 @@ class AADLayer(nn.Module): def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None: super().__init__() self.input_channels = input_channels - self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) - self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) - self.fc_beta = nn.Linear(identity_channels, input_channels) - self.fc_gamma = nn.Linear(identity_channels, input_channels) + self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) + self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) + self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1) + self.linear1 = nn.Linear(identity_channels, input_channels) + self.linear2 = nn.Linear(identity_channels, input_channels) self.instance_norm = nn.InstanceNorm2d(input_channels) - self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1) - def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: - feature_map = self.instance_norm(feature_map) - gamma_attribute = self.conv_gamma(attribute_embedding) - beta_attribute = self.conv_beta(attribute_embedding) - attribute_modulation = gamma_attribute * feature_map + beta_attribute - identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) - identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) - identity_modulation = identity_gamma * feature_map + identity_beta - feature_mask = torch.sigmoid(self.conv_mask(feature_map)) - feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation - return feature_blend + def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: + temp_tensor = self.instance_norm(input_tensor) + + attribute_scale = self.conv1(attribute_embedding) + attribute_shift = self.conv2(attribute_embedding) + attribute_modulation = attribute_scale * temp_tensor + attribute_shift + + identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor) + identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor) + identity_modulation = identity_scale * temp_tensor + identity_shift + + temp_mask = torch.sigmoid(self.conv3(temp_tensor)) + output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation + return output_tensor class PixelShuffleUpSample(nn.Module): @@ -122,6 +125,6 @@ class PixelShuffleUpSample(nn.Module): self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2) def forward(self, input_tensor : Tensor) -> Tensor: - temp_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1)) - temp_tensor = self.pixel_shuffle(temp_tensor) - return temp_tensor + output_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1)) + output_tensor = self.pixel_shuffle(output_tensor) + return output_tensor diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 7ea9b2a..d1a4f2a 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -93,11 +93,11 @@ class UpSample(nn.Module): self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor: - temp_tensor = self.conv_transpose(input_tensor) - temp_tensor = self.batch_norm(temp_tensor) - temp_tensor = self.leaky_relu(temp_tensor) - temp_tensor = torch.cat((temp_tensor, skip_tensor), dim = 1) - return temp_tensor + output_tensor = self.conv_transpose(input_tensor) + output_tensor = self.batch_norm(output_tensor) + output_tensor = self.leaky_relu(output_tensor) + output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1) + return output_tensor class DownSample(nn.Module): @@ -108,7 +108,7 @@ class DownSample(nn.Module): self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) def forward(self, input_tensor : Tensor) -> Tensor: - temp_tensor = self.conv(input_tensor) - temp_tensor = self.batch_norm(temp_tensor) - temp_tensor = self.leaky_relu(temp_tensor) - return temp_tensor + output_tensor = self.conv(input_tensor) + output_tensor = self.batch_norm(output_tensor) + output_tensor = self.leaky_relu(output_tensor) + return output_tensor