Improve naming

This commit is contained in:
henryruhs
2025-02-23 21:39:01 +01:00
parent 5bba2a1c69
commit 8b2b6892aa
2 changed files with 31 additions and 28 deletions
+22 -19
View File
@@ -95,24 +95,27 @@ class AADLayer(nn.Module):
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
super().__init__()
self.input_channels = input_channels
self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
self.fc_beta = nn.Linear(identity_channels, input_channels)
self.fc_gamma = nn.Linear(identity_channels, input_channels)
self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
self.linear1 = nn.Linear(identity_channels, input_channels)
self.linear2 = nn.Linear(identity_channels, input_channels)
self.instance_norm = nn.InstanceNorm2d(input_channels)
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
feature_map = self.instance_norm(feature_map)
gamma_attribute = self.conv_gamma(attribute_embedding)
beta_attribute = self.conv_beta(attribute_embedding)
attribute_modulation = gamma_attribute * feature_map + beta_attribute
identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
identity_modulation = identity_gamma * feature_map + identity_beta
feature_mask = torch.sigmoid(self.conv_mask(feature_map))
feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation
return feature_blend
def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
attribute_scale = self.conv1(attribute_embedding)
attribute_shift = self.conv2(attribute_embedding)
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor)
identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor)
identity_modulation = identity_scale * temp_tensor + identity_shift
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation
return output_tensor
class PixelShuffleUpSample(nn.Module):
@@ -122,6 +125,6 @@ class PixelShuffleUpSample(nn.Module):
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
temp_tensor = self.pixel_shuffle(temp_tensor)
return temp_tensor
output_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
output_tensor = self.pixel_shuffle(output_tensor)
return output_tensor
+9 -9
View File
@@ -93,11 +93,11 @@ class UpSample(nn.Module):
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
temp_tensor = self.conv_transpose(input_tensor)
temp_tensor = self.batch_norm(temp_tensor)
temp_tensor = self.leaky_relu(temp_tensor)
temp_tensor = torch.cat((temp_tensor, skip_tensor), dim = 1)
return temp_tensor
output_tensor = self.conv_transpose(input_tensor)
output_tensor = self.batch_norm(output_tensor)
output_tensor = self.leaky_relu(output_tensor)
output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)
return output_tensor
class DownSample(nn.Module):
@@ -108,7 +108,7 @@ class DownSample(nn.Module):
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
def forward(self, input_tensor : Tensor) -> Tensor:
temp_tensor = self.conv(input_tensor)
temp_tensor = self.batch_norm(temp_tensor)
temp_tensor = self.leaky_relu(temp_tensor)
return temp_tensor
output_tensor = self.conv(input_tensor)
output_tensor = self.batch_norm(output_tensor)
output_tensor = self.leaky_relu(output_tensor)
return output_tensor