mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Move output channels to config
This commit is contained in:
@@ -48,8 +48,9 @@ motion_extractor_path = .models/motion_extractor.pt
|
||||
```
|
||||
[training.model.generator]
|
||||
encoder_type = unet-pro
|
||||
num_blocks = 2
|
||||
identity_channels = 512
|
||||
output_channels = 4096
|
||||
num_blocks = 2
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -14,8 +14,9 @@ motion_extractor_path =
|
||||
|
||||
[training.model.generator]
|
||||
encoder_type =
|
||||
num_blocks =
|
||||
identity_channels =
|
||||
output_channels =
|
||||
num_blocks =
|
||||
|
||||
[training.model.discriminator]
|
||||
input_channels =
|
||||
|
||||
@@ -2,7 +2,7 @@ import configparser
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.attribute_modulator import AADGenerator
|
||||
from ..networks.aienet import AADGenerator
|
||||
from ..networks.unet import UNet, UNetPro
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
@@ -14,14 +14,15 @@ class Generator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
encoder_type = CONFIG.get('training.model.generator', 'encoder_type')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
identity_channels = CONFIG.getint('training.model.generator', 'identity_channels')
|
||||
output_channels = CONFIG.getint('training.model.generator', 'output_channels')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
if encoder_type == 'unet':
|
||||
self.encoder = UNet()
|
||||
if encoder_type == 'unet-pro':
|
||||
self.encoder = UNetPro()
|
||||
self.generator = AADGenerator(identity_channels, num_blocks)
|
||||
self.generator = AADGenerator(identity_channels, output_channels, num_blocks)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from ..types import Attributes, Embedding
|
||||
|
||||
|
||||
class AADGenerator(nn.Module):
|
||||
def __init__(self, identity_channels : int, num_blocks : int) -> None:
|
||||
def __init__(self, identity_channels : int, output_channels : int, num_blocks : int) -> None:
|
||||
super().__init__()
|
||||
self.upsample = PixelShuffleUpsample(identity_channels, 1024 * 4)
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(identity_channels, output_channels)
|
||||
self.res_block_1 = AADResBlock(1024, 1024, 1024, identity_channels, num_blocks)
|
||||
self.res_block_2 = AADResBlock(1024, 1024, 2048, identity_channels, num_blocks)
|
||||
self.res_block_3 = AADResBlock(1024, 1024, 1024, identity_channels, num_blocks)
|
||||
@@ -18,7 +18,7 @@ class AADGenerator(nn.Module):
|
||||
self.res_block_8 = AADResBlock(64, 3, 64, identity_channels, num_blocks)
|
||||
|
||||
def forward(self, target_attributes : Attributes, source_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.upsample(source_embedding)
|
||||
feature_map = self.pixel_shuffle_up_sample(source_embedding)
|
||||
feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_3 = nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
@@ -32,7 +32,7 @@ class AADGenerator(nn.Module):
|
||||
|
||||
class AADLayer(nn.Module):
|
||||
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
|
||||
super(AADLayer, self).__init__()
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.conv_beta = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv_gamma = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
@@ -109,10 +109,10 @@ class AADResBlock(nn.Module):
|
||||
return output_feature
|
||||
|
||||
|
||||
class PixelShuffleUpsample(nn.Module):
|
||||
class PixelShuffleUpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1)
|
||||
self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 3, padding = 1)
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
|
||||
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user