mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Rename attribute to feature
This commit is contained in:
@@ -76,7 +76,7 @@ num_filters = 16
|
||||
```
|
||||
[training.losses]
|
||||
adversarial_weight = 1.0
|
||||
attribute_weight = 10.0
|
||||
feature_weight = 10.0
|
||||
reconstruction_weight = 10.0
|
||||
identity_weight = 20.0
|
||||
gaze_weight = 0.05
|
||||
|
||||
@@ -36,7 +36,7 @@ num_filters =
|
||||
|
||||
[training.losses]
|
||||
adversarial_weight =
|
||||
attribute_weight =
|
||||
feature_weight =
|
||||
reconstruction_weight =
|
||||
identity_weight =
|
||||
gaze_weight =
|
||||
|
||||
@@ -5,7 +5,7 @@ from torch import Tensor, nn
|
||||
|
||||
from ..networks.aad import AAD
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Attribute, Embedding
|
||||
from ..types import Feature, Embedding
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
@@ -16,12 +16,12 @@ class Generator(nn.Module):
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]:
|
||||
target_attributes = self.encode_attributes(target_tensor)
|
||||
output_tensor = self.generator(source_embedding, target_attributes)
|
||||
return output_tensor, target_attributes
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Feature, ...]]:
|
||||
target_features = self.encode_features(target_tensor)
|
||||
output_tensor = self.generator(source_embedding, target_features)
|
||||
return output_tensor, target_features
|
||||
|
||||
def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
|
||||
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
|
||||
return self.encoder(input_tensor)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule
|
||||
from ..types import Feature, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
@@ -49,22 +49,22 @@ class AdversarialLoss(nn.Module):
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
class AttributeLoss(nn.Module):
|
||||
class FeautureLoss(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
|
||||
self.config_feature_weight = config_parser.getfloat('training.losses', 'feature_weight')
|
||||
|
||||
def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Loss, Loss]:
|
||||
def forward(self, target_features : Tuple[Feature, ...], output_features : Tuple[Feature, ...]) -> Tuple[Loss, Loss]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
|
||||
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
for target_feature, output_feature in zip(target_features, output_features):
|
||||
temp_tensor = torch.mean(torch.pow(output_feature - target_feature, 2).reshape(self.config_batch_size, -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
attribute_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
weighted_attribute_loss = attribute_loss * self.config_attribute_weight
|
||||
return attribute_loss, weighted_attribute_loss
|
||||
feature_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
weighted_feature_loss = feature_loss * self.config_feature_weight
|
||||
return feature_loss, weighted_feature_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Tuple
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Attribute, Embedding
|
||||
from ..types import Feature, Embedding
|
||||
|
||||
|
||||
class AAD(nn.Module):
|
||||
@@ -57,16 +57,16 @@ class AAD(nn.Module):
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor:
|
||||
def forward(self, source_embedding : Embedding, target_features : Tuple[Feature, ...]) -> Tensor:
|
||||
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
|
||||
|
||||
for index, layer in enumerate(self.layers[:-1]):
|
||||
target_attribute = target_attributes[index]
|
||||
temp_tensor = layer(temp_tensors, source_embedding, target_attribute)
|
||||
target_feature = target_features[index]
|
||||
temp_tensor = layer(temp_tensors, source_embedding, target_feature)
|
||||
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
|
||||
target_attribute = target_attributes[-1]
|
||||
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_attribute)
|
||||
target_feature = target_features[-1]
|
||||
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_feature)
|
||||
output_tensor = torch.tanh(temp_tensors)
|
||||
return output_tensor
|
||||
|
||||
@@ -112,12 +112,12 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
|
||||
return shortcut_layers
|
||||
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
|
||||
primary_tensor = input_tensor
|
||||
|
||||
for primary_layer in self.primary_layers:
|
||||
if isinstance(primary_layer, FeatureModulation):
|
||||
primary_tensor = primary_layer(primary_tensor, source_embedding, target_attribute)
|
||||
primary_tensor = primary_layer(primary_tensor, source_embedding, target_feature)
|
||||
else:
|
||||
primary_tensor = primary_layer(primary_tensor)
|
||||
|
||||
@@ -126,7 +126,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
|
||||
for shortcut_layer in self.shortcut_layers:
|
||||
if isinstance(shortcut_layer, FeatureModulation):
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_attribute)
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_feature)
|
||||
else:
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor)
|
||||
|
||||
@@ -146,15 +146,15 @@ class FeatureModulation(nn.Module):
|
||||
self.linear2 = nn.Linear(source_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels)
|
||||
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_feature : Feature) -> Tensor:
|
||||
temp_tensor = self.instance_norm(input_tensor)
|
||||
|
||||
source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
source_modulation = source_scale * temp_tensor + source_shift
|
||||
|
||||
target_scale = self.conv1(target_attribute)
|
||||
target_shift = self.conv2(target_attribute)
|
||||
target_scale = self.conv1(target_feature)
|
||||
target_shift = self.conv2(target_feature)
|
||||
target_modulation = target_scale * temp_tensor + target_shift
|
||||
|
||||
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
|
||||
|
||||
@@ -3,7 +3,7 @@ from configparser import ConfigParser
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Attribute, Mask
|
||||
from ..types import Feature, Mask
|
||||
|
||||
|
||||
class MaskNet(nn.Module):
|
||||
@@ -34,8 +34,8 @@ class MaskNet(nn.Module):
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Mask:
|
||||
output_mask = torch.cat([ input_tensor, input_attribute ], dim = 1)
|
||||
def forward(self, input_tensor : Tensor, input_feature : Feature) -> Mask:
|
||||
output_mask = torch.cat([ input_tensor, input_feature ], dim = 1)
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
output_mask = down_sample(output_mask)
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Tuple
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import Feature
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
@@ -79,25 +81,25 @@ class UNet(nn.Module):
|
||||
|
||||
return up_samples
|
||||
|
||||
def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
|
||||
def forward(self, target_tensor : Tensor) -> Tuple[Feature, ...]:
|
||||
down_features = []
|
||||
up_features = []
|
||||
temp_tensor = target_tensor
|
||||
temp_feature = target_tensor
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
temp_tensor = down_sample(temp_tensor)
|
||||
down_features.append(temp_tensor)
|
||||
temp_feature = down_sample(temp_feature)
|
||||
down_features.append(temp_feature)
|
||||
|
||||
bottleneck_tensor = down_features[-1]
|
||||
temp_tensor = bottleneck_tensor
|
||||
bottleneck_feature = down_features[-1]
|
||||
temp_feature = bottleneck_feature
|
||||
|
||||
for index, up_sample in enumerate(self.up_samples):
|
||||
skip_tensor = down_features[-(index + 2)]
|
||||
temp_tensor = up_sample(temp_tensor, skip_tensor)
|
||||
up_features.append(temp_tensor)
|
||||
temp_feature = up_sample(temp_feature, skip_tensor)
|
||||
up_features.append(temp_feature)
|
||||
|
||||
output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_tensor, *up_features, output_tensor
|
||||
final_feature = nn.functional.interpolate(temp_feature, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_feature, *up_features, final_feature
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
|
||||
@@ -16,7 +16,7 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding, overlay_mask
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .models.loss import AdversarialLoss, DiscriminatorLoss, FeautureLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
|
||||
from .networks.masknet import MaskNet
|
||||
from .types import Batch, Embedding, Mask, OptimizerSet
|
||||
|
||||
@@ -45,7 +45,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.masker = MaskNet(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss(config_parser)
|
||||
self.attribute_loss = AttributeLoss(config_parser)
|
||||
self.feature_loss = FeautureLoss(config_parser)
|
||||
self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder)
|
||||
self.identity_loss = IdentityLoss(config_parser, self.embedder)
|
||||
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
|
||||
@@ -55,9 +55,9 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
with torch.no_grad():
|
||||
output_tensor, target_attributes = self.generator(source_embedding, target_tensor)
|
||||
target_attribute = target_attributes[-1]
|
||||
output_mask = self.masker(target_tensor, target_attribute)
|
||||
output_tensor, target_features = self.generator(source_embedding, target_tensor)
|
||||
target_feature = target_features[-1]
|
||||
output_mask = self.masker(target_tensor, target_feature)
|
||||
|
||||
return output_tensor, output_mask
|
||||
|
||||
@@ -104,23 +104,23 @@ class FaceSwapperTrainer(LightningModule):
|
||||
generator_optimizer, discriminator_optimizer, masker_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
generator_output_tensor, generator_target_attributes = self.generator(source_embedding, target_tensor)
|
||||
generator_output_attributes = self.generator.encode_attributes(generator_output_tensor)
|
||||
generator_output_tensor, generator_target_features = self.generator(source_embedding, target_tensor)
|
||||
generator_output_features = self.generator.encode_features(generator_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss(generator_target_attributes, generator_output_attributes)
|
||||
feature_loss, weighted_feature_loss = self.feature_loss(generator_target_features, generator_output_features)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss(source_tensor, target_tensor, generator_output_tensor)
|
||||
identity_loss, weighted_identity_loss = self.identity_loss(generator_output_tensor, source_tensor)
|
||||
pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss = self.motion_loss(target_tensor, generator_output_tensor)
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
||||
generator_loss = weighted_adversarial_loss + weighted_feature_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss + weighted_gaze_loss + weighted_expression_loss
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
generator_output_attribute = generator_output_attributes[-1]
|
||||
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
|
||||
generator_output_feature = generator_output_features[-1]
|
||||
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_feature.detach())
|
||||
mask_loss = self.mask_loss(target_tensor, generator_output_mask)
|
||||
|
||||
self.toggle_optimizer(generator_optimizer)
|
||||
@@ -150,7 +150,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss', adversarial_loss)
|
||||
self.log('attribute_loss', attribute_loss)
|
||||
self.log('feature_loss', feature_loss)
|
||||
self.log('reconstruction_loss', reconstruction_loss)
|
||||
self.log('identity_loss', identity_loss)
|
||||
self.log('pose_loss', pose_loss)
|
||||
|
||||
@@ -6,7 +6,7 @@ from torch.nn import Module
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
BatchMode = Literal['equal', 'same', 'different']
|
||||
|
||||
Attribute : TypeAlias = Tensor
|
||||
Feature : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
Mask : TypeAlias = Tensor
|
||||
Loss : TypeAlias = Tensor
|
||||
|
||||
@@ -22,14 +22,14 @@ def test_aad_with_unet(output_size : int) -> None:
|
||||
}
|
||||
})
|
||||
|
||||
generator = AAD(config_parser).eval()
|
||||
encoder = UNet(config_parser).eval()
|
||||
generator = AAD(config_parser).eval()
|
||||
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, output_size, output_size)
|
||||
|
||||
target_attributes = encoder(target_tensor)
|
||||
output_tensor = generator(source_tensor, target_attributes)
|
||||
target_features = encoder(target_tensor)
|
||||
output_tensor = generator(source_tensor, target_features)
|
||||
|
||||
assert output_tensor.shape == (1, 3, output_size, output_size)
|
||||
|
||||
@@ -50,8 +50,8 @@ def test_mask_net(output_size : int) -> None:
|
||||
masker = MaskNet(config_parser).eval()
|
||||
|
||||
target_tensor = torch.randn(1, 3, output_size, output_size)
|
||||
target_attribute = torch.randn(1, 64, output_size, output_size)
|
||||
target_feature = torch.randn(1, 64, output_size, output_size)
|
||||
|
||||
output_tensor = masker(target_tensor, target_attribute)
|
||||
output_tensor = masker(target_tensor, target_feature)
|
||||
|
||||
assert output_tensor.shape == (1, 1, output_size, output_size)
|
||||
|
||||
Reference in New Issue
Block a user