This commit is contained in:
harisreedhar
2025-02-10 22:43:15 +05:30
committed by henryruhs
parent b7e2d3ccd7
commit 2ed558a873
10 changed files with 310 additions and 286 deletions
View File
+1 -2
View File
@@ -24,8 +24,7 @@ class DataLoaderVGG(TensorDataset):
image_path_set = {}
for directory_path in self.directory_paths:
image_paths = glob.glob(dataset_image_pattern.format(directory_path))
image_paths.extend(image_paths)
image_paths.extend(glob.glob(dataset_image_pattern.format(directory_path)))
image_path_set[directory_path] = image_paths
return image_paths, image_path_set
+2 -2
View File
@@ -3,7 +3,7 @@ from os import makedirs
import torch
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -17,7 +17,7 @@ def export() -> None:
makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
model = AdaptiveEmbeddingIntegrationNetwork(512, 2)
model = AdaptiveEmbeddingIntegrationNetwork()
model.load_state_dict(state_dict)
model.eval()
source_tensor = torch.randn(1, 512)
+2 -2
View File
@@ -3,8 +3,8 @@ import configparser
import cv2
import torch
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
from .types import Generator, IdEmbedder, VisionFrame
CONFIG = configparser.ConfigParser()
@@ -28,7 +28,7 @@ def infer() -> None:
output_path = CONFIG.get('inferencing', 'output_path')
state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator')
generator = AdaptiveEmbeddingIntegrationNetwork(512, 2)
generator = AdaptiveEmbeddingIntegrationNetwork()
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -1,3 +1,4 @@
import configparser
from itertools import chain
from typing import List
@@ -6,7 +7,41 @@ import torch.nn
import torch.nn as nn
from torch import Tensor
from .types import DiscriminatorOutputs
from face_swapper.src.types import DiscriminatorOutputs
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class MultiscaleDiscriminator(nn.Module):
	"""Stack of patch discriminators evaluated over an image pyramid.

	Hyper parameters are read from the [training.model.discriminator]
	section of config.ini at construction time.
	"""
	def __init__(self) -> None:
		super(MultiscaleDiscriminator, self).__init__()
		section = 'training.model.discriminator'
		self.input_channels = CONFIG.getint(section, 'input_channels')
		self.num_filters = CONFIG.getint(section, 'num_filters')
		self.kernel_size = CONFIG.getint(section, 'kernel_size')
		self.num_layers = CONFIG.getint(section, 'num_layers')
		self.num_discriminators = CONFIG.getint(section, 'num_discriminators')
		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
		self.prepare_discriminators()

	def prepare_discriminators(self) -> None:
		# register one NLayerDiscriminator per pyramid level as a named sub module so parameters are tracked
		for index in range(self.num_discriminators):
			discriminator = NLayerDiscriminator(self.input_channels, self.num_filters, self.num_layers, self.kernel_size)
			setattr(self, 'discriminator_layer_{}'.format(index), discriminator.model)

	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
		"""Run every scale, downsampling the input between evaluations."""
		discriminator_outputs = []
		pyramid_tensor = input_tensor
		last_level = self.num_discriminators - 1
		# the highest-index discriminator sees the full resolution input, lower ones see pooled copies
		for level, layer_index in enumerate(reversed(range(self.num_discriminators))):
			model_layers = getattr(self, 'discriminator_layer_{}'.format(layer_index))
			discriminator_outputs.append([ model_layers(pyramid_tensor) ])
			if level < last_level:
				pyramid_tensor = self.downsample(pyramid_tensor)
		return discriminator_outputs
class NLayerDiscriminator(nn.Module):
@@ -58,29 +93,3 @@ class NLayerDiscriminator(nn.Module):
def forward(self, input_tensor : Tensor) -> Tensor:
return self.model(input_tensor)
class MultiscaleDiscriminator(nn.Module):
	"""Multi-scale patch discriminator.

	Registers one NLayerDiscriminator per scale and evaluates them over a
	progressively average-pooled image pyramid.

	:param input_channels: channels of the input image tensor
	:param num_filters: base filter count of each NLayerDiscriminator
	:param num_layers: convolution depth of each NLayerDiscriminator
	:param num_discriminators: number of pyramid scales
	:param kernel_size: convolution kernel size of each NLayerDiscriminator
	"""
	def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int, kernel_size : int):
		super(MultiscaleDiscriminator, self).__init__()
		self.num_discriminators = num_discriminators
		self.num_layers = num_layers
		# register each scale's discriminator as a named sub module so its parameters are tracked
		for discriminator_index in range(num_discriminators):
			single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers, kernel_size)
			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]

	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
		"""Evaluate all scales; the highest-index layer sees the full resolution input."""
		discriminator_outputs = []
		temp_tensor = input_tensor
		# iterate from the last registered layer down to the first, pooling between evaluations
		for discriminator_index in range(self.num_discriminators):
			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
			discriminator_outputs.append([ model_layers(temp_tensor) ])
			if discriminator_index < (self.num_discriminators - 1):
				temp_tensor = self.downsample(temp_tensor)
		return discriminator_outputs
+43
View File
@@ -0,0 +1,43 @@
import configparser
from typing import Tuple
import torch.nn as nn
from face_swapper.src.networks.attribute_modulator import AADGenerator
from face_swapper.src.networks.encoder import UNet
from face_swapper.src.types import SourceEmbedding, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
	"""Face-swap generator pairing a UNet attribute encoder with an AAD generator.

	Model sizes come from the [training.model.generator] section of config.ini;
	both sub networks are initialized with init_weight at construction time.
	"""
	def __init__(self) -> None:
		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
		id_channels = CONFIG.getint('training.model.generator', 'id_channels')
		num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
		self.encoder = UNet()
		self.generator = AADGenerator(id_channels, num_blocks)
		for sub_network in (self.encoder, self.generator):
			sub_network.apply(init_weight)

	def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
		"""Return the swapped image tensor together with the target attribute pyramid."""
		target_attributes = self.get_attributes(target)
		return self.generator(target_attributes, source_embedding), target_attributes

	def get_attributes(self, target : VisionTensor) -> TargetAttributes:
		"""Encode the target image into its multi-resolution attributes."""
		return self.encoder(target)
def init_weight(module : nn.Module) -> None:
	"""Initialization hook for Module.apply().

	Linear layers get small gaussian weights (std 0.001) and a zero bias,
	convolution and transposed-convolution layers get xavier-normal weights.

	:param module: any sub module visited by Module.apply()
	"""
	if isinstance(module, nn.Linear):
		module.weight.data.normal_(std = 0.001)
		# nn.Linear(..., bias = False) has no bias tensor, guard against None
		if module.bias is not None:
			module.bias.data.zero_()
	elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
		nn.init.xavier_normal_(module.weight.data)
+137
View File
@@ -0,0 +1,137 @@
import configparser
from typing import Tuple
import torch
from pytorch_msssim import ssim
from torch import Tensor
from face_swapper.src.helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from face_swapper.src.types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class FaceSwapperLoss:
def __init__(self) -> None:
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = torch.nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor, is_same_person = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_id = CONFIG.getfloat('training.losses', 'weight_id')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
generator_loss_set = {}
generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
if weight_pose > 0:
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
if weight_gaze > 0:
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
return generator_loss_set
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
discriminator_loss_set = {}
loss_fake = torch.Tensor(0)
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
loss_true = torch.Tensor(0)
for true_discriminator_output in real_discriminator_outputs:
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_loss_set
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_real_loss(discriminator_output[0])
loss_adversarial = torch.mean(loss_adversarial)
return loss_adversarial
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attribute *= 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
return loss_id
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_motion_features = self.get_pose_features(swap_tensor)
target_motion_features = self.get_pose_features(target_tensor)
loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
return loss_pose
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_landmark = self.get_face_landmarks(swap_tensor)
target_landmark = self.get_face_landmarks(target_tensor)
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
gaze_loss = left_gaze_loss + right_gaze_loss
return gaze_loss
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
return landmarks
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
@@ -1,25 +1,7 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch import Tensor, nn as nn
from .types import SourceEmbedding, TargetAttributes, VisionTensor
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
	"""Face-swap generator: a UNet attribute encoder feeding an AAD generator.

	:param id_channels: size of the identity embedding fed to the AAD blocks
	:param num_blocks: number of AAD blocks per resolution stage
	"""
	def __init__(self, id_channels : int, num_blocks : int) -> None:
		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
		self.encoder = UNet()
		self.generator = AADGenerator(id_channels, num_blocks)

	def forward(self, target : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
		"""Return the swapped image tensor and the target attribute pyramid."""
		# encode the target once, reuse the attributes for the swap and as auxiliary output
		target_attributes = self.get_attributes(target)
		swap_tensor = self.generator(target_attributes, source_embedding)
		return swap_tensor, target_attributes

	def get_attributes(self, target : VisionTensor) -> TargetAttributes:
		"""Encode the target image into its multi-resolution attributes."""
		return self.encoder(target)
from face_swapper.src.types import SourceEmbedding, TargetAttributes
class AADGenerator(nn.Module):
@@ -34,7 +16,6 @@ class AADGenerator(nn.Module):
self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks)
self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
self.apply(init_weight)
def forward(self, target_attributes : TargetAttributes, source_embedding : SourceEmbedding) -> Tensor:
feature_map = self.upsample(source_embedding)
@@ -49,42 +30,6 @@ class AADGenerator(nn.Module):
return torch.tanh(output)
class UNet(nn.Module):
	"""UNet encoder that extracts multi-resolution target attributes.

	forward() returns the bottleneck, every decoder feature and a final
	2x-upsampled feature so the AAD generator can consume one attribute
	map per resolution stage.
	"""
	def __init__(self) -> None:
		super(UNet, self).__init__()
		# encoder: seven stride-2 stages, 3 -> 1024 channels
		self.downsampler_1 = DownSample(3, 32)
		self.downsampler_2 = DownSample(32, 64)
		self.downsampler_3 = DownSample(64, 128)
		self.downsampler_4 = DownSample(128, 256)
		self.downsampler_5 = DownSample(256, 512)
		self.downsampler_6 = DownSample(512, 1024)
		self.bottleneck = DownSample(1024, 1024)
		# decoder: input channel counts double because Upsample concatenates the skip tensor
		self.upsampler_1 = Upsample(1024, 1024)
		self.upsampler_2 = Upsample(2048, 512)
		self.upsampler_3 = Upsample(1024, 256)
		self.upsampler_4 = Upsample(512, 128)
		self.upsampler_5 = Upsample(256, 64)
		self.upsampler_6 = Upsample(128, 32)
		self.apply(init_weight)

	def forward(self, target : VisionTensor) -> TargetAttributes:
		"""Encode the target and return the full attribute pyramid."""
		downsample_feature_1 = self.downsampler_1(target)
		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
		bottleneck_output = self.bottleneck(downsample_feature_6)
		# each upsampler receives the previous decoder feature plus the mirrored encoder skip
		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
		# final feature upsampled once more toward the input resolution
		output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
class AADLayer(nn.Module):
def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
super(AADLayer, self).__init__()
@@ -109,22 +54,18 @@ class AADLayer(nn.Module):
return feature_blend
class AddBlocksSequential(nn.Sequential):
#todo: what are inputs? improve the name
def forward(self, *inputs : Tuple[Tensor, Tensor, SourceEmbedding]) -> Tuple[Tuple[Tensor, Tensor, SourceEmbedding], ...]:
_, attribute_embedding, id_embedding = inputs
modules = self._modules.values() #todo: what kind of modules?
class AADSequential(nn.Module):
def __init__(self, *args : nn.Module) -> None:
super(AADSequential, self).__init__()
self.layers = nn.ModuleList(args)
for module_index, module in enumerate(modules):
if module_index % 3 == 0 and module_index > 0:
inputs = (inputs, attribute_embedding, id_embedding) # type:ignore[assignment]
if isinstance(inputs, torch.Tensor):
inputs = module(inputs)
def forward(self, feature_map: Tensor, attribute_embedding: Tensor, id_embedding: SourceEmbedding) -> Tensor:
for layer in self.layers:
if isinstance(layer, AADLayer):
feature_map = layer(feature_map, attribute_embedding, id_embedding)
else:
inputs = module(*inputs)
return inputs #todo: would be easier to read when you just return xxx_inputs, attribute_embedding, id_embedding ?
feature_map = layer(feature_map)
return feature_map
class AADResBlock(nn.Module):
@@ -147,11 +88,11 @@ class AADResBlock(nn.Module):
nn.Conv2d(input_channels, intermediate_channels, kernel_size = 3, padding = 1, bias = False)
]
)
self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
self.primary_add_blocks = AADSequential(*primary_add_blocks)
def prepare_auxiliary_add_blocks(self, input_channels : int, attribute_channels : int, id_channels : int, output_channels : int) -> None:
if input_channels > output_channels:
auxiliary_add_blocks = AddBlocksSequential(
auxiliary_add_blocks = AADSequential(
AADLayer(input_channels, attribute_channels, id_channels),
nn.ReLU(inplace = True),
nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1, bias = False)
@@ -168,34 +109,6 @@ class AADResBlock(nn.Module):
return output_feature
class DownSample(nn.Module):
	"""Halves spatial resolution: strided 4x4 convolution, batch norm, leaky ReLU.

	:param input_channels: channels of the incoming tensor
	:param output_channels: channels produced by the convolution
	"""
	def __init__(self, input_channels : int, output_channels : int) -> None:
		super(DownSample, self).__init__()
		self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
		self.batch_norm = nn.BatchNorm2d(output_channels)
		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

	def forward(self, temp : Tensor) -> Tensor:
		# conv -> norm -> activation, chained in one expression
		return self.leaky_relu(self.batch_norm(self.conv(temp)))
class Upsample(nn.Module):
	"""Doubles spatial resolution with a transposed convolution and concatenates a skip tensor.

	:param input_channels: channels of the incoming tensor
	:param output_channels: channels produced by the transposed convolution
	"""
	def __init__(self, input_channels : int, output_channels : int) -> None:
		super(Upsample, self).__init__()
		self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
		self.batch_norm = nn.BatchNorm2d(output_channels)
		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

	def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
		upsampled = self.leaky_relu(self.batch_norm(self.deconv(temp)))
		# channel-wise concat with the mirrored encoder feature
		return torch.cat((upsampled, skip_tensor), dim = 1)
class PixelShuffleUpsample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super(PixelShuffleUpsample, self).__init__()
@@ -206,15 +119,3 @@ class PixelShuffleUpsample(nn.Module):
temp = self.conv(temp.view(temp.shape[0], -1, 1, 1))
temp = self.pixel_shuffle(temp)
return temp
def init_weight(module : nn.Module) -> None:
	"""Initialization hook for Module.apply().

	:param module: any sub module visited by Module.apply()
	"""
	if isinstance(module, nn.Linear):
		# small gaussian weights and zero bias for linear layers
		module.weight.data.normal_(std = 0.001)
		module.bias.data.zero_()
	if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
		# xavier-normal for both convolution flavours
		nn.init.xavier_normal_(module.weight.data)
+67
View File
@@ -0,0 +1,67 @@
import torch
from torch import Tensor, nn as nn
from face_swapper.src.types import TargetAttributes, VisionTensor
class Upsample(nn.Module):
	"""Transposed-convolution upsampling block with skip concatenation.

	Doubles the spatial resolution of the input, then concatenates the
	matching encoder feature along the channel axis.

	:param input_channels: channels of the incoming tensor
	:param output_channels: channels produced by the transposed convolution
	"""
	def __init__(self, input_channels : int, output_channels : int) -> None:
		super(Upsample, self).__init__()
		self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
		self.batch_norm = nn.BatchNorm2d(output_channels)
		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

	def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
		temp = self.deconv(temp)
		temp = self.batch_norm(temp)
		temp = self.leaky_relu(temp)
		# channel-wise concat with the encoder skip connection
		return torch.cat((temp, skip_tensor), dim = 1)
class DownSample(nn.Module):
	"""Strided-convolution downsampling block: conv 4x4/2, batch norm, leaky ReLU.

	Halves the spatial resolution of the input.

	:param input_channels: channels of the incoming tensor
	:param output_channels: channels produced by the convolution
	"""
	def __init__(self, input_channels : int, output_channels : int) -> None:
		super(DownSample, self).__init__()
		self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
		self.batch_norm = nn.BatchNorm2d(output_channels)
		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

	def forward(self, temp : Tensor) -> Tensor:
		temp = self.conv(temp)
		temp = self.batch_norm(temp)
		temp = self.leaky_relu(temp)
		return temp
class UNet(nn.Module):
	"""UNet encoder that extracts multi-resolution target attributes.

	forward() returns the bottleneck, every decoder feature and a final
	2x-upsampled feature so a downstream generator can consume one
	attribute map per resolution stage. Weights are initialized by the
	owning network, not here.
	"""
	def __init__(self) -> None:
		super(UNet, self).__init__()
		# encoder: seven stride-2 stages, 3 -> 1024 channels
		self.downsampler_1 = DownSample(3, 32)
		self.downsampler_2 = DownSample(32, 64)
		self.downsampler_3 = DownSample(64, 128)
		self.downsampler_4 = DownSample(128, 256)
		self.downsampler_5 = DownSample(256, 512)
		self.downsampler_6 = DownSample(512, 1024)
		self.bottleneck = DownSample(1024, 1024)
		# decoder: input channel counts double because Upsample concatenates the skip tensor
		self.upsampler_1 = Upsample(1024, 1024)
		self.upsampler_2 = Upsample(2048, 512)
		self.upsampler_3 = Upsample(1024, 256)
		self.upsampler_4 = Upsample(512, 128)
		self.upsampler_5 = Upsample(256, 64)
		self.upsampler_6 = Upsample(128, 32)

	def forward(self, target : VisionTensor) -> TargetAttributes:
		"""Encode the target and return the full attribute pyramid."""
		downsample_feature_1 = self.downsampler_1(target)
		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
		bottleneck_output = self.bottleneck(downsample_feature_6)
		# each upsampler receives the previous decoder feature plus the mirrored encoder skip
		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
		# final feature upsampled once more toward the input resolution
		output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
+8 -140
View File
@@ -8,157 +8,25 @@ import torchvision
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import Optimizer
from pytorch_msssim import ssim
from torch import Tensor
from torch.utils.data import DataLoader
from .data_loader import DataLoaderVGG
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from .types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SourceEmbedding, SwapAttributes, TargetAttributes, VisionTensor
from .helper import calc_id_embedding
from .models.discriminator import MultiscaleDiscriminator
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
from .models.loss import FaceSwapperLoss
from .types import Batch, SourceEmbedding, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class FaceSwapperLoss:
def __init__(self) -> None:
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = torch.nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.id_embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor, is_same_person = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_id = CONFIG.getfloat('training.losses', 'weight_id')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
generator_loss_set = {}
generator_loss_set['loss_adversarial'] = self.calc_adversarial_loss(discriminator_outputs)
generator_loss_set['loss_id'] = self.calc_id_loss(source_tensor, swap_tensor)
generator_loss_set['loss_attribute'] = self.calc_attribute_loss(target_attributes, swap_attributes)
generator_loss_set['loss_reconstruction'] = self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
if weight_pose > 0:
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_pose'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
if weight_gaze > 0:
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
else:
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
return generator_loss_set
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
discriminator_loss_set = {}
loss_fake = torch.Tensor(0)
for fake_discriminator_output in fake_discriminator_outputs:
loss_fake += hinge_fake_loss(fake_discriminator_output[0]).mean()
loss_true = torch.Tensor(0)
for true_discriminator_output in real_discriminator_outputs:
loss_true += hinge_real_loss(true_discriminator_output[0]).mean()
discriminator_loss_set['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
return discriminator_loss_set
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
loss_adversarial = torch.Tensor(0)
for discriminator_output in discriminator_outputs:
loss_adversarial += hinge_real_loss(discriminator_output[0])
loss_adversarial = torch.mean(loss_adversarial)
return loss_adversarial
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
loss_attribute = torch.Tensor(0)
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
loss_attribute *= 0.5
return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
return loss_reconstruction
def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
return loss_id
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_motion_features = self.get_pose_features(swap_tensor)
target_motion_features = self.get_pose_features(target_tensor)
loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
return loss_pose
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
swap_landmark = self.get_face_landmarks(swap_tensor)
target_landmark = self.get_face_landmarks(target_tensor)
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
gaze_loss = left_gaze_loss + right_gaze_loss
return gaze_loss
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
vision_tensor_norm = (vision_tensor + 1) * 0.5
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
return landmarks
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (vision_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
class FaceSwapperTrain(pytorch_lightning.LightningModule, FaceSwapperLoss):
def __init__(self) -> None:
super().__init__()
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
self.generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
self.discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators, kernel_size)
self.generator = AdaptiveEmbeddingIntegrationNetwork()
self.discriminator = MultiscaleDiscriminator()
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
def forward(self, target_tensor : VisionTensor, source_embedding : SourceEmbedding) -> Tuple[VisionTensor, TargetAttributes]:
@@ -244,8 +112,8 @@ def train() -> None:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
file_path = CONFIG.get('training.output', 'file_path')
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
dataset = DataLoaderVGG(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
face_swap_model = FaceSwapperTrain()
trainer = create_trainer()