mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
add infer and some cleaning
This commit is contained in:
+17
-25
@@ -2,34 +2,26 @@
|
||||
dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan
|
||||
folder_pattern = {}/*
|
||||
image_pattern = {}/*.*g
|
||||
|
||||
[preparing.dataloader]
|
||||
same_person_probability = 0.2
|
||||
|
||||
[preparing.augmentation]
|
||||
expression = false
|
||||
|
||||
[training.loader]
|
||||
batch_size = 4
|
||||
num_workers = 8
|
||||
batch_size = 24
|
||||
num_workers = 12
|
||||
|
||||
[training.generator]
|
||||
[training.model]
|
||||
id_embedder_path =
|
||||
landmarker_path =
|
||||
motion_extractor_path =
|
||||
|
||||
[training.model.generator]
|
||||
num_blocks = 2
|
||||
id_channels = 512
|
||||
learning_rate = 0.0004
|
||||
|
||||
[training.discriminator]
|
||||
[training.model.discriminator]
|
||||
input_channels = 3
|
||||
num_filters = 64
|
||||
num_layers = 5
|
||||
num_discriminators = 3
|
||||
learning_rate = 0.0004
|
||||
disable = false
|
||||
|
||||
[auxiliary_models.paths] #just model.trainer ... try to match the config of the arcface converter
|
||||
arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
|
||||
landmarker_path = /assets/pretrained_models/landmark_203.pt
|
||||
motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
|
||||
|
||||
[training.losses]
|
||||
weight_adversarial = 1
|
||||
@@ -38,12 +30,9 @@ weight_attribute = 10
|
||||
weight_reconstruction = 10
|
||||
weight_tsr = 100
|
||||
|
||||
[training.schedulers]
|
||||
step = 5000
|
||||
gamma = 0.2
|
||||
|
||||
[training.trainer]
|
||||
max_epochs = 50
|
||||
learning_rate = 0.0004
|
||||
|
||||
[training.output]
|
||||
checkpoint_path = checkpoints/last.ckpt
|
||||
@@ -52,10 +41,6 @@ file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
|
||||
preview_frequency = 250
|
||||
validation_frequency = 1000
|
||||
|
||||
[training.validation]
|
||||
sources = assets/test/front/sources
|
||||
targets = assets/test/front/targets
|
||||
|
||||
[exporting]
|
||||
directory_path =
|
||||
source_path =
|
||||
@@ -64,3 +49,10 @@ opset_version =
|
||||
|
||||
[execution]
|
||||
providers =
|
||||
|
||||
[inference]
|
||||
generator_path =
|
||||
id_embedder_path =
|
||||
source_path =
|
||||
target_path =
|
||||
output_path =
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import configparser
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from src.generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from src.helper import infer, read_image
|
||||
|
||||
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')


if __name__ == '__main__':
    # All file locations are taken from the [inference] section of config.ini.
    generator_path = CONFIG.get('inference', 'generator_path')
    id_embedder_path = CONFIG.get('inference', 'id_embedder_path')
    source_path = CONFIG.get('inference', 'source_path')
    target_path = CONFIG.get('inference', 'target_path')
    output_path = CONFIG.get('inference', 'output_path')

    # Use the generator hyperparameters from config.ini so the network matches
    # the checkpoint being loaded; the fallbacks are the previously hard-coded values.
    id_channels = CONFIG.getint('training.model.generator', 'id_channels', fallback = 512)
    num_blocks = CONFIG.getint('training.model.generator', 'num_blocks', fallback = 2)

    # The checkpoint is expected to hold the generator weights under ['state_dict']['generator'].
    state_dict = torch.load(generator_path, map_location = 'cpu')['state_dict']['generator']
    generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
    generator.load_state_dict(state_dict)
    generator.eval()
    id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') #type:ignore[no-untyped-call]
    id_embedder.eval()

    source_vision_frame = read_image(source_path)
    target_vision_frame = read_image(target_path)
    output_vision_frame = infer(generator, id_embedder, source_vision_frame, target_vision_frame)

    # cv2.imwrite signals a failed write through its boolean result, so check it explicitly.
    if not cv2.imwrite(output_path, output_vision_frame):
        raise OSError('Could not write image: ' + output_path)
|
||||
@@ -1,35 +1,24 @@
|
||||
import configparser
|
||||
import glob
|
||||
import os.path
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
from .typing import Batch, VisionFrame
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def read_image(image_path: str) -> VisionFrame:
|
||||
image = cv2.imread(image_path)[:, :, ::-1]
|
||||
return image
|
||||
from .helper import read_image
|
||||
from .typing import Batch
|
||||
|
||||
|
||||
class DataLoaderVGG(TensorDataset):
|
||||
def __init__(self, dataset_path : str) -> None:
|
||||
self.same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability')
|
||||
image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
|
||||
folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern')
|
||||
self.folder_paths = glob.glob(folder_pattern.format(dataset_path))
|
||||
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None:
|
||||
self.same_person_probability = same_person_probability
|
||||
self.folder_paths = glob.glob(dataset_folder_pattern.format(dataset_path))
|
||||
self.image_paths = []
|
||||
self.image_path_set = {}
|
||||
|
||||
for folder_path in self.folder_paths:
|
||||
image_paths = glob.glob(image_pattern.format(folder_path))
|
||||
image_paths = glob.glob(dataset_image_pattern.format(folder_path))
|
||||
self.image_paths.extend(image_paths)
|
||||
self.image_path_set[folder_path] = image_paths
|
||||
self.dataset_total = len(self.image_paths)
|
||||
@@ -40,6 +29,7 @@ class DataLoaderVGG(TensorDataset):
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda img: img[[2, 1, 0], :, :]),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
|
||||
@@ -1,6 +1,30 @@
|
||||
import cv2
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from .typing import Tensor
|
||||
from .typing import IdEmbedding, Padding, Tensor, VisionFrame, VisionTensor
|
||||
|
||||
|
||||
def read_image(image_path : str) -> VisionFrame:
    """
    Read an image file from disk with OpenCV.

    :param image_path: path of the image file to load
    :return: the array produced by cv2.imread
    :raises OSError: when OpenCV cannot read or decode the file
    """
    image = cv2.imread(image_path)

    # cv2.imread returns None instead of raising, so fail here with the
    # offending path rather than with an unrelated error further downstream.
    if image is None:
        raise OSError('Could not read image: ' + image_path)
    return image
|
||||
|
||||
|
||||
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
    """
    Convert an image array into a normalized float tensor with a leading batch axis.

    The last axis of the frame is reversed, the axes are reordered so the
    former last axis comes first, values are mapped from 0..255 to -1..1 and
    a batch dimension of size one is prepended.

    :param vision_frame: three dimensional image array
    :return: float32 tensor of shape (1, channels, rows, columns)
    """
    # Reverse the channel axis, then move it in front of the two spatial axes.
    channel_first_frame = vision_frame[:, :, ::-1].transpose(2, 0, 1)
    vision_tensor = torch.from_numpy(channel_first_frame.astype(numpy.float32))
    # Scale 0..255 to 0..1 and then shift and stretch to -1..1.
    vision_tensor = ((vision_tensor / 255) - 0.5) * 2
    return vision_tensor.unsqueeze(0)
|
||||
|
||||
|
||||
def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
    """
    Convert the first item of a normalized tensor batch into a uint8 image array.

    Values are mapped from -1..1 to 0..255, rounded to the nearest integer,
    clipped to the uint8 range and the channel axis is reversed.

    :param vision_tensor: tensor of shape (batch, channels, rows, columns); only index 0 of the batch is used
    :return: contiguous uint8 array of shape (rows, columns, channels)
    """
    vision_frame = vision_tensor.detach().cpu().numpy()[0]
    vision_frame = vision_frame.transpose(1, 2, 0)
    vision_frame = (vision_frame + 1) * 127.5
    # Round before the integer cast: a plain cast truncates, which darkens the
    # result and turns float error such as 126.99999 into 126 instead of 127.
    vision_frame = numpy.rint(vision_frame).clip(0, 255).astype(numpy.uint8)
    # Reversing the channel axis yields a negative-stride view; hand back a
    # contiguous copy so consumers that need a plain memory layout can use it.
    return numpy.ascontiguousarray(vision_frame[:, :, ::-1])
|
||||
|
||||
|
||||
def hinge_real_loss(tensor : Tensor) -> Tensor:
|
||||
@@ -9,3 +33,24 @@ def hinge_real_loss(tensor : Tensor) -> Tensor:
|
||||
|
||||
def hinge_fake_loss(tensor : Tensor) -> Tensor:
    """
    Compute the element-wise hinge term relu(tensor + 1).

    No reduction is applied; the result has the same shape as the input.

    :param tensor: input tensor
    :return: tensor of element-wise hinge values
    """
    shifted_tensor = tensor + 1
    return torch.relu(shifted_tensor)
|
||||
|
||||
|
||||
def calc_id_embedding(id_embedder : torch.nn.Module, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
|
||||
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
|
||||
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
|
||||
crop_vision_tensor[:, :, :padding[0], :] = 0
|
||||
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
source_embedding = torch.nn.functional.normalize(source_embedding, p = 2, dim = 1)
|
||||
return source_embedding
|
||||
|
||||
|
||||
def infer(generator : torch.nn.Module, id_embedder : torch.nn.Module, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
    """
    Run the generator on a target frame using the identity of a source frame.

    :param generator: generator network; called as generator(target_tensor, source_embedding)
    :param id_embedder: module used to compute the source embedding
    :param source_vision_frame: image array the identity embedding is computed from
    :param target_vision_frame: image array passed to the generator
    :return: generated image array
    """
    source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
    target_vision_tensor = convert_to_vision_tensor(target_vision_frame)

    # Inference only: skip autograd bookkeeping for both networks.
    with torch.no_grad():
        source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
        # Argument order follows the training code, which calls the generator
        # with the target tensor first and the source embedding second.
        # The generator returns a sequence whose first item is the output tensor.
        output_vision_tensor = generator(target_vision_tensor, source_embedding)[0]

    return convert_to_vision_frame(output_vision_tensor)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Tuple
|
||||
import pytorch_lightning
|
||||
import torch
|
||||
import torchvision
|
||||
from LivePortrait.src.modules.motion_extractor import MotionExtractor
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities.types import Optimizer
|
||||
@@ -16,8 +15,8 @@ from torch.utils.data import DataLoader
|
||||
from .data_loader import DataLoaderVGG
|
||||
from .discriminator import MultiscaleDiscriminator
|
||||
from .generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from .helper import hinge_fake_loss, hinge_real_loss
|
||||
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor
|
||||
from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
@@ -26,23 +25,22 @@ CONFIG.read('config.ini')
|
||||
class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
id_channels = CONFIG.getint('training.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
|
||||
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
|
||||
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
|
||||
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
|
||||
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
|
||||
arcface_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
|
||||
landmarker_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
|
||||
motion_extractor_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
|
||||
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
|
||||
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
|
||||
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
|
||||
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
|
||||
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
|
||||
self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
|
||||
self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
|
||||
self.arcface = torch.load(arcface_path, map_location = 'cpu', weights_only = False)
|
||||
self.landmarker = torch.load(landmarker_path, map_location = 'cpu', weights_only = False)
|
||||
self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
|
||||
self.motion_extractor.load_state_dict(torch.load(motion_extractor_path, map_location = 'cpu', weights_only = True))
|
||||
self.arcface.eval()
|
||||
self.id_embedder = torch.jit.load(id_embedder_path, map_location ='cpu') #type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') #type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') #type:ignore[no-untyped-call]
|
||||
self.id_embedder.eval()
|
||||
self.landmarker.eval()
|
||||
self.motion_extractor.eval()
|
||||
self.automatic_optimization = False
|
||||
@@ -54,16 +52,15 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
return output
|
||||
|
||||
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
|
||||
generator_learning_rate = CONFIG.getfloat('training.generator', 'learning_rate')
|
||||
discriminator_learning_rate = CONFIG.getfloat('training.discriminator', 'learning_rate')
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = generator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = discriminator_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
return generator_optimizer, discriminator_optimizer
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
|
||||
discriminator_outputs = self.discriminator(swap_tensor)
|
||||
|
||||
@@ -112,8 +109,8 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
return loss_reconstruction
|
||||
|
||||
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
|
||||
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
|
||||
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
|
||||
return loss_id
|
||||
|
||||
@@ -181,17 +178,6 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
|
||||
return discriminator_loss_set
|
||||
|
||||
def get_id_embedding(self, vision_tensor : VisionTensor, padding : Padding) -> IdEmbedding:
|
||||
crop_vision_tensor = torch.nn.functional.interpolate(vision_tensor, size = (112, 112), mode = 'area')
|
||||
crop_vision_tensor = crop_vision_tensor[:, :, 0:112, 8:128]
|
||||
crop_vision_tensor[:, :, :padding[0], :] = 0
|
||||
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
embedding = self.arcface(crop_vision_tensor)
|
||||
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
|
||||
return embedding
|
||||
|
||||
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
|
||||
@@ -200,10 +186,8 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
|
||||
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
motion_dict = self.motion_extractor(vision_tensor_norm)
|
||||
translation = motion_dict.get('t')
|
||||
scale = motion_dict.get('scale')
|
||||
rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
def log_generator_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> None:
|
||||
@@ -246,7 +230,11 @@ def train() -> None:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
|
||||
dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path'))
|
||||
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
|
||||
dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
|
||||
dataset_folder_pattern = CONFIG.get('preparing.dataset', 'folder_pattern')
|
||||
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
|
||||
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_folder_pattern, same_person_probability)
|
||||
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swap_model = FaceSwapper()
|
||||
|
||||
Reference in New Issue
Block a user