diff --git a/embedding_converter/src/dataset.py b/embedding_converter/src/dataset.py index 77b5ab1..79c25a4 100644 --- a/embedding_converter/src/dataset.py +++ b/embedding_converter/src/dataset.py @@ -1,15 +1,19 @@ import glob +from configparser import ConfigParser from torch import Tensor from torch.utils.data import Dataset from torchvision import io, transforms -from .types import Batch, Config +from .types import Batch class StaticDataset(Dataset[Tensor]): - def __init__(self, config : Config) -> None: - self.config = config + def __init__(self, config : ConfigParser) -> None: + self.config =\ + { + 'file_pattern': config.get('training.dataset', 'file_pattern') + } self.file_paths = glob.glob(self.config.get('file_pattern')) self.transforms = self.compose_transforms() diff --git a/embedding_converter/src/exporting.py b/embedding_converter/src/exporting.py index 0a39d69..4222295 100644 --- a/embedding_converter/src/exporting.py +++ b/embedding_converter/src/exporting.py @@ -5,18 +5,18 @@ import torch from .training import EmbeddingConverterTrainer -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') +CONFIG_PARSER = configparser.ConfigParser() +CONFIG_PARSER.read('config.ini') def export() -> None: config =\ { - 'directory_path': CONFIG.get('exporting', 'directory_path'), - 'source_path': CONFIG.get('exporting', 'source_path'), - 'target_path': CONFIG.get('exporting', 'target_path'), - 'ir_version': CONFIG.getint('exporting', 'ir_version'), - 'opset_version': CONFIG.getint('exporting', 'opset_version') + 'directory_path': CONFIG_PARSER.get('exporting', 'directory_path'), + 'source_path': CONFIG_PARSER.get('exporting', 'source_path'), + 'target_path': CONFIG_PARSER.get('exporting', 'target_path'), + 'ir_version': CONFIG_PARSER.getint('exporting', 'ir_version'), + 'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version') } makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type] diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 91aea51..4d3024d 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -3,6 +3,7 @@ import os from typing import Tuple import torch +from configparser import ConfigParser from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger @@ -12,16 +13,20 @@ from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter -from .types import Batch, Config, ConfigSet, Embedding, OptimizerSet - -CONFIG = configparser.ConfigParser() -CONFIG.read('config.ini') +from .types import Batch, Embedding, OptimizerSet +CONFIG_PARSER = ConfigParser() +CONFIG_PARSER.read('config.ini') class EmbeddingConverterTrainer(LightningModule): - def __init__(self, config : Config) -> None: + def __init__(self, config_parser : ConfigParser) -> None: super().__init__() - self.config = config + self.config_parser =\ + { + 'source_path': config_parser.get('training.model', 'source_path'), + 'target_path': config_parser.get('training.model', 'target_path'), + 'learning_rate': config_parser.getfloat('training.trainer', 'learning_rate') + } self.embedding_converter = EmbeddingConverter() self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call] self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call] @@ -68,8 +73,8 @@ class EmbeddingConverterTrainer(LightningModule): def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]: config =\ { - 'batch_size': CONFIG.getint('training.loader', 'batch_size'), - 'num_workers': CONFIG.getint('training.loader', 'num_workers') + 'batch_size': CONFIG_PARSER.getint('training.loader', 'batch_size'), + 'num_workers': CONFIG_PARSER.getint('training.loader', 'num_workers') } training_dataset, validate_dataset = split_dataset(dataset) @@ -81,7 +86,7 @@ def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]: config =\ { - 'split_ratio': CONFIG.getfloat('training.loader', 'split_ratio') + 'split_ratio': CONFIG_PARSER.getfloat('training.loader', 'split_ratio') } dataset_size = len(dataset) # type:ignore[arg-type] @@ -94,10 +99,10 @@ def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[T def create_trainer() -> Trainer: config =\ { - 'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'), - 'precision': CONFIG.get('training.trainer', 'precision'), - 'directory_path': CONFIG.get('training.output', 'directory_path'), - 'file_pattern': CONFIG.get('training.output', 'file_pattern') + 'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'), + 'precision': CONFIG_PARSER.get('training.trainer', 'precision'), + 'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'), + 'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern') } logger = TensorBoardLogger('.logs', name = 'embedding_converter') @@ -121,33 +126,20 @@ def create_trainer() -> Trainer: def train() -> None: - config_set : ConfigSet =\ + config =\ { - 'dataset': - { - 'file_pattern': CONFIG.get('training.dataset', 'file_pattern') - }, - 'trainer': - { - 'source_path': CONFIG.get('training.model', 'source_path'), - 'target_path': CONFIG.get('training.model', 'target_path'), - 'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate') - }, - 'output': - { - 'resume_path': CONFIG.get('training.output', 'resume_path') - } + 'resume_path': CONFIG_PARSER.get('training.output', 'resume_path') } if torch.cuda.is_available(): torch.set_float32_matmul_precision('high') - dataset = StaticDataset(config_set.get('dataset')) + dataset = StaticDataset(CONFIG_PARSER) training_loader, validation_loader = create_loaders(dataset) - embedding_converter_trainer = EmbeddingConverterTrainer(config_set.get('trainer')) + embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER) trainer = create_trainer() - if os.path.exists(config_set.get('output').get('resume_path')): - trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_set.get('output').get('resume_path')) + if os.path.exists(config.get('resume_path')): + trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path')) else: trainer.fit(embedding_converter_trainer, training_loader, validation_loader) diff --git a/embedding_converter/src/types.py b/embedding_converter/src/types.py index e16bd5e..0522b39 100644 --- a/embedding_converter/src/types.py +++ b/embedding_converter/src/types.py @@ -1,11 +1,8 @@ -from typing import Any, Dict, TypeAlias +from typing import Any, TypeAlias from torch import Tensor Batch : TypeAlias = Tensor Embedding : TypeAlias = Tensor -Config : TypeAlias = Dict[str, Any] -ConfigSet : TypeAlias = Dict[str, Config] - OptimizerSet : TypeAlias = Any