diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py
index a6ca04f..b52057a 100644
--- a/embedding_converter/src/training.py
+++ b/embedding_converter/src/training.py
@@ -2,9 +2,8 @@ import configparser
 import os
 from typing import Tuple
 
-import lightning
 import torch
-from lightning import Trainer
+from lightning import Trainer, LightningModule
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import TensorBoardLogger
 from torch import Tensor, nn
@@ -13,22 +12,16 @@ from torchdata.stateful_dataloader import StatefulDataLoader
 
 from .dataset import StaticDataset
 from .models.embedding_converter import EmbeddingConverter
-from .types import Batch, Embedding, OptimizerSet
+from .types import Batch, Embedding, OptimizerSet, Config
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
 
 
-class EmbeddingConverterTrainer(lightning.LightningModule):
-	def __init__(self) -> None:
-		super(EmbeddingConverterTrainer, self).__init__()
-		self.config =\
-		{
-			'source_path': CONFIG.get('training.model', 'source_path'),
-			'target_path': CONFIG.get('training.model', 'target_path'),
-			'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
-		}
-
+class EmbeddingConverterTrainer(LightningModule):
+	def __init__(self, config : Config) -> None:
+		super().__init__()
+		self.config = config
 		self.embedding_converter = EmbeddingConverter()
 		self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
 		self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -128,21 +121,27 @@ def create_trainer() -> Trainer:
 
 
 def train() -> None:
-	config =\
+	config_dataset =\
 	{
 		'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
 		'resume_path': CONFIG.get('training.output', 'resume_path')
 	}
+	config_trainer =\
+	{
+		'source_path': CONFIG.get('training.model', 'source_path'),
+		'target_path': CONFIG.get('training.model', 'target_path'),
+		'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
+	}
 
 	if torch.cuda.is_available():
 		torch.set_float32_matmul_precision('high')
 
-	dataset = StaticDataset(config)
+	dataset = StaticDataset(config_dataset)
 	training_loader, validation_loader = create_loaders(dataset)
-	embedding_converter_trainer = EmbeddingConverterTrainer()
+	embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
 	trainer = create_trainer()
 
-	if os.path.exists(config.get('resume_path')):
-		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
+	if os.path.exists(config_dataset.get('resume_path')):
+		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_dataset.get('resume_path'))
 	else:
 		trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
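For context, a minimal sketch of how the refactored constructor is now driven. It assumes the new `Config` alias in `.types` is a plain string-keyed dict (the real alias is defined outside this diff) and uses an inline stand-in for `config.ini` whose sections and keys mirror the `CONFIG.get()` calls above; all values are placeholders.

```python
# Sketch only: Config is assumed to be a plain dict alias; the actual
# definition lives in .types and is not part of this diff.
import configparser
from typing import Any, Dict

Config = Dict[str, Any]  # assumed shape of .types.Config

# Inline stand-in for config.ini with the sections referenced by the diff
# (training.model, training.trainer, training.dataset, training.output).
CONFIG = configparser.ConfigParser()
CONFIG.read_string(
"""
[training.model]
source_path = source_embedder.pt
target_path = target_embedder.pt

[training.trainer]
learning_rate = 0.0001

[training.dataset]
file_pattern = datasets/*.npy

[training.output]
resume_path = outputs/last.ckpt
""")

# train() now assembles the trainer config and injects it, instead of the
# LightningModule reading config.ini inside __init__():
config_trainer : Config =\
{
	'source_path': CONFIG.get('training.model', 'source_path'),
	'target_path': CONFIG.get('training.model', 'target_path'),
	'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
}
print(config_trainer)
# embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
```

One practical effect of this injection: the module can be constructed in tests or notebooks with an explicit dict, without a config.ini on disk.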