mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Forward config parser
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
import glob
|
||||
from configparser import ConfigParser
|
||||
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .types import Batch, Config
|
||||
from .types import Batch
|
||||
|
||||
|
||||
class StaticDataset(Dataset[Tensor]):
|
||||
def __init__(self, config : Config) -> None:
|
||||
self.config = config
|
||||
def __init__(self, config : ConfigParser) -> None:
|
||||
self.config =\
|
||||
{
|
||||
'file_pattern': config.get('training.dataset', 'file_pattern')
|
||||
}
|
||||
self.file_paths = glob.glob(self.config.get('file_pattern'))
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
|
||||
@@ -5,18 +5,18 @@ import torch
|
||||
|
||||
from .training import EmbeddingConverterTrainer
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
|
||||
def export() -> None:
|
||||
config =\
|
||||
{
|
||||
'directory_path': CONFIG.get('exporting', 'directory_path'),
|
||||
'source_path': CONFIG.get('exporting', 'source_path'),
|
||||
'target_path': CONFIG.get('exporting', 'target_path'),
|
||||
'ir_version': CONFIG.getint('exporting', 'ir_version'),
|
||||
'opset_version': CONFIG.getint('exporting', 'opset_version')
|
||||
'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'),
|
||||
'source_path': CONFIG_PARSER.get('exporting', 'source_path'),
|
||||
'target_path': CONFIG_PARSER.get('exporting', 'target_path'),
|
||||
'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'),
|
||||
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
|
||||
}
|
||||
|
||||
makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from configparser import ConfigParser
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
@@ -12,16 +13,20 @@ from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
|
||||
from .dataset import StaticDataset
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
from .types import Batch, Config, ConfigSet, Embedding, OptimizerSet
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
|
||||
CONFIG_PARSER = ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
|
||||
class EmbeddingConverterTrainer(LightningModule):
|
||||
def __init__(self, config : Config) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.config_parser =\
|
||||
{
|
||||
'source_path': config_parser.get('training.model', 'source_path'),
|
||||
'target_path': config_parser.get('training.model', 'target_path'),
|
||||
'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate')
|
||||
}
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
@@ -68,8 +73,8 @@ class EmbeddingConverterTrainer(LightningModule):
|
||||
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
|
||||
config =\
|
||||
{
|
||||
'batch_size': CONFIG.getint('training.loader', 'batch_size'),
|
||||
'num_workers': CONFIG.getint('training.loader', 'num_workers')
|
||||
'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'),
|
||||
'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers')
|
||||
}
|
||||
|
||||
training_dataset, validate_dataset = split_dataset(dataset)
|
||||
@@ -81,7 +86,7 @@ def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor
|
||||
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
|
||||
config =\
|
||||
{
|
||||
'split_ratio': CONFIG.getfloat('training.loader', 'split_ratio')
|
||||
'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
|
||||
}
|
||||
|
||||
dataset_size = len(dataset) # type:ignore[arg-type]
|
||||
@@ -94,10 +99,10 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T
|
||||
def create_trainer() -> Trainer:
|
||||
config =\
|
||||
{
|
||||
'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'),
|
||||
'precision': CONFIG.get('training.trainer', 'precision'),
|
||||
'directory_path': CONFIG.get('training.output', 'directory_path'),
|
||||
'file_pattern': CONFIG.get('training.output', 'file_pattern')
|
||||
'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'),
|
||||
'precision': CONFIG_PARSER.get('training.trainer', 'precision'),
|
||||
'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'),
|
||||
'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern')
|
||||
}
|
||||
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
|
||||
|
||||
@@ -121,33 +126,20 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
config_set : ConfigSet =\
|
||||
config =\
|
||||
{
|
||||
'dataset':
|
||||
{
|
||||
'file_pattern': CONFIG.get('training.dataset', 'file_pattern')
|
||||
},
|
||||
'trainer':
|
||||
{
|
||||
'source_path': CONFIG.get('training.model', 'source_path'),
|
||||
'target_path': CONFIG.get('training.model', 'target_path'),
|
||||
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
|
||||
},
|
||||
'output':
|
||||
{
|
||||
'resume_path': CONFIG.get('training.output', 'resume_path')
|
||||
}
|
||||
'resume_path': CONFIG_PARSER.get('training.output', 'resume_path')
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
dataset = StaticDataset(config_set.get('dataset'))
|
||||
dataset = StaticDataset(CONFIG_PARSER)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(config_set.get('trainer'))
|
||||
embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER)
|
||||
trainer = create_trainer()
|
||||
|
||||
if os.path.exists(config_set.get('output').get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_set.get('output').get('resume_path'))
|
||||
if os.path.exists(config.get('resume_path')):
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
|
||||
else:
|
||||
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from typing import Any, Dict, TypeAlias
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
Batch : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
|
||||
Config : TypeAlias = Dict[str, Any]
|
||||
ConfigSet : TypeAlias = Dict[str, Config]
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
Reference in New Issue
Block a user