Use sequential whenever possible

henryruhs
2025-03-11 08:04:14 +01:00
parent f90fd73b54
commit e758eb3e19
5 changed files with 56 additions and 31 deletions
+1 -1
@@ -11,8 +11,8 @@ class Discriminator(nn.Module):
 		super().__init__()
 		self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
 		self.config_parser = config_parser
-		self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
 		self.discriminators = self.create_discriminators()
+		self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)

 	def create_discriminators(self) -> nn.ModuleList:
 		discriminators = nn.ModuleList()
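The hunk above only reorders the Discriminator constructor; the body of create_discriminators() and the forward pass are not part of this commit. A minimal sketch of how a multi-scale discriminator along these lines could use the shared avg_pool to feed each sub-discriminator a progressively downsampled input; the sub-discriminator layers and channel sizes below are assumptions, not the repository's code:

import torch
import torch.nn as nn

class MultiScaleDiscriminator(nn.Module):
	def __init__(self, num_discriminators : int) -> None:
		super().__init__()
		self.num_discriminators = num_discriminators
		self.discriminators = self.create_discriminators()
		self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)

	def create_discriminators(self) -> nn.ModuleList:
		discriminators = nn.ModuleList()

		# hypothetical PatchGAN-style sub-discriminator, repeated once per scale
		for _ in range(self.num_discriminators):
			discriminators.append(nn.Sequential(
				nn.Conv2d(3, 64, kernel_size = 4, stride = 2, padding = 1),
				nn.LeakyReLU(0.2, inplace = True),
				nn.Conv2d(64, 1, kernel_size = 4, padding = 1)
			))
		return discriminators

	def forward(self, input_tensor : torch.Tensor) -> list:
		output_tensors = []

		# each scale sees the input downsampled one more time by avg_pool
		for discriminator in self.discriminators:
			output_tensors.append(discriminator(input_tensor))
			input_tensor = self.avg_pool(input_tensor)
		return output_tensors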
+4 -1
@@ -75,10 +75,10 @@ class ReconstructionLoss(nn.Module):
 		self.mse_loss = nn.MSELoss()

 	def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
 		with torch.no_grad():
 			source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
 			target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
 			has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
 		reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3))
@@ -141,6 +141,7 @@ class MotionLoss(nn.Module):
 		with torch.no_grad():
 			pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor)
 			rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
 		pose = translation, scale, rotation, motion_points
 		return pose, expression
@@ -174,6 +175,7 @@ class GazeLoss(nn.Module):
 		with torch.no_grad():
 			pitch, yaw = self.gazer(crop_tensor)
 		return pitch, yaw
@@ -201,4 +203,5 @@ class MaskLoss(nn.Module):
 		output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype)
+		output_tensor = output_tensor.view(-1, 1, 512, 512)
 		output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
 		return output_tensor
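In the MaskLoss hunk, the view() call matters because torch.nn.functional.interpolate with mode = 'bilinear' expects a 4D NCHW tensor: the region mask produced by torch.isin is reshaped to (batch, 1, 512, 512) before being resized. A standalone sketch of that step, with an assumed batch of 512x512 integer segmentation maps and an assumed output size of 256:

import torch

# assumed inputs: a batch of 512x512 segmentation maps and an output size of 256
face_mask_regions = torch.tensor([ 1, 2, 3 ])
output_tensor = torch.randint(0, 19, (4, 512, 512))

# keep only the selected face regions as a float mask
output_tensor = torch.isin(output_tensor, face_mask_regions).to(torch.float32)
# interpolate() needs NCHW, so restore the channel dimension first
output_tensor = output_tensor.view(-1, 1, 512, 512)
output_tensor = torch.nn.functional.interpolate(output_tensor, (256, 256), mode = 'bilinear')
print(output_tensor.shape) # torch.Size([4, 1, 256, 256])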
+10 -4
@@ -162,10 +162,16 @@ class FeatureModulation(nn.Module):
 class PixelShuffleUpSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super().__init__()
-		self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1)
-		self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
+		self.sequences = self.create_sequences(input_channels, output_channels)
+
+	@staticmethod
+	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
+		return nn.Sequential(
+			nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
+			nn.PixelShuffle(upscale_factor = 2)
+		)

 	def forward(self, input_tensor : Tensor) -> Tensor:
-		output_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
-		output_tensor = self.pixel_shuffle(output_tensor)
+		temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1)
+		output_tensor = self.sequences(temp_tensor)
 		return output_tensor
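Because nn.PixelShuffle carries no parameters, folding the convolution and the shuffle into one nn.Sequential leaves the computation unchanged; only the parameter names move (conv.weight becomes sequences.0.weight and so on), which is worth keeping in mind if older checkpoints need to be reloaded. A minimal sketch of the refactored block with assumed channel sizes, showing the shapes once the flattened input is restored to NCHW:

import torch
import torch.nn as nn

class PixelShuffleUpSample(nn.Module):
	def __init__(self, input_channels : int, output_channels : int) -> None:
		super().__init__()
		self.sequences = self.create_sequences(input_channels, output_channels)

	@staticmethod
	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
		return nn.Sequential(
			nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1),
			nn.PixelShuffle(upscale_factor = 2)
		)

	def forward(self, input_tensor : torch.Tensor) -> torch.Tensor:
		# treat the flat feature vector as a 1x1 spatial map before upsampling
		temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1)
		return self.sequences(temp_tensor)

# assumed sizes: a 512-dim feature vector mapped to 256 channels at 2x2
up_sample = PixelShuffleUpSample(512, 1024)
print(up_sample(torch.randn(2, 512)).shape) # torch.Size([2, 256, 2, 2])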
+21 -13
@@ -21,7 +21,7 @@ class MaskNet(nn.Module):
 		return nn.ModuleList(
 		[
 			DownSample(input_channels, num_filters),
-			DownSample(num_filters, num_filters * 2),
+			DownSample(num_filters, num_filters * 2)
 		])
@@ -74,26 +74,34 @@ class BottleNeck(nn.Module):
 class UpSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super().__init__()
-		self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2)
-		self.relu = nn.ReLU(inplace = True)
+		self.sequences = self.create_sequences(input_channels, output_channels)
+
+	@staticmethod
+	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
+		return nn.Sequential(
+			nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2),
+			nn.ReLU(inplace = True)
+		)

 	def forward(self, input_tensor : Tensor) -> Tensor:
-		output_tensor = self.conv_transpose(input_tensor)
-		output_tensor = self.relu(output_tensor)
+		output_tensor = self.sequences(input_tensor)
 		return output_tensor

 class DownSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super().__init__()
-		self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.relu = nn.ReLU(inplace = True)
-		self.max_pool = nn.MaxPool2d(2)
+		self.sequences = self.create_sequences(input_channels, output_channels)
+
+	@staticmethod
+	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
+		return nn.Sequential(
+			nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
+			nn.BatchNorm2d(output_channels),
+			nn.ReLU(inplace = True),
+			nn.MaxPool2d(2)
+		)

 	def forward(self, input_tensor : Tensor) -> Tensor:
-		output_tensor = self.conv(input_tensor)
-		output_tensor = self.batch_norm(output_tensor)
-		output_tensor = self.relu(output_tensor)
-		output_tensor = self.max_pool(output_tensor)
+		output_tensor = self.sequences(input_tensor)
 		return output_tensor
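The DownSample refactor groups convolution, batch norm, activation and pooling into one nn.Sequential, so forward() reduces to a single call; bias = False is the usual choice when BatchNorm2d directly follows the convolution, and MaxPool2d(2) is what halves the spatial resolution. A quick shape check of the grouped block, with assumed channel counts and input size:

import torch
import torch.nn as nn

# assumed sizes: RGB input at 256x256 mapped to 64 filters
def create_down_sample(input_channels : int, output_channels : int) -> nn.Sequential:
	return nn.Sequential(
		nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False),
		nn.BatchNorm2d(output_channels),
		nn.ReLU(inplace = True),
		nn.MaxPool2d(2)
	)

down_sample = create_down_sample(3, 64)
print(down_sample(torch.randn(1, 3, 256, 256)).shape) # torch.Size([1, 64, 128, 128])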
+20 -12
@@ -103,14 +103,18 @@ class UNet(nn.Module):
 class UpSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super().__init__()
-		self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+		self.sequences = self.create_sequences(input_channels, output_channels)
+
+	@staticmethod
+	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
+		return nn.Sequential(
+			nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
+			nn.BatchNorm2d(output_channels),
+			nn.LeakyReLU(0.1, inplace = True)
+		)

 	def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
-		output_tensor = self.conv_transpose(input_tensor)
-		output_tensor = self.batch_norm(output_tensor)
-		output_tensor = self.leaky_relu(output_tensor)
+		output_tensor = self.sequences(input_tensor)
 		output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1)
 		return output_tensor
@@ -118,12 +122,16 @@ class UpSample(nn.Module):
 class DownSample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super().__init__()
-		self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+		self.sequences = self.create_sequences(input_channels, output_channels)
+
+	@staticmethod
+	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
+		return nn.Sequential(
+			nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
+			nn.BatchNorm2d(output_channels),
+			nn.LeakyReLU(0.1, inplace = True)
+		)

 	def forward(self, input_tensor : Tensor) -> Tensor:
-		output_tensor = self.conv(input_tensor)
-		output_tensor = self.batch_norm(output_tensor)
-		output_tensor = self.leaky_relu(output_tensor)
+		output_tensor = self.sequences(input_tensor)
 		return output_tensor
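The UNet UpSample cannot collapse entirely into nn.Sequential because the skip connection needs a second input, so only the ConvTranspose2d / BatchNorm2d / LeakyReLU trio is grouped while torch.cat stays in forward(). With kernel_size = 4, stride = 2, padding = 1 the transposed convolution exactly doubles the spatial size, and after concatenation the channel count is output_channels plus the skip tensor's channels. A sketch with assumed channel sizes:

import torch
import torch.nn as nn

class UpSample(nn.Module):
	def __init__(self, input_channels : int, output_channels : int) -> None:
		super().__init__()
		self.sequences = self.create_sequences(input_channels, output_channels)

	@staticmethod
	def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential:
		return nn.Sequential(
			nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False),
			nn.BatchNorm2d(output_channels),
			nn.LeakyReLU(0.1, inplace = True)
		)

	def forward(self, input_tensor : torch.Tensor, skip_tensor : torch.Tensor) -> torch.Tensor:
		# upsample, then attach the encoder feature map along the channel axis
		output_tensor = self.sequences(input_tensor)
		return torch.cat((output_tensor, skip_tensor), dim = 1)

# assumed sizes: 256 channels at 32x32 upsampled and fused with a 128-channel skip at 64x64
up_sample = UpSample(256, 128)
output_tensor = up_sample(torch.randn(1, 256, 32, 32), torch.randn(1, 128, 64, 64))
print(output_tensor.shape) # torch.Size([1, 256, 64, 64])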