From 3b8b6442fc1f09189c38660a7e8d69fd858770ff Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 26 Jan 2025 22:03:34 +0530 Subject: [PATCH] add infer and some cleaning --- face_swapper/config.ini | 42 +++++++++----------- face_swapper/infer.py | 29 ++++++++++++++ face_swapper/src/data_loader.py | 24 ++++-------- face_swapper/src/helper.py | 47 ++++++++++++++++++++++- face_swapper/src/training.py | 68 ++++++++++++++------------------- 5 files changed, 127 insertions(+), 83 deletions(-) create mode 100644 face_swapper/infer.py diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 116c6ec..8422a1f 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -2,34 +2,26 @@ dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan folder_pattern = {}/* image_pattern = {}/*.*g - -[preparing.dataloader] same_person_probability = 0.2 -[preparing.augmentation] -expression = false - [training.loader] -batch_size = 4 -num_workers = 8 +batch_size = 24 +num_workers = 12 -[training.generator] +[training.model] +id_embedder_path = +landmarker_path = +motion_extractor_path = + +[training.model.generator] num_blocks = 2 id_channels = 512 -learning_rate = 0.0004 -[training.discriminator] +[training.model.discriminator] input_channels = 3 num_filters = 64 num_layers = 5 num_discriminators = 3 -learning_rate = 0.0004 -disable = false - -[auxiliary_models.paths] #just model.trainer ... try to match the config of the arcface converter -arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt -landmarker_path = /assets/pretrained_models/landmark_203.pt -motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth [training.losses] weight_adversarial = 1 @@ -38,12 +30,9 @@ weight_attribute = 10 weight_reconstruction = 10 weight_tsr = 100 -[training.schedulers] -step = 5000 -gamma = 0.2 - [training.trainer] max_epochs = 50 +learning_rate = 0.0004 [training.output] checkpoint_path = checkpoints/last.ckpt @@ -52,10 +41,6 @@ file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}' preview_frequency = 250 validation_frequency = 1000 -[training.validation] -sources = assets/test/front/sources -targets = assets/test/front/targets - [exporting] directory_path = source_path = @@ -64,3 +49,10 @@ opset_version = [execution] providers = + +[inference] +generator_path = +id_embedder_path = +source_path = +target_path = +output_path = diff --git a/face_swapper/infer.py b/face_swapper/infer.py new file mode 100644 index 0000000..1a2ac35 --- /dev/null +++ b/face_swapper/infer.py @@ -0,0 +1,29 @@ +import configparser + +import cv2 +import torch +from src.generator import AdaptiveEmbeddingIntegrationNetwork +from src.helper import infer, read_image + +CONFIG = configparser.ConfigParser() +CONFIG.read('config.ini') + + +if __name__ == '__main__': + generator_path = CONFIG.get('inference', 'generator_path') + id_embedder_path = CONFIG.get('inference', 'id_embedder_path') + source_path = CONFIG.get('inference', 'source_path') + target_path = CONFIG.get('inference', 'target_path') + output_path = CONFIG.get('inference', 'output_path') + + state_dict = torch.load(generator_path, map_location = 'cpu')['state_dict']['generator'] + generator = AdaptiveEmbeddingIntegrationNetwork(512, 2) + generator.load_state_dict(state_dict) + generator.eval() + id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') #type:ignore[no-untyped-call] + id_embedder.eval() + + source_vision_frame = read_image(source_path) + target_vision_frame = read_image(target_path) + output_vision_frame = infer(generator, id_embedder, source_vision_frame, target_vision_frame) + cv2.imwrite(output_path, output_vision_frame) diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py index c4f8a9f..c8b6784 100644 --- a/face_swapper/src/data_loader.py +++ b/face_swapper/src/data_loader.py @@ -1,35 +1,24 @@ -import configparser import glob import os.path import random -import cv2 import torch import torchvision.transforms as transforms from torch.utils.data import TensorDataset -from .typing import Batch, VisionFrame - -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - - -def read_image(image_path: str) -> VisionFrame: - image = cv2.imread(image_path)[:, :, ::-1] - return image +from .helper import read_image +from .typing import Batch class DataLoaderVGG(TensorDataset): - def __init__(self, dataset_path : str) -> None: - self.same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability') - image_pattern = CONFIG.get('preparing.dataset', 'image_pattern') - folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern') - self.folder_paths = glob.glob(folder_pattern.format(dataset_path)) + def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None: + self.same_person_probability = same_person_probability + self.folder_paths = glob.glob(dataset_folder_pattern.format(dataset_path)) self.image_paths = [] self.image_path_set = {} for folder_path in self.folder_paths: - image_paths = glob.glob(image_pattern.format(folder_path)) + image_paths = glob.glob(dataset_image_pattern.format(folder_path)) self.image_paths.extend(image_paths) self.image_path_set[folder_path] = image_paths self.dataset_total = len(self.image_paths) @@ -40,6 +29,7 @@ class DataLoaderVGG(TensorDataset): transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1), transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0), transforms.ToTensor(), + transforms.Lambda(lambda img: img[[2, 1, 0], :, :]), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 5abdb41..438e273 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -1,6 +1,30 @@ +import cv2 +import numpy import torch -from .typing import Tensor +from .typing import IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor + + +def read_image(image_path : str) -> VisionFrame: + image = cv2.imread(image_path) + return image + + +def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor: + vision_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32)) + vision_tensor = vision_tensor / 255 + vision_tensor = (vision_tensor - 0.5) * 2 + vision_tensor = vision_tensor.unsqueeze(0) + return vision_tensor + + +def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame: + vision_frame = vision_tensor.detach().cpu().numpy()[0] + vision_frame = vision_frame.transpose(1, 2, 0) + vision_frame = (vision_frame + 1) * 127.5 + vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8) + vision_frame = vision_frame[:, :, ::-1] + return vision_frame def hinge_real_loss(tensor : Tensor) -> Tensor: @@ -9,3 +33,24 @@ def hinge_real_loss(tensor : Tensor) -> Tensor: def hinge_fake_loss(tensor : Tensor) -> Tensor: return torch.relu(tensor + 1) + + +def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding: + crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241] + crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area') + crop_vision_tensor[:, :, :padding[0], :] = 0 + crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 + crop_vision_tensor[:, :, :, :padding[2]] = 0 + crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 + source_embedding = id_embedder(crop_vision_tensor) + source_embedding = torch.nn.functional.normalize(source_embedding, p = 2, dim = 1) + return source_embedding + + +def infer(generator : torch.nn.Module, id_embedder : torch.nn.Module, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: + source_vision_tensor = convert_to_vision_tensor(source_vision_frame) + target_vision_tensor = convert_to_vision_tensor(target_vision_frame) + source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0)) + output_vision_tensor = generator(source_embedding, target_vision_tensor)[0] + output_vision_frame = convert_to_vision_frame(output_vision_tensor) + return output_vision_frame diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 9416543..e248ee3 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -5,7 +5,6 @@ from typing import Tuple import pytorch_lightning import torch import torchvision -from LivePortrait.src.modules.motion_extractor import MotionExtractor from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities.types import Optimizer @@ -16,8 +15,8 @@ from torch.utils.data import DataLoader from .data_loader import DataLoaderVGG from .discriminator import MultiscaleDiscriminator from .generator import AdaptiveEmbeddingIntegrationNetwork -from .helper import hinge_fake_loss, hinge_real_loss -from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor +from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss +from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -26,23 +25,22 @@ CONFIG.read('config.ini') class FaceSwapper(pytorch_lightning.LightningModule): def __init__(self) -> None: super().__init__() - id_channels = CONFIG.getint('training.generator', 'id_channels') - num_blocks = CONFIG.getint('training.generator', 'num_blocks') - input_channels = CONFIG.getint('training.discriminator', 'input_channels') - num_filters = CONFIG.getint('training.discriminator', 'num_filters') - num_layers = CONFIG.getint('training.discriminator', 'num_layers') - num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators') - arcface_path = CONFIG.get('auxiliary_models.paths', 'arcface_path') - landmarker_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path') - motion_extractor_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path') + id_channels = CONFIG.getint('training.model.generator', 'id_channels') + num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') + input_channels = CONFIG.getint('training.model.discriminator', 'input_channels') + num_filters = CONFIG.getint('training.model.discriminator', 'num_filters') + num_layers = CONFIG.getint('training.model.discriminator', 'num_layers') + num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators') + id_embedder_path = CONFIG.get('training.model', 'id_embedder_path') + landmarker_path = CONFIG.get('training.model', 'landmarker_path') + motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks) self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators) - self.arcface = torch.load(arcface_path, map_location = 'cpu', weights_only = False) - self.landmarker = torch.load(landmarker_path, map_location = 'cpu', weights_only = False) - self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny') - self.motion_extractor.load_state_dict(torch.load(motion_extractor_path, map_location = 'cpu', weights_only = True)) - self.arcface.eval() + self.id_embedder = torch.jit.load(id_embedder_path, map_location ='cpu') #type:ignore[no-untyped-call] + self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') #type:ignore[no-untyped-call] + self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') #type:ignore[no-untyped-call] + self.id_embedder.eval() self.landmarker.eval() self.motion_extractor.eval() self.automatic_optimization = False @@ -54,16 +52,15 @@ class FaceSwapper(pytorch_lightning.LightningModule): return output def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]: - generator_learning_rate = CONFIG.getfloat('training.generator', 'learning_rate') - discriminator_learning_rate = CONFIG.getfloat('training.discriminator', 'learning_rate') - generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = generator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = discriminator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate') + generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) return generator_optimizer, discriminator_optimizer def training_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor, is_same_person = batch generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] - source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0)) + source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0)) swap_tensor, target_attributes = self.generator(target_tensor, source_embedding) discriminator_outputs = self.discriminator(swap_tensor) @@ -112,8 +109,8 @@ class FaceSwapper(pytorch_lightning.LightningModule): return loss_reconstruction def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: - swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10)) - source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10)) + swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10)) + source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10)) loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean() return loss_id @@ -181,17 +178,6 @@ class FaceSwapper(pytorch_lightning.LightningModule): discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5 return discriminator_loss_set - def get_id_embedding(self, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding: - crop_vision_tensor = torch.nn.functional.interpolate(vision_tensor, size = (112, 112), mode = 'area') - crop_vision_tensor = crop_vision_tensor[:, :, 0:112, 8:128] - crop_vision_tensor[:, :, :padding[0], :] = 0 - crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 - crop_vision_tensor[:, :, :, :padding[2]] = 0 - crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 - embedding = self.arcface(crop_vision_tensor) - embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1) - return embedding - def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203: vision_tensor_norm = (vision_tensor + 1) * 0.5 vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') @@ -200,10 +186,8 @@ class FaceSwapper(pytorch_lightning.LightningModule): def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: vision_tensor_norm = (vision_tensor + 1) * 0.5 - motion_dict = self.motion_extractor(vision_tensor_norm) - translation = motion_dict.get('t') - scale = motion_dict.get('scale') - rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1) + pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) + rotation = torch.cat([ pitch, yaw, roll ], dim = 1) return translation, scale, rotation def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None: @@ -246,7 +230,11 @@ def train() -> None: batch_size = CONFIG.getint('training.loader', 'batch_size') num_workers = CONFIG.getint('training.loader', 'num_workers') checkpoint_path = CONFIG.get('training.output', 'checkpoint_path') - dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path')) + dataset_path = CONFIG.get('preparing.dataset', 'dataset_path') + dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern') + dataset_folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern') + same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability') + dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_folder_pattern, same_person_probability) data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) face_swap_model = FaceSwapper()