mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Remove UnetPro, make float values visible in README
This commit is contained in:
@@ -16,7 +16,6 @@ gazer_path =
|
||||
motion_extractor_path =
|
||||
|
||||
[training.model.generator]
|
||||
encoder_type =
|
||||
identity_channels =
|
||||
output_channels =
|
||||
output_size =
|
||||
|
||||
@@ -3,7 +3,7 @@ import configparser
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.aad import AAD
|
||||
from ..networks.unet import UNet, UNetPro
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -13,16 +13,12 @@ CONFIG.read('config.ini')
|
||||
class Generator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
encoder_type = CONFIG.get('training.model.generator', 'encoder_type')
|
||||
identity_channels = CONFIG.getint('training.model.generator', 'identity_channels')
|
||||
output_channels = CONFIG.getint('training.model.generator', 'output_channels')
|
||||
output_size = CONFIG.getint('training.model.generator', 'output_size')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
if encoder_type == 'unet':
|
||||
self.encoder = UNet(output_size)
|
||||
if encoder_type == 'unet-pro':
|
||||
self.encoder = UNetPro(output_size)
|
||||
self.encoder = UNet(output_size)
|
||||
self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
@@ -169,7 +169,7 @@ class GazeLoss(nn.Module):
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
|
||||
scale_factor = CONFIG.getint('training.losses', 'gaze_scale_factor')
|
||||
scale_factor = CONFIG.getfloat('training.losses', 'gaze_scale_factor')
|
||||
y_min = int(60 * scale_factor)
|
||||
y_max = int(224 * scale_factor)
|
||||
x_min = int(16 * scale_factor)
|
||||
|
||||
@@ -2,9 +2,6 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torchvision import models
|
||||
from torchvision.models import ResNet34_Weights
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, output_size : int) -> None:
|
||||
@@ -88,38 +85,6 @@ class UNet(nn.Module):
|
||||
return bottleneck_tensor, *up_features, output_tensor
|
||||
|
||||
|
||||
class UNetPro(UNet):
|
||||
def __init__(self, output_size : int) -> None:
|
||||
super().__init__(output_size)
|
||||
self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT)
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
def create_down_samples(self) -> nn.ModuleList:
|
||||
down_samples = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
self.resnet.conv1,
|
||||
self.resnet.bn1,
|
||||
self.resnet.relu,
|
||||
nn.Conv2d(64, 32, kernel_size = 1, bias = False),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.LeakyReLU(0.1, inplace = True)
|
||||
),
|
||||
DownSample(32, 64),
|
||||
self.resnet.layer2,
|
||||
self.resnet.layer3,
|
||||
self.resnet.layer4,
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
if self.output_size == 512:
|
||||
down_samples.append(DownSample(2048, 2048))
|
||||
|
||||
return down_samples
|
||||
|
||||
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -80,7 +80,7 @@ class FaceSwapperTrainer(lightning.LightningModule):
|
||||
return generator_config, discriminator_config
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
preview_frequency = CONFIG.getfloat('training.trainer', 'preview_frequency')
|
||||
preview_frequency = CONFIG.getint('training.trainer', 'preview_frequency')
|
||||
|
||||
source_tensor, target_tensor = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
Reference in New Issue
Block a user