mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Refacto UNet
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Tuple
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.attribute_modulator import AADGenerator
|
||||
from face_swapper.src.networks.encoder import UNet
|
||||
from face_swapper.src.networks.unet import UNet
|
||||
from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -17,9 +17,9 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
self.encoder = UNet()
|
||||
self.unet = UNet()
|
||||
self.generator = AADGenerator(id_channels, num_blocks)
|
||||
self.encoder.apply(init_weight)
|
||||
self.unet.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
|
||||
@@ -28,7 +28,7 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
return swap_tensor, target_attributes
|
||||
|
||||
def get_attributes(self, target : VisionTensor) -> TargetAttributes:
|
||||
return self.encoder(target)
|
||||
return self.unet(target)
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import TargetAttributes, VisionTensor
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(Upsample, self).__init__()
|
||||
self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.deconv(input_tensor)
|
||||
temp_tensor = self.batch_norm(temp_tensor)
|
||||
temp_tensor = self.leaky_relu(temp_tensor)
|
||||
return torch.cat((temp_tensor, skip_tensor), dim = 1)
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(DownSample, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.conv(input_tensor)
|
||||
temp_tensor = self.batch_norm(temp_tensor)
|
||||
temp_tensor = self.leaky_relu(temp_tensor)
|
||||
return temp_tensor
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(UNet, self).__init__()
|
||||
self.downsampler_1 = DownSample(3, 32)
|
||||
self.downsampler_2 = DownSample(32, 64)
|
||||
self.downsampler_3 = DownSample(64, 128)
|
||||
self.downsampler_4 = DownSample(128, 256)
|
||||
self.downsampler_5 = DownSample(256, 512)
|
||||
self.downsampler_6 = DownSample(512, 1024)
|
||||
self.bottleneck = DownSample(1024, 1024)
|
||||
self.upsampler_1 = Upsample(1024, 1024)
|
||||
self.upsampler_2 = Upsample(2048, 512)
|
||||
self.upsampler_3 = Upsample(1024, 256)
|
||||
self.upsampler_4 = Upsample(512, 128)
|
||||
self.upsampler_5 = Upsample(256, 64)
|
||||
self.upsampler_6 = Upsample(128, 32)
|
||||
|
||||
def forward(self, target : VisionTensor) -> TargetAttributes:
|
||||
downsample_feature_1 = self.downsampler_1(target)
|
||||
downsample_feature_2 = self.downsampler_2(downsample_feature_1)
|
||||
downsample_feature_3 = self.downsampler_3(downsample_feature_2)
|
||||
downsample_feature_4 = self.downsampler_4(downsample_feature_3)
|
||||
downsample_feature_5 = self.downsampler_5(downsample_feature_4)
|
||||
downsample_feature_6 = self.downsampler_6(downsample_feature_5)
|
||||
bottleneck_output = self.bottleneck(downsample_feature_6)
|
||||
upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
|
||||
upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
|
||||
upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
|
||||
upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
|
||||
upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
|
||||
upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
|
||||
output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
|
||||
@@ -0,0 +1,85 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import TargetAttributes
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(UNet, self).__init__()
|
||||
self.down = self.create_down()
|
||||
self.up = self.create_up()
|
||||
|
||||
@staticmethod
|
||||
def create_down():
|
||||
return nn.ModuleList(
|
||||
[
|
||||
Down(3, 32),
|
||||
Down(32, 64),
|
||||
Down(64, 128),
|
||||
Down(128, 256),
|
||||
Down(256, 512),
|
||||
Down(512, 1024),
|
||||
Down(1024, 1024)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def create_up():
|
||||
return nn.ModuleList(
|
||||
[
|
||||
Up(1024, 1024),
|
||||
Up(2048, 512),
|
||||
Up(1024, 256),
|
||||
Up(512, 128),
|
||||
Up(256, 64),
|
||||
Up(128, 32)
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor) -> TargetAttributes:
|
||||
down_features = []
|
||||
up_features = []
|
||||
temp_tensor = target_tensor
|
||||
|
||||
for down in self.down:
|
||||
temp_tensor = down(temp_tensor)
|
||||
down_features.append(temp_tensor)
|
||||
|
||||
bottleneck_tensor = down_features[-1]
|
||||
temp_tensor = bottleneck_tensor
|
||||
|
||||
for index, up in enumerate(self.up):
|
||||
down_index = -(index + 2)
|
||||
up_feature = up(temp_tensor, down_features[down_index])
|
||||
up_features.append(up_feature)
|
||||
|
||||
output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_tensor, *up_features, output_tensor
|
||||
|
||||
|
||||
class Up(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(Up, self).__init__()
|
||||
self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.conv_transpose(input_tensor)
|
||||
temp_tensor = self.batch_norm(temp_tensor)
|
||||
temp_tensor = self.leaky_relu(temp_tensor)
|
||||
temp_tensor = torch.cat((temp_tensor, skip_tensor), dim = 1)
|
||||
return temp_tensor
|
||||
|
||||
|
||||
class Down(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(Down, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.conv(input_tensor)
|
||||
temp_tensor = self.batch_norm(temp_tensor)
|
||||
temp_tensor = self.leaky_relu(temp_tensor)
|
||||
return temp_tensor
|
||||
Reference in New Issue
Block a user