mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Introduce feature flag for Unet
This commit is contained in:
@@ -47,6 +47,7 @@ motion_extractor_path = .models/motion_extractor.pt
|
||||
|
||||
```
|
||||
[training.model.generator]
|
||||
encoder_type = unet-pro
|
||||
num_blocks = 2
|
||||
id_channels = 512
|
||||
```
|
||||
|
||||
@@ -3,7 +3,7 @@ import configparser
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.attribute_modulator import AADGenerator
|
||||
from ..networks.unet import UNet
|
||||
from ..networks.unet import UNet, UNetPro
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -13,21 +13,25 @@ CONFIG.read('config.ini')
|
||||
class Generator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(Generator, self).__init__()
|
||||
encoder_type = CONFIG.getint('training.model.generator', 'encoder_type')
|
||||
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
self.unet = UNet()
|
||||
self.aad_generator = AADGenerator(id_channels, num_blocks)
|
||||
self.unet.apply(init_weight)
|
||||
self.aad_generator.apply(init_weight)
|
||||
if encoder_type == 'unet':
|
||||
self.encoder = UNet()
|
||||
if encoder_type == 'unet-pro':
|
||||
self.encoder = UNetPro()
|
||||
self.generator = AADGenerator(id_channels, num_blocks)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
target_attributes = self.get_attributes(target_tensor)
|
||||
output_tensor = self.aad_generator(target_attributes, source_embedding)
|
||||
output_tensor = self.generator(target_attributes, source_embedding)
|
||||
return output_tensor
|
||||
|
||||
def get_attributes(self, input_tensor : Tensor) -> Attributes:
|
||||
return self.unet(input_tensor)
|
||||
return self.encoder(input_tensor)
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
|
||||
|
||||
@@ -8,7 +8,6 @@ from torchvision import models
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super(UNet, self).__init__()
|
||||
self.resnet = models.resnet34(pretrained = True)
|
||||
self.down_samples = self.create_down_samples(self)
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@@ -16,18 +15,11 @@ class UNet(nn.Module):
|
||||
def create_down_samples(self : nn.Module) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
self.resnet.conv1,
|
||||
self.resnet.bn1,
|
||||
self.resnet.relu,
|
||||
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
),
|
||||
DownSample(3, 32),
|
||||
DownSample(32, 64),
|
||||
self.resnet.layer2,
|
||||
self.resnet.layer3,
|
||||
self.resnet.layer4,
|
||||
DownSample(64, 128),
|
||||
DownSample(128, 256),
|
||||
DownSample(256, 512),
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
@@ -65,6 +57,34 @@ class UNet(nn.Module):
|
||||
return bottleneck_tensor, *up_features, output_tensor
|
||||
|
||||
|
||||
class UNetPro(UNet):
|
||||
def __init__(self) -> None:
|
||||
super(UNetPro, self).__init__()
|
||||
self.resnet = models.resnet34(pretrained = True)
|
||||
self.down_samples = self.create_down_samples(self)
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@staticmethod
|
||||
def create_down_samples(self : nn.Module) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
self.resnet.conv1,
|
||||
self.resnet.bn1,
|
||||
self.resnet.relu,
|
||||
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
),
|
||||
DownSample(32, 64),
|
||||
self.resnet.layer2,
|
||||
self.resnet.layer3,
|
||||
self.resnet.layer4,
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super(UpSample, self).__init__()
|
||||
|
||||
Reference in New Issue
Block a user