Forward config parser

This commit is contained in:
henryruhs
2025-03-06 14:11:56 +01:00
parent a2a9b78dac
commit 01278d679f
4 changed files with 39 additions and 46 deletions
+7 -3
View File
@@ -1,15 +1,19 @@
import glob
from configparser import ConfigParser
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .types import Batch, Config
from .types import Batch
class StaticDataset(Dataset[Tensor]):
def __init__(self, config : Config) -> None:
self.config = config
def __init__(self, config : ConfigParser) -> None:
self.config =\
{
'file_pattern': config.get('training.dataset', 'file_pattern')
}
self.file_paths = glob.glob(self.config.get('file_pattern'))
self.transforms = self.compose_transforms()
+7 -7
View File
@@ -5,18 +5,18 @@ import torch
from .training import EmbeddingConverterTrainer
# NOTE(review): this span comes from a rendered diff whose +/- markers are missing; the CONFIG
# pair and the CONFIG_PARSER pair look like the before/after of a rename - confirm against the repository.
# Module-level parser, filled from 'config.ini' (a relative path) when the module is executed.
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
# Same construction bound to CONFIG_PARSER, the name the CONFIG_PARSER.get / getint calls in export() use.
CONFIG_PARSER = configparser.ConfigParser()
CONFIG_PARSER.read('config.ini')
def export() -> None:
config =\
{
'directory_path': CONFIG.get('exporting', 'directory_path'),
'source_path': CONFIG.get('exporting', 'source_path'),
'target_path': CONFIG.get('exporting', 'target_path'),
'ir_version': CONFIG.getint('exporting', 'ir_version'),
'opset_version': CONFIG.getint('exporting', 'opset_version')
'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'),
'source_path': CONFIG_PARSER.get('exporting', 'source_path'),
'target_path': CONFIG_PARSER.get('exporting', 'target_path'),
'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'),
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
}
makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
+24 -32
View File
@@ -3,6 +3,7 @@ import os
from typing import Tuple
import torch
from configparser import ConfigParser
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
@@ -12,16 +13,20 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Config, ConfigSet, Embedding, OptimizerSet
# NOTE(review): rendered diff hunk without +/- markers; the CONFIG pair and the CONFIG_PARSER pair
# look like before/after variants of the same module-level setup - confirm against the repository.
# Variant that goes through the `configparser` module name; the only configparser import visible
# in this hunk is `from configparser import ConfigParser`.
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
from .types import Batch, Embedding, OptimizerSet
# Variant that uses the directly imported ConfigParser class and reads 'config.ini' at module execution.
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
class EmbeddingConverterTrainer(LightningModule):
# NOTE(review): rendered diff hunk without +/- markers - the two `def __init__` lines and the
# `self.config = config` / `self.config_parser =\` statements look like before/after variants; confirm.
# Variant that receives a ready-made Config mapping.
def __init__(self, config : Config) -> None:
# Variant that receives the ConfigParser itself and extracts the needed options below.
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.config = config
# Collects the two model paths (section 'training.model') and the learning rate
# (section 'training.trainer', parsed as float) into a plain dict.
# NOTE(review): the dict is stored on self.config_parser, yet the torch.jit.load calls below read
# self.config.get(...). Under the ConfigParser signature there is no `config` parameter, so nothing
# would bind self.config - storing the dict as self.config looks intended; confirm.
self.config_parser =\
{
'source_path': config_parser.get('training.model', 'source_path'),
'target_path': config_parser.get('training.model', 'target_path'),
'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate')
}
self.embedding_converter = EmbeddingConverter()
# Source and target embedders are loaded with torch.jit.load from the configured paths, mapped to the CPU.
self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -68,8 +73,8 @@ class EmbeddingConverterTrainer(LightningModule):
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
config =\
{
'batch_size': CONFIG.getint('training.loader', 'batch_size'),
'num_workers': CONFIG.getint('training.loader', 'num_workers')
'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'),
'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers')
}
training_dataset, validate_dataset = split_dataset(dataset)
@@ -81,7 +86,7 @@ def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
config =\
{
'split_ratio': CONFIG.getfloat('training.loader', 'split_ratio')
'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio')
}
dataset_size = len(dataset) # type:ignore[arg-type]
@@ -94,10 +99,10 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T
def create_trainer() -> Trainer:
config =\
{
'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'),
'precision': CONFIG.get('training.trainer', 'precision'),
'directory_path': CONFIG.get('training.output', 'directory_path'),
'file_pattern': CONFIG.get('training.output', 'file_pattern')
'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'),
'precision': CONFIG_PARSER.get('training.trainer', 'precision'),
'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'),
'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern')
}
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
@@ -121,33 +126,20 @@ def create_trainer() -> Trainer:
# Entry point of the training run: builds the dataset, the loaders, the Lightning module and the
# trainer, then calls trainer.fit, passing ckpt_path when the configured resume path exists on disk.
# NOTE(review): rendered diff hunk without +/- markers. The nested `config_set` mapping read from CONFIG
# and the flat `config` mapping read from CONFIG_PARSER look like before/after variants, as do the
# doubled `dataset = ...`, `embedding_converter_trainer = ...` and `if os.path.exists(...)` lines - confirm.
def train() -> None:
config_set : ConfigSet =\
config =\
{
# Nested variant: one sub-mapping per consumer (dataset, trainer, output).
'dataset':
{
'file_pattern': CONFIG.get('training.dataset', 'file_pattern')
},
'trainer':
{
'source_path': CONFIG.get('training.model', 'source_path'),
'target_path': CONFIG.get('training.model', 'target_path'),
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
},
'output':
{
'resume_path': CONFIG.get('training.output', 'resume_path')
}
# Flat variant: only the resume path is kept here; the parser itself is handed to the consumers below.
'resume_path': CONFIG_PARSER.get('training.output', 'resume_path')
}
# The matmul precision is switched to 'high' only when CUDA is available.
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = StaticDataset(config_set.get('dataset'))
dataset = StaticDataset(CONFIG_PARSER)
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer(config_set.get('trainer'))
embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER)
trainer = create_trainer()
# With an existing resume path the fit call receives it as ckpt_path; otherwise fit is called without one.
if os.path.exists(config_set.get('output').get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_set.get('output').get('resume_path'))
if os.path.exists(config.get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
else:
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
+1 -4
View File
@@ -1,11 +1,8 @@
from typing import Any, Dict, TypeAlias
from typing import Any, TypeAlias
from torch import Tensor
# Shared type aliases. Batch and Embedding are both plain torch Tensors.
Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
# Config maps option names to arbitrary values; ConfigSet maps a name to one such Config.
# NOTE(review): both need `Dict`, which only the `from typing import Any, Dict, TypeAlias` import line
# provides - in this rendered diff they look like the lines being removed together with it; confirm.
Config : TypeAlias = Dict[str, Any]
ConfigSet : TypeAlias = Dict[str, Config]
# Left as Any, so the alias places no constraint on the optimizer structure.
OptimizerSet : TypeAlias = Any