diff --git a/face_swapper/README.md b/face_swapper/README.md index 399b69e..e95d853 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -76,7 +76,7 @@ num_filters = 16 ``` [training.losses] adversarial_weight = 1.0 -attribute_weight = 10.0 +feature_weight = 10.0 reconstruction_weight = 10.0 identity_weight = 20.0 gaze_weight = 0.05 diff --git a/face_swapper/config.ini b/face_swapper/config.ini index eec38bd..113a0e1 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -36,7 +36,7 @@ num_filters = [training.losses] adversarial_weight = -attribute_weight = +feature_weight = reconstruction_weight = identity_weight = gaze_weight = diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index d083614..0084d1e 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -5,7 +5,7 @@ from torch import Tensor, nn from ..networks.aad import AAD from ..networks.unet import UNet -from ..types import Attribute, Embedding +from ..types import Embedding, Feature class Generator(nn.Module): @@ -16,12 +16,12 @@ class Generator(nn.Module): self.encoder.apply(init_weight) self.generator.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]: - target_attributes = self.encode_attributes(target_tensor) - output_tensor = self.generator(source_embedding, target_attributes) - return output_tensor, target_attributes + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Feature, ...]]: + target_features = self.encode_features(target_tensor) + output_tensor = self.generator(source_embedding, target_features) + return output_tensor, target_features - def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]: + def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]: return self.encoder(input_tensor) diff --git 
a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 08835f6..37034c5 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -7,7 +7,7 @@ from torch import Tensor, nn from torchvision import transforms from ..helper import calc_embedding -from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule +from ..types import EmbedderModule, FaceParserModule, Feature, GazerModule, Loss, Mask, MotionExtractorModule class DiscriminatorLoss(nn.Module): @@ -49,22 +49,22 @@ class AdversarialLoss(nn.Module): return adversarial_loss, weighted_adversarial_loss -class AttributeLoss(nn.Module): +class FeatureLoss(nn.Module): def __init__(self, config_parser : ConfigParser) -> None: super().__init__() self.config_batch_size = config_parser.getint('training.loader', 'batch_size') - self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight') + self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight') - def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Loss, Loss]: + def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]: temp_tensors = [] - for target_attribute, output_attribute in zip(target_attributes, output_attributes): - temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config_batch_size, -1), dim = 1).mean() + for target_feature, output_feature in zip(target_features, output_features): + temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean() temp_tensors.append(temp_tensor) - attribute_loss = torch.stack(temp_tensors).mean() * 0.5 - weighted_attribute_loss = attribute_loss * self.config_attribute_weight - return attribute_loss, weighted_attribute_loss + feature_loss = 
torch.stack(temp_tensors).mean() * 0.5 + weighted_feature_loss = feature_loss * self.config_feature_weight + return feature_loss, weighted_feature_loss class ReconstructionLoss(nn.Module): diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py index 7babfbe..b2211e2 100644 --- a/face_swapper/src/networks/aad.py +++ b/face_swapper/src/networks/aad.py @@ -4,7 +4,7 @@ from typing import Tuple import torch from torch import Tensor, nn -from ..types import Attribute, Embedding +from ..types import Embedding, Feature class AAD(nn.Module): @@ -57,16 +57,16 @@ class AAD(nn.Module): return layers - def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor: + def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor: temp_tensors = self.pixel_shuffle_up_sample(source_embedding) for index, layer in enumerate(self.layers[:-1]): - target_attribute = target_attributes[index] - temp_tensor = layer(temp_tensors, source_embedding, target_attribute) + target_feature = target_features[index] + temp_tensor = layer(temp_tensors, source_embedding, target_feature) temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False) - target_attribute = target_attributes[-1] - temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_attribute) + target_feature = target_features[-1] + temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature) output_tensor = torch.tanh(temp_tensors) return output_tensor @@ -112,12 +112,12 @@ class AdaptiveFeatureModulation(nn.Module): return shortcut_layers - def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor: + def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor: primary_tensor = input_tensor for primary_layer in self.primary_layers: if 
isinstance(primary_layer, FeatureModulation): - primary_tensor = primary_layer(primary_tensor, source_embedding, target_attribute) + primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature) else: primary_tensor = primary_layer(primary_tensor) @@ -126,7 +126,7 @@ class AdaptiveFeatureModulation(nn.Module): for shortcut_layer in self.shortcut_layers: if isinstance(shortcut_layer, FeatureModulation): - shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_attribute) + shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature) else: shortcut_tensor = shortcut_layer(shortcut_tensor) @@ -146,15 +146,15 @@ class FeatureModulation(nn.Module): self.linear2 = nn.Linear(source_channels, input_channels) self.instance_norm = nn.InstanceNorm2d(input_channels) - def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor: + def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor: temp_tensor = self.instance_norm(input_tensor) source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor) source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor) source_modulation = source_scale * temp_tensor + source_shift - target_scale = self.conv1(target_attribute) - target_shift = self.conv2(target_attribute) + target_scale = self.conv1(target_feature) + target_shift = self.conv2(target_feature) target_modulation = target_scale * temp_tensor + target_shift temp_mask = torch.sigmoid(self.conv3(temp_tensor)) diff --git a/face_swapper/src/networks/masknet.py b/face_swapper/src/networks/masknet.py index 0a465e4..2b66467 100644 --- a/face_swapper/src/networks/masknet.py +++ b/face_swapper/src/networks/masknet.py @@ -3,7 +3,7 @@ from configparser import ConfigParser import torch from 
torch import Tensor, nn -from ..types import Attribute, Mask +from ..types import Feature, Mask class MaskNet(nn.Module): @@ -34,8 +34,8 @@ class MaskNet(nn.Module): UpSample(num_filters, num_filters) ]) - def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Mask: - output_mask = torch.cat([ input_tensor, input_attribute ], dim = 1) + def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask: + output_mask = torch.cat([ input_tensor, input_feature ], dim = 1) for down_sample in self.down_samples: output_mask = down_sample(output_mask) diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 1a27fe4..c102b18 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -4,6 +4,8 @@ from typing import Tuple import torch from torch import Tensor, nn +from ..types import Feature + class UNet(nn.Module): def __init__(self, config_parser : ConfigParser) -> None: @@ -79,25 +81,25 @@ class UNet(nn.Module): return up_samples - def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]: + def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]: down_features = [] up_features = [] - temp_tensor = target_tensor + temp_feature = target_tensor for down_sample in self.down_samples: - temp_tensor = down_sample(temp_tensor) - down_features.append(temp_tensor) + temp_feature = down_sample(temp_feature) + down_features.append(temp_feature) - bottleneck_tensor = down_features[-1] - temp_tensor = bottleneck_tensor + bottleneck_feature = down_features[-1] + temp_feature = bottleneck_feature for index, up_sample in enumerate(self.up_samples): skip_tensor = down_features[-(index + 2)] - temp_tensor = up_sample(temp_tensor, skip_tensor) - up_features.append(temp_tensor) + temp_feature = up_sample(temp_feature, skip_tensor) + up_features.append(temp_feature) - output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = 
False) - return bottleneck_tensor, *up_features, output_tensor + final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False) + return bottleneck_feature, *up_features, final_feature class UpSample(nn.Module): diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 502c907..02df9aa 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,7 +16,7 @@ from .dataset import DynamicDataset from .helper import calc_embedding, overlay_mask from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss +from .models.loss import AdversarialLoss, DiscriminatorLoss, FeatureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss from .networks.masknet import MaskNet from .types import Batch, Embedding, Mask, OptimizerSet @@ -45,7 +45,7 @@ class FaceSwapperTrainer(LightningModule): self.masker = MaskNet(config_parser) self.discriminator_loss = DiscriminatorLoss() self.adversarial_loss = AdversarialLoss(config_parser) - self.attribute_loss = AttributeLoss(config_parser) + self.feature_loss = FeatureLoss(config_parser) self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder) self.identity_loss = IdentityLoss(config_parser, self.embedder) self.motion_loss = MotionLoss(config_parser, self.motion_extractor) @@ -55,9 +55,9 @@ class FaceSwapperTrainer(LightningModule): def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]: with torch.no_grad(): - output_tensor, target_attributes = self.generator(source_embedding, target_tensor) - target_attribute = target_attributes[-1] - output_mask = self.masker(target_tensor, target_attribute) + output_tensor, target_features = self.generator(source_embedding, target_tensor) + target_feature = 
target_features[-1] + output_mask = self.masker(target_tensor, target_feature) return output_tensor, output_mask @@ -104,23 +104,23 @@ class FaceSwapperTrainer(LightningModule): generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) - generator_output_tensor, generator_target_attributes = self.generator(source_embedding, target_tensor) - generator_output_attributes = self.generator.encode_attributes(generator_output_tensor) + generator_output_tensor, generator_target_features = self.generator(source_embedding, target_tensor) + generator_output_features = self.generator.encode_features(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors) - attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes) + feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor) identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor) pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor) gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor) - generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss + generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss 
discriminator_source_tensors = self.discriminator(source_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) - generator_output_attribute = generator_output_attributes[-1] - generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach()) + generator_output_feature = generator_output_features[-1] + generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_feature.detach()) mask_loss = self.mask_loss(target_tensor, generator_output_mask) self.toggle_optimizer(generator_optimizer) @@ -150,7 +150,7 @@ class FaceSwapperTrainer(LightningModule): self.log('generator_loss', generator_loss, prog_bar = True) self.log('discriminator_loss', discriminator_loss, prog_bar = True) self.log('adversarial_loss', adversarial_loss) - self.log('attribute_loss', attribute_loss) + self.log('feature_loss', feature_loss) self.log('reconstruction_loss', reconstruction_loss) self.log('identity_loss', identity_loss) self.log('pose_loss', pose_loss) diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 3789995..915741d 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -6,7 +6,7 @@ from torch.nn import Module Batch : TypeAlias = Tuple[Tensor, Tensor] BatchMode = Literal['equal', 'same', 'different'] -Attribute : TypeAlias = Tensor +Feature : TypeAlias = Tensor Embedding : TypeAlias = Tensor Mask : TypeAlias = Tensor Loss : TypeAlias = Tensor diff --git a/face_swapper/tests/test_networks.py b/face_swapper/tests/test_networks.py index 37a9c1a..c0551be 100644 --- a/face_swapper/tests/test_networks.py +++ b/face_swapper/tests/test_networks.py @@ -22,14 +22,14 @@ def test_aad_with_unet(output_size : int) -> None: } }) - generator = AAD(config_parser).eval() encoder = UNet(config_parser).eval() + generator = AAD(config_parser).eval() 
source_tensor = torch.randn(1, 512) target_tensor = torch.randn(1, 3, output_size, output_size) - target_attributes = encoder(target_tensor) - output_tensor = generator(source_tensor, target_attributes) + target_features = encoder(target_tensor) + output_tensor = generator(source_tensor, target_features) assert output_tensor.shape == (1, 3, output_size, output_size) @@ -50,8 +50,8 @@ def test_mask_net(output_size : int) -> None: masker = MaskNet(config_parser).eval() target_tensor = torch.randn(1, 3, output_size, output_size) - target_attribute = torch.randn(1, 64, output_size, output_size) + target_feature = torch.randn(1, 64, output_size, output_size) - output_tensor = masker(target_tensor, target_attribute) + output_tensor = masker(target_tensor, target_feature) assert output_tensor.shape == (1, 1, output_size, output_size)