diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index c36b82b..b44305d 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -11,8 +11,8 @@ class Discriminator(nn.Module): super().__init__() self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators') self.config_parser = config_parser - self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) self.discriminators = self.create_discriminators() + self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) def create_discriminators(self) -> nn.ModuleList: discriminators = nn.ModuleList() diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 7e8f871..ac53975 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -75,10 +75,10 @@ class ReconstructionLoss(nn.Module): self.mse_loss = nn.MSELoss() def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: - with torch.no_grad(): source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) + has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 reconstruction_loss = torch.mean((source_tensor - target_tensor) ** 2, dim = (1, 2, 3)) @@ -141,6 +141,7 @@ class MotionLoss(nn.Module): with torch.no_grad(): pitch, yaw, roll, translation, expression, scale, motion_points = self.motion_extractor(input_tensor) + rotation = torch.cat([ pitch, yaw, roll ], dim = 1) pose = translation, scale, rotation, motion_points return pose, expression @@ -174,6 +175,7 @@ class GazeLoss(nn.Module): with torch.no_grad(): pitch, yaw = self.gazer(crop_tensor) + return pitch, yaw @@ -201,4 +203,5 @@ class MaskLoss(nn.Module): output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype) output_tensor = output_tensor.view(-1, 1, 512, 512) output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear') + return output_tensor diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index a912319..aeb1969 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -162,10 +162,16 @@ class FeatureModulation(nn.Module): class PixelShuffleUpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1) - self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2) + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1), + nn.PixelShuffle(upscale_factor = 2) + ) def forward(self, input_tensor : Tensor) -> Tensor: - output_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1)) - output_tensor = self.pixel_shuffle(output_tensor) + temp_tensor = input_tensor.view(input_tensor.shape[0], -1, 1, 1) + output_tensor = self.sequences(temp_tensor) return output_tensor diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index bbe4e55..0a207aa 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -21,7 +21,7 @@ class MaskNet(nn.Module): return nn.ModuleList( [ DownSample(input_channels, num_filters), - DownSample(num_filters, num_filters * 2), + DownSample(num_filters, num_filters * 2) ]) @staticmethod @@ -74,26 +74,34 @@ class BottleNeck(nn.Module): class UpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2) - self.relu = nn.ReLU(inplace = True) + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2), + nn.ReLU(inplace = True) + ) def forward(self, input_tensor : Tensor) -> Tensor: - output_tensor = self.conv_transpose(input_tensor) - output_tensor = self.relu(output_tensor) + output_tensor = self.sequences(input_tensor) return output_tensor class DownSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False) - self.batch_norm = nn.BatchNorm2d(output_channels) - self.relu = nn.ReLU(inplace = True) - self.max_pool = nn.MaxPool2d(2) + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False), + nn.BatchNorm2d(output_channels), + nn.ReLU(inplace = True), + nn.MaxPool2d(2) + ) def forward(self, input_tensor : Tensor) -> Tensor: - output_tensor = self.conv(input_tensor) - output_tensor = self.batch_norm(output_tensor) - output_tensor = self.relu(output_tensor) - output_tensor = self.max_pool(output_tensor) + output_tensor = self.sequences(input_tensor) return output_tensor diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 4aaa653..1a27fe4 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -103,14 +103,18 @@ class UNet(nn.Module): class UpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) - self.batch_norm = nn.BatchNorm2d(output_channels) - self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False), + nn.BatchNorm2d(output_channels), + nn.LeakyReLU(0.1, inplace = True) + ) def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor: - output_tensor = self.conv_transpose(input_tensor) - output_tensor = self.batch_norm(output_tensor) - output_tensor = self.leaky_relu(output_tensor) + output_tensor = self.sequences(input_tensor) output_tensor = torch.cat((output_tensor, skip_tensor), dim = 1) return output_tensor @@ -118,12 +122,16 @@ class UpSample(nn.Module): class DownSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) - self.batch_norm = nn.BatchNorm2d(output_channels) - self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) + self.sequences = self.create_sequences(input_channels, output_channels) + + @staticmethod + def create_sequences(input_channels : int, output_channels : int) -> nn.Sequential: + return nn.Sequential( + nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False), + nn.BatchNorm2d(output_channels), + nn.LeakyReLU(0.1, inplace = True) + ) def forward(self, input_tensor : Tensor) -> Tensor: - output_tensor = self.conv(input_tensor) - output_tensor = self.batch_norm(output_tensor) - output_tensor = self.leaky_relu(output_tensor) + output_tensor = self.sequences(input_tensor) return output_tensor