From ab3b69912420558483615ae0bd529b02bf2daed3 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 6 Mar 2025 16:32:07 +0100 Subject: [PATCH] Rework on config --- face_swapper/src/dataset.py | 39 ++++---- face_swapper/src/exporting.py | 34 +++---- face_swapper/src/inferencing.py | 27 +++--- face_swapper/src/models/discriminator.py | 28 +++--- face_swapper/src/models/generator.py | 16 +--- face_swapper/src/models/loss.py | 14 +-- face_swapper/src/networks/nld.py | 28 +++--- face_swapper/src/networks/unet.py | 20 +++-- face_swapper/src/training.py | 109 ++++++++++++----------- 9 files changed, 169 insertions(+), 146 deletions(-) diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py index 393a934..99232be 100644 --- a/face_swapper/src/dataset.py +++ b/face_swapper/src/dataset.py @@ -1,6 +1,8 @@ import glob import os import random +from configparser import ConfigParser +from typing import cast import albumentations from torch import Tensor @@ -8,25 +10,29 @@ from torch.utils.data import Dataset from torchvision import io, transforms from .helper import warp_tensor -from .types import Batch, BatchMode, WarpTemplate +from .types import Batch, WarpTemplate, BatchMode class DynamicDataset(Dataset[Tensor]): - def __init__(self, file_pattern : str, warp_template : WarpTemplate, transform_size : int, batch_mode : BatchMode, batch_ratio : float) -> None: - self.file_paths = glob.glob(file_pattern) - self.warp_template = warp_template - self.transform_size = transform_size - self.batch_mode = batch_mode - self.batch_ratio = batch_ratio + def __init__(self, config_parser : ConfigParser) -> None: + self.config =\ + { + 'file_pattern': config_parser.get('training.dataset', 'file_pattern'), + 'transform_size': config_parser.getint('training.dataset', 'transform_size'), + 'batch_mode': cast(BatchMode, config_parser.get('training.dataset', 'batch_mode')), + 'batch_ratio': config_parser.getfloat('training.dataset', 'batch_ratio'), + } + self.config_parser = config_parser 
+ self.file_paths = glob.glob(self.config.get('file_pattern')) self.transforms = self.compose_transforms() def __getitem__(self, index : int) -> Batch: file_path = self.file_paths[index] - if random.random() < self.batch_ratio: - if self.batch_mode == 'equal': + if random.random() < self.config.get('batch_ratio'): + if self.config.get('batch_mode') == 'equal': return self.prepare_equal_batch(file_path) - if self.batch_mode == 'same': + if self.config.get('batch_mode') == 'same': return self.prepare_same_batch(file_path) return self.prepare_different_batch(file_path) @@ -39,9 +45,9 @@ class DynamicDataset(Dataset[Tensor]): [ AugmentTransform(), transforms.ToPILImage(), - transforms.Resize((self.transform_size, self.transform_size), interpolation = transforms.InterpolationMode.BICUBIC), + transforms.Resize((self.config.get('transform_size'), self.config.get('transform_size')), interpolation = transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), - WarpTransform(self.warp_template), + WarpTransform(self.config_parser), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) @@ -92,9 +98,12 @@ class AugmentTransform: class WarpTransform: - def __init__(self, warp_template : WarpTemplate) -> None: - self.warp_template = warp_template + def __init__(self, config_parser : ConfigParser) -> None: + self.config =\ + { + 'warp_template': cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template')) + } def __call__(self, input_tensor : Tensor) -> Tensor: temp_tensor = input_tensor.unsqueeze(0) - return warp_tensor(temp_tensor, self.warp_template).squeeze(0) + return warp_tensor(temp_tensor, self.config.get('warp_template')).squeeze(0) diff --git a/face_swapper/src/exporting.py b/face_swapper/src/exporting.py index 8e64c8b..a24aa0d 100644 --- a/face_swapper/src/exporting.py +++ b/face_swapper/src/exporting.py @@ -1,26 +1,30 @@ -import configparser -from os import makedirs +import os +from configparser import ConfigParser + import torch from .training 
import FaceSwapperTrainer -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') def export() -> None: - directory_path = CONFIG.get('exporting', 'directory_path') - source_path = CONFIG.get('exporting', 'source_path') - target_path = CONFIG.get('exporting', 'target_path') - target_size = CONFIG.getint('exporting', 'target_size') - ir_version = CONFIG.getint('exporting', 'ir_version') - opset_version = CONFIG.getint('exporting', 'opset_version') + config =\ + { + 'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'), + 'source_path': CONFIG_PARSER.get('exporting', 'source_path'), + 'target_path': CONFIG_PARSER.get('exporting', 'target_path'), + 'target_size': CONFIG_PARSER.getint('exporting', 'target_size'), + 'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'), + 'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version') + } - makedirs(directory_path, exist_ok = True) - model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu') + os.makedirs(config.get('directory_path'), exist_ok = True) + model = FaceSwapperTrainer.load_from_checkpoint(config.get('source_path'), config_parser = CONFIG_PARSER, map_location = 'cpu') model.eval() - model.ir_version = torch.tensor(ir_version) + model.ir_version = torch.tensor(config.get('ir_version')) source_tensor = torch.randn(1, 512) - target_tensor = torch.randn(1, 3, target_size, target_size) - torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version) + target_tensor = torch.randn(1, 3, config.get('target_size'), config.get('target_size')) + torch.onnx.export(model, (source_tensor, target_tensor), config.get('target_path'), input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = config.get('opset_version')) diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index 
52c5f80..3f718b3 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -6,26 +6,29 @@ from torchvision import io from .helper import calc_embedding from .models.generator import Generator -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') +CONFIG_PARSER = configparser.ConfigParser() +CONFIG_PARSER.read('config.ini') def infer() -> None: - generator_path = CONFIG.get('inferencing', 'generator_path') - embedder_path = CONFIG.get('inferencing', 'embedder_path') - source_path = CONFIG.get('inferencing', 'source_path') - target_path = CONFIG.get('inferencing', 'target_path') - output_path = CONFIG.get('inferencing', 'output_path') + config =\ + { + 'generator_path': CONFIG_PARSER.get('inferencing', 'generator_path'), + 'embedder_path': CONFIG_PARSER.get('inferencing', 'embedder_path'), + 'source_path': CONFIG_PARSER.get('inferencing', 'source_path'), + 'target_path': CONFIG_PARSER.get('inferencing', 'target_path'), + 'output_path': CONFIG_PARSER.get('inferencing', 'output_path') + } - state_dict = torch.load(generator_path).get('state_dict').get('generator') - generator = Generator() + state_dict = torch.load(config.get('generator_path')).get('state_dict').get('generator') + generator = Generator(CONFIG_PARSER) generator.load_state_dict(state_dict) generator.eval() - embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] + embedder = torch.jit.load(config.get('embedder_path'), map_location = 'cpu') # type:ignore[no-untyped-call] embedder.eval() - source_tensor = io.read_image(source_path) - target_tensor = io.read_image(target_path) + source_tensor = io.read_image(config.get('source_path')) + target_tensor = io.read_image(config.get('target_path')) source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0)) output_tensor = generator(source_embedding, target_tensor)[0] - io.write_jpeg(output_tensor, output_path) + io.write_jpeg(output_tensor, config.get('output_path')) diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index 
9f7bf49..4ccb88f 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -1,31 +1,28 @@ -import configparser +from configparser import ConfigParser from typing import List from torch import Tensor, nn from ..networks.nld import NLD -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - class Discriminator(nn.Module): - def __init__(self) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() + self.config =\ + { + 'num_discriminators': config_parser.getint('training.model.discriminator', 'num_discriminators') + } + self.config_parser = config_parser self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False) self.discriminators = self.create_discriminators() - @staticmethod - def create_discriminators() -> nn.ModuleList: - num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators') - input_channels = CONFIG.getint('training.model.discriminator', 'input_channels') - num_filters = CONFIG.getint('training.model.discriminator', 'num_filters') - kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size') - num_layers = CONFIG.getint('training.model.discriminator', 'num_layers') + + def create_discriminators(self) -> nn.ModuleList: discriminators = nn.ModuleList() - for _ in range(num_discriminators): - discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).sequences + for _ in range(self.config.get('num_discriminators')): + discriminator = NLD(self.config_parser).sequences discriminators.append(discriminator) return discriminators @@ -35,7 +32,8 @@ class Discriminator(nn.Module): output_tensors = [] for discriminator in self.discriminators: - output_tensors.append(discriminator(temp_tensor)) + output_tensor = discriminator(temp_tensor) + output_tensors.append(output_tensor) temp_tensor = self.avg_pool(temp_tensor) return output_tensors diff --git 
a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 7e1ec08..d02be0b 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -1,4 +1,4 @@ -import configparser +from configparser import ConfigParser from torch import Tensor, nn @@ -6,20 +6,12 @@ from ..networks.aad import AAD from ..networks.unet import UNet from ..types import Attributes, Embedding -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - class Generator(nn.Module): - def __init__(self) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - identity_channels = CONFIG.getint('training.model.generator', 'identity_channels') - output_channels = CONFIG.getint('training.model.generator', 'output_channels') - output_size = CONFIG.getint('training.model.generator', 'output_size') - num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') - - self.encoder = UNet(output_size) - self.generator = AAD(identity_channels, output_channels, output_size, num_blocks) + self.encoder = UNet(config_parser) + self.generator = AAD(config_parser) self.encoder.apply(init_weight) self.generator.apply(init_weight) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 71a20ea..85e85e1 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -1,4 +1,4 @@ -import configparser +from configparser import ConfigParser from typing import List, Tuple import torch @@ -9,9 +9,6 @@ from torchvision import transforms from ..helper import calc_embedding from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') - class DiscriminatorLoss(nn.Module): def __init__(self) -> None: @@ -36,11 +33,14 @@ class DiscriminatorLoss(nn.Module): class AdversarialLoss(nn.Module): - def __init__(self) -> None: + def __init__(self, config_parser : ConfigParser) -> None: 
super().__init__() + self.config =\ + { + 'adversarial_weight': config_parser.getfloat('training.losses', 'adversarial_weight') + } def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]: - adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight') temp_tensors = [] for discriminator_output_tensor in discriminator_output_tensors: @@ -48,7 +48,7 @@ class AdversarialLoss(nn.Module): temp_tensors.append(temp_tensor) adversarial_loss = torch.stack(temp_tensors).mean() - weighted_adversarial_loss = adversarial_loss * adversarial_weight + weighted_adversarial_loss = adversarial_loss * self.config.get('adversarial_weight') return adversarial_loss, weighted_adversarial_loss diff --git a/face_swapper/src/networks/nld.py b/face_swapper/src/networks/nld.py index 2ef6865..d0347f8 100644 --- a/face_swapper/src/networks/nld.py +++ b/face_swapper/src/networks/nld.py @@ -1,33 +1,37 @@ import math +from configparser import ConfigParser from torch import Tensor, nn class NLD(nn.Module): - def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - self.input_channels = input_channels - self.num_filters = num_filters - self.num_layers = num_layers - self.kernel_size = kernel_size + self.config =\ + { + 'input_channels': config_parser.getint('training.model.discriminator', 'input_channels'), + 'num_filters': config_parser.getint('training.model.discriminator', 'num_filters'), + 'kernel_size': config_parser.getint('training.model.discriminator', 'kernel_size'), + 'num_layers': config_parser.getint('training.model.discriminator', 'num_layers') + } self.layers = self.create_layers() self.sequences = nn.Sequential(*self.layers) def create_layers(self) -> nn.ModuleList: - padding = math.ceil((self.kernel_size - 1) / 2) - current_filters = self.num_filters + padding = math.ceil((self.config.get('kernel_size') - 1) 
/ 2) + current_filters = self.config.get('num_filters') layers = nn.ModuleList( [ - nn.Conv2d(self.input_channels, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding), + nn.Conv2d(self.config.get('input_channels'), current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding), nn.LeakyReLU(0.2, True) ]) - for _ in range(1, self.num_layers): + for _ in range(1, self.config.get('num_layers')): previous_filters = current_filters current_filters = min(current_filters * 2, 512) layers +=\ [ - nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding), + nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding), nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True) ] @@ -36,10 +40,10 @@ class NLD(nn.Module): current_filters = min(current_filters * 2, 512) layers +=\ [ - nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, padding = padding), + nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), padding = padding), nn.InstanceNorm2d(current_filters), nn.LeakyReLU(0.2, True), - nn.Conv2d(current_filters, 1, kernel_size = self.kernel_size, padding = padding) + nn.Conv2d(current_filters, 1, kernel_size = self.config.get('kernel_size'), padding = padding) ] return layers diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 1e0494c..250c7b0 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -1,3 +1,4 @@ +from configparser import ConfigParser from typing import Tuple import torch @@ -5,9 +6,12 @@ from torch import Tensor, nn class UNet(nn.Module): - def __init__(self, output_size : int) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - self.output_size = output_size + self.config =\ + { + 'output_size': 
config_parser.getint('training.model.generator', 'output_size') + } self.down_samples = self.create_down_samples() self.up_samples = self.create_up_samples() @@ -21,20 +25,20 @@ class UNet(nn.Module): DownSample(256, 512) ]) - if self.output_size == 128: + if self.config.get('output_size') == 128: down_samples.extend( [ DownSample(512, 512) ]) - if self.output_size == 256: + if self.config.get('output_size') == 256: down_samples.extend( [ DownSample(512, 1024), DownSample(1024, 1024) ]) - if self.output_size == 512: + if self.config.get('output_size') == 512: down_samples.extend( [ DownSample(512, 1024), @@ -47,20 +51,20 @@ class UNet(nn.Module): def create_up_samples(self) -> nn.ModuleList: up_samples = nn.ModuleList() - if self.output_size == 128: + if self.config.get('output_size') == 128: up_samples.extend( [ UpSample(512, 512) ]) - if self.output_size == 256: + if self.config.get('output_size') == 256: up_samples.extend( [ UpSample(1024, 1024), UpSample(2048, 512) ]) - if self.output_size == 512: + if self.config.get('output_size') == 512: up_samples.extend( [ UpSample(2048, 2048), diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index a99f89f..a375dad 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -1,5 +1,5 @@ -import configparser import os +from configparser import ConfigParser import warnings from typing import Tuple, cast @@ -21,30 +21,34 @@ from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate warnings.filterwarnings('ignore', category = UserWarning, module = 'torch') -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') class FaceSwapperTrainer(LightningModule): - def __init__(self) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - embedder_path = CONFIG.get('training.model', 'embedder_path') - gazer_path = CONFIG.get('training.model', 'gazer_path') - 
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') + self.config =\ + { + 'embedder_path': config_parser.get('training.model', 'embedder_path'), + 'gazer_path': config_parser.get('training.model', 'gazer_path'), + 'motion_extractor_path': config_parser.get('training.model', 'motion_extractor_path'), + 'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate'), + 'preview_frequency': config_parser.getint('training.trainer', 'preview_frequency') + } + self.embedder = torch.jit.load(self.config.get('embedder_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call] + self.gazer = torch.jit.load(self.config.get('gazer_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call] + self.motion_extractor = torch.jit.load(self.config.get('motion_extractor_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call] - self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] - self.gazer = torch.jit.load(gazer_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] - self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call] - - self.generator = Generator() - self.discriminator = Discriminator() - self.discriminator_loss = DiscriminatorLoss() - self.adversarial_loss = AdversarialLoss() - self.attribute_loss = AttributeLoss() - self.reconstruction_loss = ReconstructionLoss(self.embedder) - self.identity_loss = IdentityLoss(self.embedder) - self.motion_loss = MotionLoss(self.motion_extractor) - self.gaze_loss = GazeLoss(self.gazer) + self.generator = Generator(config_parser) + self.discriminator = Discriminator(config_parser) + self.discriminator_loss = DiscriminatorLoss() + self.adversarial_loss = AdversarialLoss(config_parser) + self.attribute_loss = AttributeLoss(config_parser) + self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder) + 
self.identity_loss = IdentityLoss(config_parser, self.embedder) + self.motion_loss = MotionLoss(config_parser, self.motion_extractor) + self.gaze_loss = GazeLoss(config_parser, self.gazer) self.automatic_optimization = False def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: @@ -52,9 +56,8 @@ class FaceSwapperTrainer(LightningModule): return output_tensor def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]: - learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate') - generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) - discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4) + generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) + discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2) discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2) @@ -79,8 +82,6 @@ class FaceSwapperTrainer(LightningModule): return generator_config, discriminator_config def training_step(self, batch : Batch, batch_index : int) -> Tensor: - preview_frequency = CONFIG.getint('training.trainer', 'preview_frequency') - source_tensor, target_tensor = batch generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined] source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0)) @@ -113,7 +114,7 @@ class FaceSwapperTrainer(LightningModule): discriminator_optimizer.step() self.untoggle_optimizer(discriminator_optimizer) - if self.global_step % 
preview_frequency == 0: + if self.global_step % self.config.get('preview_frequency') == 0: self.generate_preview(source_tensor, target_tensor, generator_output_tensor) self.log('generator_loss', generator_loss, prog_bar = True) @@ -149,42 +150,52 @@ class FaceSwapperTrainer(LightningModule): def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: - batch_size = CONFIG.getint('training.loader', 'batch_size') - num_workers = CONFIG.getint('training.loader', 'num_workers') + config =\ + { + 'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'), + 'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers') + } training_dataset, validate_dataset = split_dataset(dataset) - training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True) - validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True) + training_loader = StatefulDataLoader(training_dataset, batch_size = config.get('batch_size'), shuffle = True, num_workers = config.get('num_workers'), drop_last = True, pin_memory = True, persistent_workers = True) + validation_loader = StatefulDataLoader(validate_dataset, batch_size = config.get('batch_size'), shuffle = False, num_workers = config.get('num_workers'), pin_memory = True, persistent_workers = True) return training_loader, validation_loader def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]: - split_ratio = CONFIG.getfloat('training.loader', 'split_ratio') + config =\ + { + 'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio') + } + dataset_size = len(dataset) # type:ignore[arg-type] - training_size = int(dataset_size * split_ratio) + training_size = int(dataset_size * config.get('split_ratio')) 
validation_size = int(dataset_size - training_size) training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ]) return training_dataset, validate_dataset def create_trainer() -> Trainer: - trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs') - output_directory_path = CONFIG.get('training.output', 'directory_path') - output_file_pattern = CONFIG.get('training.output', 'file_pattern') - trainer_precision = CONFIG.get('training.trainer', 'precision') + config =\ + { + 'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'), + 'precision': CONFIG_PARSER.get('training.trainer', 'precision'), + 'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'), + 'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern') + } logger = TensorBoardLogger('.logs', name = 'face_swapper') return Trainer( logger = logger, log_every_n_steps = 10, - max_epochs = trainer_max_epochs, - precision = trainer_precision, # type:ignore[arg-type] + max_epochs = config.get('max_epochs'), + precision = config.get('precision'), callbacks = [ ModelCheckpoint( monitor = 'generator_loss', - dirpath = output_directory_path, - filename = output_file_pattern, + dirpath = config.get('directory_path'), + filename = config.get('file_pattern'), every_n_train_steps = 1000, save_top_k = 3, save_last = True @@ -195,22 +206,20 @@ def create_trainer() -> Trainer: def train() -> None: - dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') - dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template')) - dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode')) - dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio') - output_resume_path = CONFIG.get('training.output', 'resume_path') - output_size = CONFIG.getint('training.model.generator', 'output_size') + config =\ + { + 'resume_path': CONFIG_PARSER.get('training.output', 'resume_path') + 
} if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, output_size, dataset_batch_mode, dataset_batch_ratio) + dataset = DynamicDataset(CONFIG_PARSER) training_loader, validation_loader = create_loaders(dataset) - face_swapper_trainer = FaceSwapperTrainer() + face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER) trainer = create_trainer() - if os.path.isfile(output_resume_path): - trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = output_resume_path) + if os.path.isfile(config.get('resume_path')): + trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path')) else: trainer.fit(face_swapper_trainer, training_loader, validation_loader)