mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Extend Unet with more layers
This commit is contained in:
@@ -20,9 +20,9 @@ class Generator(nn.Module):
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
if encoder_type == 'unet':
|
||||
self.encoder = UNet()
|
||||
self.encoder = UNet(output_size)
|
||||
if encoder_type == 'unet-pro':
|
||||
self.encoder = UNetPro()
|
||||
self.encoder = UNetPro(output_size)
|
||||
self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
@@ -23,18 +23,15 @@ class AAD(nn.Module):
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks)
|
||||
])
|
||||
|
||||
if self.output_size in [ 384, 512, 768, 1024 ]:
|
||||
layers.append(AdaptiveFeatureModulation(64, 32, 32, self.identity_channels, self.num_blocks))
|
||||
|
||||
if self.output_size in [ 512, 768, 1024 ]:
|
||||
layers.append(AdaptiveFeatureModulation(32, 16, 16, self.identity_channels, self.num_blocks))
|
||||
|
||||
if self.output_size in [ 768, 1024 ]:
|
||||
layers.append(AdaptiveFeatureModulation(16, 8, 8, self.identity_channels, self.num_blocks))
|
||||
|
||||
if self.output_size == 1024:
|
||||
layers.append(AdaptiveFeatureModulation(8, 4, 4, self.identity_channels, self.num_blocks))
|
||||
|
||||
|
||||
@@ -7,14 +7,14 @@ from torchvision.models import ResNet34_Weights
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, output_size : int) -> None:
|
||||
super().__init__()
|
||||
self.down_samples = self.create_down_samples(self)
|
||||
self.output_size = output_size
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@staticmethod
|
||||
def create_down_samples(self : nn.Module) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
def create_down_samples(self) -> nn.ModuleList:
|
||||
down_samples = nn.ModuleList(
|
||||
[
|
||||
DownSample(3, 32),
|
||||
DownSample(32, 64),
|
||||
@@ -25,9 +25,19 @@ class UNet(nn.Module):
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def create_up_samples() -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
if self.output_size in [ 384, 512, 768, 1024 ]:
|
||||
down_samples.append(DownSample(1024, 2048))
|
||||
if self.output_size in [ 512, 768, 1024 ]:
|
||||
down_samples.append(DownSample(2048, 4096))
|
||||
if self.output_size in [ 768, 1024 ]:
|
||||
down_samples.append(DownSample(4096, 8192))
|
||||
if self.output_size == 1024:
|
||||
down_samples.append(DownSample(8192, 16384))
|
||||
|
||||
return down_samples
|
||||
|
||||
def create_up_samples(self) -> nn.ModuleList:
|
||||
up_samples = nn.ModuleList(
|
||||
[
|
||||
UpSample(1024, 1024),
|
||||
UpSample(2048, 512),
|
||||
@@ -37,6 +47,17 @@ class UNet(nn.Module):
|
||||
UpSample(128, 32)
|
||||
])
|
||||
|
||||
if self.output_size in [ 384, 512, 768, 1024 ]:
|
||||
up_samples.append(UpSample(32, 16))
|
||||
if self.output_size in [ 512, 768, 1024 ]:
|
||||
up_samples.append(UpSample(16, 8))
|
||||
if self.output_size in [ 768, 1024 ]:
|
||||
up_samples.append(UpSample(8, 4))
|
||||
if self.output_size == 1024:
|
||||
up_samples.append(UpSample(4, 2))
|
||||
|
||||
return up_samples
|
||||
|
||||
def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
|
||||
down_features = []
|
||||
up_features = []
|
||||
@@ -62,12 +83,11 @@ class UNetPro(UNet):
|
||||
def __init__(self) -> None:
|
||||
super(UNet, self).__init__()
|
||||
self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT)
|
||||
self.down_samples = self.create_down_samples(self)
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@staticmethod
|
||||
def create_down_samples(self : nn.Module) -> nn.ModuleList:
|
||||
return nn.ModuleList(
|
||||
def create_down_samples(self) -> nn.ModuleList:
|
||||
down_samples = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
self.resnet.conv1,
|
||||
@@ -85,6 +105,17 @@ class UNetPro(UNet):
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
if self.output_size in [ 384, 512, 768, 1024 ]:
|
||||
down_samples.append(DownSample(1024, 2048))
|
||||
if self.output_size in [ 512, 768, 1024 ]:
|
||||
down_samples.append(DownSample(2048, 4096))
|
||||
if self.output_size in [ 768, 1024 ]:
|
||||
down_samples.append(DownSample(4096, 8192))
|
||||
if self.output_size == 1024:
|
||||
down_samples.append(DownSample(8192, 16384))
|
||||
|
||||
return down_samples
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user