From 085c493e18deef8af38396e0603a03261bb6bda3 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sat, 22 Feb 2025 12:37:47 +0100 Subject: [PATCH] Rename id_embedder to embedder, Tons of naming in training step, Introduce new IdentityLoss class --- face_swapper/README.md | 4 +- face_swapper/src/helper.py | 20 +++++----- face_swapper/src/inferencing.py | 14 +++---- face_swapper/src/models/loss.py | 26 ++++++++++--- face_swapper/src/training.py | 65 +++++++++++++++++++++------------ 5 files changed, 80 insertions(+), 49 deletions(-) diff --git a/face_swapper/README.md b/face_swapper/README.md index fda306a..633b832 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -40,7 +40,7 @@ split_ratio = 0.9995 ``` [training.model] -id_embedder_path = .models/id_embedder.pt +embedder_path = .models/arcface.pt landmarker_path = .models/landmarker.pt motion_extractor_path = .models/motion_extractor.pt ``` @@ -99,7 +99,7 @@ opset_version = 15 ``` [inferencing] generator_path = .outputs/last.ckpt -id_embedder_path = .models/id_embedder.pt +embedder_path = .models/arcface.pt source_path = .assets/source.jpg target_path = .assets/target.jpg output_path = .outputs/output.jpg diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index cd9a41a..01acf39 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -34,13 +34,13 @@ def hinge_fake_loss(input_tensor : Tensor) -> Tensor: return fake_loss -def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor, padding : Padding) -> Embedding: - crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241] - crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area') - crop_vision_tensor[:, :, :padding[0], :] = 0 - crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 - crop_vision_tensor[:, :, :, :padding[2]] = 0 - crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 - source_embedding = id_embedder(crop_vision_tensor) - source_embedding = nn.functional.normalize(source_embedding, p = 2) - return source_embedding +def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding: + crop_tensor = input_tensor[:, :, 15: 241, 15: 241] + crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area') + crop_tensor[:, :, :padding[0], :] = 0 + crop_tensor[:, :, 112 - padding[1]:, :] = 0 + crop_tensor[:, :, :, :padding[2]] = 0 + crop_tensor[:, :, :, 112 - padding[3]:] = 0 + embedding = embedder(crop_tensor) + embedding = nn.functional.normalize(embedding, p = 2) + return embedding diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 658ec69..d01510d 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -3,7 +3,7 @@ import configparser import cv2 import torch -from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor +from .helper import calc_embedding, convert_to_vision_frame, convert_to_vision_tensor from .models.generator import Generator from .types import EmbedderModule, GeneratorModule, VisionFrame @@ -11,10 +11,10 @@ CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: +def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: source_vision_tensor = convert_to_vision_tensor(source_vision_frame) target_vision_tensor = convert_to_vision_tensor(target_vision_frame) - source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0)) + source_embedding = calc_embedding(embedder, source_vision_tensor, (0, 0, 0, 0)) output_vision_tensor = generator(source_embedding, target_vision_tensor)[0] output_vision_frame = convert_to_vision_frame(output_vision_tensor) return output_vision_frame @@ -22,7 +22,7 @@ def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_v def infer() -> None: generator_path = CONFIG.get('inferencing', 'generator_path') - id_embedder_path = CONFIG.get('inferencing', 'id_embedder_path') + embedder_path = CONFIG.get('inferencing', 'embedder_path') source_path = CONFIG.get('inferencing', 'source_path') target_path = CONFIG.get('inferencing', 'target_path') output_path = CONFIG.get('inferencing', 'output_path') @@ -31,10 +31,10 @@ def infer() -> None: generator = Generator() generator.load_state_dict(state_dict) generator.eval() - id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] - id_embedder.eval() + embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + embedder.eval() source_vision_frame = cv2.imread(source_path) target_vision_frame = cv2.imread(target_path) - output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame) + output_vision_frame = run_swap(generator, embedder, source_vision_frame, target_vision_frame) cv2.imwrite(output_path, output_vision_frame) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 093cc43..b66a2e6 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -5,7 +5,7 @@ import torch from pytorch_msssim import ssim from torch import Tensor, nn -from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss +from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() @@ -14,15 +14,15 @@ CONFIG.read('config.ini') class FaceSwapperLoss: def __init__(self) -> None: - id_embedder_path = CONFIG.get('training.model', 'id_embedder_path') + embedder_path = CONFIG.get('training.model', 'embedder_path') landmarker_path = CONFIG.get('training.model', 'landmarker_path') motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') self.batch_size = CONFIG.getint('training.loader', 'batch_size') self.mse_loss = nn.MSELoss() - self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.id_embedder.eval() + self.embedder.eval() self.landmarker.eval() self.motion_extractor.eval() @@ -105,8 +105,8 @@ class FaceSwapperLoss: return loss_reconstruction def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: - swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10)) - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10)) + swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10)) + source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean() return loss_identity @@ -139,3 +139,17 @@ class FaceSwapperLoss: pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) rotation = torch.cat([ pitch, yaw, roll ], dim = 1) return translation, scale, rotation + + +class IdentityLoss(torch.nn.Module): + def __init__(self) -> None: + super(IdentityLoss, self).__init__() + embedder_path = CONFIG.get('training.model', 'embedder_path') + self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + self.embedder.eval() + + def calc_loss(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor: + output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) + source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) + loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() + return loss diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 9687048..169e6d8 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -13,10 +13,10 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset, random_split from .dataset import DynamicDataset -from .helper import calc_id_embedding +from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import FaceSwapperLoss +from .models.loss import FaceSwapperLoss, IdentityLoss from .types import Batch, Embedding, VisionTensor CONFIG = configparser.ConfigParser() @@ -27,9 +27,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): def __init__(self) -> None: super().__init__() FaceSwapperLoss.__init__(self) + automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') + self.generator = Generator() self.discriminator = Discriminator() - self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') + self.identity_loss = IdentityLoss() + self.automatic_optimization = automatic_optimization def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor: output_tensor = self.generator(source_embedding, target_tensor) @@ -42,43 +45,57 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): return generator_optimizer, discriminator_optimizer def training_step(self, batch : Batch, batch_index : int) -> Tensor: + preview_frequency = CONFIG.getfloat('training.trainer', 'preview_frequency') + source_tensor, target_tensor = batch generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) - swap_tensor = self.generator(source_embedding, target_tensor) + source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_attributes = self.generator.get_attributes(target_tensor) - swap_attributes = self.generator.get_attributes(swap_tensor) - fake_discriminator_outputs = self.discriminator(swap_tensor) + generator_output_tensor = self.generator(source_embedding, target_tensor) + generator_output_attributes = self.generator.get_attributes(generator_output_tensor) + discriminator_output_tensor = self.discriminator(generator_output_tensor) - generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch) + generator_loss_set = self.calc_generator_loss(generator_output_tensor, target_attributes, generator_output_attributes, discriminator_output_tensor, batch) generator_optimizer.zero_grad() - self.manual_backward(generator_losses.get('loss_generator')) + self.manual_backward(generator_loss_set.get('loss_generator')) generator_optimizer.step() - real_discriminator_outputs = self.discriminator(source_tensor) - fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) + discriminator_source_tensor = self.discriminator(source_tensor) + discriminator_output_tensor = self.discriminator(generator_output_tensor.detach()) - discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs) + discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensor, discriminator_output_tensor) discriminator_optimizer.zero_grad() - self.manual_backward(discriminator_losses.get('loss_discriminator')) + self.manual_backward(discriminator_loss_set.get('loss_discriminator')) discriminator_optimizer.step() - if self.global_step % CONFIG.getint('training.trainer', 'preview_frequency') == 0: - self.generate_preview(source_tensor, target_tensor, swap_tensor) + if self.global_step % preview_frequency == 0: + self.generate_preview(source_tensor, target_tensor, generator_output_tensor) - self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True) - self.log('loss_discriminator', discriminator_losses.get('loss_discriminator'), prog_bar = True) - self.log('loss_adversarial', generator_losses.get('loss_adversarial')) - self.log('loss_attribute', generator_losses.get('loss_attribute')) - self.log('loss_identity', generator_losses.get('loss_identity')) - self.log('loss_reconstruction', generator_losses.get('loss_reconstruction')) - return generator_losses.get('loss_generator') + self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True) + self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True) + self.log('loss_adversarial', generator_loss_set.get('loss_adversarial')) + self.log('loss_attribute', generator_loss_set.get('loss_attribute')) + self.log('loss_identity', generator_loss_set.get('loss_identity'), prog_bar = True) + self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) + + identity_loss = self.identity_loss.calc_loss(generator_output_tensor, source_tensor) + generator_loss = self.calc_generator_loss_new(identity_loss) + + self.log('loss_generator_new', generator_loss, prog_bar = True) + self.log('loss_identity_new', identity_loss, prog_bar = True) + return generator_loss_set.get('loss_generator') + + def calc_generator_loss_new(self, identity_loss : Tensor) -> Tensor: + weight_identity = CONFIG.getfloat('training.losses', 'weight_identity') + generator_loss = identity_loss * weight_identity + + return generator_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch - source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) + source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) output_tensor = self.generator(source_embedding, target_tensor) - output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0)) + output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0)) validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5 self.log('validation', validation) return validation