diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 632ba1c..09ebe54 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -40,9 +40,9 @@ class EmbeddingConverterTrainer(lightning.LightningModule): source_embedding = self.source_embedder(batch) target_embedding = self.target_embedder(batch) output_embedding = self(source_embedding) - loss_training = self.mse_loss(output_embedding, target_embedding) - self.log('loss_training', loss_training, prog_bar = True) - return loss_training + training_loss = self.mse_loss(output_embedding, target_embedding) + self.log('training_loss', training_loss, prog_bar = True) + return training_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor: with torch.no_grad(): @@ -63,7 +63,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule): 'lr_scheduler': { 'scheduler': scheduler, - 'monitor': 'loss_training', + 'monitor': 'training_loss', 'interval': 'epoch', 'frequency': 1 } @@ -102,7 +102,7 @@ def create_trainer() -> Trainer: callbacks = [ ModelCheckpoint( - monitor = 'loss_training', + monitor = 'training_loss', dirpath = output_directory_path, filename = output_file_pattern, every_n_epochs = 10, diff --git a/face_swapper/README.md b/face_swapper/README.md index fc19b7a..e3a5c37 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -65,8 +65,8 @@ kernel_size = 4 [training.losses] adversarial_weight = 1.5 attribute_weight = 10 -reconstruction_weight = 15 -identity_weight = 15 +reconstruction_weight = 20 +identity_weight = 20 pose_weight = 0 gaze_weight = 0 ``` diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 5029efb..4f5c621 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -3,10 +3,11 @@ from typing import List, Tuple import torch from pytorch_msssim import ssim +from sqlalchemy.dialects.mssql.information_schema import identity_columns from torch import Tensor, nn from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss -from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor +from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -153,9 +154,27 @@ class AdversarialLoss(torch.nn.Module): temp_tensor = torch.relu(1 - discriminator_output_tensor[0]).mean(dim = [ 1, 2, 3 ]).mean() temp_tensors.append(temp_tensor) - loss = torch.stack(temp_tensors).mean() - weighted_loss = loss * adversarial_weight - return loss, weighted_loss + adversarial_loss = torch.stack(temp_tensors).mean() + weighted_adversarial_loss = adversarial_loss * adversarial_weight + return adversarial_loss, weighted_adversarial_loss + + +class AttributeLoss(torch.nn.Module): + def __init__(self) -> None: + super(AttributeLoss, self).__init__() + + def calc(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: + batch_size = CONFIG.getint('training.loader', 'batch_size') + attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight') + temp_tensors = [] + + for target_attribute, output_attribute in zip(target_attributes, output_attributes): + temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(batch_size, -1), dim = 1).mean() + temp_tensors.append(temp_tensor) + + attribute_loss = torch.stack(temp_tensors).mean() * 0.5 + weighted_attribute_loss = attribute_loss * attribute_weight + return attribute_loss, weighted_attribute_loss class ReconstructionLoss(torch.nn.Module): @@ -165,20 +184,20 @@ class ReconstructionLoss(torch.nn.Module): def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: batch_size = CONFIG.getint('training.loader', 'batch_size') reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') - loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1) - loss = torch.mean(loss, dim = 1) * 0.5 + reconstruction_loss = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1) + reconstruction_loss = torch.mean(reconstruction_loss, dim = 1) * 0.5 if torch.equal(source_tensor, target_tensor): - loss = torch.sum(loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4) + reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4) else: - loss = torch.sum(loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4) + reconstruction_loss = torch.sum(reconstruction_loss * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4) data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() - loss = (loss + similarity) * 0.5 - weighted_loss = loss * reconstruction_weight - return loss, weighted_loss + reconstruction_loss = (reconstruction_loss + similarity) * 0.5 + weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight + return reconstruction_loss, weighted_reconstruction_loss class IdentityLoss(torch.nn.Module): @@ -192,6 +211,6 @@ class IdentityLoss(torch.nn.Module): identity_weight = CONFIG.getfloat('training.losses', 'identity_weight') output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) - loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() - weighted_loss = loss * identity_weight - return loss, weighted_loss + identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() + weighted_identity_loss = identity_loss * identity_weight + return identity_loss, weighted_identity_loss diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 48e1199..e52c6cd 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, AttributeLoss, FaceSwapperLoss, IdentityLoss, ReconstructionLoss from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() @@ -32,6 +32,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.generator = Generator() self.discriminator = Discriminator() self.adversarial_loss = AdversarialLoss() + self.attribute_loss = AttributeLoss() self.reconstruction_loss = ReconstructionLoss() self.identity_loss = IdentityLoss() self.automatic_optimization = automatic_optimization @@ -74,23 +75,25 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.generate_preview(source_tensor, target_tensor, generator_output_tensor) self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True) - self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True) + self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator')) self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'), prog_bar = True) - self.log('loss_attribute', generator_loss_set.get('loss_attribute')) + self.log('loss_attribute', generator_loss_set.get('loss_attribute'), prog_bar = True) self.log('loss_identity', generator_loss_set.get('loss_identity')) self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) ############################################### adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors) + attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor) - generator_loss = weighted_adversarial_loss + weighted_reconstruction_loss + weighted_identity_loss + generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss self.log('generator_loss_new', generator_loss, prog_bar = True) - self.log('loss_adversarial_new', adversarial_loss, prog_bar = True) - self.log('loss_reconstruction_new', reconstruction_loss) - self.log('loss_identity_new', identity_loss) + self.log('adversarial_loss_new', adversarial_loss) + self.log('attribute_loss_new', attribute_loss, prog_bar = True) + self.log('reconstruction_loss_new', reconstruction_loss) + self.log('identity_loss_new', identity_loss) return generator_loss_set.get('loss_generator') def validation_step(self, batch : Batch, batch_index : int) -> Tensor: