diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 4d3024d..a9e945a 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -1,9 +1,8 @@ -import configparser import os +from configparser import ConfigParser from typing import Tuple import torch -from configparser import ConfigParser from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger @@ -18,6 +17,7 @@ from .types import Batch, Embedding, OptimizerSet CONFIG_PARSER = ConfigParser() CONFIG_PARSER.read('config.ini') + class EmbeddingConverterTrainer(LightningModule): def __init__(self, config_parser : ConfigParser) -> None: super().__init__()