mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
clean
This commit is contained in:
@@ -8,7 +8,6 @@ import tqdm
|
||||
from PIL import Image
|
||||
from torch.utils.data import TensorDataset
|
||||
|
||||
from .augmentations import apply_random_motion_blur
|
||||
from .typing import Batch
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -52,7 +51,6 @@ class DataLoaderVGG(TensorDataset):
|
||||
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.RandomHorizontalFlip(p = 0.5),
|
||||
transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
|
||||
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
@@ -81,11 +79,3 @@ class DataLoaderVGG(TensorDataset):
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.dataset_total
|
||||
|
||||
|
||||
def state_dict(self):
|
||||
return {'current_index': self._current_index}
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._current_index = state_dict['current_index']
|
||||
|
||||
+3
-101
@@ -1,114 +1,16 @@
|
||||
import configparser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from .typing import Tensor
|
||||
import numpy
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .typing import Tensor, Tuple
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
if CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
from LivePortrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
|
||||
|
||||
L2_loss = torch.nn.MSELoss()
|
||||
EXPRESSION_MIN = numpy.array(
|
||||
[
|
||||
[
|
||||
[-2.88067125e-02, -8.12731311e-02, -1.70541159e-03],
|
||||
[-4.88598682e-02, -3.32196616e-02, -1.67431499e-04],
|
||||
[-6.75425082e-02, -4.28681746e-02, -1.98950816e-04],
|
||||
[-7.23103955e-02, -3.28503326e-02, -7.31324719e-04],
|
||||
[-3.87073644e-02, -6.01546466e-02, -5.50269964e-04],
|
||||
[-6.38048723e-02, -2.23840728e-01, -7.13261834e-04],
|
||||
[-3.02710701e-02, -3.93195450e-02, -8.24086510e-06],
|
||||
[-2.95799859e-02, -5.39318882e-02, -1.74219604e-04],
|
||||
[-2.92359516e-02, -1.53050944e-02, -6.30460854e-05],
|
||||
[-5.56493877e-03, -2.34344602e-02, -1.26858242e-04],
|
||||
[-4.37593013e-02, -2.77768299e-02, -2.70503685e-02],
|
||||
[-1.76926646e-02, -1.91676542e-02, -1.15090821e-04],
|
||||
[-8.34268332e-03, -3.99775570e-03, -3.27481248e-05],
|
||||
[-3.40162888e-02, -2.81868968e-02, -1.96679524e-04],
|
||||
[-2.91855410e-02, -3.97511162e-02, -2.81230678e-05],
|
||||
[-1.50395725e-02, -2.49494594e-02, -9.42573533e-05],
|
||||
[-1.67938769e-02, -2.00953931e-02, -4.00750607e-04],
|
||||
[-1.86435618e-02, -2.48535164e-02, -2.74416432e-02],
|
||||
[-4.61211195e-03, -1.21660791e-02, -2.93173041e-04],
|
||||
[-4.10017073e-02, -7.43824020e-02, -4.42762971e-02],
|
||||
[-1.90370996e-02, -3.74363363e-02, -1.34740388e-02]
|
||||
]
|
||||
]).astype(numpy.float32)
|
||||
EXPRESSION_MAX = numpy.array(
|
||||
[
|
||||
[
|
||||
[4.46682945e-02, 7.08772913e-02, 4.08344204e-04],
|
||||
[2.14308221e-02, 6.15894832e-02, 4.85319615e-05],
|
||||
[3.02363783e-02, 4.45043296e-02, 1.28298725e-05],
|
||||
[3.05869691e-02, 3.79812494e-02, 6.57040102e-04],
|
||||
[4.45670523e-02, 3.97259220e-02, 7.10966764e-04],
|
||||
[9.43699256e-02, 9.85926315e-02, 2.02551950e-04],
|
||||
[1.61131397e-02, 2.92906128e-02, 3.44733417e-06],
|
||||
[5.23825921e-02, 1.07065082e-01, 6.61510974e-04],
|
||||
[2.85718683e-03, 8.32320191e-03, 2.39314613e-04],
|
||||
[2.57947259e-02, 1.60935968e-02, 2.41853559e-05],
|
||||
[4.90833223e-02, 3.43903080e-02, 3.22353356e-02],
|
||||
[1.44766076e-02, 3.39248963e-02, 1.42291479e-04],
|
||||
[8.75749043e-04, 6.82212645e-03, 2.76097053e-05],
|
||||
[1.86958015e-02, 3.84016186e-02, 7.33085908e-05],
|
||||
[2.01714113e-02, 4.90544215e-02, 2.34028921e-05],
|
||||
[2.46518422e-02, 3.29151377e-02, 3.48571630e-05],
|
||||
[2.22457591e-02, 1.21796541e-02, 1.56396593e-04],
|
||||
[1.72109623e-02, 3.01626958e-02, 1.36556877e-02],
|
||||
[1.83460284e-02, 1.61141958e-02, 2.87440169e-04],
|
||||
[3.57594155e-02, 1.80554688e-01, 2.75554154e-02],
|
||||
[2.17450950e-02, 8.66811201e-02, 3.34241726e-02]
|
||||
]
|
||||
]).astype(numpy.float32)
|
||||
|
||||
|
||||
def randomize_expression(face_tensor, feature_extractor, motion_extractor, warping_network, spade_generator):
|
||||
with torch.no_grad():
|
||||
face_tensor_norm = (face_tensor + 1) * 0.5
|
||||
input_device = face_tensor.device
|
||||
feature_volume = feature_extractor(face_tensor_norm)
|
||||
motion_extractor_dict = motion_extractor(face_tensor_norm)
|
||||
|
||||
translation = motion_extractor_dict.get('t')
|
||||
expression = motion_extractor_dict.get('exp')
|
||||
scale = motion_extractor_dict.get('scale')
|
||||
points = motion_extractor_dict.get('kp')
|
||||
|
||||
pitch = headpose_pred_to_degree(motion_extractor_dict.get('pitch'))[:, None]
|
||||
yaw = headpose_pred_to_degree(motion_extractor_dict.get('yaw'))[:, None]
|
||||
roll = headpose_pred_to_degree(motion_extractor_dict.get('roll'))[:, None]
|
||||
rotation_matrix = get_rotation_matrix(pitch, yaw, roll)
|
||||
random_expression = get_random_expression_blend(expression)
|
||||
|
||||
points_transformed = transform_points(points, rotation_matrix, expression, scale, translation)
|
||||
points_driv = transform_points(points, rotation_matrix, random_expression, scale, translation)
|
||||
|
||||
data = warping_network(feature_volume, points_driv, points_transformed).get('out')
|
||||
output = spade_generator(data)
|
||||
output = output.to(input_device)
|
||||
output = F.interpolate(output.clamp(0, 1), [256, 256], mode='bilinear', align_corners=False)
|
||||
output = (output - 0.5) * 2
|
||||
return output
|
||||
|
||||
|
||||
def get_random_expression_blend(expression : Tensor) -> Tensor:
|
||||
blend = 0.35
|
||||
expression = expression.view(-1, 21, 3)
|
||||
min_array = torch.from_numpy(EXPRESSION_MIN).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
|
||||
max_array = torch.from_numpy(EXPRESSION_MAX).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
|
||||
random_batch = torch.rand_like(min_array).to(expression.device) * (max_array - min_array) + min_array
|
||||
random_batch[:, [0, 1, 8, 6, 9, 4, 5, 10]] = expression[:, [0, 1, 8, 6, 9, 4, 5, 10]]
|
||||
random_batch[:, [3, 7]] = random_batch[:, [13, 16]] * 0.1 + expression[:, [13, 16]] * 0.9
|
||||
random_batch[:, [3, 7]] = random_batch[:, [3, 7]] * 0.5 + expression[:, [3, 7]] * 0.5
|
||||
return random_batch * 0.8 * blend + expression * (1 - blend)
|
||||
|
||||
|
||||
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor):
|
||||
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor) -> Tensor:
|
||||
points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3)
|
||||
points_transformed *= scale[..., None]
|
||||
points_transformed[:, :, 0:2] += translation[:, None, 0:2]
|
||||
|
||||
+50
-186
@@ -1,96 +1,28 @@
|
||||
import configparser
|
||||
import random
|
||||
import numpy
|
||||
|
||||
from typing import Tuple
|
||||
import os
|
||||
import cv2
|
||||
import torchvision
|
||||
|
||||
import cv2
|
||||
import numpy
|
||||
import pytorch_lightning
|
||||
import torch
|
||||
import torchvision
|
||||
from LivePortrait.src.modules.motion_extractor import MotionExtractor
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities.types import Optimizer
|
||||
from pytorch_msssim import ssim
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning.utilities.types import OptimizerLRScheduler
|
||||
import torch
|
||||
|
||||
from .data_loader import DataLoaderVGG, read_image
|
||||
from .discriminator import MultiscaleDiscriminator
|
||||
from .generator import AdaptiveEmbeddingIntegrationNetwork
|
||||
from .data_loader import DataLoaderVGG, read_image
|
||||
|
||||
from .typing import Tensor, LossDict, TargetAttributes, DiscriminatorOutputs, Batch
|
||||
from .helper import hinge_loss, calc_distance_ratio, L2_loss, randomize_expression
|
||||
from pytorch_msssim import ssim
|
||||
from .helper import L2_loss, hinge_loss
|
||||
from .typing import Batch, DiscriminatorOutputs, IDEmbedding, LossDict, TargetAttributes, Tensor, Tuple
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def load_models():
|
||||
id_channels = CONFIG.getint('training.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
|
||||
generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
|
||||
|
||||
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
|
||||
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
|
||||
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
|
||||
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
|
||||
discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
|
||||
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
|
||||
arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
|
||||
arcface.eval()
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
|
||||
landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
|
||||
landmarker.eval()
|
||||
else:
|
||||
landmarker = None
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0 or CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
from LivePortrait.src.modules.motion_extractor import MotionExtractor
|
||||
|
||||
model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
|
||||
motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
|
||||
motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
|
||||
motion_extractor.eval()
|
||||
else:
|
||||
motion_extractor = None
|
||||
|
||||
if CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
from LivePortrait.src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
|
||||
from LivePortrait.src.modules.warping_network import WarpingNetwork
|
||||
from LivePortrait.src.modules.spade_generator import SPADEDecoder
|
||||
|
||||
feature_extractor_path = CONFIG.get('auxiliary_models.paths', 'feature_extractor_path')
|
||||
feature_extractor = AppearanceFeatureExtractor(3, 64, 2, 512, 32, 16, 6)
|
||||
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location = 'cpu', weights_only = True))
|
||||
feature_extractor.eval()
|
||||
|
||||
warping_network_path = CONFIG.get('auxiliary_models.paths', 'warping_network_path')
|
||||
dense_motion_params = {
|
||||
'block_expansion': 32,
|
||||
'max_features': 1024,
|
||||
'num_blocks': 5,
|
||||
'reshape_depth': 16,
|
||||
'compress': 4
|
||||
}
|
||||
warping_network = WarpingNetwork(num_kp = 21, block_expansion = 64, max_features = 512, num_down_blocks = 2, reshape_channel = 32, estimate_occlusion_map = True, dense_motion_params = dense_motion_params)
|
||||
warping_network.load_state_dict(torch.load(warping_network_path, map_location='cpu', weights_only=True))
|
||||
warping_network.eval()
|
||||
|
||||
spade_generator_path = CONFIG.get('auxiliary_models.paths', 'spade_generator_path')
|
||||
spade_generator = SPADEDecoder(upscale = 2, block_expansion = 64, max_features = 512, num_down_blocks = 2)
|
||||
spade_generator.load_state_dict(torch.load(spade_generator_path, map_location = 'cpu', weights_only = True))
|
||||
spade_generator.eval()
|
||||
else:
|
||||
feature_extractor = None
|
||||
warping_network = None
|
||||
spade_generator = None
|
||||
return generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
@@ -106,7 +38,6 @@ def create_trainer() -> Trainer:
|
||||
monitor = 'l_G',
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
# every_n_epochs = 1,
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 5,
|
||||
mode = 'min',
|
||||
@@ -118,7 +49,7 @@ def create_trainer() -> Trainer:
|
||||
)
|
||||
|
||||
|
||||
def train():
|
||||
def train() -> None:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
|
||||
@@ -127,85 +58,50 @@ def train():
|
||||
if not (checkpoint_path and os.path.exists(checkpoint_path)):
|
||||
checkpoint_path = None
|
||||
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
face_swap_model = FaceSwapper(*load_models())
|
||||
face_swap_model = FaceSwapper()
|
||||
trainer = create_trainer()
|
||||
trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
|
||||
|
||||
|
||||
class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
def __init__(self, generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator) -> None:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
self.arcface = arcface
|
||||
self.landmarker = landmarker
|
||||
self.motion_extractor = motion_extractor
|
||||
self.feature_extractor = feature_extractor
|
||||
self.warping_network = warping_network
|
||||
self.spade_generator = spade_generator
|
||||
|
||||
self.loss_adversarial_accumulated = 20
|
||||
self.generator = AdaptiveEmbeddingIntegrationNetwork(CONFIG.getint('training.generator', 'id_channels'), CONFIG.getint('training.generator', 'num_blocks'))
|
||||
self.discriminator = MultiscaleDiscriminator(CONFIG.getint('training.discriminator', 'input_channels'), CONFIG.getint('training.discriminator', 'num_filters'), CONFIG.getint('training.discriminator', 'num_layers'), CONFIG.getint('training.discriminator', 'num_discriminators'))
|
||||
self.arcface = torch.load(CONFIG.get('auxiliary_models.paths', 'arcface_path'), map_location = 'cpu', weights_only = False)
|
||||
self.landmarker = torch.load(CONFIG.get('auxiliary_models.paths', 'landmarker_path'), map_location = 'cpu', weights_only = False)
|
||||
self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
|
||||
self.motion_extractor.load_state_dict(torch.load(CONFIG.get('auxiliary_models.paths', 'motion_extractor_path'), map_location = 'cpu', weights_only = True))
|
||||
self.arcface.eval()
|
||||
self.landmarker.eval()
|
||||
self.motion_extractor.eval()
|
||||
self.automatic_optimization = False
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
|
||||
|
||||
def forward(self, target_tensor : Tensor, source_embedding : Tensor) -> Tensor:
|
||||
def forward(self, target_tensor : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]:
|
||||
output = self.generator(target_tensor, source_embedding)
|
||||
return output
|
||||
|
||||
|
||||
def state_dict(self, *args, **kwargs):
|
||||
return {
|
||||
"generator": self.generator.state_dict(),
|
||||
"discriminator": self.discriminator.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict, strict: bool = True):
|
||||
if "generator" in state_dict:
|
||||
self.generator.load_state_dict(state_dict["generator"], strict = strict)
|
||||
if "discriminator" in state_dict:
|
||||
self.discriminator.load_state_dict(state_dict["discriminator"], strict = strict)
|
||||
|
||||
|
||||
def configure_optimizers(self) -> OptimizerLRScheduler:
|
||||
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
if CONFIG.getboolean('training.schedulers', 'enable'):
|
||||
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
return (
|
||||
{
|
||||
"optimizer": generator_optimizer,
|
||||
"lr_scheduler": generator_scheduler
|
||||
},
|
||||
{
|
||||
"optimizer": discriminator_optimizer,
|
||||
"lr_scheduler": discriminator_scheduler
|
||||
})
|
||||
return generator_optimizer, discriminator_optimizer
|
||||
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers()
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
|
||||
|
||||
if random.random() > 0.5 and CONFIG.getboolean('preparing.augmentation', 'expression'):
|
||||
target_tensor = randomize_expression(target_tensor, self.feature_extractor, self.motion_extractor, self.warping_network, self.spade_generator)
|
||||
|
||||
swap_tensor, target_attributes = self(target_tensor, source_embedding)
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
|
||||
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
|
||||
discriminator_outputs = self.discriminator(swap_tensor)
|
||||
|
||||
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_losses.get('loss_generator'))
|
||||
generator_optimizer.step()
|
||||
|
||||
discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_losses.get('loss_discriminator'))
|
||||
|
||||
if not CONFIG.getboolean('training.discriminator', 'disable') or self.loss_adversarial_accumulated < 0.4:
|
||||
if not CONFIG.getboolean('training.discriminator', 'disable'):
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
|
||||
@@ -215,36 +111,33 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
self.log_validation_preview()
|
||||
self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True)
|
||||
self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('l_ADV_A', self.loss_adversarial_accumulated, prog_bar = True)
|
||||
self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False)
|
||||
self.log('l_id', generator_losses.get('loss_identity'), prog_bar = True)
|
||||
self.log('l_attr', generator_losses.get('loss_attribute'), prog_bar = True)
|
||||
self.log('l_rec', generator_losses.get('loss_reconstruction'), prog_bar = True)
|
||||
self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True)
|
||||
self.log('l_ID', generator_losses.get('loss_identity'), prog_bar=True)
|
||||
self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True)
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
generator_losses = {}
|
||||
# adversarial loss
|
||||
loss_adversarial = 0
|
||||
loss_adversarial = torch.Tensor(0)
|
||||
|
||||
for discriminator_output in discriminator_outputs:
|
||||
loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ])
|
||||
loss_adversarial = torch.mean(loss_adversarial)
|
||||
generator_losses['loss_adversarial'] = loss_adversarial
|
||||
generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial')
|
||||
self.loss_adversarial_accumulated = self.loss_adversarial_accumulated * 0.98 + loss_adversarial.item() * 0.02
|
||||
|
||||
# identity loss
|
||||
swap_embedding = self.get_arcface_embedding(swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (30, 0, 10, 10))
|
||||
swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
|
||||
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
|
||||
generator_losses['loss_identity'] = loss_identity
|
||||
generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity')
|
||||
|
||||
# attribute loss
|
||||
loss_attribute = 0
|
||||
loss_attribute = torch.Tensor(0)
|
||||
swap_attributes = self.generator.get_attributes(swap_tensor)
|
||||
|
||||
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
|
||||
@@ -256,7 +149,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
# reconstruction loss
|
||||
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
generator_losses['loss_reconstruction'] = loss_reconstruction
|
||||
generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction')
|
||||
|
||||
@@ -271,49 +164,32 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
generator_losses['loss_tsr'] = loss_tsr
|
||||
generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr')
|
||||
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
|
||||
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0:
|
||||
swap_landmark_features = self.get_landmark_features(swap_tensor)
|
||||
target_landmark_features = self.get_landmark_features(target_tensor)
|
||||
|
||||
# eye gaze loss
|
||||
loss_left_eye_gaze = L2_loss(swap_landmark_features[3], target_landmark_features[3])
|
||||
loss_right_eye_gaze = L2_loss(swap_landmark_features[4], target_landmark_features[4])
|
||||
loss_left_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1])
|
||||
loss_right_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1])
|
||||
loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze
|
||||
generator_losses['loss_eye_gaze'] = loss_eye_gaze
|
||||
generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze')
|
||||
|
||||
# eye open loss
|
||||
loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
|
||||
loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
|
||||
loss_eye_open = loss_left_eye_open + loss_right_eye_open
|
||||
generator_losses['loss_eye_open'] = loss_eye_open
|
||||
generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
|
||||
|
||||
# lip open loss
|
||||
loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
|
||||
generator_losses['loss_lip_open'] = loss_lip_open
|
||||
generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
|
||||
return generator_losses
|
||||
|
||||
|
||||
def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict:
|
||||
discriminator_losses = {}
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
|
||||
loss_fake = 0
|
||||
loss_fake = torch.Tensor(0)
|
||||
|
||||
for fake_discriminator_output in fake_discriminator_outputs:
|
||||
loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3]))
|
||||
true_discriminator_outputs = self.discriminator(source_tensor)
|
||||
loss_true = 0
|
||||
loss_true = torch.Tensor(0)
|
||||
|
||||
for true_discriminator_output in true_discriminator_outputs:
|
||||
loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3]))
|
||||
discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
|
||||
return discriminator_losses
|
||||
|
||||
|
||||
def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
|
||||
def get_id_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
|
||||
_, _, height, width = vision_tensor.shape
|
||||
crop_height = int(height * 0.0586)
|
||||
crop_width = int(width * 0.0586)
|
||||
@@ -327,19 +203,12 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
|
||||
landmarks = self.landmarker(vision_tensor_norm)[2]
|
||||
landmarks = landmarks.view(-1, 203, 2) * 256
|
||||
left_eye_open_ratio = calc_distance_ratio(landmarks, (6, 18, 0, 12))
|
||||
right_eye_open_ratio = calc_distance_ratio(landmarks, (30, 42, 24, 36))
|
||||
lip_open_ratio = calc_distance_ratio(landmarks, (90, 102, 48, 66))
|
||||
left_eye_gaze = landmarks[:, 198]
|
||||
right_eye_gaze = landmarks[:, 197]
|
||||
return left_eye_open_ratio, right_eye_open_ratio, lip_open_ratio, left_eye_gaze, right_eye_gaze
|
||||
|
||||
return landmarks[:, 198], landmarks[:, 197]
|
||||
|
||||
def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
@@ -349,21 +218,17 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
def log_generator_preview(self, source_tensor, target_tensor, swap_tensor):
|
||||
def log_generator_preview(self, source_tensor : Tensor, target_tensor : Tensor, swap_tensor : Tensor) -> None:
|
||||
max_preview = 8
|
||||
source_tensor = source_tensor[:max_preview]
|
||||
target_tensor = target_tensor[:max_preview]
|
||||
swap_tensor = swap_tensor[:max_preview]
|
||||
rows = [torch.cat([src, tgt, swp], dim=2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
|
||||
grid = torchvision.utils.make_grid(torch.cat(rows, dim=1).unsqueeze(0), nrow=1, normalize=True, scale_each=True)
|
||||
os.makedirs("previews", exist_ok=True)
|
||||
torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
|
||||
rows = [torch.cat([src, tgt, swp], dim = 2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
|
||||
grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True)
|
||||
self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
|
||||
|
||||
|
||||
def log_validation_preview(self):
|
||||
read_images = lambda path: [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
def log_validation_preview(self) -> None:
|
||||
read_images = lambda path : [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5
|
||||
transforms = torchvision.transforms.Compose(
|
||||
[
|
||||
@@ -377,7 +242,6 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup'))
|
||||
targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion'))
|
||||
|
||||
|
||||
self.generator.eval()
|
||||
|
||||
results_source = []
|
||||
@@ -388,7 +252,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
|
||||
for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion):
|
||||
source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
|
||||
source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
|
||||
target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half()
|
||||
target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half()
|
||||
target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
@@ -10,6 +11,6 @@ TargetAttributes = Tuple[Tensor, ...]
|
||||
DiscriminatorOutputs = List[List[Tensor]]
|
||||
LossDict = Dict[str, Tensor]
|
||||
IDEmbedding = Tensor
|
||||
|
||||
StateDict = OrderedDict[str, Any]
|
||||
Embedding = NDArray[Any]
|
||||
VisionFrame = NDArray[Any]
|
||||
|
||||
Reference in New Issue
Block a user