diff --git a/face_swapper/README.md b/face_swapper/README.md index 633b832..fc19b7a 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -63,12 +63,12 @@ kernel_size = 4 ``` [training.losses] -weight_adversarial = 1.5 -weight_identity = 15 -weight_attribute = 10 -weight_reconstruction = 15 -weight_pose = 0 -weight_gaze = 0 +adversarial_weight = 1.5 +attribute_weight = 10 +reconstruction_weight = 15 +identity_weight = 15 +pose_weight = 0 +gaze_weight = 0 ``` ``` diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index b66a2e6..1f85614 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -141,6 +141,28 @@ class FaceSwapperLoss: return translation, scale, rotation +class ReconstructionLoss(torch.nn.Module): + def __init__(self) -> None: + super(ReconstructionLoss, self).__init__() + + def calc(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tensor: + batch_size = CONFIG.getint('training.loader', 'batch_size') + + loss_tensor = torch.pow(output_tensor - target_tensor, 2).reshape(batch_size, -1) + loss_tensor = torch.mean(loss_tensor, dim = 1) * 0.5 + + if torch.equal(source_tensor, target_tensor): + loss_tensor = torch.sum(loss_tensor * torch.tensor(0)) / (torch.tensor(0).sum() + 1e-4) + else: + loss_tensor = torch.sum(loss_tensor * torch.tensor(1)) / (torch.tensor(1).sum() + 1e-4) + + data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) + similarity = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() + + loss_tensor = (loss_tensor + similarity) * 0.5 + return loss_tensor + + class IdentityLoss(torch.nn.Module): def __init__(self) -> None: super(IdentityLoss, self).__init__() @@ -148,8 +170,8 @@ class IdentityLoss(torch.nn.Module): self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.embedder.eval() - def calc_loss(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor: + def calc(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor: output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) - loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() - return loss + loss_tensor = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() + return loss_tensor diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 169e6d8..4e3e0c2 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import FaceSwapperLoss, IdentityLoss +from .models.loss import FaceSwapperLoss, IdentityLoss, ReconstructionLoss from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() @@ -31,6 +31,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.generator = Generator() self.discriminator = Discriminator() + self.reconstruction_loss = ReconstructionLoss() self.identity_loss = IdentityLoss() self.automatic_optimization = automatic_optimization @@ -76,18 +77,24 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.log('loss_adversarial', generator_loss_set.get('loss_adversarial')) self.log('loss_attribute', generator_loss_set.get('loss_attribute')) self.log('loss_identity', generator_loss_set.get('loss_identity'), prog_bar = True) - self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) + self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'), prog_bar = True) - identity_loss = self.identity_loss.calc_loss(generator_output_tensor, source_tensor) - generator_loss = self.calc_generator_loss_new(identity_loss) + reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) + identity_loss = self.identity_loss.calc(generator_output_tensor, source_tensor) + generator_loss = self.calc_generator_loss_new(reconstruction_loss, identity_loss) - self.log('loss_generator_new', generator_loss, prog_bar = True) + self.log('generator_loss_new', generator_loss, prog_bar = True) + self.log('loss_reconstruction_new', reconstruction_loss, prog_bar = True) self.log('loss_identity_new', identity_loss, prog_bar = True) return generator_loss_set.get('loss_generator') - def calc_generator_loss_new(self, identity_loss : Tensor) -> Tensor: - weight_identity = CONFIG.getfloat('training.losses', 'weight_identity') - generator_loss = identity_loss * weight_identity + @staticmethod + def calc_generator_loss_new(reconstruction_loss : Tensor, identity_loss : Tensor) -> Tensor: + reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') + identity_weight = CONFIG.getfloat('training.losses', 'identity_weight') + + generator_loss = reconstruction_loss * reconstruction_weight + generator_loss += identity_loss * identity_weight return generator_loss