mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Rework on config
This commit is contained in:
+24
-15
@@ -1,6 +1,8 @@
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from configparser import ConfigParser
|
||||
from typing import cast
|
||||
|
||||
import albumentations
|
||||
from torch import Tensor
|
||||
@@ -8,25 +10,29 @@ from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch, BatchMode, WarpTemplate
|
||||
from .types import Batch, WarpTemplate, BatchMode
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
def __init__(self, file_pattern : str, warp_template : WarpTemplate, transform_size : int, batch_mode : BatchMode, batch_ratio : float) -> None:
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
self.warp_template = warp_template
|
||||
self.transform_size = transform_size
|
||||
self.batch_mode = batch_mode
|
||||
self.batch_ratio = batch_ratio
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config =\
|
||||
{
|
||||
'file_pattern': config_parser.get('training.dataset', 'file_pattern'),
|
||||
'transform_size': config_parser.get('training.dataset', 'transform_size'),
|
||||
'batch_mode': cast(BatchMode, config_parser.get('training.dataset', 'batch_mode')),
|
||||
'batch_ratio': config_parser.getfloat('training.dataset', 'batch_ratio'),
|
||||
}
|
||||
self.config_parser = config_parser
|
||||
self.file_paths = glob.glob(self.config.get('file_pattern'))
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
file_path = self.file_paths[index]
|
||||
|
||||
if random.random() < self.batch_ratio:
|
||||
if self.batch_mode == 'equal':
|
||||
if random.random() < self.config.get('batch_ratio'):
|
||||
if self.config.get('batch_mode') == 'equal':
|
||||
return self.prepare_equal_batch(file_path)
|
||||
if self.batch_mode == 'same':
|
||||
if self.config.get('batch_mode') == 'same':
|
||||
return self.prepare_same_batch(file_path)
|
||||
|
||||
return self.prepare_different_batch(file_path)
|
||||
@@ -39,9 +45,9 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
[
|
||||
AugmentTransform(),
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((self.transform_size, self.transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.Resize((self.config.get('transform_size'), self.config.get('transform_size')), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
WarpTransform(self.warp_template),
|
||||
WarpTransform(self.config_parser),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
@@ -92,9 +98,12 @@ class AugmentTransform:
|
||||
|
||||
|
||||
class WarpTransform:
|
||||
def __init__(self, warp_template : WarpTemplate) -> None:
|
||||
self.warp_template = warp_template
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config =\
|
||||
{
|
||||
'warp_template': cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
|
||||
}
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.unsqueeze(0)
|
||||
return warp_tensor(temp_tensor, self.warp_template).squeeze(0)
|
||||
return warp_tensor(temp_tensor, self.config.get('warp_template')).squeeze(0)
|
||||
|
||||
@@ -1,26 +1,30 @@
|
||||
import configparser
|
||||
from os import makedirs
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
|
||||
directory_path = CONFIG.get('exporting', 'directory_path')
|
||||
source_path = CONFIG.get('exporting', 'source_path')
|
||||
target_path = CONFIG.get('exporting', 'target_path')
|
||||
target_size = CONFIG.getint('exporting', 'target_size')
|
||||
ir_version = CONFIG.getint('exporting', 'ir_version')
|
||||
opset_version = CONFIG.getint('exporting', 'opset_version')
|
||||
config =\
|
||||
{
|
||||
'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'),
|
||||
'source_path': CONFIG_PARSER.get('exporting', 'source_path'),
|
||||
'target_path': CONFIG_PARSER.get('exporting', 'target_path'),
|
||||
'target_size': CONFIG_PARSER.getint('exporting', 'target_size'),
|
||||
'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'),
|
||||
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
}
|
||||
|
||||
makedirs(directory_path, exist_ok = True)
|
||||
model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
|
||||
os.makedirs(config.get('directory_path'), exist_ok = True)
|
||||
model = FaceSwapperTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
|
||||
model.eval()
|
||||
model.ir_version = torch.tensor(ir_version)
|
||||
model.ir_version = torch.tensor(config.get('ir_version'))
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, target_size, target_size)
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
target_tensor = torch.randn(1, 3, config.get('target_size'), config.get('target_size'))
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), config.get('target_path'), input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = config.get('opset_version'))
|
||||
|
||||
@@ -6,26 +6,29 @@ from torchvision import io
|
||||
from .helper import calc_embedding
|
||||
from .models.generator import Generator
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def infer() -> None:
|
||||
generator_path = CONFIG.get('inferencing', 'generator_path')
|
||||
embedder_path = CONFIG.get('inferencing', 'embedder_path')
|
||||
source_path = CONFIG.get('inferencing', 'source_path')
|
||||
target_path = CONFIG.get('inferencing', 'target_path')
|
||||
output_path = CONFIG.get('inferencing', 'output_path')
|
||||
config =\
|
||||
{
|
||||
'generator_path': CONFIG_PARSER.get('inferencing', 'generator_path'),
|
||||
'embedder_path': CONFIG_PARSER.get('inferencing', 'embedder_path'),
|
||||
'source_path': CONFIG_PARSER.get('inferencing', 'source_path'),
|
||||
'target_path': CONFIG_PARSER.get('inferencing', 'target_path'),
|
||||
'output_path': CONFIG_PARSER.get('inferencing', 'output_path')
|
||||
}
|
||||
|
||||
state_dict = torch.load(generator_path).get('state_dict').get('generator')
|
||||
generator = Generator()
|
||||
state_dict = torch.load(config.get('generator_path')).get('state_dict').get('generator')
|
||||
generator = Generator(CONFIG_PARSER)
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
embedder.eval()
|
||||
|
||||
source_tensor = io.read_image(source_path)
|
||||
target_tensor = io.read_image(target_path)
|
||||
source_tensor = io.read_image(config.get('source_path'))
|
||||
target_tensor = io.read_image(config.get('target_path'))
|
||||
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = generator(source_embedding, target_tensor)[0]
|
||||
io.write_jpeg(output_tensor, output_path)
|
||||
io.write_jpeg(output_tensor, config.get('output_path'))
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
import configparser
|
||||
from configparser import ConfigParser
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.nld import NLD
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'num_discriminators': config_parser.getint('training.model.discriminator', 'num_discriminators')
|
||||
}
|
||||
self.config_parser = config_parser
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
|
||||
@staticmethod
|
||||
def create_discriminators() -> nn.ModuleList:
|
||||
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
|
||||
input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
|
||||
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
|
||||
kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
|
||||
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
|
||||
|
||||
def create_discriminators(self) -> nn.ModuleList:
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
for _ in range(num_discriminators):
|
||||
discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).sequences
|
||||
for _ in range(self.config.get('num_discriminators')):
|
||||
discriminator = NLD(self.config_parser).sequences
|
||||
discriminators.append(discriminator)
|
||||
|
||||
return discriminators
|
||||
@@ -35,7 +32,8 @@ class Discriminator(nn.Module):
|
||||
output_tensors = []
|
||||
|
||||
for discriminator in self.discriminators:
|
||||
output_tensors.append(discriminator(temp_tensor))
|
||||
output_tensor = discriminator(temp_tensor)
|
||||
output_tensors.append(output_tensor)
|
||||
temp_tensor = self.avg_pool(temp_tensor)
|
||||
|
||||
return output_tensors
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import configparser
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -6,20 +6,12 @@ from ..networks.aad import AAD
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
identity_channels = CONFIG.getint('training.model.generator', 'identity_channels')
|
||||
output_channels = CONFIG.getint('training.model.generator', 'output_channels')
|
||||
output_size = CONFIG.getint('training.model.generator', 'output_size')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
self.encoder = UNet(output_size)
|
||||
self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
|
||||
self.encoder = UNet(config_parser)
|
||||
self.generator = AAD(config_parser)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import configparser
|
||||
from configparser import ConfigParser
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
@@ -9,9 +9,6 @@ from torchvision import transforms
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@@ -36,11 +33,14 @@ class DiscriminatorLoss(nn.Module):
|
||||
|
||||
|
||||
class AdversarialLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'adversarial_weight': config_parser.getfloat('training.losses', 'adversarial_weight')
|
||||
}
|
||||
|
||||
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
|
||||
adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight')
|
||||
temp_tensors = []
|
||||
|
||||
for discriminator_output_tensor in discriminator_output_tensors:
|
||||
@@ -48,7 +48,7 @@ class AdversarialLoss(nn.Module):
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
adversarial_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_adversarial_loss = adversarial_loss * adversarial_weight
|
||||
weighted_adversarial_loss = adversarial_loss * self.config.get('adversarial_weight')
|
||||
return adversarial_loss, weighted_adversarial_loss
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,37 @@
|
||||
import math
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class NLD(nn.Module):
|
||||
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.num_filters = num_filters
|
||||
self.num_layers = num_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.config =\
|
||||
{
|
||||
'input_channels': config_parser.getint('training.model.discriminator', 'input_channels'),
|
||||
'num_filters': config_parser.getint('training.model.discriminator', 'num_filters'),
|
||||
'kernel_size': config_parser.getint('training.model.discriminator', 'kernel_size'),
|
||||
'num_layers': config_parser.getint('training.model.discriminator', 'num_layers')
|
||||
}
|
||||
self.layers = self.create_layers()
|
||||
self.sequences = nn.Sequential(*self.layers)
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
padding = math.ceil((self.kernel_size - 1) / 2)
|
||||
current_filters = self.num_filters
|
||||
padding = math.ceil((self.config.get('kernel_size') - 1) / 2)
|
||||
current_filters = self.config.get('num_filters')
|
||||
layers = nn.ModuleList(
|
||||
[
|
||||
nn.Conv2d(self.input_channels, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding),
|
||||
nn.Conv2d(self.config.get('input_channels'), current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
])
|
||||
|
||||
for _ in range(1, self.num_layers):
|
||||
for _ in range(1, self.config.get('num_layers')):
|
||||
previous_filters = current_filters
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding),
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
@@ -36,10 +40,10 @@ class NLD(nn.Module):
|
||||
current_filters = min(current_filters * 2, 512)
|
||||
layers +=\
|
||||
[
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, padding = padding),
|
||||
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), padding = padding),
|
||||
nn.InstanceNorm2d(current_filters),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(current_filters, 1, kernel_size = self.kernel_size, padding = padding)
|
||||
nn.Conv2d(current_filters, 1, kernel_size = self.config.get('kernel_size'), padding = padding)
|
||||
]
|
||||
return layers
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
@@ -5,9 +6,12 @@ from torch import Tensor, nn
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, output_size : int) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.output_size = output_size
|
||||
self.config =\
|
||||
{
|
||||
'output_size': config_parser.getint('training.model.generator', 'output_size')
|
||||
}
|
||||
self.down_samples = self.create_down_samples()
|
||||
self.up_samples = self.create_up_samples()
|
||||
|
||||
@@ -21,20 +25,20 @@ class UNet(nn.Module):
|
||||
DownSample(256, 512)
|
||||
])
|
||||
|
||||
if self.output_size == 128:
|
||||
if self.config.get('output_size') == 128:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 512)
|
||||
])
|
||||
|
||||
if self.output_size == 256:
|
||||
if self.config.get('output_size') == 256:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
DownSample(1024, 1024)
|
||||
])
|
||||
|
||||
if self.output_size == 512:
|
||||
if self.config.get('output_size') == 512:
|
||||
down_samples.extend(
|
||||
[
|
||||
DownSample(512, 1024),
|
||||
@@ -47,20 +51,20 @@ class UNet(nn.Module):
|
||||
def create_up_samples(self) -> nn.ModuleList:
|
||||
up_samples = nn.ModuleList()
|
||||
|
||||
if self.output_size == 128:
|
||||
if self.config.get('output_size') == 128:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(512, 512)
|
||||
])
|
||||
|
||||
if self.output_size == 256:
|
||||
if self.config.get('output_size') == 256:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(1024, 1024),
|
||||
UpSample(2048, 512)
|
||||
])
|
||||
|
||||
if self.output_size == 512:
|
||||
if self.config.get('output_size') == 512:
|
||||
up_samples.extend(
|
||||
[
|
||||
UpSample(2048, 2048),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import configparser
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
import warnings
|
||||
from typing import Tuple, cast
|
||||
|
||||
@@ -21,30 +21,34 @@ from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
class FaceSwapperTrainer(LightningModule):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
gazer_path = CONFIG.get('training.model', 'gazer_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.config =\
|
||||
{
|
||||
'embedder_path': config_parser.get('training.model', 'embedder_path'),
|
||||
'gazer_path': config_parser.get('training.model', 'gazer_path'),
|
||||
'motion_extractor_path': config_parser.get('training.model', 'motion_extractor_path'),
|
||||
'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate'),
|
||||
'preview_frequency': config_parser.getint('training.trainer', 'preview_frequency')
|
||||
}
|
||||
self.embedder = torch.jit.load(self.config.get('embedder_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.gazer = torch.jit.load(self.config.get('gazer_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(self.config.get('motion_extractor_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.gazer = torch.jit.load(gazer_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss()
|
||||
self.attribute_loss = AttributeLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss(self.embedder)
|
||||
self.identity_loss = IdentityLoss(self.embedder)
|
||||
self.motion_loss = MotionLoss(self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(self.gazer)
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss(config_parser)
|
||||
self.adversarial_loss = AdversarialLoss(config_parser)
|
||||
self.attribute_loss = AttributeLoss(config_parser)
|
||||
self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder)
|
||||
self.identity_loss = IdentityLoss(config_parser, self.embedder)
|
||||
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
|
||||
self.gaze_loss = GazeLoss(config_parser, self.gazer)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
@@ -52,9 +56,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
return output_tensor
|
||||
|
||||
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
|
||||
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
|
||||
|
||||
@@ -79,8 +82,6 @@ class FaceSwapperTrainer(LightningModule):
|
||||
return generator_config, discriminator_config
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
preview_frequency = CONFIG.getint('training.trainer', 'preview_frequency')
|
||||
|
||||
source_tensor, target_tensor = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
@@ -113,7 +114,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
discriminator_optimizer.step()
|
||||
self.untoggle_optimizer(discriminator_optimizer)
|
||||
|
||||
if self.global_step % preview_frequency == 0:
|
||||
if self.global_step % self.config.get('preview_frequency') == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
@@ -149,42 +150,52 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
config =\
|
||||
{
|
||||
'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'),
|
||||
'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
}
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
|
||||
training_loader = StatefulDataLoader(training_dataset, batch_size = config.get('batch_size'), shuffle = True, num_workers = config.get('num_workers'), drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config.get('batch_size'), shuffle = False, num_workers = config.get('num_workers'), pin_memory = True, persistent_workers = True)
|
||||
return training_loader, validation_loader
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
|
||||
split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
|
||||
config =\
|
||||
{
|
||||
'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
|
||||
}
|
||||
|
||||
dataset_size = len(dataset) # type:ignore[arg-type]
|
||||
training_size = int(dataset_size * split_ratio)
|
||||
training_size = int(dataset_size * config.get('split_ratio'))
|
||||
validation_size = int(dataset_size - training_size)
|
||||
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
|
||||
return training_dataset, validate_dataset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
|
||||
output_directory_path = CONFIG.get('training.output', 'directory_path')
|
||||
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
|
||||
trainer_precision = CONFIG.get('training.trainer', 'precision')
|
||||
config =\
|
||||
{
|
||||
'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'),
|
||||
'precision': CONFIG_PARSER.get('training.trainer', 'precision'),
|
||||
'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'),
|
||||
'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
}
|
||||
logger = TensorBoardLogger('.logs', name = 'face_swapper')
|
||||
|
||||
return Trainer(
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = trainer_max_epochs,
|
||||
precision = trainer_precision, # type:ignore[arg-type]
|
||||
max_epochs = config.get('max_epochs'),
|
||||
precision = config.get('precision'),
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
monitor = 'generator_loss',
|
||||
dirpath = output_directory_path,
|
||||
filename = output_file_pattern,
|
||||
dirpath = config.get('directory_path'),
|
||||
filename = config.get('file_pattern'),
|
||||
every_n_train_steps = 1000,
|
||||
save_top_k = 3,
|
||||
save_last = True
|
||||
@@ -195,22 +206,20 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
|
||||
dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
|
||||
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
output_size = CONFIG.getint('training.model.generator', 'output_size')
|
||||
config =\
|
||||
{
|
||||
'resume_path': CONFIG_PARSER.get('training.output', 'resume_path')
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, output_size, dataset_batch_mode, dataset_batch_ratio)
|
||||
dataset = DynamicDataset(CONFIG_PARSER)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.isfile(output_resume_path):
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = output_resume_path)
|
||||
if os.path.isfile(config.get('resume_path')):
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
|
||||
else:
|
||||
trainer.fit(face_swapper_trainer, training_loader, validation_loader)
|
||||
|
||||
Reference in New Issue
Block a user