Always use config injection

This commit is contained in:
henryruhs
2025-03-06 09:22:14 +01:00
parent 368da824aa
commit b59e172fa3
+17 -18
View File
@@ -2,9 +2,8 @@ import configparser
import os
from typing import Tuple
import lightning
import torch
from lightning import Trainer
from lightning import Trainer, LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch import Tensor, nn
@@ -13,22 +12,16 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Embedding, OptimizerSet
from .types import Batch, Embedding, OptimizerSet, Config
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class EmbeddingConverterTrainer(lightning.LightningModule):
def __init__(self) -> None:
super(EmbeddingConverterTrainer, self).__init__()
self.config =\
{
'source_path': CONFIG.get('training.model', 'source_path'),
'target_path': CONFIG.get('training.model', 'target_path'),
'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
}
class EmbeddingConverterTrainer(LightningModule):
def __init__(self, config : Config) -> None:
super().__init__()
self.config = config
self.embedding_converter = EmbeddingConverter()
self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
@@ -128,21 +121,27 @@ def create_trainer() -> Trainer:
def train() -> None:
    """
    Train the embedding converter using values read from the module-level CONFIG.

    Two separate config dicts are built and injected into their consumers:
    one for the dataset (file pattern and resume path) and one for the
    trainer module (source/target model paths and learning rate). When the
    configured resume path exists on disk it is passed to the fit call as
    the checkpoint path, otherwise the fit call is made without one.
    """
    config_dataset =\
    {
        'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
        'resume_path': CONFIG.get('training.output', 'resume_path')
    }
    config_trainer =\
    {
        'source_path': CONFIG.get('training.model', 'source_path'),
        'target_path': CONFIG.get('training.model', 'target_path'),
        'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
    }

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision('high')

    dataset = StaticDataset(config_dataset)
    training_loader, validation_loader = create_loaders(dataset)
    embedding_converter_trainer = EmbeddingConverterTrainer(config_trainer)
    trainer = create_trainer()
    # look the resume path up once and reuse it for both the check and the fit call
    resume_path = config_dataset.get('resume_path')

    if os.path.exists(resume_path):
        trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = resume_path)
    else:
        trainer.fit(embedding_converter_trainer, training_loader, validation_loader)