diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 85e85e1..9c0525d 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -53,31 +53,37 @@ class AdversarialLoss(nn.Module): class AttributeLoss(nn.Module): - def __init__(self) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() + self.config =\ + { + 'batch_size': config_parser.getint('training.loader', 'batch_size'), + 'attribute_weight': config_parser.getfloat('training.losses', 'attribute_weight') + } def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]: - batch_size = CONFIG.getint('training.loader', 'batch_size') - attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight') temp_tensors = [] for target_attribute, output_attribute in zip(target_attributes, output_attributes): - temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(batch_size, -1), dim = 1).mean() + temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config.get('batch_size'), -1), dim = 1).mean() temp_tensors.append(temp_tensor) attribute_loss = torch.stack(temp_tensors).mean() * 0.5 - weighted_attribute_loss = attribute_loss * attribute_weight + weighted_attribute_loss = attribute_loss * self.config.get('attribute_weight') return attribute_loss, weighted_attribute_loss class ReconstructionLoss(nn.Module): - def __init__(self, embedder : EmbedderModule) -> None: + def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None: super().__init__() + self.config =\ + { + 'reconstruction_weight': config_parser.getfloat('training.losses', 'reconstruction_weight') + } self.embedder = embedder self.mse_loss = nn.MSELoss() def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: - reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight') source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0)) has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8 @@ -88,27 +94,35 @@ class ReconstructionLoss(nn.Module): data_range = float(torch.max(output_tensor) - torch.min(output_tensor)) visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean() reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5 - weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight + weighted_reconstruction_loss = reconstruction_loss * self.config.get('reconstruction_weight') return reconstruction_loss, weighted_reconstruction_loss class IdentityLoss(nn.Module): - def __init__(self, embedder : EmbedderModule) -> None: + def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None: super().__init__() + self.config =\ + { + 'identity_weight': config_parser.getfloat('training.losses', 'identity_weight') + } self.embedder = embedder def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: - identity_weight = CONFIG.getfloat('training.losses', 'identity_weight') output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10)) source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean() - weighted_identity_loss = identity_loss * identity_weight + weighted_identity_loss = identity_loss * self.config.get('identity_weight') return identity_loss, weighted_identity_loss class MotionLoss(nn.Module): - def __init__(self, motion_extractor : MotionExtractorModule): + def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule): super().__init__() + self.config =\ + { + 'pose_weight': config_parser.getfloat('training.losses', 'pose_weight'), + 'expression_weight': config_parser.getfloat('training.losses', 'expression_weight') + } self.motion_extractor = motion_extractor self.mse_loss = nn.MSELoss() @@ -120,7 +134,6 @@ class MotionLoss(nn.Module): return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]: - pose_weight = CONFIG.getfloat('training.losses', 'pose_weight') temp_tensors = [] for target_pose, output_pose in zip(target_poses, output_poses): @@ -128,13 +141,12 @@ class MotionLoss(nn.Module): temp_tensors.append(temp_tensor) pose_loss = torch.stack(temp_tensors).mean() - weighted_pose_loss = pose_loss * pose_weight + weighted_pose_loss = pose_loss * self.config.get('pose_weight') return pose_loss, weighted_pose_loss def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]: - expression_weight = CONFIG.getfloat('training.losses', 'expression_weight') expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean() - weighted_expression_loss = expression_loss * expression_weight + weighted_expression_loss = expression_loss * self.config.get('expression_weight') return expression_loss, weighted_expression_loss def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]: @@ -148,13 +160,17 @@ class MotionLoss(nn.Module): class GazeLoss(nn.Module): - def __init__(self, gazer : GazerModule) -> None: + def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None: super().__init__() + self.config =\ + { + 'gaze_weight': config_parser.getfloat('training.losses', 'gaze_weight'), + 'output_size': config_parser.getint('training.model.generator', 'output_size') + } self.gazer = gazer self.l1_loss = nn.L1Loss() def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]: - gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight') output_pitch, output_yaw = self.detect_gaze(output_tensor) target_pitch, target_yaw = self.detect_gaze(target_tensor) @@ -162,12 +178,11 @@ class GazeLoss(nn.Module): yaw_loss = self.l1_loss(output_yaw, target_yaw) gaze_loss = (pitch_loss + yaw_loss) * 0.5 - weighted_gaze_loss = gaze_loss * gaze_weight + weighted_gaze_loss = gaze_loss * self.config.get('gaze_weight') return gaze_loss, weighted_gaze_loss def detect_gaze(self, input_tensor : Tensor) -> Gaze: - output_size = CONFIG.getint('training.model.generator', 'output_size') - crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * output_size).int() + crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config.get('output_size')).int() crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]] crop_tensor = (crop_tensor + 1) * 0.5 crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor) diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index 8810ee8..8bf261e 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -1,3 +1,5 @@ +from configparser import ConfigParser + import torch from torch import Tensor, nn @@ -5,51 +7,55 @@ from ..types import Attributes, Embedding class AAD(nn.Module): - def __init__(self, identity_channels : int, output_channels : int, output_size : int, num_blocks : int) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - self.identity_channels = identity_channels - self.output_channels = output_channels - self.output_size = output_size - self.num_blocks = num_blocks - self.pixel_shuffle_up_sample = PixelShuffleUpSample(identity_channels, output_channels) + self.config =\ + { + 'identity_channels': config_parser.getint('training.model.generator', 'identity_channels'), + 'output_channels': config_parser.getint('training.model.generator', 'output_channels'), + 'output_size': config_parser.getint('training.model.generator', 'output_size'), + 'num_blocks': config_parser.getint('training.model.generator', 'num_blocks') + } + self.config_parser = config_parser + self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config.get('identity_channels'), self.config.get('output_channels')) self.layers = self.create_layers() def create_layers(self) -> nn.ModuleList: layers = nn.ModuleList() - if self.output_size == 128: + if self.config.get('output_size') == 128: layers.extend( [ - AdaptiveFeatureModulation(512, 512, 512, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(512, 512, 1024, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(512, 512, 512, self.identity_channels, self.num_blocks), + AdaptiveFeatureModulation(512, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(512, 512, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(512, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks')) ]) - if self.output_size == 256: + if self.config.get('output_size') == 256: layers.extend( [ - AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(1024, 1024, 2048, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks) + AdaptiveFeatureModulation(1024, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(1024, 1024, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(1024, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(1024, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks')) ]) - if self.output_size == 512: + if self.config.get('output_size') == 512: layers.extend( [ - AdaptiveFeatureModulation(2048, 2048, 2048, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(2048, 2048, 4096, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(2048, 2048, 2048, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(2048, 1024, 1024, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks) + AdaptiveFeatureModulation(2048, 2048, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(2048, 2048, 4096, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(2048, 2048, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(2048, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(1024, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks')) ]) layers.extend( [ - AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks), - AdaptiveFeatureModulation(64, 3, 64, self.identity_channels, self.num_blocks) + AdaptiveFeatureModulation(512, 256, 256, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(256, 128, 128, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(128, 64, 64, self.config.get('identity_channels'), self.config.get('num_blocks')), + AdaptiveFeatureModulation(64, 3, 64, self.config.get('identity_channels'), self.config.get('num_blocks')) ]) return layers @@ -69,40 +75,43 @@ class AAD(nn.Module): class AdaptiveFeatureModulation(nn.Module): def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None: super().__init__() - self.input_channels = input_channels - self.output_channels = output_channels - self.attribute_channels = attribute_channels - self.identity_channels = identity_channels - self.num_blocks = num_blocks + self.context =\ + { + 'input_channels': input_channels, + 'output_channels': output_channels, + 'attribute_channels': attribute_channels, + 'identity_channels': identity_channels, + 'num_blocks': num_blocks + } self.primary_layers = self.create_primary_layers() self.shortcut_layers = self.create_shortcut_layers() def create_primary_layers(self) -> nn.ModuleList: primary_layers = nn.ModuleList() - for index in range(self.num_blocks): + for index in range(self.context.get('num_blocks')): primary_layers.extend( [ - FeatureModulation(self.input_channels, self.attribute_channels, self.identity_channels), + FeatureModulation(self.context.get('input_channels'), self.context.get('attribute_channels'), self.context.get('identity_channels')), nn.ReLU(inplace = True) ]) - if index < self.num_blocks - 1: - primary_layers.append(nn.Conv2d(self.input_channels, self.input_channels, kernel_size = 3, padding = 1, bias = False)) + if index < self.context.get('num_blocks') - 1: + primary_layers.append(nn.Conv2d(self.context.get('input_channels'), self.context.get('input_channels'), kernel_size = 3, padding = 1, bias = False)) else: - primary_layers.append(nn.Conv2d(self.input_channels, self.output_channels, kernel_size = 3, padding = 1, bias = False)) + primary_layers.append(nn.Conv2d(self.context.get('input_channels'), self.context.get('output_channels'), kernel_size = 3, padding = 1, bias = False)) return primary_layers def create_shortcut_layers(self) -> nn.ModuleList: shortcut_layers = nn.ModuleList() - if self.input_channels > self.output_channels: + if self.context.get('input_channels') > self.context.get('output_channels'): shortcut_layers.extend( [ - FeatureModulation(self.input_channels, self.attribute_channels, self.identity_channels), + FeatureModulation(self.context.get('input_channels'), self.context.get('attribute_channels'), self.context.get('identity_channels')), nn.ReLU(inplace = True), - nn.Conv2d(self.input_channels, self.output_channels, kernel_size = 3, padding = 1, bias = False) + nn.Conv2d(self.context.get('input_channels'), self.context.get('output_channels'), kernel_size = 3, padding = 1, bias = False) ]) return shortcut_layers @@ -116,7 +125,7 @@ class AdaptiveFeatureModulation(nn.Module): else: primary_tensor = primary_layer(primary_tensor) - if self.input_channels > self.output_channels: + if self.context.get('input_channels') > self.context.get('output_channels'): shortcut_tensor = input_tensor for shortcut_layer in self.shortcut_layers: @@ -133,7 +142,10 @@ class AdaptiveFeatureModulation(nn.Module): class FeatureModulation(nn.Module): def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None: super().__init__() - self.input_channels = input_channels + self.context =\ + { + 'input_channels': input_channels + } self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1) self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1) @@ -148,8 +160,8 @@ class FeatureModulation(nn.Module): attribute_shift = self.conv2(attribute_embedding) attribute_modulation = attribute_scale * temp_tensor + attribute_shift - identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor) - identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor) + identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context.get('input_channels'), 1, 1).expand_as(temp_tensor) + identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context.get('input_channels'), 1, 1).expand_as(temp_tensor) identity_modulation = identity_scale * temp_tensor + identity_shift temp_mask = torch.sigmoid(self.conv3(temp_tensor)) diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 250c7b0..892a593 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -106,7 +106,7 @@ class UNet(nn.Module): class UpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) + self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) @@ -121,7 +121,7 @@ class UpSample(nn.Module): class DownSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() - self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) + self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) self.batch_norm = nn.BatchNorm2d(output_channels) self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a375dad..c9ffde4 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -39,10 +39,9 @@ class FaceSwapperTrainer(LightningModule): self.embedder = torch.jit.load(self.config.get('embedder_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call] self.gazer = torch.jit.load(self.config.get('gazer_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call] self.motion_extractor = torch.jit.load(self.config.get('motion_extractor_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call] - self.generator = Generator(config_parser) self.discriminator = Discriminator(config_parser) - self.discriminator_loss = DiscriminatorLoss(config_parser) + self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss(config_parser) self.attribute_loss = AttributeLoss(config_parser) self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder) diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py index e33f5ca..16305e8 100644 --- a/face_swapper/tests/test_networks.py +++ b/face_swapper/tests/test_networks.py @@ -1,3 +1,5 @@ +from configparser import ConfigParser + import pytest import torch @@ -7,18 +9,17 @@ from face_swapper.src.networks.unet import UNet @pytest.mark.parametrize('output_size', [ 128, 256, 512 ]) def test_aad_with_unet(output_size : int) -> None: - identity_channels = 512 - output_channels = 1024 - if output_size == 128: - output_channels = 2048 - if output_size == 256: - output_channels = 4096 - if output_size == 512: - output_channels = 8192 - num_blocks = 2 + config_parser = ConfigParser() + config_parser['training.model.generator'] =\ + { + 'identity_channels': '512', + 'output_channels': str(output_size * 16), + 'output_size': str(output_size), + 'num_blocks': '2' + } - generator = AAD(identity_channels, output_channels, output_size, num_blocks).eval() - encoder = UNet(output_size).eval() + generator = AAD(config_parser).eval() + encoder = UNet(config_parser).eval() source_tensor = torch.randn(1, 512) target_tensor = torch.randn(1, 3, output_size, output_size)