add infer and some cleaning

This commit is contained in:
harisreedhar
2025-01-26 22:03:34 +05:30
committed by henryruhs
parent cfcf0ee2bd
commit 3b8b6442fc
5 changed files with 127 additions and 83 deletions
+17 -25
View File
@@ -2,34 +2,26 @@
dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan
folder_pattern = {}/*
image_pattern = {}/*.*g
[preparing.dataloader]
same_person_probability = 0.2
[preparing.augmentation]
expression = false
[training.loader]
batch_size = 4
num_workers = 8
batch_size = 24
num_workers = 12
[training.generator]
[training.model]
id_embedder_path =
landmarker_path =
motion_extractor_path =
[training.model.generator]
num_blocks = 2
id_channels = 512
learning_rate = 0.0004
[training.discriminator]
[training.model.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3
learning_rate = 0.0004
disable = false
[auxiliary_models.paths] #just model.trainer ... try to match the config of the arcface converter
arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
landmarker_path = /assets/pretrained_models/landmark_203.pt
motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
[training.losses]
weight_adversarial = 1
@@ -38,12 +30,9 @@ weight_attribute = 10
weight_reconstruction = 10
weight_tsr = 100
[training.schedulers]
step = 5000
gamma = 0.2
[training.trainer]
max_epochs = 50
learning_rate = 0.0004
[training.output]
checkpoint_path = checkpoints/last.ckpt
@@ -52,10 +41,6 @@ file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
preview_frequency = 250
validation_frequency = 1000
[training.validation]
sources = assets/test/front/sources
targets = assets/test/front/targets
[exporting]
directory_path =
source_path =
@@ -64,3 +49,10 @@ opset_version =
[execution]
providers =
[inference]
generator_path =
id_embedder_path =
source_path =
target_path =
output_path =
+29
View File
@@ -0,0 +1,29 @@
import configparser
import cv2
import torch
from src.generator import AdaptiveEmbeddingIntegrationNetwork
from src.helper import infer, read_image
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
if __name__ == '__main__':
generator_path = CONFIG.get('inference', 'generator_path')
id_embedder_path = CONFIG.get('inference', 'id_embedder_path')
source_path = CONFIG.get('inference', 'source_path')
target_path = CONFIG.get('inference', 'target_path')
output_path = CONFIG.get('inference', 'output_path')
state_dict = torch.load(generator_path, map_location = 'cpu')['state_dict']['generator']
generator = AdaptiveEmbeddingIntegrationNetwork(512, 2)
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') #type:ignore[no-untyped-call]
id_embedder.eval()
source_vision_frame = read_image(source_path)
target_vision_frame = read_image(target_path)
output_vision_frame = infer(generator, id_embedder, source_vision_frame, target_vision_frame)
cv2.imwrite(output_path, output_vision_frame)
+7 -17
View File
@@ -1,35 +1,24 @@
import configparser
import glob
import os.path
import random
import cv2
import torch
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset
from .typing import Batch, VisionFrame
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def read_image(image_path: str) -> VisionFrame:
image = cv2.imread(image_path)[:, :, ::-1]
return image
from .helper import read_image
from .typing import Batch
class DataLoaderVGG(TensorDataset):
def __init__(self, dataset_path : str) -> None:
self.same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability')
image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern')
self.folder_paths = glob.glob(folder_pattern.format(dataset_path))
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None:
self.same_person_probability = same_person_probability
self.folder_paths = glob.glob(dataset_folder_pattern.format(dataset_path))
self.image_paths = []
self.image_path_set = {}
for folder_path in self.folder_paths:
image_paths = glob.glob(image_pattern.format(folder_path))
image_paths = glob.glob(dataset_image_pattern.format(folder_path))
self.image_paths.extend(image_paths)
self.image_path_set[folder_path] = image_paths
self.dataset_total = len(self.image_paths)
@@ -40,6 +29,7 @@ class DataLoaderVGG(TensorDataset):
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
transforms.ToTensor(),
transforms.Lambda(lambda img: img[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
+46 -1
View File
@@ -1,6 +1,30 @@
import cv2
import numpy
import torch
from .typing import Tensor
from .typing import IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor
def read_image(image_path : str) -> VisionFrame:
image = cv2.imread(image_path)
return image
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
vision_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
vision_tensor = vision_tensor / 255
vision_tensor = (vision_tensor - 0.5) * 2
vision_tensor = vision_tensor.unsqueeze(0)
return vision_tensor
def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
vision_frame = vision_tensor.detach().cpu().numpy()[0]
vision_frame = vision_frame.transpose(1, 2, 0)
vision_frame = (vision_frame + 1) * 127.5
vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8)
vision_frame = vision_frame[:, :, ::-1]
return vision_frame
def hinge_real_loss(tensor : Tensor) -> Tensor:
@@ -9,3 +33,24 @@ def hinge_real_loss(tensor : Tensor) -> Tensor:
def hinge_fake_loss(tensor : Tensor) -> Tensor:
return torch.relu(tensor + 1)
def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor[:, :, :padding[0], :] = 0
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = torch.nn.functional.normalize(source_embedding, p = 2, dim = 1)
return source_embedding
def infer(generator : torch.nn.Module, id_embedder : torch.nn.Module, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
return output_vision_frame
+28 -40
View File
@@ -5,7 +5,6 @@ from typing import Tuple
import pytorch_lightning
import torch
import torchvision
from LivePortrait.src.modules.motion_extractor import MotionExtractor
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import Optimizer
@@ -16,8 +15,8 @@ from torch.utils.data import DataLoader
from .data_loader import DataLoaderVGG
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import hinge_fake_loss, hinge_real_loss
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor
from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -26,23 +25,22 @@ CONFIG.read('config.ini')
class FaceSwapper(pytorch_lightning.LightningModule):
def __init__(self) -> None:
super().__init__()
id_channels = CONFIG.getint('training.generator', 'id_channels')
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
arcface_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
landmarker_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
motion_extractor_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
self.arcface = torch.load(arcface_path, map_location = 'cpu', weights_only = False)
self.landmarker = torch.load(landmarker_path, map_location = 'cpu', weights_only = False)
self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
self.motion_extractor.load_state_dict(torch.load(motion_extractor_path, map_location = 'cpu', weights_only = True))
self.arcface.eval()
self.id_embedder = torch.jit.load(id_embedder_path, map_location ='cpu') #type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') #type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') #type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
self.automatic_optimization = False
@@ -54,16 +52,15 @@ class FaceSwapper(pytorch_lightning.LightningModule):
return output
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
generator_learning_rate = CONFIG.getfloat('training.generator', 'learning_rate')
discriminator_learning_rate = CONFIG.getfloat('training.discriminator', 'learning_rate')
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = generator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = discriminator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
return generator_optimizer, discriminator_optimizer
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, is_same_person = batch
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
discriminator_outputs = self.discriminator(swap_tensor)
@@ -112,8 +109,8 @@ class FaceSwapper(pytorch_lightning.LightningModule):
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
return loss_id
@@ -181,17 +178,6 @@ class FaceSwapper(pytorch_lightning.LightningModule):
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_loss_set
def get_id_embedding(self, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
crop_vision_tensor = torch.nn.functional.interpolate(vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor = crop_vision_tensor[:, :, 0:112, 8:128]
crop_vision_tensor[:, :, :padding[0], :] = 0
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
embedding = self.arcface(crop_vision_tensor)
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
return embedding
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
@@ -200,10 +186,8 @@ class FaceSwapper(pytorch_lightning.LightningModule):
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
motion_dict = self.motion_extractor(vision_tensor_norm)
translation = motion_dict.get('t')
scale = motion_dict.get('scale')
rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
@@ -246,7 +230,11 @@ def train() -> None:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path'))
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
dataset_folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern')
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_folder_pattern, same_person_probability)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapper()