Rework on config

This commit is contained in:
henryruhs
2025-03-06 16:32:07 +01:00
parent a5d99c139e
commit ab3b699124
9 changed files with 169 additions and 146 deletions
+24 -15
View File
@@ -1,6 +1,8 @@
import glob
import os
import random
from configparser import ConfigParser
from typing import cast
import albumentations
from torch import Tensor
@@ -8,25 +10,29 @@ from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import warp_tensor
from .types import Batch, BatchMode, WarpTemplate
from .types import Batch, WarpTemplate, BatchMode
class DynamicDataset(Dataset[Tensor]):
def __init__(self, file_pattern : str, warp_template : WarpTemplate, transform_size : int, batch_mode : BatchMode, batch_ratio : float) -> None:
self.file_paths = glob.glob(file_pattern)
self.warp_template = warp_template
self.transform_size = transform_size
self.batch_mode = batch_mode
self.batch_ratio = batch_ratio
def __init__(self, config_parser : ConfigParser) -> None:
self.config =\
{
'file_pattern': config_parser.get('training.dataset', 'file_pattern'),
'transform_size': config_parser.get('training.dataset', 'transform_size'),
'batch_mode': cast(BatchMode, config_parser.get('training.dataset', 'batch_mode')),
'batch_ratio': config_parser.getfloat('training.dataset', 'batch_ratio'),
}
self.config_parser = config_parser
self.file_paths = glob.glob(self.config.get('file_pattern'))
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
if random.random() < self.batch_ratio:
if self.batch_mode == 'equal':
if random.random() < self.config.get('batch_ratio'):
if self.config.get('batch_mode') == 'equal':
return self.prepare_equal_batch(file_path)
if self.batch_mode == 'same':
if self.config.get('batch_mode') == 'same':
return self.prepare_same_batch(file_path)
return self.prepare_different_batch(file_path)
@@ -39,9 +45,9 @@ class DynamicDataset(Dataset[Tensor]):
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.transform_size, self.transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.Resize((self.config.get('transform_size'), self.config.get('transform_size')), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
WarpTransform(self.warp_template),
WarpTransform(self.config_parser),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
@@ -92,9 +98,12 @@ class AugmentTransform:
class WarpTransform:
def __init__(self, warp_template : WarpTemplate) -> None:
self.warp_template = warp_template
def __init__(self, config_parser : ConfigParser) -> None:
self.config =\
{
'warp_template': cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
}
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return warp_tensor(temp_tensor, self.warp_template).squeeze(0)
return warp_tensor(temp_tensor, self.config.get('warp_template')).squeeze(0)
+19 -15
View File
@@ -1,26 +1,30 @@
import configparser
from os import makedirs
import os
from configparser import ConfigParser
import torch
from .training import FaceSwapperTrainer
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
def export() -> None:
directory_path = CONFIG.get('exporting', 'directory_path')
source_path = CONFIG.get('exporting', 'source_path')
target_path = CONFIG.get('exporting', 'target_path')
target_size = CONFIG.getint('exporting', 'target_size')
ir_version = CONFIG.getint('exporting', 'ir_version')
opset_version = CONFIG.getint('exporting', 'opset_version')
config =\
{
'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'),
'source_path': CONFIG_PARSER.get('exporting', 'source_path'),
'target_path': CONFIG_PARSER.get('exporting', 'target_path'),
'target_size': CONFIG_PARSER.getint('exporting', 'target_size'),
'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'),
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
}
makedirs(directory_path, exist_ok = True)
model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
os.makedirs(config.get('directory_path'), exist_ok = True)
model = FaceSwapperTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(ir_version)
model.ir_version = torch.tensor(config.get('ir_version'))
source_tensor = torch.randn(1, 512)
target_tensor = torch.randn(1, 3, target_size, target_size)
torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version)
target_tensor = torch.randn(1, 3, config.get('target_size'), config.get('target_size'))
torch.onnx.export(model, (source_tensor, target_tensor), config.get('target_path'), input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = config.get('opset_version'))
+15 -12
View File
@@ -6,26 +6,29 @@ from torchvision import io
from .helper import calc_embedding
from .models.generator import Generator
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
CONFIG_PARSER = configparser.ConfigParser()
CONFIG_PARSER.read('config.ini')
def infer() -> None:
generator_path = CONFIG.get('inferencing', 'generator_path')
embedder_path = CONFIG.get('inferencing', 'embedder_path')
source_path = CONFIG.get('inferencing', 'source_path')
target_path = CONFIG.get('inferencing', 'target_path')
output_path = CONFIG.get('inferencing', 'output_path')
config =\
{
'generator_path': CONFIG_PARSER.get('inferencing', 'generator_path'),
'embedder_path': CONFIG_PARSER.get('inferencing', 'embedder_path'),
'source_path': CONFIG_PARSER.get('inferencing', 'source_path'),
'target_path': CONFIG_PARSER.get('inferencing', 'target_path'),
'output_path': CONFIG_PARSER.get('inferencing', 'output_path')
}
state_dict = torch.load(generator_path).get('state_dict').get('generator')
generator = Generator()
state_dict = torch.load(config.get('generator_path')).get('state_dict').get('generator')
generator = Generator(CONFIG_PARSER)
generator.load_state_dict(state_dict)
generator.eval()
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
embedder.eval()
source_tensor = io.read_image(source_path)
target_tensor = io.read_image(target_path)
source_tensor = io.read_image(config.get('source_path'))
target_tensor = io.read_image(config.get('target_path'))
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor = generator(source_embedding, target_tensor)[0]
io.write_jpeg(output_tensor, output_path)
io.write_jpeg(output_tensor, config.get('output_path'))
+13 -15
View File
@@ -1,31 +1,28 @@
import configparser
from configparser import ConfigParser
from typing import List
from torch import Tensor, nn
from ..networks.nld import NLD
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class Discriminator(nn.Module):
def __init__(self) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config =\
{
'num_discriminators': config_parser.getint('training.model.discriminator', 'num_discriminators')
}
self.config_parser = config_parser
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
self.discriminators = self.create_discriminators()
@staticmethod
def create_discriminators() -> nn.ModuleList:
num_discriminators = CONFIG.getint('training.model.discriminator', 'num_discriminators')
input_channels = CONFIG.getint('training.model.discriminator', 'input_channels')
num_filters = CONFIG.getint('training.model.discriminator', 'num_filters')
kernel_size = CONFIG.getint('training.model.discriminator', 'kernel_size')
num_layers = CONFIG.getint('training.model.discriminator', 'num_layers')
def create_discriminators(self) -> nn.ModuleList:
discriminators = nn.ModuleList()
for _ in range(num_discriminators):
discriminator = NLD(input_channels, num_filters, num_layers, kernel_size).sequences
for _ in range(self.config.get('num_discriminators')):
discriminator = NLD(self.config_parser).sequences
discriminators.append(discriminator)
return discriminators
@@ -35,7 +32,8 @@ class Discriminator(nn.Module):
output_tensors = []
for discriminator in self.discriminators:
output_tensors.append(discriminator(temp_tensor))
output_tensor = discriminator(temp_tensor)
output_tensors.append(output_tensor)
temp_tensor = self.avg_pool(temp_tensor)
return output_tensors
+4 -12
View File
@@ -1,4 +1,4 @@
import configparser
from configparser import ConfigParser
from torch import Tensor, nn
@@ -6,20 +6,12 @@ from ..networks.aad import AAD
from ..networks.unet import UNet
from ..types import Attributes, Embedding
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class Generator(nn.Module):
def __init__(self) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
identity_channels = CONFIG.getint('training.model.generator', 'identity_channels')
output_channels = CONFIG.getint('training.model.generator', 'output_channels')
output_size = CONFIG.getint('training.model.generator', 'output_size')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
self.encoder = UNet(output_size)
self.generator = AAD(identity_channels, output_channels, output_size, num_blocks)
self.encoder = UNet(config_parser)
self.generator = AAD(config_parser)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
+7 -7
View File
@@ -1,4 +1,4 @@
import configparser
from configparser import ConfigParser
from typing import List, Tuple
import torch
@@ -9,9 +9,6 @@ from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attributes, EmbedderModule, Gaze, GazerModule, MotionExtractorModule
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
@@ -36,11 +33,14 @@ class DiscriminatorLoss(nn.Module):
class AdversarialLoss(nn.Module):
def __init__(self) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config =\
{
'adversarial_weight': config_parser.getfloat('training.losses', 'adversarial_weight')
}
def forward(self, discriminator_output_tensors : List[Tensor]) -> Tuple[Tensor, Tensor]:
adversarial_weight = CONFIG.getfloat('training.losses', 'adversarial_weight')
temp_tensors = []
for discriminator_output_tensor in discriminator_output_tensors:
@@ -48,7 +48,7 @@ class AdversarialLoss(nn.Module):
temp_tensors.append(temp_tensor)
adversarial_loss = torch.stack(temp_tensors).mean()
weighted_adversarial_loss = adversarial_loss * adversarial_weight
weighted_adversarial_loss = adversarial_loss * self.config.get('adversarial_weight')
return adversarial_loss, weighted_adversarial_loss
+16 -12
View File
@@ -1,33 +1,37 @@
import math
from configparser import ConfigParser
from torch import Tensor, nn
class NLD(nn.Module):
def __init__(self, input_channels : int, num_filters : int, num_layers : int, kernel_size : int) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.input_channels = input_channels
self.num_filters = num_filters
self.num_layers = num_layers
self.kernel_size = kernel_size
self.config =\
{
'input_channels': config_parser.getint('training.model.discriminator', 'input_channels'),
'num_filters': config_parser.getint('training.model.discriminator', 'num_filters'),
'kernel_size': config_parser.getint('training.model.discriminator', 'kernel_size'),
'num_layers': config_parser.getint('training.model.discriminator', 'num_layers')
}
self.layers = self.create_layers()
self.sequences = nn.Sequential(*self.layers)
def create_layers(self) -> nn.ModuleList:
padding = math.ceil((self.kernel_size - 1) / 2)
current_filters = self.num_filters
padding = math.ceil((self.config.get('kernel_size') - 1) / 2)
current_filters = self.config.get('num_filters')
layers = nn.ModuleList(
[
nn.Conv2d(self.input_channels, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding),
nn.Conv2d(self.config.get('input_channels'), current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
nn.LeakyReLU(0.2, True)
])
for _ in range(1, self.num_layers):
for _ in range(1, self.config.get('num_layers')):
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, stride = 2, padding = padding),
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
]
@@ -36,10 +40,10 @@ class NLD(nn.Module):
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.kernel_size, padding = padding),
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True),
nn.Conv2d(current_filters, 1, kernel_size = self.kernel_size, padding = padding)
nn.Conv2d(current_filters, 1, kernel_size = self.config.get('kernel_size'), padding = padding)
]
return layers
+12 -8
View File
@@ -1,3 +1,4 @@
from configparser import ConfigParser
from typing import Tuple
import torch
@@ -5,9 +6,12 @@ from torch import Tensor, nn
class UNet(nn.Module):
def __init__(self, output_size : int) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.output_size = output_size
self.config =\
{
'output_size': config_parser.getint('training.model.generator', 'output_size')
}
self.down_samples = self.create_down_samples()
self.up_samples = self.create_up_samples()
@@ -21,20 +25,20 @@ class UNet(nn.Module):
DownSample(256, 512)
])
if self.output_size == 128:
if self.config.get('output_size') == 128:
down_samples.extend(
[
DownSample(512, 512)
])
if self.output_size == 256:
if self.config.get('output_size') == 256:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 1024)
])
if self.output_size == 512:
if self.config.get('output_size') == 512:
down_samples.extend(
[
DownSample(512, 1024),
@@ -47,20 +51,20 @@ class UNet(nn.Module):
def create_up_samples(self) -> nn.ModuleList:
up_samples = nn.ModuleList()
if self.output_size == 128:
if self.config.get('output_size') == 128:
up_samples.extend(
[
UpSample(512, 512)
])
if self.output_size == 256:
if self.config.get('output_size') == 256:
up_samples.extend(
[
UpSample(1024, 1024),
UpSample(2048, 512)
])
if self.output_size == 512:
if self.config.get('output_size') == 512:
up_samples.extend(
[
UpSample(2048, 2048),
+59 -50
View File
@@ -1,5 +1,5 @@
import configparser
import os
from configparser import ConfigParser
import warnings
from typing import Tuple, cast
@@ -21,30 +21,34 @@ from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class FaceSwapperTrainer(LightningModule):
def __init__(self) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
embedder_path = CONFIG.get('training.model', 'embedder_path')
gazer_path = CONFIG.get('training.model', 'gazer_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.config =\
{
'embedder_path': config_parser.get('training.model', 'embedder_path'),
'gazer_path': config_parser.get('training.model', 'gazer_path'),
'motion_extractor_path': config_parser.get('training.model', 'motion_extractor_path'),
'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate'),
'preview_frequency': config_parser.getint('training.trainer', 'preview_frequency')
}
self.embedder = torch.jit.load(self.config.get('embedder_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.gazer = torch.jit.load(self.config.get('gazer_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(self.config.get('motion_extractor_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.gazer = torch.jit.load(gazer_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu').eval() # type:ignore[no-untyped-call]
self.generator = Generator()
self.discriminator = Discriminator()
self.discriminator_loss = DiscriminatorLoss()
self.adversarial_loss = AdversarialLoss()
self.attribute_loss = AttributeLoss()
self.reconstruction_loss = ReconstructionLoss(self.embedder)
self.identity_loss = IdentityLoss(self.embedder)
self.motion_loss = MotionLoss(self.motion_extractor)
self.gaze_loss = GazeLoss(self.gazer)
self.generator = Generator(config_parser)
self.discriminator = Discriminator(config_parser)
self.discriminator_loss = DiscriminatorLoss(config_parser)
self.adversarial_loss = AdversarialLoss(config_parser)
self.attribute_loss = AttributeLoss(config_parser)
self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder)
self.identity_loss = IdentityLoss(config_parser, self.embedder)
self.motion_loss = MotionLoss(config_parser, self.motion_extractor)
self.gaze_loss = GazeLoss(config_parser, self.gazer)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
@@ -52,9 +56,8 @@ class FaceSwapperTrainer(LightningModule):
return output_tensor
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet]:
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr = self.config.get('learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(generator_optimizer, T_0 = 300, T_mult = 2)
discriminator_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(discriminator_optimizer, T_0 = 300, T_mult = 2)
@@ -79,8 +82,6 @@ class FaceSwapperTrainer(LightningModule):
return generator_config, discriminator_config
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
preview_frequency = CONFIG.getint('training.trainer', 'preview_frequency')
source_tensor, target_tensor = batch
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
@@ -113,7 +114,7 @@ class FaceSwapperTrainer(LightningModule):
discriminator_optimizer.step()
self.untoggle_optimizer(discriminator_optimizer)
if self.global_step % preview_frequency == 0:
if self.global_step % self.config.get('preview_frequency') == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.log('generator_loss', generator_loss, prog_bar = True)
@@ -149,42 +150,52 @@ class FaceSwapperTrainer(LightningModule):
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
config =\
{
'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'),
'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers')
}
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
training_loader = StatefulDataLoader(training_dataset, batch_size = config.get('batch_size'), shuffle = True, num_workers = config.get('num_workers'), drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = StatefulDataLoader(validate_dataset, batch_size = config.get('batch_size'), shuffle = False, num_workers = config.get('num_workers'), pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
config =\
{
'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
}
dataset_size = len(dataset) # type:ignore[arg-type]
training_size = int(dataset_size * split_ratio)
training_size = int(dataset_size * config.get('split_ratio'))
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
return training_dataset, validate_dataset
def create_trainer() -> Trainer:
trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
output_directory_path = CONFIG.get('training.output', 'directory_path')
output_file_pattern = CONFIG.get('training.output', 'file_pattern')
trainer_precision = CONFIG.get('training.trainer', 'precision')
config =\
{
'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'),
'precision': CONFIG_PARSER.get('training.trainer', 'precision'),
'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'),
'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern')
}
logger = TensorBoardLogger('.logs', name = 'face_swapper')
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = trainer_max_epochs,
precision = trainer_precision, # type:ignore[arg-type]
max_epochs = config.get('max_epochs'),
precision = config.get('precision'),
callbacks =
[
ModelCheckpoint(
monitor = 'generator_loss',
dirpath = output_directory_path,
filename = output_file_pattern,
dirpath = config.get('directory_path'),
filename = config.get('file_pattern'),
every_n_train_steps = 1000,
save_top_k = 3,
save_last = True
@@ -195,22 +206,20 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
output_resume_path = CONFIG.get('training.output', 'resume_path')
output_size = CONFIG.getint('training.model.generator', 'output_size')
config =\
{
'resume_path': CONFIG_PARSER.get('training.output', 'resume_path')
}
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, output_size, dataset_batch_mode, dataset_batch_ratio)
dataset = DynamicDataset(CONFIG_PARSER)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()
face_swapper_trainer = FaceSwapperTrainer(CONFIG_PARSER)
trainer = create_trainer()
if os.path.isfile(output_resume_path):
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = output_resume_path)
if os.path.isfile(config.get('resume_path')):
trainer.fit(face_swapper_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
else:
trainer.fit(face_swapper_trainer, training_loader, validation_loader)