mirror of https://github.com/facefusion/facefusion-labs.git (synced 2026-04-19 15:56:37 +02:00)

cleanup
@@ -24,8 +24,7 @@ class DataLoaderVGG(TensorDataset):
 		image_path_set = {}
 
 		for directory_path in self.directory_paths:
-			image_paths = glob.glob(dataset_image_pattern.format(directory_path))
-			image_paths.extend(image_paths)
+			image_paths.extend(glob.glob(dataset_image_pattern.format(directory_path)))
 			image_path_set[directory_path] = image_paths
 
 		return image_paths, image_path_set
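Note on the hunk above: the old loop rebound image_paths to each directory's fresh glob result and then extended that list with itself, so matches were duplicated instead of accumulated across directories. A minimal sketch of the difference (hypothetical file names, not from the repo):

paths_old = ['a.jpg', 'b.jpg']
paths_old.extend(paths_old)
print(paths_old)    # ['a.jpg', 'b.jpg', 'a.jpg', 'b.jpg'] -- duplicated, not accumulated

accumulator = []
for matches in (['a.jpg'], ['c.jpg']):  # stand-ins for per-directory glob results
    accumulator.extend(matches)
print(accumulator)  # ['a.jpg', 'c.jpg'] -- one list grows across directories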
@@ -3,7 +3,7 @@ from os import makedirs
 
 import torch
 
-from .generator import AdaptiveEmbeddingIntegrationNetwork
+from .models.generator import AdaptiveEmbeddingIntegrationNetwork
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
@@ -17,7 +17,7 @@ def export() -> None:
 
 	makedirs(directory_path, exist_ok = True)
 	state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
-	model = AdaptiveEmbeddingIntegrationNetwork(512, 2)
+	model = AdaptiveEmbeddingIntegrationNetwork()
 	model.load_state_dict(state_dict)
 	model.eval()
 	source_tensor = torch.randn(1, 512)
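The removed call hard-coded AdaptiveEmbeddingIntegrationNetwork(512, 2); the replacement constructor reads the same values from config.ini, as the new models/generator.py further down shows. A sketch of that lookup, assuming the ini values simply mirror the old arguments (the section and key names are taken from the CONFIG.getint calls in this commit):

import configparser

config = configparser.ConfigParser()
config.read_string('''
[training.model.generator]
id_channels = 512
num_blocks = 2
''')

print(config.getint('training.model.generator', 'id_channels'))  # 512
print(config.getint('training.model.generator', 'num_blocks'))   # 2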
@@ -3,8 +3,8 @@ import configparser
 import cv2
 import torch
 
-from .generator import AdaptiveEmbeddingIntegrationNetwork
 from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
+from .models.generator import AdaptiveEmbeddingIntegrationNetwork
 from .types import Generator, IdEmbedder, VisionFrame
 
 CONFIG = configparser.ConfigParser()
@@ -28,7 +28,7 @@ def infer() -> None:
 	output_path = CONFIG.get('inferencing', 'output_path')
 
 	state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator')
-	generator = AdaptiveEmbeddingIntegrationNetwork(512, 2)
+	generator = AdaptiveEmbeddingIntegrationNetwork()
 	generator.load_state_dict(state_dict)
 	generator.eval()
 	id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -1,3 +1,4 @@
+import configparser
 from itertools import chain
 from typing import List
 
@@ -6,7 +7,41 @@ import torch.nn
 import torch.nn as nn
 from torch import Tensor
 
-from .types import DiscriminatorOutputs
+from face_swapper.src.types import DiscriminatorOutputs
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+class MultiscaleDiscriminator(nn.Module):
+	def __init__(self) -> None:
+		super(MultiscaleDiscriminator, self).__init__()
+		self.input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
+		self.num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
+		self.kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
+		self.num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
+		self.num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
+
+		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
+		self.prepare_discriminators()
+
+	def prepare_discriminators(self) -> None:
+		for discriminator_index in range(self.num_discriminators):
+			single_discriminator = NLayerDiscriminator(self.input_channels, self.num_filters, self.num_layers, self.kernel_size)
+			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
+
+	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
+		discriminator_outputs = []
+		temp_tensor = input_tensor
+
+		for discriminator_index in range(self.num_discriminators):
+			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
+			discriminator_outputs.append([ model_layers(temp_tensor) ])
+
+			if discriminator_index < (self.num_discriminators - 1):
+				temp_tensor = self.downsample(temp_tensor)
+
+		return discriminator_outputs
 
 
 class NLayerDiscriminator(nn.Module):
@@ -58,29 +93,3 @@ class NLayerDiscriminator(nn.Module):
 
 	def forward(self, input_tensor : Tensor) -> Tensor:
 		return self.model(input_tensor)
-
-
-class MultiscaleDiscriminator(nn.Module):
-	def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int, kernel_size : int):
-		super(MultiscaleDiscriminator, self).__init__()
-		self.num_discriminators = num_discriminators
-		self.num_layers = num_layers
-
-		for discriminator_index in range(num_discriminators):
-			single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size)
-			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
-
-		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
-
-	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
-		discriminator_outputs = []
-		temp_tensor = input_tensor
-
-		for discriminator_index in range(self.num_discriminators):
-			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
-			discriminator_outputs.append([ model_layers(temp_tensor) ])
-
-			if discriminator_index < (self.num_discriminators - 1):
-				temp_tensor = self.downsample(temp_tensor)
-
-		return discriminator_outputs
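The relocated MultiscaleDiscriminator keeps the same coarse-to-fine scheme: each NLayerDiscriminator scores the input at one scale, and AvgPool2d(kernel_size = 3, stride = 2, padding = 1) halves the spatial size between scales. A quick sketch of the resulting pyramid (the input size and num_discriminators = 3 are assumptions):

import torch
import torch.nn as nn

downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = 1, count_include_pad = False)
tensor = torch.randn(1, 3, 256, 256)

for _ in range(3):
    print(tuple(tensor.shape))  # (1, 3, 256, 256) -> (1, 3, 128, 128) -> (1, 3, 64, 64)
    tensor = downsample(tensor)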
@@ -0,0 +1,43 @@
+import configparser
+from typing import Tuple
+
+import torch.nn as nn
+
+from face_swapper.src.networks.attribute_modulator import AADGenerator
+from face_swapper.src.networks.encoder import UNet
+from face_swapper.src.types import SourceEmbedding, TargetAttributes, VisionTensor
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
+	def __init__(self) -> None:
+		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
+		id_channels = CONFIG.getint('training.model.generator', 'id_channels')
+		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
+
+		self.encoder = UNet()
+		self.generator = AADGenerator(id_channels, num_blocks)
+		self.encoder.apply(init_weight)
+		self.generator.apply(init_weight)
+
+	def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
+		target_attributes = self.get_attributes(target)
+		swap_tensor = self.generator(target_attributes, source_embedding)
+		return swap_tensor, target_attributes
+
+	def get_attributes(self, target : VisionTensor) -> TargetAttributes:
+		return self.encoder(target)
+
+
+def init_weight(module : nn.Module) -> None:
+	if isinstance(module, nn.Linear):
+		module.weight.data.normal_(std = 0.001)
+		module.bias.data.zero_()
+
+	if isinstance(module, nn.Conv2d):
+		nn.init.xavier_normal_(module.weight.data)
+
+	if isinstance(module, nn.ConvTranspose2d):
+		nn.init.xavier_normal_(module.weight.data)
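Moving the init_weight calls into this wrapper works because Module.apply recurses into every submodule, so one call each on the encoder and the generator covers all of their layers. A toy demonstration of that recursion (stub module, not the repo's):

import torch.nn as nn

def init_weight(module : nn.Module) -> None:
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(std = 0.001)
        module.bias.data.zero_()

    if isinstance(module, nn.Conv2d):
        nn.init.xavier_normal_(module.weight.data)

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8, 4))
toy.apply(init_weight)  # visits the Sequential itself, then Conv2d, Flatten and Linear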
@@ -0,0 +1,137 @@
+import configparser
+from typing import Tuple
+
+import torch
+from pytorch_msssim import ssim
+from torch import Tensor
+
+from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
+from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+class FaceSwapperLoss:
+	def __init__(self) -> None:
+		id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
+		landmarker_path = CONFIG.get('training.model', 'landmarker_path')
+		motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
+		self.batch_size = CONFIG.getint('training.loader', 'batch_size')
+		self.mse_loss = torch.nn.MSELoss()
+		self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.id_embedder.eval()
+		self.landmarker.eval()
+		self.motion_extractor.eval()
+
+	def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
+		source_tensor, target_tensor, is_same_person = batch
+		weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
+		weight_id = CONFIG.getfloat('training.losses', 'weight_id')
+		weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
+		weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
+		weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
+		weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
+		generator_loss_set = {}
+
+		generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
+		generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
+		generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
+		generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
+
+		if weight_pose > 0:
+			generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
+		else:
+			generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+		if weight_gaze > 0:
+			generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
+		else:
+			generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+		generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
+		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
+		return generator_loss_set
+
+	def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
+		discriminator_loss_set = {}
+		loss_fake = torch.Tensor(0)
+
+		for fake_discriminator_output in fake_discriminator_outputs:
+			loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
+
+		loss_true = torch.Tensor(0)
+
+		for true_discriminator_output in real_discriminator_outputs:
+			loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
+
+		discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
+		return discriminator_loss_set
+
+	def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
+		loss_adversarial = torch.Tensor(0)
+
+		for discriminator_output in discriminator_outputs:
+			loss_adversarial += hinge_real_loss(discriminator_output[0])
+
+		loss_adversarial = torch.mean(loss_adversarial)
+		return loss_adversarial
+
+	def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
+		loss_attribute = torch.Tensor(0)
+
+		for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
+			loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
+
+		loss_attribute *= 0.5
+		return loss_attribute
+
+	def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
+		loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
+		loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
+		loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
+		loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
+		loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
+		return loss_reconstruction
+
+	def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
+		swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
+		source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
+		loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
+		return loss_id
+
+	def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
+		swap_motion_features = self.get_pose_features(swap_tensor)
+		target_motion_features = self.get_pose_features(target_tensor)
+		loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+		for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
+			loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
+
+		return loss_pose
+
+	def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
+		swap_landmark = self.get_face_landmarks(swap_tensor)
+		target_landmark = self.get_face_landmarks(target_tensor)
+		left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
+		right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
+		gaze_loss = left_gaze_loss + right_gaze_loss
+		return gaze_loss
+
+	def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
+		vision_tensor_norm = (vision_tensor + 1) * 0.5
+		vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
+		landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
+		return landmarks
+
+	def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+		vision_tensor_norm = (vision_tensor + 1) * 0.5
+		pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
+		rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
+		return translation, scale, rotation
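The loss module above imports hinge_fake_loss and hinge_real_loss from helper.py, which this commit does not touch, so their bodies do not appear in the diff. A plausible sketch, assuming the standard hinge GAN formulation (if the repo's helpers differ, this does not apply):

import torch
from torch import Tensor

def hinge_real_loss(discriminator_output : Tensor) -> Tensor:
    return torch.relu(1 - discriminator_output)  # penalise real scores below +1

def hinge_fake_loss(discriminator_output : Tensor) -> Tensor:
    return torch.relu(1 + discriminator_output)  # penalise fake scores above -1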
@@ -1,25 +1,7 @@
 from typing import Tuple
 
 import torch
-import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, nn as nn
 
-from .types import SourceEmbedding, TargetAttributes, VisionTensor
-
-
-class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
-	def __init__(self, id_channels : int, num_blocks : int) -> None:
-		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
-		self.encoder = UNet()
-		self.generator = AADGenerator(id_channels, num_blocks)
-
-	def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
-		target_attributes = self.get_attributes(target)
-		swap_tensor = self.generator(target_attributes, source_embedding)
-		return swap_tensor, target_attributes
-
-	def get_attributes(self, target : VisionTensor) -> TargetAttributes:
-		return self.encoder(target)
+from face_swapper.src.types import SourceEmbedding, TargetAttributes
 
 
 class AADGenerator(nn.Module):
@@ -34,7 +16,6 @@ class AADGenerator(nn.Module):
 		self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks)
 		self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
 		self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
-		self.apply(init_weight)
 
 	def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor:
 		feature_map = self.upsample(source_embedding)
@@ -49,42 +30,6 @@ class AADGenerator(nn.Module):
 		return torch.tanh(output)
 
 
-class UNet(nn.Module):
-	def __init__(self) -> None:
-		super(UNet, self).__init__()
-		self.downsampler_1 = DownSample(3, 32)
-		self.downsampler_2 = DownSample(32, 64)
-		self.downsampler_3 = DownSample(64, 128)
-		self.downsampler_4 = DownSample(128, 256)
-		self.downsampler_5 = DownSample(256, 512)
-		self.downsampler_6 = DownSample(512, 1024)
-		self.bottleneck = DownSample(1024, 1024)
-		self.upsampler_1 = Upsample(1024, 1024)
-		self.upsampler_2 = Upsample(2048, 512)
-		self.upsampler_3 = Upsample(1024, 256)
-		self.upsampler_4 = Upsample(512, 128)
-		self.upsampler_5 = Upsample(256, 64)
-		self.upsampler_6 = Upsample(128, 32)
-		self.apply(init_weight)
-
-	def forward(self, target : VisionTensor) -> TargetAttributes:
-		downsample_feature_1 = self.downsampler_1(target)
-		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
-		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
-		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
-		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
-		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
-		bottleneck_output = self.bottleneck(downsample_feature_6)
-		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
-		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
-		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
-		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
-		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
-		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
-		output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
-		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
 
 
 class AADLayer(nn.Module):
 	def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
 		super(AADLayer, self).__init__()
@@ -109,22 +54,18 @@ class AADLayer(nn.Module):
 		return feature_blend
 
 
-class AddBlocksSequential(nn.Sequential):
-	#todo: what are inputs? improve the name
-	def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]:
-		_, attribute_embedding, id_embedding = inputs
-		modules = self._modules.values() #todo: what kind of modules?
-
-		for module_index, module in enumerate(modules):
-			if module_index % 3 == 0 and module_index > 0:
-				inputs = (inputs, attribute_embedding, id_embedding) # type:ignore[assignment]
-
-			if isinstance(inputs, torch.Tensor):
-				inputs = module(inputs)
-			else:
-				inputs = module(*inputs)
-
-		return inputs #todo: would be easier to read when you just return xxx_inputs, attribute_embedding, id_embedding ?
+class AADSequential(nn.Module):
+	def __init__(self, *args : nn.Module) -> None:
+		super(AADSequential, self).__init__()
+		self.layers = nn.ModuleList(args)
+
+	def forward(self, feature_map: Tensor, attribute_embedding: Tensor, id_embedding: SourceEmbedding) -> Tensor:
+		for layer in self.layers:
+			if isinstance(layer, AADLayer):
+				feature_map = layer(feature_map, attribute_embedding, id_embedding)
+			else:
+				feature_map = layer(feature_map)
+		return feature_map
 
 
 class AADResBlock(nn.Module):
@@ -147,11 +88,11 @@ class AADResBlock(nn.Module):
 				nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False)
 			]
 		)
-		self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
+		self.primary_add_blocks = AADSequential(*primary_add_blocks)
 
 	def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int) -> None:
 		if input_channels > output_channels:
-			auxiliary_add_blocks = AddBlocksSequential(
+			auxiliary_add_blocks = AADSequential(
 				AADLayer(input_channels, attribute_channels, id_channels),
 				nn.ReLU(inplace = True),
 				nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
@@ -168,34 +109,6 @@ class AADResBlock(nn.Module):
 		return output_feature
 
 
-class DownSample(nn.Module):
-	def __init__(self, input_channels : int, output_channels : int) -> None:
-		super(DownSample, self).__init__()
-		self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
-
-	def forward(self, temp : Tensor) -> Tensor:
-		temp = self.conv(temp)
-		temp = self.batch_norm(temp)
-		temp = self.leaky_relu(temp)
-		return temp
-
-
-class Upsample(nn.Module):
-	def __init__(self, input_channels : int, output_channels : int) -> None:
-		super(Upsample, self).__init__()
-		self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
-		self.batch_norm = nn.BatchNorm2d(output_channels)
-		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
-
-	def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
-		temp = self.deconv(temp)
-		temp = self.batch_norm(temp)
-		temp = self.leaky_relu(temp)
-		return torch.cat((temp, skip_tensor), dim = 1)
 
 
 class PixelShuffleUpsample(nn.Module):
 	def __init__(self, input_channels : int, output_channels : int) -> None:
 		super(PixelShuffleUpsample, self).__init__()
@@ -206,15 +119,3 @@ class PixelShuffleUpsample(nn.Module):
 		temp = self.conv(temp.view(temp.shape[0], -1, 1, 1))
 		temp = self.pixel_shuffle(temp)
 		return temp
-
-
-def init_weight(module : nn.Module) -> None:
-	if isinstance(module, nn.Linear):
-		module.weight.data.normal_(std = 0.001)
-		module.bias.data.zero_()
-
-	if isinstance(module, nn.Conv2d):
-		nn.init.xavier_normal_(module.weight.data)
-
-	if isinstance(module, nn.ConvTranspose2d):
-		nn.init.xavier_normal_(module.weight.data)
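The AADSequential rewrite in the hunks above replaces AddBlocksSequential's index-modulo bookkeeping with explicit dispatch: an nn.ModuleList plus an isinstance check routes the attribute and identity embeddings only to AADLayer members, while plain layers receive just the feature map. A toy version of the same pattern (stub layers, not the repo's classes):

import torch
from torch import Tensor, nn as nn

class NeedsExtras(nn.Module):
    def forward(self, feature_map : Tensor, attribute : Tensor, identity : Tensor) -> Tensor:
        return feature_map + attribute + identity

class Dispatching(nn.Module):
    def __init__(self, *args : nn.Module) -> None:
        super().__init__()
        self.layers = nn.ModuleList(args)

    def forward(self, feature_map : Tensor, attribute : Tensor, identity : Tensor) -> Tensor:
        for layer in self.layers:
            if isinstance(layer, NeedsExtras):
                feature_map = layer(feature_map, attribute, identity)
            else:
                feature_map = layer(feature_map)
        return feature_map

print(Dispatching(NeedsExtras(), nn.ReLU())(torch.zeros(2, 4), torch.ones(2, 4), torch.ones(2, 4)))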
@@ -0,0 +1,67 @@
+import torch
+from torch import Tensor, nn as nn
+
+from face_swapper.src.types import TargetAttributes, VisionTensor
+
+
+class Upsample(nn.Module):
+	def __init__(self, input_channels : int, output_channels : int) -> None:
+		super(Upsample, self).__init__()
+		self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+		self.batch_norm = nn.BatchNorm2d(output_channels)
+		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+
+	def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
+		temp = self.deconv(temp)
+		temp = self.batch_norm(temp)
+		temp = self.leaky_relu(temp)
+		return torch.cat((temp, skip_tensor), dim = 1)
+
+
+class DownSample(nn.Module):
+	def __init__(self, input_channels : int, output_channels : int) -> None:
+		super(DownSample, self).__init__()
+		self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+		self.batch_norm = nn.BatchNorm2d(output_channels)
+		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+
+	def forward(self, temp : Tensor) -> Tensor:
+		temp = self.conv(temp)
+		temp = self.batch_norm(temp)
+		temp = self.leaky_relu(temp)
+		return temp
+
+
+class UNet(nn.Module):
+	def __init__(self) -> None:
+		super(UNet, self).__init__()
+		self.downsampler_1 = DownSample(3, 32)
+		self.downsampler_2 = DownSample(32, 64)
+		self.downsampler_3 = DownSample(64, 128)
+		self.downsampler_4 = DownSample(128, 256)
+		self.downsampler_5 = DownSample(256, 512)
+		self.downsampler_6 = DownSample(512, 1024)
+		self.bottleneck = DownSample(1024, 1024)
+		self.upsampler_1 = Upsample(1024, 1024)
+		self.upsampler_2 = Upsample(2048, 512)
+		self.upsampler_3 = Upsample(1024, 256)
+		self.upsampler_4 = Upsample(512, 128)
+		self.upsampler_5 = Upsample(256, 64)
+		self.upsampler_6 = Upsample(128, 32)
+
+	def forward(self, target : VisionTensor) -> TargetAttributes:
+		downsample_feature_1 = self.downsampler_1(target)
+		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
+		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
+		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
+		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
+		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
+		bottleneck_output = self.bottleneck(downsample_feature_6)
+		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
+		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
+		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
+		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
+		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
+		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
+		output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
+		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
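A check of the UNet channel arithmetic above: Upsample.forward returns torch.cat((deconv_output, skip_tensor), dim = 1), so each stage emits its output_channels plus the skip's channels, which is why upsampler_2 takes 2048 = 1024 + 1024. Worked out for all stages:

down_channels = [ 32, 64, 128, 256, 512, 1024 ]    # downsampler_1 .. downsampler_6
up_out_channels = [ 1024, 512, 256, 128, 64, 32 ]  # upsampler_1 .. upsampler_6

for up_out, skip in zip(up_out_channels, reversed(down_channels)):
    print(up_out + skip)  # 2048, 1024, 512, 256, 128, 64 -> the next stage's input_channels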
@@ -8,157 +8,25 @@ import torchvision
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.types import Optimizer
-from pytorch_msssim import ssim
-from torch import Tensor
 from torch.utils.data import DataLoader
 
 from .data_loader import DataLoaderVGG
-from .discriminator import MultiscaleDiscriminator
-from .generator import AdaptiveEmbeddingIntegrationNetwork
-from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
-from .types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, SwapAttributes, TargetAttributes, VisionTensor
+from .helper import calc_id_embedding
+from .models.discriminator import MultiscaleDiscriminator
+from .models.generator import AdaptiveEmbeddingIntegrationNetwork
+from .models.loss import FaceSwapperLoss
+from .types import Batch, SourceEmbedding, TargetAttributes, VisionTensor
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
 
 
-class FaceSwapperLoss:
-	def __init__(self) -> None:
-		id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
-		landmarker_path = CONFIG.get('training.model', 'landmarker_path')
-		motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
-		self.batch_size = CONFIG.getint('training.loader', 'batch_size')
-		self.mse_loss = torch.nn.MSELoss()
-		self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.id_embedder.eval()
-		self.landmarker.eval()
-		self.motion_extractor.eval()
-
-	def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
-		source_tensor, target_tensor, is_same_person = batch
-		weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
-		weight_id = CONFIG.getfloat('training.losses', 'weight_id')
-		weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
-		weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
-		weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
-		weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
-		generator_loss_set = {}
-
-		generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
-		generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
-		generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
-		generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
-
-		if weight_pose > 0:
-			generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
-		else:
-			generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
-
-		if weight_gaze > 0:
-			generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
-		else:
-			generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
-
-		generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
-		generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
-		return generator_loss_set
-
-	def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
-		discriminator_loss_set = {}
-		loss_fake = torch.Tensor(0)
-
-		for fake_discriminator_output in fake_discriminator_outputs:
-			loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
-
-		loss_true = torch.Tensor(0)
-
-		for true_discriminator_output in real_discriminator_outputs:
-			loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
-
-		discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
-		return discriminator_loss_set
-
-	def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
-		loss_adversarial = torch.Tensor(0)
-
-		for discriminator_output in discriminator_outputs:
-			loss_adversarial += hinge_real_loss(discriminator_output[0])
-
-		loss_adversarial = torch.mean(loss_adversarial)
-		return loss_adversarial
-
-	def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
-		loss_attribute = torch.Tensor(0)
-
-		for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
-			loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
-
-		loss_attribute *= 0.5
-		return loss_attribute
-
-	def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
-		loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
-		loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
-		loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
-		loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
-		loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
-		return loss_reconstruction
-
-	def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
-		swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
-		source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
-		loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
-		return loss_id
-
-	def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
-		swap_motion_features = self.get_pose_features(swap_tensor)
-		target_motion_features = self.get_pose_features(target_tensor)
-		loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
-
-		for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
-			loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
-
-		return loss_pose
-
-	def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
-		swap_landmark = self.get_face_landmarks(swap_tensor)
-		target_landmark = self.get_face_landmarks(target_tensor)
-		left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
-		right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
-		gaze_loss = left_gaze_loss + right_gaze_loss
-		return gaze_loss
-
-	def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
-		vision_tensor_norm = (vision_tensor + 1) * 0.5
-		vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
-		landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
-		return landmarks
-
-	def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
-		vision_tensor_norm = (vision_tensor + 1) * 0.5
-		pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
-		rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
-		return translation, scale, rotation
-
-
 class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
 	def __init__(self) -> None:
 		super().__init__()
-		id_channels = CONFIG.getint('training.model.generator', 'id_channels')
-		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
-		input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
-		num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
-		num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
-		num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
-		kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
-		self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
-		self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators, kernel_size)
+		self.generator = AdaptiveEmbeddingIntegrationNetwork()
+		self.discriminator = MultiscaleDiscriminator()
 		self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
 
 	def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
@@ -244,8 +112,8 @@ def train() -> None:
 	batch_size = CONFIG.getint('training.loader', 'batch_size')
 	num_workers = CONFIG.getint('training.loader', 'num_workers')
 	file_path = CONFIG.get('training.output', 'file_path')
-	dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
+
+	dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
 	data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
 	face_swap_model = FaceSwapperTrain()
 	trainer = create_trainer()
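Even though FaceSwapperLoss now lives in models/loss.py, FaceSwapperTrain still inherits it alongside pytorch_lightning.LightningModule, so the training steps keep calling self.calc_generator_loss and friends directly. A stripped-down sketch of that mixin shape (stub classes, hypothetical names):

class LossMixin:
    def calc_generator_loss(self) -> str:
        return 'generator loss'

class LightningLike:
    pass

class TrainModule(LightningLike, LossMixin):
    def training_step(self) -> str:
        return self.calc_generator_loss()

print(TrainModule().training_step())  # 'generator loss'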