mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Merge pull request #57 from facefusion/fix-aad-naming
Fix AAD naming, Attribute vs. Embedding
This commit is contained in:
@@ -51,7 +51,7 @@ face_parser_path = .models/face_parser.pt
|
||||
|
||||
```
|
||||
[training.model.generator]
|
||||
identity_channels = 512
|
||||
source_channels = 512
|
||||
output_channels = 4096
|
||||
output_size = 256
|
||||
num_blocks = 2
|
||||
|
||||
@@ -17,7 +17,7 @@ motion_extractor_path =
|
||||
face_parser_path =
|
||||
|
||||
[training.model.generator]
|
||||
identity_channels =
|
||||
source_channels =
|
||||
output_channels =
|
||||
output_size =
|
||||
num_blocks =
|
||||
|
||||
@@ -10,11 +10,11 @@ from ..types import Attribute, Embedding
|
||||
class AAD(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config_identity_channels = config_parser.getint('training.model.generator', 'identity_channels')
|
||||
self.config_source_channels = config_parser.getint('training.model.generator', 'source_channels')
|
||||
self.config_output_channels = config_parser.getint('training.model.generator', 'output_channels')
|
||||
self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
|
||||
self.config_num_blocks = config_parser.getint('training.model.generator', 'num_blocks')
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_identity_channels, self.config_output_channels)
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config_source_channels, self.config_output_channels)
|
||||
self.layers = self.create_layers()
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
@@ -23,36 +23,36 @@ class AAD(nn.Module):
|
||||
if self.config_output_size == 128:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_identity_channels, self.config_num_blocks)
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 256:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
if self.config_output_size == 512:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_identity_channels, self.config_num_blocks)
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.config_identity_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.config_identity_channels, self.config_num_blocks)
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.config_source_channels, self.config_num_blocks),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.config_source_channels, self.config_num_blocks)
|
||||
])
|
||||
|
||||
return layers
|
||||
@@ -61,21 +61,23 @@ class AAD(nn.Module):
|
||||
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
|
||||
|
||||
for index, layer in enumerate(self.layers[:-1]):
|
||||
temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
|
||||
target_attribute = target_attributes[index]
|
||||
temp_tensor = layer(temp_tensors, source_embedding, target_attribute)
|
||||
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
|
||||
temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
|
||||
target_attribute = target_attributes[-1]
|
||||
temp_tensors = self.layers[-1](temp_tensors, source_embedding, target_attribute)
|
||||
output_tensor = torch.tanh(temp_tensors)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class AdaptiveFeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None:
|
||||
def __init__(self, input_channels : int, output_channels : int, target_channels : int, source_channels : int, num_blocks : int) -> None:
|
||||
super().__init__()
|
||||
self.context_input_channels = input_channels
|
||||
self.context_output_channels = output_channels
|
||||
self.context_attribute_channels = attribute_channels
|
||||
self.context_identity_channels = identity_channels
|
||||
self.context_target_channels = target_channels
|
||||
self.context_source_channels = source_channels
|
||||
self.context_num_blocks = num_blocks
|
||||
self.primary_layers = self.create_primary_layers()
|
||||
self.shortcut_layers = self.create_shortcut_layers()
|
||||
@@ -86,7 +88,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
for index in range(self.context_num_blocks):
|
||||
primary_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
|
||||
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
|
||||
nn.ReLU(inplace = True)
|
||||
])
|
||||
|
||||
@@ -103,19 +105,19 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
if self.context_input_channels > self.context_output_channels:
|
||||
shortcut_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.context_input_channels, self.context_attribute_channels, self.context_identity_channels),
|
||||
FeatureModulation(self.context_input_channels, self.context_target_channels, self.context_source_channels),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(self.context_input_channels, self.context_output_channels, kernel_size = 3, padding = 1, bias = False)
|
||||
])
|
||||
|
||||
return shortcut_layers
|
||||
|
||||
def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
|
||||
primary_tensor = input_tensor
|
||||
|
||||
for primary_layer in self.primary_layers:
|
||||
if isinstance(primary_layer, FeatureModulation):
|
||||
primary_tensor = primary_layer(primary_tensor, attribute_embedding, identity_embedding)
|
||||
primary_tensor = primary_layer(primary_tensor, source_embedding, target_attribute)
|
||||
else:
|
||||
primary_tensor = primary_layer(primary_tensor)
|
||||
|
||||
@@ -124,7 +126,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
|
||||
for shortcut_layer in self.shortcut_layers:
|
||||
if isinstance(shortcut_layer, FeatureModulation):
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor, attribute_embedding, identity_embedding)
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor, source_embedding, target_attribute)
|
||||
else:
|
||||
shortcut_tensor = shortcut_layer(shortcut_tensor)
|
||||
|
||||
@@ -134,29 +136,29 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
|
||||
|
||||
class FeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
|
||||
def __init__(self, input_channels : int, target_channels : int, source_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.context_input_channels = input_channels
|
||||
self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv1 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
|
||||
self.conv2 = nn.Conv2d(target_channels, input_channels, kernel_size = 1)
|
||||
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
self.linear1 = nn.Linear(identity_channels, input_channels)
|
||||
self.linear2 = nn.Linear(identity_channels, input_channels)
|
||||
self.linear1 = nn.Linear(source_channels, input_channels)
|
||||
self.linear2 = nn.Linear(source_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels)
|
||||
|
||||
def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
def forward(self, input_tensor : Tensor, source_embedding : Embedding, target_attribute : Attribute) -> Tensor:
|
||||
temp_tensor = self.instance_norm(input_tensor)
|
||||
|
||||
attribute_scale = self.conv1(attribute_embedding)
|
||||
attribute_shift = self.conv2(attribute_embedding)
|
||||
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
|
||||
source_scale = self.linear2(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
source_shift = self.linear1(source_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
source_modulation = source_scale * temp_tensor + source_shift
|
||||
|
||||
identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context_input_channels, 1, 1).expand_as(temp_tensor)
|
||||
identity_modulation = identity_scale * temp_tensor + identity_shift
|
||||
target_scale = self.conv1(target_attribute)
|
||||
target_shift = self.conv2(target_attribute)
|
||||
target_modulation = target_scale * temp_tensor + target_shift
|
||||
|
||||
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
|
||||
output_tensor = (1 - temp_mask) * attribute_modulation + temp_mask * identity_modulation
|
||||
output_tensor = (1 - temp_mask) * target_modulation + temp_mask * source_modulation
|
||||
return output_tensor
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_aad_with_unet(output_size : int) -> None:
|
||||
{
|
||||
'training.model.generator':
|
||||
{
|
||||
'identity_channels': '512',
|
||||
'source_channels': '512',
|
||||
'output_channels': str(output_size * 16),
|
||||
'output_size': str(output_size),
|
||||
'num_blocks': '2'
|
||||
|
||||
Reference in New Issue
Block a user