This commit is contained in:
harisreedhar
2025-01-14 16:06:09 +05:30
committed by henryruhs
parent ef313042c6
commit 9e1c71b498
4 changed files with 55 additions and 298 deletions
-10
View File
@@ -8,7 +8,6 @@ import tqdm
from PIL import Image
from torch.utils.data import TensorDataset
from .augmentations import apply_random_motion_blur
from .typing import Batch
CONFIG = configparser.ConfigParser()
@@ -52,7 +51,6 @@ class DataLoaderVGG(TensorDataset):
transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p = 0.5),
transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
@@ -81,11 +79,3 @@ class DataLoaderVGG(TensorDataset):
def __len__(self) -> int:
return self.dataset_total
def state_dict(self):
return {'current_index': self._current_index}
def load_state_dict(self, state_dict):
self._current_index = state_dict['current_index']
+3 -101
View File
@@ -1,114 +1,16 @@
import configparser
from typing import Tuple
import torch
from .typing import Tensor
import numpy
import torch.nn.functional as F
from .typing import Tensor, Tuple
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
if CONFIG.getboolean('preparing.augmentation', 'expression'):
from LivePortrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
L2_loss = torch.nn.MSELoss()
EXPRESSION_MIN = numpy.array(
[
[
[-2.88067125e-02, -8.12731311e-02, -1.70541159e-03],
[-4.88598682e-02, -3.32196616e-02, -1.67431499e-04],
[-6.75425082e-02, -4.28681746e-02, -1.98950816e-04],
[-7.23103955e-02, -3.28503326e-02, -7.31324719e-04],
[-3.87073644e-02, -6.01546466e-02, -5.50269964e-04],
[-6.38048723e-02, -2.23840728e-01, -7.13261834e-04],
[-3.02710701e-02, -3.93195450e-02, -8.24086510e-06],
[-2.95799859e-02, -5.39318882e-02, -1.74219604e-04],
[-2.92359516e-02, -1.53050944e-02, -6.30460854e-05],
[-5.56493877e-03, -2.34344602e-02, -1.26858242e-04],
[-4.37593013e-02, -2.77768299e-02, -2.70503685e-02],
[-1.76926646e-02, -1.91676542e-02, -1.15090821e-04],
[-8.34268332e-03, -3.99775570e-03, -3.27481248e-05],
[-3.40162888e-02, -2.81868968e-02, -1.96679524e-04],
[-2.91855410e-02, -3.97511162e-02, -2.81230678e-05],
[-1.50395725e-02, -2.49494594e-02, -9.42573533e-05],
[-1.67938769e-02, -2.00953931e-02, -4.00750607e-04],
[-1.86435618e-02, -2.48535164e-02, -2.74416432e-02],
[-4.61211195e-03, -1.21660791e-02, -2.93173041e-04],
[-4.10017073e-02, -7.43824020e-02, -4.42762971e-02],
[-1.90370996e-02, -3.74363363e-02, -1.34740388e-02]
]
]).astype(numpy.float32)
EXPRESSION_MAX = numpy.array(
[
[
[4.46682945e-02, 7.08772913e-02, 4.08344204e-04],
[2.14308221e-02, 6.15894832e-02, 4.85319615e-05],
[3.02363783e-02, 4.45043296e-02, 1.28298725e-05],
[3.05869691e-02, 3.79812494e-02, 6.57040102e-04],
[4.45670523e-02, 3.97259220e-02, 7.10966764e-04],
[9.43699256e-02, 9.85926315e-02, 2.02551950e-04],
[1.61131397e-02, 2.92906128e-02, 3.44733417e-06],
[5.23825921e-02, 1.07065082e-01, 6.61510974e-04],
[2.85718683e-03, 8.32320191e-03, 2.39314613e-04],
[2.57947259e-02, 1.60935968e-02, 2.41853559e-05],
[4.90833223e-02, 3.43903080e-02, 3.22353356e-02],
[1.44766076e-02, 3.39248963e-02, 1.42291479e-04],
[8.75749043e-04, 6.82212645e-03, 2.76097053e-05],
[1.86958015e-02, 3.84016186e-02, 7.33085908e-05],
[2.01714113e-02, 4.90544215e-02, 2.34028921e-05],
[2.46518422e-02, 3.29151377e-02, 3.48571630e-05],
[2.22457591e-02, 1.21796541e-02, 1.56396593e-04],
[1.72109623e-02, 3.01626958e-02, 1.36556877e-02],
[1.83460284e-02, 1.61141958e-02, 2.87440169e-04],
[3.57594155e-02, 1.80554688e-01, 2.75554154e-02],
[2.17450950e-02, 8.66811201e-02, 3.34241726e-02]
]
]).astype(numpy.float32)
def randomize_expression(face_tensor, feature_extractor, motion_extractor, warping_network, spade_generator):
with torch.no_grad():
face_tensor_norm = (face_tensor + 1) * 0.5
input_device = face_tensor.device
feature_volume = feature_extractor(face_tensor_norm)
motion_extractor_dict = motion_extractor(face_tensor_norm)
translation = motion_extractor_dict.get('t')
expression = motion_extractor_dict.get('exp')
scale = motion_extractor_dict.get('scale')
points = motion_extractor_dict.get('kp')
pitch = headpose_pred_to_degree(motion_extractor_dict.get('pitch'))[:, None]
yaw = headpose_pred_to_degree(motion_extractor_dict.get('yaw'))[:, None]
roll = headpose_pred_to_degree(motion_extractor_dict.get('roll'))[:, None]
rotation_matrix = get_rotation_matrix(pitch, yaw, roll)
random_expression = get_random_expression_blend(expression)
points_transformed = transform_points(points, rotation_matrix, expression, scale, translation)
points_driv = transform_points(points, rotation_matrix, random_expression, scale, translation)
data = warping_network(feature_volume, points_driv, points_transformed).get('out')
output = spade_generator(data)
output = output.to(input_device)
output = F.interpolate(output.clamp(0, 1), [256, 256], mode='bilinear', align_corners=False)
output = (output - 0.5) * 2
return output
def get_random_expression_blend(expression : Tensor) -> Tensor:
blend = 0.35
expression = expression.view(-1, 21, 3)
min_array = torch.from_numpy(EXPRESSION_MIN).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
max_array = torch.from_numpy(EXPRESSION_MAX).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
random_batch = torch.rand_like(min_array).to(expression.device) * (max_array - min_array) + min_array
random_batch[:, [0, 1, 8, 6, 9, 4, 5, 10]] = expression[:, [0, 1, 8, 6, 9, 4, 5, 10]]
random_batch[:, [3, 7]] = random_batch[:, [13, 16]] * 0.1 + expression[:, [13, 16]] * 0.9
random_batch[:, [3, 7]] = random_batch[:, [3, 7]] * 0.5 + expression[:, [3, 7]] * 0.5
return random_batch * 0.8 * blend + expression * (1 - blend)
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor):
def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor) -> Tensor:
points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3)
points_transformed *= scale[..., None]
points_transformed[:, :, 0:2] += translation[:, None, 0:2]
+50 -186
View File
@@ -1,96 +1,28 @@
import configparser
import random
import numpy
from typing import Tuple
import os
import cv2
import torchvision
import cv2
import numpy
import pytorch_lightning
import torch
import torchvision
from LivePortrait.src.modules.motion_extractor import MotionExtractor
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import Optimizer
from pytorch_msssim import ssim
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.types import OptimizerLRScheduler
import torch
from .data_loader import DataLoaderVGG, read_image
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .data_loader import DataLoaderVGG, read_image
from .typing import Tensor, LossDict, TargetAttributes, DiscriminatorOutputs, Batch
from .helper import hinge_loss, calc_distance_ratio, L2_loss, randomize_expression
from pytorch_msssim import ssim
from .helper import L2_loss, hinge_loss
from .typing import Batch, DiscriminatorOutputs, IDEmbedding, LossDict, TargetAttributes, Tensor, Tuple
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def load_models():
id_channels = CONFIG.getint('training.generator', 'id_channels')
num_blocks = CONFIG.getint('training.generator', 'num_blocks')
generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
input_channels = CONFIG.getint('training.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
arcface.eval()
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
landmarker.eval()
else:
landmarker = None
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0 or CONFIG.getboolean('preparing.augmentation', 'expression'):
from LivePortrait.src.modules.motion_extractor import MotionExtractor
model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
motion_extractor.eval()
else:
motion_extractor = None
if CONFIG.getboolean('preparing.augmentation', 'expression'):
from LivePortrait.src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
from LivePortrait.src.modules.warping_network import WarpingNetwork
from LivePortrait.src.modules.spade_generator import SPADEDecoder
feature_extractor_path = CONFIG.get('auxiliary_models.paths', 'feature_extractor_path')
feature_extractor = AppearanceFeatureExtractor(3, 64, 2, 512, 32, 16, 6)
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location = 'cpu', weights_only = True))
feature_extractor.eval()
warping_network_path = CONFIG.get('auxiliary_models.paths', 'warping_network_path')
dense_motion_params = {
'block_expansion': 32,
'max_features': 1024,
'num_blocks': 5,
'reshape_depth': 16,
'compress': 4
}
warping_network = WarpingNetwork(num_kp = 21, block_expansion = 64, max_features = 512, num_down_blocks = 2, reshape_channel = 32, estimate_occlusion_map = True, dense_motion_params = dense_motion_params)
warping_network.load_state_dict(torch.load(warping_network_path, map_location='cpu', weights_only=True))
warping_network.eval()
spade_generator_path = CONFIG.get('auxiliary_models.paths', 'spade_generator_path')
spade_generator = SPADEDecoder(upscale = 2, block_expansion = 64, max_features = 512, num_down_blocks = 2)
spade_generator.load_state_dict(torch.load(spade_generator_path, map_location = 'cpu', weights_only = True))
spade_generator.eval()
else:
feature_extractor = None
warping_network = None
spade_generator = None
return generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator
def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
@@ -106,7 +38,6 @@ def create_trainer() -> Trainer:
monitor = 'l_G',
dirpath = output_directory_path,
filename = output_file_pattern,
# every_n_epochs = 1,
every_n_train_steps = 1000,
save_top_k = 5,
mode = 'min',
@@ -118,7 +49,7 @@ def create_trainer() -> Trainer:
)
def train():
def train() -> None:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
@@ -127,85 +58,50 @@ def train():
if not (checkpoint_path and os.path.exists(checkpoint_path)):
checkpoint_path = None
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapper(*load_models())
face_swap_model = FaceSwapper()
trainer = create_trainer()
trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
class FaceSwapper(pytorch_lightning.LightningModule):
def __init__(self, generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator) -> None:
def __init__(self) -> None:
super().__init__()
self.generator = generator
self.discriminator = discriminator
self.arcface = arcface
self.landmarker = landmarker
self.motion_extractor = motion_extractor
self.feature_extractor = feature_extractor
self.warping_network = warping_network
self.spade_generator = spade_generator
self.loss_adversarial_accumulated = 20
self.generator = AdaptiveEmbeddingIntegrationNetwork(CONFIG.getint('training.generator', 'id_channels'), CONFIG.getint('training.generator', 'num_blocks'))
self.discriminator = MultiscaleDiscriminator(CONFIG.getint('training.discriminator', 'input_channels'), CONFIG.getint('training.discriminator', 'num_filters'), CONFIG.getint('training.discriminator', 'num_layers'), CONFIG.getint('training.discriminator', 'num_discriminators'))
self.arcface = torch.load(CONFIG.get('auxiliary_models.paths', 'arcface_path'), map_location = 'cpu', weights_only = False)
self.landmarker = torch.load(CONFIG.get('auxiliary_models.paths', 'landmarker_path'), map_location = 'cpu', weights_only = False)
self.motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
self.motion_extractor.load_state_dict(torch.load(CONFIG.get('auxiliary_models.paths', 'motion_extractor_path'), map_location = 'cpu', weights_only = True))
self.arcface.eval()
self.landmarker.eval()
self.motion_extractor.eval()
self.automatic_optimization = False
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
def forward(self, target_tensor : Tensor, source_embedding : Tensor) -> Tensor:
def forward(self, target_tensor : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]:
output = self.generator(target_tensor, source_embedding)
return output
def state_dict(self, *args, **kwargs):
return {
"generator": self.generator.state_dict(),
"discriminator": self.discriminator.state_dict(),
}
def load_state_dict(self, state_dict, strict: bool = True):
if "generator" in state_dict:
self.generator.load_state_dict(state_dict["generator"], strict = strict)
if "discriminator" in state_dict:
self.discriminator.load_state_dict(state_dict["discriminator"], strict = strict)
def configure_optimizers(self) -> OptimizerLRScheduler:
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
if CONFIG.getboolean('training.schedulers', 'enable'):
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
return (
{
"optimizer": generator_optimizer,
"lr_scheduler": generator_scheduler
},
{
"optimizer": discriminator_optimizer,
"lr_scheduler": discriminator_scheduler
})
return generator_optimizer, discriminator_optimizer
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor, is_same_person = batch
generator_optimizer, discriminator_optimizer = self.optimizers()
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
if random.random() > 0.5 and CONFIG.getboolean('preparing.augmentation', 'expression'):
target_tensor = randomize_expression(target_tensor, self.feature_extractor, self.motion_extractor, self.warping_network, self.spade_generator)
swap_tensor, target_attributes = self(target_tensor, source_embedding)
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
discriminator_outputs = self.discriminator(swap_tensor)
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
generator_optimizer.zero_grad()
self.manual_backward(generator_losses.get('loss_generator'))
generator_optimizer.step()
discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_losses.get('loss_discriminator'))
if not CONFIG.getboolean('training.discriminator', 'disable') or self.loss_adversarial_accumulated < 0.4:
if not CONFIG.getboolean('training.discriminator', 'disable'):
discriminator_optimizer.step()
if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
@@ -215,36 +111,33 @@ class FaceSwapper(pytorch_lightning.LightningModule):
self.log_validation_preview()
self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True)
self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True)
self.log('l_ADV_A', self.loss_adversarial_accumulated, prog_bar = True)
self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False)
self.log('l_id', generator_losses.get('loss_identity'), prog_bar = True)
self.log('l_attr', generator_losses.get('loss_attribute'), prog_bar = True)
self.log('l_rec', generator_losses.get('loss_reconstruction'), prog_bar = True)
self.log('l_ATTR', generator_losses.get('loss_attribute'), prog_bar = True)
self.log('l_ID', generator_losses.get('loss_identity'), prog_bar=True)
self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True)
return generator_losses.get('loss_generator')
def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict:
source_tensor, target_tensor, is_same_person = batch
generator_losses = {}
# adversarial loss
loss_adversarial = 0
loss_adversarial = torch.Tensor(0)
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ])
loss_adversarial = torch.mean(loss_adversarial)
generator_losses['loss_adversarial'] = loss_adversarial
generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial')
self.loss_adversarial_accumulated = self.loss_adversarial_accumulated * 0.98 + loss_adversarial.item() * 0.02
# identity loss
swap_embedding = self.get_arcface_embedding(swap_tensor, (30, 0, 10, 10))
source_embedding = self.get_arcface_embedding(source_tensor, (30, 0, 10, 10))
swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
generator_losses['loss_identity'] = loss_identity
generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity')
# attribute loss
loss_attribute = 0
loss_attribute = torch.Tensor(0)
swap_attributes = self.generator.get_attributes(swap_tensor)
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
@@ -256,7 +149,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
# reconstruction loss
loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
generator_losses['loss_reconstruction'] = loss_reconstruction
generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction')
@@ -271,49 +164,32 @@ class FaceSwapper(pytorch_lightning.LightningModule):
generator_losses['loss_tsr'] = loss_tsr
generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr')
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0:
swap_landmark_features = self.get_landmark_features(swap_tensor)
target_landmark_features = self.get_landmark_features(target_tensor)
# eye gaze loss
loss_left_eye_gaze = L2_loss(swap_landmark_features[3], target_landmark_features[3])
loss_right_eye_gaze = L2_loss(swap_landmark_features[4], target_landmark_features[4])
loss_left_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1])
loss_right_eye_gaze = L2_loss(swap_landmark_features[0], target_landmark_features[1])
loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze
generator_losses['loss_eye_gaze'] = loss_eye_gaze
generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze')
# eye open loss
loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
loss_eye_open = loss_left_eye_open + loss_right_eye_open
generator_losses['loss_eye_open'] = loss_eye_open
generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
# lip open loss
loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
generator_losses['loss_lip_open'] = loss_lip_open
generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
return generator_losses
def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict:
discriminator_losses = {}
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
loss_fake = 0
loss_fake = torch.Tensor(0)
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3]))
true_discriminator_outputs = self.discriminator(source_tensor)
loss_true = 0
loss_true = torch.Tensor(0)
for true_discriminator_output in true_discriminator_outputs:
loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3]))
discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_losses
def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
def get_id_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
_, _, height, width = vision_tensor.shape
crop_height = int(height * 0.0586)
crop_width = int(width * 0.0586)
@@ -327,19 +203,12 @@ class FaceSwapper(pytorch_lightning.LightningModule):
embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
return embedding
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2]
landmarks = landmarks.view(-1, 203, 2) * 256
left_eye_open_ratio = calc_distance_ratio(landmarks, (6, 18, 0, 12))
right_eye_open_ratio = calc_distance_ratio(landmarks, (30, 42, 24, 36))
lip_open_ratio = calc_distance_ratio(landmarks, (90, 102, 48, 66))
left_eye_gaze = landmarks[:, 198]
right_eye_gaze = landmarks[:, 197]
return left_eye_open_ratio, right_eye_open_ratio, lip_open_ratio, left_eye_gaze, right_eye_gaze
return landmarks[:, 198], landmarks[:, 197]
def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
@@ -349,21 +218,17 @@ class FaceSwapper(pytorch_lightning.LightningModule):
rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
return translation, scale, rotation
def log_generator_preview(self, source_tensor, target_tensor, swap_tensor):
def log_generator_preview(self, source_tensor : Tensor, target_tensor : Tensor, swap_tensor : Tensor) -> None:
max_preview = 8
source_tensor = source_tensor[:max_preview]
target_tensor = target_tensor[:max_preview]
swap_tensor = swap_tensor[:max_preview]
rows = [torch.cat([src, tgt, swp], dim=2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
grid = torchvision.utils.make_grid(torch.cat(rows, dim=1).unsqueeze(0), nrow=1, normalize=True, scale_each=True)
os.makedirs("previews", exist_ok=True)
torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
rows = [torch.cat([src, tgt, swp], dim = 2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
grid = torchvision.utils.make_grid(torch.cat(rows, dim = 1).unsqueeze(0), nrow = 1, normalize = True, scale_each = True)
self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
def log_validation_preview(self):
read_images = lambda path: [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
def log_validation_preview(self) -> None:
read_images = lambda path : [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5
transforms = torchvision.transforms.Compose(
[
@@ -377,7 +242,6 @@ class FaceSwapper(pytorch_lightning.LightningModule):
targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup'))
targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion'))
self.generator.eval()
results_source = []
@@ -388,7 +252,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion):
source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
source_embedding = self.get_id_embedding(source_tensor, (0, 0, 0, 0))
target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half()
target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half()
target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half()
+2 -1
View File
@@ -1,3 +1,4 @@
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
from numpy.typing import NDArray
@@ -10,6 +11,6 @@ TargetAttributes = Tuple[Tensor, ...]
DiscriminatorOutputs = List[List[Tensor]]
LossDict = Dict[str, Tensor]
IDEmbedding = Tensor
StateDict = OrderedDict[str, Any]
Embedding = NDArray[Any]
VisionFrame = NDArray[Any]