Introduce new DiscriminatorLoss class, Remove useless super call params

This commit is contained in:
henryruhs
2025-02-23 00:29:41 +01:00
parent 579d3ef51c
commit 14b9bccafe
10 changed files with 84 additions and 49 deletions
@@ -4,7 +4,7 @@ from torch import Tensor, nn
class EmbeddingConverter(nn.Module):
def __init__(self) -> None:
super(EmbeddingConverter, self).__init__()
super().__init__()
self.layers = self.create_layers()
self.leaky_relu = nn.LeakyReLU()
+6 -6
View File
@@ -25,12 +25,12 @@ num_discriminators =
kernel_size =
[training.losses]
weight_adversarial =
weight_identity =
weight_attribute =
weight_reconstruction =
weight_pose =
weight_gaze =
adversarial_weight =
attribute_weight =
reconstruction_weight =
identity_weight =
pose_weight =
gaze_weight =
[training.trainer]
learning_rate =
-12
View File
@@ -22,18 +22,6 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
return vision_frame
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
    """
    Compute the "real" hinge term relu(1 - input_tensor).

    The result is averaged over dimensions 1, 2 and 3, so one value remains
    per entry along dimension 0. The input therefore needs at least four
    dimensions.
    """
    return torch.relu(1 - input_tensor).mean(dim = [ 1, 2, 3 ])
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
    """
    Compute the "fake" hinge term relu(input_tensor + 1).

    The result is averaged over dimensions 1, 2 and 3, so one value remains
    per entry along dimension 0. The input therefore needs at least four
    dimensions.
    """
    return torch.relu(input_tensor + 1).mean(dim = [ 1, 2, 3 ])
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = input_tensor[:, :, 15: 241, 15: 241]
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
+1 -1
View File
@@ -11,7 +11,7 @@ CONFIG.read('config.ini')
class Discriminator(nn.Module):
def __init__(self) -> None:
super(Discriminator, self).__init__()
super().__init__()
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
self.discriminators = self.create_discriminators()
+1 -1
View File
@@ -12,7 +12,7 @@ CONFIG.read('config.ini')
class Generator(nn.Module):
def __init__(self) -> None:
super(Generator, self).__init__()
super().__init__()
encoder_type = CONFIG.get('training.model.generator', 'encoder_type')
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
+57 -13
View File
@@ -5,13 +5,25 @@ import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
from ..helper import calc_embedding
from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
    """
    Compute the "real" hinge term relu(1 - input_tensor).

    The result is averaged over dimensions 1, 2 and 3, so one value remains
    per entry along dimension 0. The input therefore needs at least four
    dimensions.
    """
    return torch.relu(1 - input_tensor).mean(dim = [ 1, 2, 3 ])
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
    """
    Compute the "fake" hinge term relu(input_tensor + 1).

    The result is averaged over dimensions 1, 2 and 3, so one value remains
    per entry along dimension 0. The input therefore needs at least four
    dimensions.
    """
    return torch.relu(input_tensor + 1).mean(dim = [ 1, 2, 3 ])
class FaceSwapperLoss:
def __init__(self) -> None:
embedder_path = CONFIG.get('training.model', 'embedder_path')
@@ -51,6 +63,16 @@ class FaceSwapperLoss:
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
return generator_loss_set
# NOTE(review): a function with this exact name and body is also defined
# earlier in the same module hunk set - confirm that only one definition is
# intended, and whether this copy is meant to live inside FaceSwapperLoss.
def hinge_real_loss(input_tensor: Tensor) -> Tensor:
    """
    Compute the "real" hinge term relu(1 - input_tensor).

    The result is averaged over dimensions 1, 2 and 3, so one value remains
    per entry along dimension 0. The input therefore needs at least four
    dimensions.
    """
    return torch.relu(1 - input_tensor).mean(dim = [1, 2, 3])
# NOTE(review): a function with this exact name and body is also defined
# earlier in the same module hunk set - confirm that only one definition is
# intended, and whether this copy is meant to live inside FaceSwapperLoss.
def hinge_fake_loss(input_tensor: Tensor) -> Tensor:
    """
    Compute the "fake" hinge term relu(input_tensor + 1).

    The result is averaged over dimensions 1, 2 and 3, so one value remains
    per entry along dimension 0. The input therefore needs at least four
    dimensions.
    """
    return torch.relu(input_tensor + 1).mean(dim = [1, 2, 3])
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
discriminator_loss_set = {}
loss_fakes = []
@@ -131,9 +153,31 @@ class FaceSwapperLoss:
return translation, scale, rotation
class AdversarialLoss(torch.nn.Module):
class DiscriminatorLoss(nn.Module):
    """
    Hinge loss for the discriminator, averaged over all discriminator outputs.
    """

    def __init__(self) -> None:
        super().__init__()

    def calc(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor:
        """
        Compute the discriminator loss from source and generated predictions.

        Each element of both sequences is indexed with [0] before the hinge
        term is applied, and the indexed value is averaged over dimensions
        1, 2 and 3.
        NOTE(review): this suggests every element is itself a sequence whose
        first entry is a 4-D tensor, rather than a plain Tensor as annotated -
        confirm against Discriminator.forward.

        :param discriminator_source_tensors: discriminator predictions for the
            source input, scored with relu(1 - x).
        :param discriminator_output_tensors: discriminator predictions for the
            generator output, scored with relu(x + 1).
        :return: scalar tensor, the mean of the two averaged hinge terms.
        """
        # Hinge term for generated samples, one entry per discriminator output.
        fake_losses = [
            torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ])
            for discriminator_output_tensor in discriminator_output_tensors
        ]
        # Hinge term for source samples, one entry per discriminator output.
        real_losses = [
            torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ])
            for discriminator_source_tensor in discriminator_source_tensors
        ]
        fake_loss = torch.stack(fake_losses).mean()
        real_loss = torch.stack(real_losses).mean()
        # Weight both sides equally.
        return (fake_loss + real_loss) * 0.5
class AdversarialLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
def calc(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight')
@@ -148,9 +192,9 @@ class AdversarialLoss(torch.nn.Module):
return adversarial_loss, weighted_adversarial_loss
class AttributeLoss(torch.nn.Module):
class AttributeLoss(nn.Module):
def __init__(self) -> None:
super(AttributeLoss, self).__init__()
super().__init__()
def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
@@ -166,9 +210,9 @@ class AttributeLoss(torch.nn.Module):
return attribute_loss, weighted_attribute_loss
class ReconstructionLoss(torch.nn.Module):
class ReconstructionLoss(nn.Module):
def __init__(self) -> None:
super(ReconstructionLoss, self).__init__()
super().__init__()
def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
@@ -189,9 +233,9 @@ class ReconstructionLoss(torch.nn.Module):
return reconstruction_loss, weighted_reconstruction_loss
class IdentityLoss(torch.nn.Module):
class IdentityLoss(nn.Module):
def __init__(self) -> None:
super(IdentityLoss, self).__init__()
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -204,9 +248,9 @@ class IdentityLoss(torch.nn.Module):
return identity_loss, weighted_identity_loss
class PoseLoss(torch.nn.Module):
class PoseLoss(nn.Module):
def __init__(self) -> None:
super(PoseLoss, self).__init__()
super().__init__()
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
@@ -232,9 +276,9 @@ class PoseLoss(torch.nn.Module):
return translation, scale, rotation
class GazeLoss(torch.nn.Module):
class GazeLoss(nn.Module):
def __init__(self) -> None:
super(GazeLoss, self).__init__()
super().__init__()
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
@@ -6,7 +6,7 @@ from ..types import Embedding, TargetAttributes
class AADGenerator(nn.Module):
def __init__(self, id_channels : int, num_blocks : int) -> None:
super(AADGenerator, self).__init__()
super().__init__()
self.upsample = PixelShuffleUpsample(id_channels, 1024 * 4)
self.res_block_1 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks)
self.res_block_2 = AADResBlock(1024, 1024, 2048, id_channels, num_blocks)
@@ -56,7 +56,7 @@ class AADLayer(nn.Module):
class AADSequential(nn.Module):
def __init__(self, *args : nn.Module) -> None:
super(AADSequential, self).__init__()
super().__init__()
self.layers = nn.ModuleList(args)
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
@@ -70,7 +70,7 @@ class AADSequential(nn.Module):
class AADResBlock(nn.Module):
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, id_channels : int, num_blocks : int) -> None:
super(AADResBlock, self).__init__()
super().__init__()
self.input_channels = input_channels
self.output_channels = output_channels
self.prepare_primary_add_blocks(input_channels, attribute_channels, id_channels, output_channels, num_blocks)
@@ -111,7 +111,7 @@ class AADResBlock(nn.Module):
class PixelShuffleUpsample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super(PixelShuffleUpsample, self).__init__()
super().__init__()
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
+1 -1
View File
@@ -5,7 +5,7 @@ from torch import Tensor, nn
class NLD(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
super(NLD, self).__init__()
super().__init__()
self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size)
@staticmethod
+3 -3
View File
@@ -7,7 +7,7 @@ from torchvision import models
class UNet(nn.Module):
def __init__(self) -> None:
super(UNet, self).__init__()
super().__init__()
self.down_samples = self.create_down_samples(self)
self.up_samples = self.create_up_samples()
@@ -87,7 +87,7 @@ class UNetPro(UNet):
class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super(UpSample, self).__init__()
super().__init__()
self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
self.batch_norm = nn.BatchNorm2d(output_channels)
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
@@ -102,7 +102,7 @@ class UpSample(nn.Module):
class DownSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super(DownSample, self).__init__()
super().__init__()
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
self.batch_norm = nn.BatchNorm2d(output_channels)
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+10 -7
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -31,6 +31,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.generator = Generator()
self.discriminator = Discriminator()
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss()
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss()
@@ -65,10 +66,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.manual_backward(generator_loss_set.get('loss_generator'))
generator_optimizer.step()
discriminator_source_tensor = self.discriminator(source_tensor)
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensor, discriminator_output_tensors)
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss_set.get('loss_discriminator'))
discriminator_optimizer.step()
@@ -77,16 +78,17 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'))
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_identity', generator_loss_set.get('loss_identity'))
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
self.log('loss_pose', generator_loss_set.get('loss_pose'))
self.log('loss_gaze', generator_loss_set.get('loss_gaze'), prog_bar = True)
self.log('loss_gaze', generator_loss_set.get('loss_gaze'))
###############################################
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
@@ -96,12 +98,13 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
self.log('generator_loss_new', generator_loss, prog_bar = True)
self.log('discriminator_loss_new', discriminator_loss, prog_bar = True)
self.log('adversarial_loss_new', adversarial_loss)
self.log('attribute_loss_new', attribute_loss)
self.log('reconstruction_loss_new', reconstruction_loss)
self.log('identity_loss_new', identity_loss)
self.log('pose_loss_new', pose_loss)
self.log('gaze_loss_new', gaze_loss, prog_bar = True)
self.log('gaze_loss_new', gaze_loss)
return generator_loss_set.get('loss_generator')
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
@@ -161,7 +164,7 @@ def create_trainer() -> Trainer:
callbacks =
[
ModelCheckpoint(
monitor = 'loss_generator',
monitor = 'generator_loss',
dirpath = output_directory_path,
filename = output_file_pattern,
every_n_train_steps = 1000,