Rename attribute to feature

This commit is contained in:
henryruhs
2025-03-16 08:39:35 +01:00
parent 904a447e06
commit 94571c5676
10 changed files with 62 additions and 60 deletions
+1 -1
View File
@@ -76,7 +76,7 @@ num_filters = 16
```
[training.losses]
adversarial_weight = 1.0
attribute_weight = 10.0
feature_weight = 10.0
reconstruction_weight = 10.0
identity_weight = 20.0
gaze_weight = 0.05
+1 -1
View File
@@ -36,7 +36,7 @@ num_filters =
[training.losses]
adversarial_weight =
attribute_weight =
feature_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
+6 -6
View File
@@ -5,7 +5,7 @@ from torch import Tensor, nn
from ..networks.aad import AAD
from ..networks.unet import UNet
from ..types import Attribute, Embedding
from ..types import Feature, Embedding
class Generator(nn.Module):
@@ -16,12 +16,12 @@ class Generator(nn.Module):
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]:
target_attributes = self.encode_attributes(target_tensor)
output_tensor = self.generator(source_embedding, target_attributes)
return output_tensor, target_attributes
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Feature, ...]]:
target_features = self.encode_features(target_tensor)
output_tensor = self.generator(source_embedding, target_features)
return output_tensor, target_features
def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
return self.encoder(input_tensor)
+9 -9
View File
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule
from ..types import Feature, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule
class DiscriminatorLoss(nn.Module):
@@ -49,22 +49,22 @@ class AdversarialLoss(nn.Module):
return adversarial_loss, weighted_adversarial_loss
class AttributeLoss(nn.Module):
# NOTE(review): "FeautureLoss" looks like a typo for "FeatureLoss". The same spelling is
# used by the trainer's import and instantiation in this commit, so any rename has to
# change all three places together.
class FeautureLoss(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight')
def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Loss, Loss]:
def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
temp_tensors = []
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
for target_feature, output_feature in zip(target_features, output_features):
# Squared difference per feature pair, flattened per sample, then averaged.
# NOTE(review): the reshape uses the configured batch size, not the tensor's own
# leading dimension; presumably every batch is full-sized -- confirm against the loader.
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
temp_tensors.append(temp_tensor)
attribute_loss = torch.stack(temp_tensors).mean() * 0.5
weighted_attribute_loss = attribute_loss * self.config_attribute_weight
return attribute_loss, weighted_attribute_loss
# Mean over all feature pairs, halved; returned both raw and scaled by the configured weight.
feature_loss = torch.stack(temp_tensors).mean() * 0.5
weighted_feature_loss = feature_loss * self.config_feature_weight
return feature_loss, weighted_feature_loss
class ReconstructionLoss(nn.Module):
+12 -12
View File
@@ -4,7 +4,7 @@ from typing import Tuple
import torch
from torch import Tensor, nn
from ..types import Attribute, Embedding
from ..types import Feature, Embedding
class AAD(nn.Module):
@@ -57,16 +57,16 @@ class AAD(nn.Module):
return layers
def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor:
def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor:
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
for index, layer in enumerate(self.layers[:-1]):
target_attribute = target_attributes[index]
temp_tensor = layer(temp_tensors, source_embedding, target_attribute)
target_feature = target_features[index]
temp_tensor = layer(temp_tensors, source_embedding, target_feature)
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
target_attribute = target_attributes[-1]
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_attribute)
target_feature = target_features[-1]
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature)
output_tensor = torch.tanh(temp_tensors)
return output_tensor
@@ -112,12 +112,12 @@ class AdaptiveFeatureModulation(nn.Module):
return shortcut_layers
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
primary_tensor = input_tensor
for primary_layer in self.primary_layers:
if isinstance(primary_layer, FeatureModulation):
primary_tensor = primary_layer(primary_tensor, source_embedding, target_attribute)
primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature)
else:
primary_tensor = primary_layer(primary_tensor)
@@ -126,7 +126,7 @@ class AdaptiveFeatureModulation(nn.Module):
for shortcut_layer in self.shortcut_layers:
if isinstance(shortcut_layer, FeatureModulation):
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_attribute)
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature)
else:
shortcut_tensor = shortcut_layer(shortcut_tensor)
@@ -146,15 +146,15 @@ class FeatureModulation(nn.Module):
self.linear2 = nn.Linear(source_channels, input_channels)
self.instance_norm = nn.InstanceNorm2d(input_channels)
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
source_modulation = source_scale * temp_tensor + source_shift
target_scale = self.conv1(target_attribute)
target_shift = self.conv2(target_attribute)
target_scale = self.conv1(target_feature)
target_shift = self.conv2(target_feature)
target_modulation = target_scale * temp_tensor + target_shift
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
+3 -3
View File
@@ -3,7 +3,7 @@ from configparser import ConfigParser
import torch
from torch import Tensor, nn
from ..types import Attribute, Mask
from ..types import Feature, Mask
class MaskNet(nn.Module):
@@ -34,8 +34,8 @@ class MaskNet(nn.Module):
UpSample(num_filters, num_filters)
])
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Mask:
output_mask = torch.cat([ input_tensor, input_attribute ], dim = 1)
def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask:
output_mask = torch.cat([ input_tensor, input_feature ], dim = 1)
for down_sample in self.down_samples:
output_mask = down_sample(output_mask)
+12 -10
View File
@@ -4,6 +4,8 @@ from typing import Tuple
import torch
from torch import Tensor, nn
from face_swapper.src.types import Feature
class UNet(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
@@ -79,25 +81,25 @@ class UNet(nn.Module):
return up_samples
def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]:
down_features = []
up_features = []
temp_tensor = target_tensor
temp_feature = target_tensor
for down_sample in self.down_samples:
temp_tensor = down_sample(temp_tensor)
down_features.append(temp_tensor)
temp_feature = down_sample(temp_feature)
down_features.append(temp_feature)
bottleneck_tensor = down_features[-1]
temp_tensor = bottleneck_tensor
bottleneck_feature = down_features[-1]
temp_feature = bottleneck_feature
for index, up_sample in enumerate(self.up_samples):
skip_tensor = down_features[-(index + 2)]
temp_tensor = up_sample(temp_tensor, skip_tensor)
up_features.append(temp_tensor)
temp_feature = up_sample(temp_feature, skip_tensor)
up_features.append(temp_feature)
output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
return bottleneck_tensor, *up_features, output_tensor
final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
return bottleneck_feature, *up_features, final_feature
class UpSample(nn.Module):
+12 -12
View File
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
from .helper import calc_embedding, overlay_mask
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .networks.masknet import MaskNet
from .types import Batch, Embedding, Mask, OptimizerSet
@@ -45,7 +45,7 @@ class FaceSwapperTrainer(LightningModule):
self.masker = MaskNet(config_parser)
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss(config_parser)
self.attribute_loss = AttributeLoss(config_parser)
self.feature_loss = FeautureLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder)
self.identity_loss = IdentityLoss(config_parser, self.embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
@@ -55,9 +55,9 @@ class FaceSwapperTrainer(LightningModule):
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
with torch.no_grad():
output_tensor, target_attributes = self.generator(source_embedding, target_tensor)
target_attribute = target_attributes[-1]
output_mask = self.masker(target_tensor, target_attribute)
output_tensor, target_features = self.generator(source_embedding, target_tensor)
target_feature = target_features[-1]
output_mask = self.masker(target_tensor, target_feature)
return output_tensor, output_mask
@@ -104,23 +104,23 @@ class FaceSwapperTrainer(LightningModule):
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
generator_output_tensor, generator_target_attributes = self.generator(source_embedding, target_tensor)
generator_output_attributes = self.generator.encode_attributes(generator_output_tensor)
generator_output_tensor, generator_target_features = self.generator(source_embedding, target_tensor)
generator_output_features = self.generator.encode_features(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes)
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
generator_output_attribute = generator_output_attributes[-1]
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
generator_output_feature = generator_output_features[-1]
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_feature.detach())
mask_loss = self.mask_loss(target_tensor, generator_output_mask)
self.toggle_optimizer(generator_optimizer)
@@ -150,7 +150,7 @@ class FaceSwapperTrainer(LightningModule):
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('attribute_loss', attribute_loss)
self.log('feature_loss', feature_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)
self.log('pose_loss', pose_loss)
+1 -1
View File
@@ -6,7 +6,7 @@ from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
BatchMode = Literal['equal', 'same', 'different']
Attribute : TypeAlias = Tensor
Feature : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
Mask : TypeAlias = Tensor
Loss : TypeAlias = Tensor
+5 -5
View File
@@ -22,14 +22,14 @@ def test_aad_with_unet(output_size : int) -> None:
}
})
generator = AAD(config_parser).eval()
encoder = UNet(config_parser).eval()
generator = AAD(config_parser).eval()
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, output_size, output_size)
target_attributes = encoder(target_tensor)
output_tensor = generator(source_tensor, target_attributes)
target_features = encoder(target_tensor)
output_tensor = generator(source_tensor, target_features)
assert output_tensor.shape == (1, 3, output_size, output_size)
@@ -50,8 +50,8 @@ def test_mask_net(output_size : int) -> None:
masker = MaskNet(config_parser).eval()
target_tensor = torch.randn(1, 3, output_size, output_size)
target_attribute = torch.randn(1, 64, output_size, output_size)
target_feature = torch.randn(1, 64, output_size, output_size)
output_tensor = masker(target_tensor, target_attribute)
output_tensor = masker(target_tensor, target_feature)
assert output_tensor.shape == (1, 1, output_size, output_size)