Mirror of https://github.com/facefusion/facefusion-labs.git (synced 2026-04-19 15:56:37 +02:00)
Improve naming
@@ -95,24 +95,27 @@ class AADLayer(nn.Module):
 	def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
 		super().__init__()
 		self.input_channels = input_channels
-		self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
-		self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
-		self.fc_beta = nn.Linear(identity_channels, input_channels)
-		self.fc_gamma = nn.Linear(identity_channels, input_channels)
+		self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
+		self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
+		self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
+		self.linear1 = nn.Linear(identity_channels, input_channels)
+		self.linear2 = nn.Linear(identity_channels, input_channels)
 		self.instance_norm = nn.InstanceNorm2d(input_channels)
-		self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)

-	def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
-		feature_map = self.instance_norm(feature_map)
-		gamma_attribute = self.conv_gamma(attribute_embedding)
-		beta_attribute = self.conv_beta(attribute_embedding)
-		attribute_modulation = gamma_attribute * feature_map + beta_attribute
-		identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
-		identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
-		identity_modulation = identity_gamma * feature_map + identity_beta
-		feature_mask = torch.sigmoid(self.conv_mask(feature_map))
-		feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation
-		return feature_blend
+	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
+		temp_tensor = self.instance_norm(input_tensor)
+
+		attribute_scale = self.conv1(attribute_embedding)
+		attribute_shift = self.conv2(attribute_embedding)
+		attribute_modulation = attribute_scale * temp_tensor + attribute_shift
+
+		identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor)
+		identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor)
+		identity_modulation = identity_scale * temp_tensor + identity_shift
+
+		temp_mask = torch.sigmoid(self.conv3(temp_tensor))
+		output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation
+		return output_tensor


 class PixelShuffleUpSample(nn.Module):
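For orientation, the renamed forward still computes the same adaptive attentional blend: an attribute-driven modulation and an identity-driven modulation of the instance-normalized feature map are mixed per pixel by a learned sigmoid mask. A minimal standalone sketch of that blend, assuming hypothetical dummy shapes (the real channel sizes come from the surrounding model and are not part of this diff):

import torch

# Sketch of the blend at the end of AADLayer.forward (shapes are hypothetical).
batch_size, channels, height, width = 2, 64, 32, 32
temp_tensor = torch.randn(batch_size, channels, height, width)        # instance-normalized feature map
attribute_modulation = torch.randn_like(temp_tensor)                  # stands in for attribute_scale * temp_tensor + attribute_shift
identity_modulation = torch.randn_like(temp_tensor)                   # stands in for identity_scale * temp_tensor + identity_shift
temp_mask = torch.sigmoid(torch.randn(batch_size, 1, height, width))  # stands in for the conv3 output, squashed to (0, 1)
output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation
assert output_tensor.shape == (batch_size, channels, height, width)

Because the mask has a single channel, it broadcasts across all feature channels, so each spatial position picks its own mixture of attribute and identity information.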
@@ -122,6 +125,6 @@ class PixelShuffleUpSample(nn.Module):
 		self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)

 	def forward(self, input_tensor : Tensor) -> Tensor:
-		temp_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
-		temp_tensor = self.pixel_shuffle(temp_tensor)
-		return temp_tensor
+		output_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
+		output_tensor = self.pixel_shuffle(output_tensor)
+		return output_tensor
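Whatever the result is named, the layer reshapes a flat embedding to (N, C, 1, 1), expands channels with self.conv, and then trades channels for resolution. A short sketch of the shuffle step alone, with a hypothetical channel count:

import torch
import torch.nn as nn

pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
temp_tensor = torch.randn(1, 256, 4, 4)     # (N, C * 4, H, W); 256 channels is a hypothetical placeholder
output_tensor = pixel_shuffle(temp_tensor)  # channels / 4, spatial * 2, no learned parameters
assert output_tensor.shape == (1, 64, 8, 8)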
@@ -93,11 +93,11 @@ class UpSample(nn.Module):
 		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

 	def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
-		temp_tensor = self.conv_transpose(input_tensor)
-		temp_tensor = self.batch_norm(temp_tensor)
-		temp_tensor = self.leaky_relu(temp_tensor)
-		temp_tensor = torch.cat((temp_tensor, skip_tensor), dim = 1)
-		return temp_tensor
+		output_tensor = self.conv_transpose(input_tensor)
+		output_tensor = self.batch_norm(output_tensor)
+		output_tensor = self.leaky_relu(output_tensor)
+		output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)
+		return output_tensor


 class DownSample(nn.Module):
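UpSample is a standard decoder block: transposed convolution, batch norm, LeakyReLU, then channel-wise concatenation with an encoder skip connection. A sketch under hypothetical hyperparameters (the channel counts and the kernel_size = 4, stride = 2, padding = 1 transposed convolution that doubles the spatial size are assumptions; the actual values are defined outside this hunk):

import torch
import torch.nn as nn

# All hyperparameters below are hypothetical stand-ins for the ones outside this hunk.
conv_transpose = nn.ConvTranspose2d(128, 64, kernel_size = 4, stride = 2, padding = 1)
batch_norm = nn.BatchNorm2d(64)
leaky_relu = nn.LeakyReLU(0.1, inplace = True)

input_tensor = torch.randn(1, 128, 16, 16)
skip_tensor = torch.randn(1, 64, 32, 32)                             # encoder feature at the doubled resolution
output_tensor = leaky_relu(batch_norm(conv_transpose(input_tensor)))  # (1, 64, 32, 32)
output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)      # channels: 64 + 64
assert output_tensor.shape == (1, 128, 32, 32)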
@@ -108,7 +108,7 @@ class DownSample(nn.Module):
 		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

 	def forward(self, input_tensor : Tensor) -> Tensor:
-		temp_tensor = self.conv(input_tensor)
-		temp_tensor = self.batch_norm(temp_tensor)
-		temp_tensor = self.leaky_relu(temp_tensor)
-		return temp_tensor
+		output_tensor = self.conv(input_tensor)
+		output_tensor = self.batch_norm(output_tensor)
+		output_tensor = self.leaky_relu(output_tensor)
+		return output_tensor
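DownSample mirrors this on the encoder side: a strided convolution followed by batch norm and LeakyReLU. A sketch under the same kind of hypothetical assumptions (a kernel_size = 4, stride = 2, padding = 1 convolution that halves the spatial size):

import torch
import torch.nn as nn

# Hypothetical halving convolution; the real channel counts and kernel live outside this hunk.
conv = nn.Conv2d(64, 128, kernel_size = 4, stride = 2, padding = 1)
batch_norm = nn.BatchNorm2d(128)
leaky_relu = nn.LeakyReLU(0.1, inplace = True)

input_tensor = torch.randn(1, 64, 32, 32)
output_tensor = leaky_relu(batch_norm(conv(input_tensor)))
assert output_tensor.shape == (1, 128, 16, 16)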