diff --git a/embedding_converter/src/models/embedding_converter.py b/embedding_converter/src/models/embedding_converter.py index de483f3..bc599c3 100644 --- a/embedding_converter/src/models/embedding_converter.py +++ b/embedding_converter/src/models/embedding_converter.py @@ -4,7 +4,7 @@ from torch import Tensor, nn class EmbeddingConverter(nn.Module): def __init__(self) -> None: - super(EmbeddingConverter, self).__init__() + super().__init__() self.layers = self.create_layers() self.leaky_relu = nn.LeakyReLU() diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 44f453d..d627d88 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -25,12 +25,12 @@ num_discriminators = kernel_size = [training.losses] -weight_adversarial = -weight_identity = -weight_attribute = -weight_reconstruction = -weight_pose = -weight_gaze = +adversarial_weight = +attribute_weight = +reconstruction_weight = +identity_weight = +pose_weight = +gaze_weight = [training.trainer] learning_rate = diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 01acf39..855a317 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -22,18 +22,6 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame: return vision_frame -def hinge_real_loss(input_tensor : Tensor) -> Tensor: - real_loss = torch.relu(1 - input_tensor) - real_loss = real_loss.mean(dim = [ 1, 2, 3 ]) - return real_loss - - -def hinge_fake_loss(input_tensor : Tensor) -> Tensor: - fake_loss = torch.relu(input_tensor + 1) - fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ]) - return fake_loss - - def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: crop_tensor = input_tensor[:, :, 15: 241, 15: 241] crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area') diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index bd78ae0..4f49cd2 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -11,7 +11,7 @@ CONFIG.read('config.ini') class Discriminator(nn.Module): def __init__(self) -> None: - super(Discriminator, self).__init__() + super().__init__() self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) self.discriminators = self.create_discriminators() diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 6f757ec..6e530a3 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -12,7 +12,7 @@ CONFIG.read('config.ini') class Generator(nn.Module): def __init__(self) -> None: - super(Generator, self).__init__() + super().__init__() encoder_type = CONFIG.get('training.model.generator', 'encoder_type') id_channels = CONFIG.getint('training.model.generator', 'id_channels') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 801257e..8397652 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -5,13 +5,25 @@ import torch from pytorch_msssim import ssim from torch import Tensor, nn -from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss +from ..helper import calc_embedding from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') +def hinge_real_loss(input_tensor : Tensor) -> Tensor: + real_loss = torch.relu(1 - input_tensor) + real_loss = real_loss.mean(dim = [ 1, 2, 3 ]) + return real_loss + + +def hinge_fake_loss(input_tensor : Tensor) -> Tensor: + fake_loss = torch.relu(input_tensor + 1) + fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ]) + return fake_loss + + class FaceSwapperLoss: def __init__(self) -> None: embedder_path = CONFIG.get('training.model', 'embedder_path') @@ -51,6 +63,16 @@ class FaceSwapperLoss: generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze return generator_loss_set + def hinge_real_loss(input_tensor: Tensor) -> Tensor: + real_loss = torch.relu(1 - input_tensor) + real_loss = real_loss.mean(dim = [1, 2, 3]) + return real_loss + + def hinge_fake_loss(input_tensor: Tensor) -> Tensor: + fake_loss = torch.relu(input_tensor + 1) + fake_loss = fake_loss.mean(dim = [1, 2, 3]) + return fake_loss + def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet: discriminator_loss_set = {} loss_fakes = [] @@ -131,9 +153,31 @@ class FaceSwapperLoss: return translation, scale, rotation -class AdversarialLoss(torch.nn.Module): +class DiscriminatorLoss(nn.Module): def __init__(self) -> None: - super(AdversarialLoss, self).__init__() + super().__init__() + + def calc(self, discriminator_source_tensors : List[Tensor], discriminator_output_tensors : List[Tensor]) -> Tensor: + temp1_tensors = [] + temp2_tensors = [] + + for discriminator_output_tensor in discriminator_output_tensors: + temp1_tensor = torch.relu(discriminator_output_tensor[0] + 1).mean(dim = [ 1, 2, 3 ]) + temp1_tensors.append(temp1_tensor) + + for discriminator_source_tensor in discriminator_source_tensors: + temp2_tensor = torch.relu(1 - discriminator_source_tensor[0]).mean(dim = [ 1, 2, 3 ]) + temp2_tensors.append(temp2_tensor) + + discriminator1_loss = torch.stack(temp1_tensors).mean() + discriminator2_loss = torch.stack(temp2_tensors).mean() + discriminator_loss = (discriminator1_loss + discriminator2_loss) * 0.5 + return discriminator_loss + + +class AdversarialLoss(nn.Module): + def __init__(self) -> None: + super().__init__() def calc(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]: adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight') @@ -148,9 +192,9 @@ class AdversarialLoss(torch.nn.Module): return adversarial_loss, weighted_adversarial_loss -class AttributeLoss(torch.nn.Module): +class AttributeLoss(nn.Module): def __init__(self) -> None: - super(AttributeLoss, self).__init__() + super().__init__() def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: batch_size = CONFIG.getint('training.loader', 'batch_size') @@ -166,9 +210,9 @@ class AttributeLoss(torch.nn.Module): return attribute_loss, weighted_attribute_loss -class ReconstructionLoss(torch.nn.Module): +class ReconstructionLoss(nn.Module): def __init__(self) -> None: - super(ReconstructionLoss, self).__init__() + super().__init__() def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: batch_size = CONFIG.getint('training.loader', 'batch_size') @@ -189,9 +233,9 @@ class ReconstructionLoss(torch.nn.Module): return reconstruction_loss, weighted_reconstruction_loss -class IdentityLoss(torch.nn.Module): +class IdentityLoss(nn.Module): def __init__(self) -> None: - super(IdentityLoss, self).__init__() + super().__init__() embedder_path = CONFIG.get('training.model', 'embedder_path') self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] @@ -204,9 +248,9 @@ class IdentityLoss(torch.nn.Module): return identity_loss, weighted_identity_loss -class PoseLoss(torch.nn.Module): +class PoseLoss(nn.Module): def __init__(self) -> None: - super(PoseLoss, self).__init__() + super().__init__() motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.mse_loss = nn.MSELoss() @@ -232,9 +276,9 @@ class PoseLoss(torch.nn.Module): return translation, scale, rotation -class GazeLoss(torch.nn.Module): +class GazeLoss(nn.Module): def __init__(self) -> None: - super(GazeLoss, self).__init__() + super().__init__() landmarker_path = CONFIG.get('training.model', 'landmarker_path') self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.mse_loss = nn.MSELoss() diff --git a/face_swapper/src/networks/attribute_modulator.py b/face_swapper/src/networks/attribute_modulator.py index 4d3e973..cdf44e1 100644 --- a/face_swapper/src/networks/attribute_modulator.py +++ b/face_swapper/src/networks/attribute_modulator.py @@ -6,7 +6,7 @@ from ..types import Embedding, TargetAttributes class AADGenerator(nn.Module): def __init__(self, id_channels : int, num_blocks : int) -> None: - super(AADGenerator, self).__init__() + super().__init__() self.upsample = PixelShuffleUpsample(id_channels, 1024 * 4) self.res_block_1 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks) self.res_block_2 = AADResBlock(1024, 1024, 2048, id_channels, num_blocks) @@ -56,7 +56,7 @@ class AADLayer(nn.Module): class AADSequential(nn.Module): def __init__(self, *args : nn.Module) -> None: - super(AADSequential, self).__init__() + super().__init__() self.layers = nn.ModuleList(args) def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor: @@ -70,7 +70,7 @@ class AADSequential(nn.Module): class AADResBlock(nn.Module): def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, id_channels : int, num_blocks : int) -> None: - super(AADResBlock, self).__init__() + super().__init__() self.input_channels = input_channels self.output_channels = output_channels self.prepare_primary_add_blocks(input_channels, attribute_channels, id_channels, output_channels, num_blocks) @@ -111,7 +111,7 @@ class AADResBlock(nn.Module): class PixelShuffleUpsample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: - super(PixelShuffleUpsample, self).__init__() + super().__init__() self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1) self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2) diff --git a/face_swapper/src/networks/nld.py b/face_swapper/src/networks/nld.py index 73fc5a3..5fee9e1 100644 --- a/face_swapper/src/networks/nld.py +++ b/face_swapper/src/networks/nld.py @@ -5,7 +5,7 @@ from torch import Tensor, nn class NLD(nn.Module): def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None: - super(NLD, self).__init__() + super().__init__() self.nld = self.create_nld(input_channels, num_filters, num_layers, kernel_size) @staticmethod diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 5f9cb05..7ea9b2a 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -7,7 +7,7 @@ from torchvision import models class UNet(nn.Module): def __init__(self) -> None: - super(UNet, self).__init__() + super().__init__() self.down_samples = self.create_down_samples(self) self.up_samples = self.create_up_samples() @@ -87,7 +87,7 @@ class UNetPro(UNet): class UpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: - super(UpSample, self).__init__() + super().__init__() self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) @@ -102,7 +102,7 @@ class UpSample(nn.Module): class DownSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: - super(DownSample, self).__init__() + super().__init__() self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 0c2ab5e..e5760cb 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() @@ -31,6 +31,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.generator = Generator() self.discriminator = Discriminator() + self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss() self.attribute_loss = AttributeLoss() self.reconstruction_loss = ReconstructionLoss() @@ -65,10 +66,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.manual_backward(generator_loss_set.get('loss_generator')) generator_optimizer.step() - discriminator_source_tensor = self.discriminator(source_tensor) + discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) - discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensor, discriminator_output_tensors) + discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) discriminator_optimizer.zero_grad() self.manual_backward(discriminator_loss_set.get('loss_discriminator')) discriminator_optimizer.step() @@ -77,16 +78,17 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.generate_preview(source_tensor, target_tensor, generator_output_tensor) self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True) - self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator')) + self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True) self.log('loss_adversarial', generator_loss_set.get('loss_adversarial')) self.log('loss_attribute', generator_loss_set.get('loss_attribute')) self.log('loss_identity', generator_loss_set.get('loss_identity')) self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) self.log('loss_pose', generator_loss_set.get('loss_pose')) - self.log('loss_gaze', generator_loss_set.get('loss_gaze'), prog_bar = True) + self.log('loss_gaze', generator_loss_set.get('loss_gaze')) ############################################### + discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors) attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) @@ -96,12 +98,13 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss self.log('generator_loss_new', generator_loss, prog_bar = True) + self.log('discriminator_loss_new', discriminator_loss, prog_bar = True) self.log('adversarial_loss_new', adversarial_loss) self.log('attribute_loss_new', attribute_loss) self.log('reconstruction_loss_new', reconstruction_loss) self.log('identity_loss_new', identity_loss) self.log('pose_loss_new', pose_loss) - self.log('gaze_loss_new', gaze_loss, prog_bar = True) + self.log('gaze_loss_new', gaze_loss) return generator_loss_set.get('loss_generator') def validation_step(self, batch : Batch, batch_index : int) -> Tensor: @@ -161,7 +164,7 @@ def create_trainer() -> Trainer: callbacks = [ ModelCheckpoint( - monitor = 'loss_generator', + monitor = 'generator_loss', dirpath = output_directory_path, filename = output_file_pattern, every_n_train_steps = 1000,