Merge pull request #57 from facefusion/fix-aad-naming

Fix AAD naming, Attribute vs. Embedding
Henry Ruhs authored on 2025-03-12 18:07:46 +01:00, committed by GitHub
4 changed files with 46 additions and 44 deletions
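For readers skimming the diff: the old names conflated two roles. The source face contributes an identity Embedding (a flat vector), while the target face contributes spatial Attribute maps; the renames below make that split explicit. Both types are imported from `..types`; a minimal sketch of what those aliases plausibly are (the real definitions live in facefusion's types module and may differ):

```
from torch import Tensor

# Hypothetical reading of the aliases imported from ..types:
Embedding = Tensor  # source identity vector, e.g. shape (batch, source_channels)
Attribute = Tensor  # target attribute map, e.g. shape (batch, channels, height, width)
```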
+1 -1
@@ -51,7 +51,7 @@ face_parser_path = .models/face_parser.pt

 ```
 [training.model.generator]
-identity_channels = 512
+source_channels = 512
 output_channels = 4096
 output_size = 256
 num_blocks = 2
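The renamed key is read by the generator with `ConfigParser.getint`, as the code diff further down shows. A minimal sketch of loading it, assuming a `config.ini` laid out like this README excerpt:

```
from configparser import ConfigParser

config_parser = ConfigParser()
config_parser.read('config.ini')  # hypothetical path
source_channels = config_parser.getint('training.model.generator', 'source_channels')
assert source_channels == 512
```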
+1 -1
@@ -17,7 +17,7 @@ motion_extractor_path =
 face_parser_path =

 [training.model.generator]
-identity_channels =
+source_channels =
 output_channels =
 output_size =
 num_blocks =
+43 -41
@@ -10,11 +10,11 @@ from ..types import Attribute, Embedding
 class AAD(nn.Module):
 	def __init__(self, config_parser : ConfigParser) -> None:
 		super().__init__()
-		self.config_identity_channels = config_parser.getint('training.model.generator', 'identity_channels')
+		self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
 		self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
 		self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
 		self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
-		self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_identity_channels, self.config_output_channels)
+		self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
 		self.layers = self.create_layers()

 	def create_layers(self) -> nn.ModuleList:
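`PixelShuffleUpSample` now receives the source channels (512) and output channels (4096). Its implementation is outside this diff; a plausible sketch, assuming a 1x1 projection followed by `nn.PixelShuffle`, which under the README config would turn a `(batch, 512)` embedding into the `(batch, 1024, 2, 2)` seed map the first 256-branch layer expects:

```
import torch
import torch.nn as nn

class PixelShuffleUpSampleSketch(nn.Module):
	# hypothetical stand-in, not the repository implementation
	def __init__(self, source_channels : int, output_channels : int) -> None:
		super().__init__()
		self.conv = nn.Conv2d(source_channels, output_channels, kernel_size = 1)
		self.pixel_shuffle = nn.PixelShuffle(2)  # quarters the channels, doubles height and width

	def forward(self, source_embedding : torch.Tensor) -> torch.Tensor:
		# treat the flat embedding as a 1x1 feature map, project, then shuffle
		temp_tensor = source_embedding.reshape(source_embedding.shape[0], -1, 1, 1)
		return self.pixel_shuffle(self.conv(temp_tensor))
```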
@@ -23,36 +23,36 @@ class AAD(nn.Module):
 		if self.config_output_size == 128:
 			layers.extend(
 			[
-				AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(512, 512, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks)
+				AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
 			])

 		if self.config_output_size == 256:
 			layers.extend(
 			[
-				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 1024, 2048, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
+				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
 			])

 		if self.config_output_size == 512:
 			layers.extend(
 			[
-				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(2048, 2048, 4096, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(2048, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
-				AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
+				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
+				AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
 			])

 		layers.extend(
 		[
-			AdaptiveFeatureModulation(512, 256, 256, self.config_identity_channels, self.config_num_blocks),
-			AdaptiveFeatureModulation(256, 128, 128, self.config_identity_channels, self.config_num_blocks),
-			AdaptiveFeatureModulation(128, 64, 64, self.config_identity_channels, self.config_num_blocks),
-			AdaptiveFeatureModulation(64, 3, 64, self.config_identity_channels, self.config_num_blocks)
+			AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
+			AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks),
+			AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks),
+			AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks)
 		])
 		return layers
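These branch sizes are consistent with the forward pass in the next hunk: every layer except the last is followed by a 2x bilinear upsample, so, assuming the 2x2 seed map sketched above, the 128/256/512 branches (3, 4 and 5 blocks plus the 4 shared tail blocks) land exactly on the configured output size:

```
# layers per output_size: branch blocks + 4 shared tail blocks;
# all but the last layer are followed by a 2x upsample in forward()
for output_size, num_layers in [(128, 3 + 4), (256, 4 + 4), (512, 5 + 4)]:
	assert 2 * 2 ** (num_layers - 1) == output_size
```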
@@ -61,21 +61,23 @@ class AAD(nn.Module):
 		temp_tensors = self.pixel_shuffle_up_sample(source_embedding)

 		for index, layer in enumerate(self.layers[:-1]):
-			temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
+			target_attribute = target_attributes[index]
+			temp_tensor = layer(temp_tensors, source_embedding, target_attribute)
 			temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)

-		temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
+		target_attribute = target_attributes[-1]
+		temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_attribute)
 		output_tensor = torch.tanh(temp_tensors)
 		return output_tensor


 class AdaptiveFeatureModulation(nn.Module):
-	def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None:
+	def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
 		super().__init__()
 		self.context_input_channels = input_channels
 		self.context_output_channels = output_channels
-		self.context_attribute_channels = attribute_channels
-		self.context_identity_channels = identity_channels
+		self.context_target_channels = target_channels
+		self.context_source_channels = source_channels
 		self.context_num_blocks = num_blocks
 		self.primary_layers = self.create_primary_layers()
 		self.shortcut_layers = self.create_shortcut_layers()
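Every layer call now takes the source embedding first and the target attribute second. `AAD.forward`'s own signature sits outside this hunk; assuming it accepts the source embedding plus a coarse-to-fine list of target attribute maps, the shapes for the 256 config would be:

```
import torch

source_embedding = torch.randn(1, 512)
# one map per modulation layer; channels follow each layer's
# target_channels, spatial sizes follow the 2x upsampling schedule
target_attributes = [
	torch.randn(1, channels, size, size)
	for channels, size in [
		(1024, 2), (2048, 4), (1024, 8), (512, 16),
		(256, 32), (128, 64), (64, 128), (64, 256)
	]
]
# hypothetical call: output_tensor = aad(source_embedding, target_attributes)
# expected shape: (1, 3, 256, 256) after the final tanh
```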
@@ -86,7 +88,7 @@ class AdaptiveFeatureModulation(nn.Module):
 		for index in range(self.context_num_blocks):
 			primary_layers.extend(
 			[
-				FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
+				FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
 				nn.ReLU(inplace = True)
 			])
@@ -103,19 +105,19 @@ class AdaptiveFeatureModulation(nn.Module):
 		if self.context_input_channels > self.context_output_channels:
 			shortcut_layers.extend(
 			[
-				FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
+				FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
 				nn.ReLU(inplace = True),
 				nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
 			])
 		return shortcut_layers

-	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
+	def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
 		primary_tensor = input_tensor

 		for primary_layer in self.primary_layers:
 			if isinstance(primary_layer, FeatureModulation):
-				primary_tensor = primary_layer(primary_tensor, attribute_embedding, identity_embedding)
+				primary_tensor = primary_layer(primary_tensor, source_embedding, target_attribute)
 			else:
 				primary_tensor = primary_layer(primary_tensor)
@@ -124,7 +126,7 @@ class AdaptiveFeatureModulation(nn.Module):
 		for shortcut_layer in self.shortcut_layers:
 			if isinstance(shortcut_layer, FeatureModulation):
-				shortcut_tensor = shortcut_layer(shortcut_tensor, attribute_embedding, identity_embedding)
+				shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_attribute)
 			else:
 				shortcut_tensor = shortcut_layer(shortcut_tensor)
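The tail of `AdaptiveFeatureModulation.forward` is not part of these hunks; presumably the two paths are combined residual-style, along the lines of:

```
# assumed combination step, not shown in this diff
output_tensor = primary_tensor + shortcut_tensor
return output_tensor
```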
@@ -134,29 +136,29 @@ class AdaptiveFeatureModulation(nn.Module):

 class FeatureModulation(nn.Module):
-	def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
+	def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
 		super().__init__()
 		self.context_input_channels = input_channels
-		self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
-		self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
+		self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
+		self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
 		self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
-		self.linear1 = nn.Linear(identity_channels, input_channels)
-		self.linear2 = nn.Linear(identity_channels, input_channels)
+		self.linear1 = nn.Linear(source_channels, input_channels)
+		self.linear2 = nn.Linear(source_channels, input_channels)
 		self.instance_norm = nn.InstanceNorm2d(input_channels)

-	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
+	def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
 		temp_tensor = self.instance_norm(input_tensor)
-		attribute_scale = self.conv1(attribute_embedding)
-		attribute_shift = self.conv2(attribute_embedding)
-		attribute_modulation = attribute_scale * temp_tensor + attribute_shift
+		source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
+		source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
+		source_modulation = source_scale * temp_tensor + source_shift
-		identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
-		identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
-		identity_modulation = identity_scale * temp_tensor + identity_shift
+		target_scale = self.conv1(target_attribute)
+		target_shift = self.conv2(target_attribute)
+		target_modulation = target_scale * temp_tensor + target_shift
 		temp_mask = torch.sigmoid(self.conv3(temp_tensor))
-		output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation
+		output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation
 		return output_tensor
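In the new terms, `FeatureModulation` blends two affine modulations of the instance-normalized input: a spatial one driven by the target attribute map (1x1 convolutions) and a global one driven by the source embedding (linear layers), gated per pixel by a learned sigmoid mask. A standalone shape check using the signature from this diff, with sizes borrowed from the first 256-branch layer:

```
import torch

feature_modulation = FeatureModulation(1024, 1024, 512)
input_tensor = torch.randn(1, 1024, 2, 2)
source_embedding = torch.randn(1, 512)
target_attribute = torch.randn(1, 1024, 2, 2)
output_tensor = feature_modulation(input_tensor, source_embedding, target_attribute)
assert output_tensor.shape == input_tensor.shape
```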
+1 -1
@@ -15,7 +15,7 @@ def test_aad_with_unet(output_size : int) -> None:
 	{
 		'training.model.generator':
 		{
-			'identity_channels': '512',
+			'source_channels': '512',
 			'output_channels': str(output_size * 16),
 			'output_size': str(output_size),
 			'num_blocks': '2'
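The test derives `output_channels` as `output_size * 16`. Assuming the 2x2 pixel-shuffle seed sketched earlier (which divides channels by 4), that keeps the seed map's channels aligned with the first modulation layer's `input_channels` in every branch:

```
# (output_size, first layer input_channels) per branch
for output_size, first_input_channels in [(128, 512), (256, 1024), (512, 2048)]:
	assert output_size * 16 // 4 == first_input_channels
```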