From 9e1c71b498ca51c57d4d1ce22540d8c600d81af7 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Tue, 14 Jan 2025 16:06:09 +0530 Subject: [PATCH] clean --- face_swapper/src/data_loader.py | 10 -- face_swapper/src/helper.py | 104 +------------- face_swapper/src/training.py | 236 +++++++------------------------- face_swapper/src/typing.py | 3 +- 4 files changed, 55 insertions(+), 298 deletions(-) diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index 412a970..34c7dca 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -8,7 +8,6 @@ import tqdm from PIL import Image from torch.utils.data import TensorDataset -from .augmentations import apply_random_motion_blur from .typing import Batch CONFIG = configparser.ConfigParser() @@ -52,7 +51,6 @@ class DataLoaderVGG(TensorDataset): transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.RandomHorizontalFlip(p = 0.5), - transforms.RandomApply([ apply_random_motion_blur ], p = 0.3), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1), transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), @@ -81,11 +79,3 @@ class DataLoaderVGG(TensorDataset): def __len__(self) -> int: return self.dataset_total - - - def state_dict(self): - return {'current_index': self._current_index} - - - def load_state_dict(self, state_dict): - self._current_index = state_dict['current_index'] diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 4a6b2bd..6f42cd2 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,114 +1,16 @@ import configparser -from typing import Tuple import torch -from .typing import Tensor -import numpy -import torch.nn.functional as F + +from .typing import Tensor, Tuple CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -if CONFIG.getboolean('preparing.augmentation', 'expression'): - from LivePortrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix - L2_loss = torch.nn.MSELoss() -EXPRESSION_MIN = numpy.array( -[ - [ - [-2.88067125e-02, -8.12731311e-02, -1.70541159e-03], - [-4.88598682e-02, -3.32196616e-02, -1.67431499e-04], - [-6.75425082e-02, -4.28681746e-02, -1.98950816e-04], - [-7.23103955e-02, -3.28503326e-02, -7.31324719e-04], - [-3.87073644e-02, -6.01546466e-02, -5.50269964e-04], - [-6.38048723e-02, -2.23840728e-01, -7.13261834e-04], - [-3.02710701e-02, -3.93195450e-02, -8.24086510e-06], - [-2.95799859e-02, -5.39318882e-02, -1.74219604e-04], - [-2.92359516e-02, -1.53050944e-02, -6.30460854e-05], - [-5.56493877e-03, -2.34344602e-02, -1.26858242e-04], - [-4.37593013e-02, -2.77768299e-02, -2.70503685e-02], - [-1.76926646e-02, -1.91676542e-02, -1.15090821e-04], - [-8.34268332e-03, -3.99775570e-03, -3.27481248e-05], - [-3.40162888e-02, -2.81868968e-02, -1.96679524e-04], - [-2.91855410e-02, -3.97511162e-02, -2.81230678e-05], - [-1.50395725e-02, -2.49494594e-02, -9.42573533e-05], - [-1.67938769e-02, -2.00953931e-02, -4.00750607e-04], - [-1.86435618e-02, -2.48535164e-02, -2.74416432e-02], - [-4.61211195e-03, -1.21660791e-02, -2.93173041e-04], - [-4.10017073e-02, -7.43824020e-02, -4.42762971e-02], - [-1.90370996e-02, -3.74363363e-02, -1.34740388e-02] - ] -]).astype(numpy.float32) -EXPRESSION_MAX = numpy.array( -[ - [ - [4.46682945e-02, 7.08772913e-02, 4.08344204e-04], - [2.14308221e-02, 6.15894832e-02, 4.85319615e-05], - [3.02363783e-02, 4.45043296e-02, 1.28298725e-05], - [3.05869691e-02, 3.79812494e-02, 6.57040102e-04], - [4.45670523e-02, 3.97259220e-02, 7.10966764e-04], - [9.43699256e-02, 9.85926315e-02, 2.02551950e-04], - [1.61131397e-02, 2.92906128e-02, 3.44733417e-06], - [5.23825921e-02, 1.07065082e-01, 6.61510974e-04], - [2.85718683e-03, 8.32320191e-03, 2.39314613e-04], - [2.57947259e-02, 1.60935968e-02, 2.41853559e-05], - [4.90833223e-02, 3.43903080e-02, 3.22353356e-02], - [1.44766076e-02, 3.39248963e-02, 1.42291479e-04], - [8.75749043e-04, 6.82212645e-03, 2.76097053e-05], - [1.86958015e-02, 3.84016186e-02, 7.33085908e-05], - [2.01714113e-02, 4.90544215e-02, 2.34028921e-05], - [2.46518422e-02, 3.29151377e-02, 3.48571630e-05], - [2.22457591e-02, 1.21796541e-02, 1.56396593e-04], - [1.72109623e-02, 3.01626958e-02, 1.36556877e-02], - [1.83460284e-02, 1.61141958e-02, 2.87440169e-04], - [3.57594155e-02, 1.80554688e-01, 2.75554154e-02], - [2.17450950e-02, 8.66811201e-02, 3.34241726e-02] - ] -]).astype(numpy.float32) -def randomize_expression(face_tensor, feature_extractor, motion_extractor, warping_network, spade_generator): - with torch.no_grad(): - face_tensor_norm = (face_tensor + 1) * 0.5 - input_device = face_tensor.device - feature_volume = feature_extractor(face_tensor_norm) - motion_extractor_dict = motion_extractor(face_tensor_norm) - - translation = motion_extractor_dict.get('t') - expression = motion_extractor_dict.get('exp') - scale = motion_extractor_dict.get('scale') - points = motion_extractor_dict.get('kp') - - pitch = headpose_pred_to_degree(motion_extractor_dict.get('pitch'))[:, None] - yaw = headpose_pred_to_degree(motion_extractor_dict.get('yaw'))[:, None] - roll = headpose_pred_to_degree(motion_extractor_dict.get('roll'))[:, None] - rotation_matrix = get_rotation_matrix(pitch, yaw, roll) - random_expression = get_random_expression_blend(expression) - - points_transformed = transform_points(points, rotation_matrix, expression, scale, translation) - points_driv = transform_points(points, rotation_matrix, random_expression, scale, translation) - - data = warping_network(feature_volume, points_driv, points_transformed).get('out') - output = spade_generator(data) - output = output.to(input_device) - output = F.interpolate(output.clamp(0, 1), [256, 256], mode='bilinear', align_corners=False) - output = (output - 0.5) * 2 - return output - - -def get_random_expression_blend(expression : Tensor) -> Tensor: - blend = 0.35 - expression = expression.view(-1, 21, 3) - min_array = torch.from_numpy(EXPRESSION_MIN).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1) - max_array = torch.from_numpy(EXPRESSION_MAX).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1) - random_batch = torch.rand_like(min_array).to(expression.device) * (max_array - min_array) + min_array - random_batch[:, [0, 1, 8, 6, 9, 4, 5, 10]] = expression[:, [0, 1, 8, 6, 9, 4, 5, 10]] - random_batch[:, [3, 7]] = random_batch[:, [13, 16]] * 0.1 + expression[:, [13, 16]] * 0.9 - random_batch[:, [3, 7]] = random_batch[:, [3, 7]] * 0.5 + expression[:, [3, 7]] * 0.5 - return random_batch * 0.8 * blend + expression * (1 - blend) - - -def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor): +def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor) -> Tensor: points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3) points_transformed *= scale[..., None] points_transformed[:, :, 0:2] += translation[:, None, 0:2] diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 342db36..12d740b 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,96 +1,28 @@ import configparser -import random -import numpy - -from typing import Tuple import os -import cv2 -import torchvision +import cv2 +import numpy import pytorch_lightning +import torch +import torchvision +from LivePortrait.src.modules.motion_extractor import MotionExtractor from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.types import Optimizer +from pytorch_msssim import ssim from torch.utils.data import DataLoader -from pytorch_lightning.utilities.types import OptimizerLRScheduler -import torch +from .data_loader import DataLoaderVGG, read_image from .discriminator import MultiscaleDiscriminator from .generator import AdaptiveEmbeddingIntegrationNetwork -from .data_loader import DataLoaderVGG, read_image - -from .typing import Tensor, LossDict, TargetAttributes, DiscriminatorOutputs, Batch -from .helper import hinge_loss, calc_distance_ratio, L2_loss, randomize_expression -from pytorch_msssim import ssim +from .helper import L2_loss, hinge_loss +from .typing import Batch, DiscriminatorOutputs, IDEmbedding, LossDict, TargetAttributes, Tensor, Tuple CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -def load_models(): - id_channels = CONFIG.getint('training.generator', 'id_channels') - num_blocks = CONFIG.getint('training.generator', 'num_blocks') - generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks) - - input_channels = CONFIG.getint('training.discriminator', 'input_channels') - num_filters = CONFIG.getint('training.discriminator', 'num_filters') - num_layers = CONFIG.getint('training.discriminator', 'num_layers') - num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators') - discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators) - - model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path') - arcface = torch.load(model_path, map_location = 'cpu', weights_only = False) - arcface.eval() - - if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0: - model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path') - landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False) - landmarker.eval() - else: - landmarker = None - - if CONFIG.getfloat('training.losses', 'weight_tsr') > 0 or CONFIG.getboolean('preparing.augmentation', 'expression'): - from LivePortrait.src.modules.motion_extractor import MotionExtractor - - model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path') - motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny') - motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True)) - motion_extractor.eval() - else: - motion_extractor = None - - if CONFIG.getboolean('preparing.augmentation', 'expression'): - from LivePortrait.src.modules.appearance_feature_extractor import AppearanceFeatureExtractor - from LivePortrait.src.modules.warping_network import WarpingNetwork - from LivePortrait.src.modules.spade_generator import SPADEDecoder - - feature_extractor_path = CONFIG.get('auxiliary_models.paths', 'feature_extractor_path') - feature_extractor = AppearanceFeatureExtractor(3, 64, 2, 512, 32, 16, 6) - feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location = 'cpu', weights_only = True)) - feature_extractor.eval() - - warping_network_path = CONFIG.get('auxiliary_models.paths', 'warping_network_path') - dense_motion_params = { - 'block_expansion': 32, - 'max_features': 1024, - 'num_blocks': 5, - 'reshape_depth': 16, - 'compress': 4 - } - warping_network = WarpingNetwork(num_kp = 21, block_expansion = 64, max_features = 512, num_down_blocks = 2, reshape_channel = 32, estimate_occlusion_map = True, dense_motion_params = dense_motion_params) - warping_network.load_state_dict(torch.load(warping_network_path, map_location='cpu', weights_only=True)) - warping_network.eval() - - spade_generator_path = CONFIG.get('auxiliary_models.paths', 'spade_generator_path') - spade_generator = SPADEDecoder(upscale = 2, block_expansion = 64, max_features = 512, num_down_blocks = 2) - spade_generator.load_state_dict(torch.load(spade_generator_path, map_location = 'cpu', weights_only = True)) - spade_generator.eval() - else: - feature_extractor = None - warping_network = None - spade_generator = None - return generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator - - def create_trainer() -> Trainer: trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs') output_directory_path = CONFIG.get('training.output', 'directory_path') @@ -106,7 +38,6 @@ def create_trainer() -> Trainer: monitor = 'l_G', dirpath = output_directory_path, filename = output_file_pattern, - # every_n_epochs = 1, every_n_train_steps = 1000, save_top_k = 5, mode = 'min', @@ -118,7 +49,7 @@ def create_trainer() -> Trainer: ) -def train(): +def train() -> None: batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') checkpoint_path = CONFIG.get('training.output', 'checkpoint_path') @@ -127,85 +58,50 @@ def train(): if not (checkpoint_path and os.path.exists(checkpoint_path)): checkpoint_path = None data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - face_swap_model = FaceSwapper(*load_models()) + face_swap_model = FaceSwapper() trainer = create_trainer() trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path) class FaceSwapper(pytorch_lightning.LightningModule): - def __init__(self, generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator) -> None: + def __init__(self) -> None: super().__init__() - self.generator = generator - self.discriminator = discriminator - self.arcface = arcface - self.landmarker = landmarker - self.motion_extractor = motion_extractor - self.feature_extractor = feature_extractor - self.warping_network = warping_network - self.spade_generator = spade_generator - - self.loss_adversarial_accumulated = 20 + self.generator = AdaptiveEmbeddingIntegrationNetwork(CONFIG.getint('training.generator', 'id_channels'), CONFIG.getint('training.generator', 'num_blocks')) + self.discriminator = MultiscaleDiscriminator(CONFIG.getint('training.discriminator', 'input_channels'), CONFIG.getint('training.discriminator', 'num_filters'), CONFIG.getint('training.discriminator', 'num_layers'), CONFIG.getint('training.discriminator', 'num_discriminators')) + self.arcface = torch.load(CONFIG.get('auxiliary_models.paths', 'arcface_path'), map_location = 'cpu', weights_only = False) + self.landmarker = torch.load(CONFIG.get('auxiliary_models.paths', 'landmarker_path'), map_location = 'cpu', weights_only = False) + self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny') + self.motion_extractor.load_state_dict(torch.load(CONFIG.get('auxiliary_models.paths', 'motion_extractor_path'), map_location = 'cpu', weights_only = True)) + self.arcface.eval() + self.landmarker.eval() + self.motion_extractor.eval() self.automatic_optimization = False self.batch_size = CONFIG.getint('training.loader', 'batch_size') - - def forward(self, target_tensor : Tensor, source_embedding : Tensor) -> Tensor: + def forward(self, target_tensor : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]: output = self.generator(target_tensor, source_embedding) return output - - def state_dict(self, *args, **kwargs): - return { - "generator": self.generator.state_dict(), - "discriminator": self.discriminator.state_dict(), - } - - def load_state_dict(self, state_dict, strict: bool = True): - if "generator" in state_dict: - self.generator.load_state_dict(state_dict["generator"], strict = strict) - if "discriminator" in state_dict: - self.discriminator.load_state_dict(state_dict["discriminator"], strict = strict) - - - def configure_optimizers(self) -> OptimizerLRScheduler: + def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]: generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) - if CONFIG.getboolean('training.schedulers', 'enable'): - generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma')) - discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma')) - return ( - { - "optimizer": generator_optimizer, - "lr_scheduler": generator_scheduler - }, - { - "optimizer": discriminator_optimizer, - "lr_scheduler": discriminator_scheduler - }) return generator_optimizer, discriminator_optimizer - def training_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor, is_same_person = batch - generator_optimizer, discriminator_optimizer = self.optimizers() - source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0)) - - if random.random() > 0.5 and CONFIG.getboolean('preparing.augmentation', 'expression'): - target_tensor = randomize_expression(target_tensor, self.feature_extractor, self.motion_extractor, self.warping_network, self.spade_generator) - - swap_tensor, target_attributes = self(target_tensor, source_embedding) + generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] + source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0)) + swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) discriminator_outputs = self.discriminator(swap_tensor) - generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch) generator_optimizer.zero_grad() self.manual_backward(generator_losses.get('loss_generator')) generator_optimizer.step() - discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor) discriminator_optimizer.zero_grad() self.manual_backward(discriminator_losses.get('loss_discriminator')) - if not CONFIG.getboolean('training.discriminator', 'disable') or self.loss_adversarial_accumulated < 0.4: + if not CONFIG.getboolean('training.discriminator', 'disable'): discriminator_optimizer.step() if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0: @@ -215,36 +111,33 @@ class FaceSwapper(pytorch_lightning.LightningModule): self.log_validation_preview() self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True) self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True) - self.log('l_ADV_A', self.loss_adversarial_accumulated, prog_bar = True) self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False) - self.log('l_id', generator_losses.get('loss_identity'), prog_bar = True) - self.log('l_attr', generator_losses.get('loss_attribute'), prog_bar = True) - self.log('l_rec', generator_losses.get('loss_reconstruction'), prog_bar = True) + self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True) + self.log('l_ID', generator_losses.get('loss_identity'), prog_bar=True) + self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True) return generator_losses.get('loss_generator') - def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict: source_tensor, target_tensor, is_same_person = batch generator_losses = {} # adversarial loss - loss_adversarial = 0 + loss_adversarial = torch.Tensor(0) for discriminator_output in discriminator_outputs: loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ]) loss_adversarial = torch.mean(loss_adversarial) generator_losses['loss_adversarial'] = loss_adversarial generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial') - self.loss_adversarial_accumulated = self.loss_adversarial_accumulated * 0.98 + loss_adversarial.item() * 0.02 # identity loss - swap_embedding = self.get_arcface_embedding(swap_tensor, (30, 0, 10, 10)) - source_embedding = self.get_arcface_embedding(source_tensor, (30, 0, 10, 10)) + swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10)) + source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10)) loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean() generator_losses['loss_identity'] = loss_identity generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity') # attribute loss - loss_attribute = 0 + loss_attribute = torch.Tensor(0) swap_attributes = self.generator.get_attributes(swap_tensor) for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): @@ -256,7 +149,7 @@ class FaceSwapper(pytorch_lightning.LightningModule): # reconstruction loss loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4) loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() - loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7 + loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 generator_losses['loss_reconstruction'] = loss_reconstruction generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction') @@ -271,49 +164,32 @@ class FaceSwapper(pytorch_lightning.LightningModule): generator_losses['loss_tsr'] = loss_tsr generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr') - - if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0: + if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0: swap_landmark_features = self.get_landmark_features(swap_tensor) target_landmark_features = self.get_landmark_features(target_tensor) - - # eye gaze loss - loss_left_eye_gaze = L2_loss(swap_landmark_features[3], target_landmark_features[3]) - loss_right_eye_gaze = L2_loss(swap_landmark_features[4], target_landmark_features[4]) + loss_left_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1]) + loss_right_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1]) loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze generator_losses['loss_eye_gaze'] = loss_eye_gaze generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze') - - # eye open loss - loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0]) - loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1]) - loss_eye_open = loss_left_eye_open + loss_right_eye_open - generator_losses['loss_eye_open'] = loss_eye_open - generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open') - - # lip open loss - loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2]) - generator_losses['loss_lip_open'] = loss_lip_open - generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open') return generator_losses - def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict: discriminator_losses = {} fake_discriminator_outputs = self.discriminator(swap_tensor.detach()) - loss_fake = 0 + loss_fake = torch.Tensor(0) for fake_discriminator_output in fake_discriminator_outputs: loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3])) true_discriminator_outputs = self.discriminator(source_tensor) - loss_true = 0 + loss_true = torch.Tensor(0) for true_discriminator_output in true_discriminator_outputs: loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3])) discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 return discriminator_losses - - def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor: + def get_id_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor: _, _, height, width = vision_tensor.shape crop_height = int(height * 0.0586) crop_width = int(width * 0.0586) @@ -327,19 +203,12 @@ class FaceSwapper(pytorch_lightning.LightningModule): embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1) return embedding - - def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor]: vision_tensor_norm = (vision_tensor + 1) * 0.5 vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') landmarks = self.landmarker(vision_tensor_norm)[2] landmarks = landmarks.view(-1, 203, 2) * 256 - left_eye_open_ratio = calc_distance_ratio(landmarks, (6, 18, 0, 12)) - right_eye_open_ratio = calc_distance_ratio(landmarks, (30, 42, 24, 36)) - lip_open_ratio = calc_distance_ratio(landmarks, (90, 102, 48, 66)) - left_eye_gaze = landmarks[:, 198] - right_eye_gaze = landmarks[:, 197] - return left_eye_open_ratio, right_eye_open_ratio, lip_open_ratio, left_eye_gaze, right_eye_gaze - + return landmarks[:, 198], landmarks[:, 197] def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: vision_tensor_norm = (vision_tensor + 1) * 0.5 @@ -349,21 +218,17 @@ class FaceSwapper(pytorch_lightning.LightningModule): rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1) return translation, scale, rotation - - def log_generator_preview(self, source_tensor, target_tensor, swap_tensor): + def log_generator_preview(self, source_tensor : Tensor, target_tensor : Tensor, swap_tensor : Tensor) -> None: max_preview = 8 source_tensor = source_tensor[:max_preview] target_tensor = target_tensor[:max_preview] swap_tensor = swap_tensor[:max_preview] - rows = [torch.cat([src, tgt, swp], dim=2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)] - grid = torchvision.utils.make_grid(torch.cat(rows, dim=1).unsqueeze(0), nrow=1, normalize=True, scale_each=True) - os.makedirs("previews", exist_ok=True) - torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg") + rows = [torch.cat([src, tgt, swp], dim = 2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)] + grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True) self.logger.experiment.add_image("Generator Preview", grid, self.global_step) - - def log_validation_preview(self): - read_images = lambda path: [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')] + def log_validation_preview(self) -> None: + read_images = lambda path : [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')] to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5 transforms = torchvision.transforms.Compose( [ @@ -377,7 +242,6 @@ class FaceSwapper(pytorch_lightning.LightningModule): targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup')) targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion')) - self.generator.eval() results_source = [] @@ -388,7 +252,7 @@ class FaceSwapper(pytorch_lightning.LightningModule): for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion): source_tensor = transforms(source).unsqueeze(0).to(self.device).half() - source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0)) + source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0)) target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half() target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half() target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half() diff --git a/face_swapper/src/typing.py b/face_swapper/src/typing.py index b99e58b..29ab130 100644 --- a/face_swapper/src/typing.py +++ b/face_swapper/src/typing.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from typing import Any, Dict, List, Tuple from numpy.typing import NDArray @@ -10,6 +11,6 @@ TargetAttributes = Tuple[Tensor, ...] DiscriminatorOutputs = List[List[Tensor]] LossDict = Dict[str, Tensor] IDEmbedding = Tensor - +StateDict = OrderedDict[str, Any] Embedding = NDArray[Any] VisionFrame = NDArray[Any]