mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Modernize AIENet
This commit is contained in:
@@ -2,7 +2,7 @@ import configparser
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.aienet import AADGenerator
|
||||
from ..networks.aienet import AIENet
|
||||
from ..networks.unet import UNet, UNetPro
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
@@ -22,13 +22,13 @@ class Generator(nn.Module):
|
||||
self.encoder = UNet()
|
||||
if encoder_type == 'unet-pro':
|
||||
self.encoder = UNetPro()
|
||||
self.generator = AADGenerator(identity_channels, output_channels, num_blocks)
|
||||
self.generator = AIENet(identity_channels, output_channels, num_blocks)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
target_attributes = self.get_attributes(target_tensor)
|
||||
output_tensor = self.generator(target_attributes, source_embedding)
|
||||
output_tensor = self.generator(source_embedding, target_attributes)
|
||||
return output_tensor
|
||||
|
||||
def get_attributes(self, input_tensor : Tensor) -> Attributes:
|
||||
|
||||
@@ -4,68 +4,36 @@ from torch import Tensor, nn
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
|
||||
class AADGenerator(nn.Module):
|
||||
class AIENet(nn.Module):
|
||||
def __init__(self, identity_channels : int, output_channels : int, num_blocks : int) -> None:
|
||||
super().__init__()
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(identity_channels, output_channels)
|
||||
self.res_block_1 = AADResBlock(1024, 1024, 1024, identity_channels, num_blocks)
|
||||
self.res_block_2 = AADResBlock(1024, 1024, 2048, identity_channels, num_blocks)
|
||||
self.res_block_3 = AADResBlock(1024, 1024, 1024, identity_channels, num_blocks)
|
||||
self.res_block_4 = AADResBlock(1024, 512, 512, identity_channels, num_blocks)
|
||||
self.res_block_5 = AADResBlock(512, 256, 256, identity_channels, num_blocks)
|
||||
self.res_block_6 = AADResBlock(256, 128, 128, identity_channels, num_blocks)
|
||||
self.res_block_7 = AADResBlock(128, 64, 64, identity_channels, num_blocks)
|
||||
self.res_block_8 = AADResBlock(64, 3, 64, identity_channels, num_blocks)
|
||||
self.layers = self.create_layers(identity_channels, num_blocks)
|
||||
|
||||
def forward(self, target_attributes : Attributes, source_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.pixel_shuffle_up_sample(source_embedding)
|
||||
feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_3 = nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_4 = nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_5 = nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_6 = nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_7 = nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
output = self.res_block_8(feature_map_7, target_attributes[7], source_embedding)
|
||||
return torch.tanh(output)
|
||||
@staticmethod
|
||||
def create_layers(identity_channels : int, num_blocks : int) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
AADResBlock(1024, 1024, 1024, identity_channels, num_blocks),
|
||||
AADResBlock(1024, 1024, 2048, identity_channels, num_blocks),
|
||||
AADResBlock(1024, 1024, 1024, identity_channels, num_blocks),
|
||||
AADResBlock(1024, 512, 512, identity_channels, num_blocks),
|
||||
AADResBlock(512, 256, 256, identity_channels, num_blocks),
|
||||
AADResBlock(256, 128, 128, identity_channels, num_blocks),
|
||||
AADResBlock(128, 64, 64, identity_channels, num_blocks),
|
||||
AADResBlock(64, 3, 64, identity_channels, num_blocks)
|
||||
])
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor:
|
||||
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
|
||||
|
||||
class AADLayer(nn.Module):
|
||||
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.fc_beta = nn.Linear(identity_channels, input_channels)
|
||||
self.fc_gamma = nn.Linear(identity_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels)
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
for index, layer in enumerate(self.layers[:-1]):
|
||||
temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
|
||||
temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.instance_norm(feature_map)
|
||||
gamma_attribute = self.conv_gamma(attribute_embedding)
|
||||
beta_attribute = self.conv_beta(attribute_embedding)
|
||||
attribute_modulation = gamma_attribute * feature_map + beta_attribute
|
||||
identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
identity_modulation = identity_gamma * feature_map + identity_beta
|
||||
feature_mask = torch.sigmoid(self.conv_mask(feature_map))
|
||||
feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation
|
||||
return feature_blend
|
||||
|
||||
|
||||
class AADSequential(nn.Module):
|
||||
def __init__(self, *args : nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(args)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
for layer in self.layers:
|
||||
if isinstance(layer, AADLayer):
|
||||
feature_map = layer(feature_map, attribute_embedding, identity_embedding)
|
||||
else:
|
||||
feature_map = layer(feature_map)
|
||||
return feature_map
|
||||
temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
|
||||
output_tensor = torch.tanh(temp_tensors)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class AADResBlock(nn.Module):
|
||||
@@ -109,6 +77,44 @@ class AADResBlock(nn.Module):
|
||||
return output_feature
|
||||
|
||||
|
||||
class AADSequential(nn.Module):
|
||||
def __init__(self, *args : nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(args)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
for layer in self.layers:
|
||||
if isinstance(layer, AADLayer):
|
||||
feature_map = layer(feature_map, attribute_embedding, identity_embedding)
|
||||
else:
|
||||
feature_map = layer(feature_map)
|
||||
return feature_map
|
||||
|
||||
|
||||
class AADLayer(nn.Module):
|
||||
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.fc_beta = nn.Linear(identity_channels, input_channels)
|
||||
self.fc_gamma = nn.Linear(identity_channels, input_channels)
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels)
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.instance_norm(feature_map)
|
||||
gamma_attribute = self.conv_gamma(attribute_embedding)
|
||||
beta_attribute = self.conv_beta(attribute_embedding)
|
||||
attribute_modulation = gamma_attribute * feature_map + beta_attribute
|
||||
identity_gamma = self.fc_gamma(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
identity_beta = self.fc_beta(identity_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
|
||||
identity_modulation = identity_gamma * feature_map + identity_beta
|
||||
feature_mask = torch.sigmoid(self.conv_mask(feature_map))
|
||||
feature_blend = (1 - feature_mask) * attribute_modulation + feature_mask * identity_modulation
|
||||
return feature_blend
|
||||
|
||||
|
||||
class PixelShuffleUpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user