diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 02d093a..a492df9 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -12,7 +12,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter -from .types import Batch, Config, Embedding, OptimizerSet +from .types import Batch, Config, ConfigSet, Embedding, OptimizerSet CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -121,7 +121,7 @@ def create_trainer() -> Trainer: def train() -> None: - config : Config =\ + config : ConfigSet =\ { 'dataset': { diff --git a/embedding_converter/src/types.py b/embedding_converter/src/types.py index fd6235b..38a5708 100644 --- a/embedding_converter/src/types.py +++ b/embedding_converter/src/types.py @@ -1,9 +1,10 @@ -from typing import Any, TypeAlias +from typing import Any, Dict, TypeAlias from torch import Tensor Batch : TypeAlias = Tensor Embedding : TypeAlias = Tensor -Config : TypeAlias = Any +Config : TypeAlias = Dict[str, Any] +ConfigSet : TypeAlias = Dict[str, Config] OptimizerSet : TypeAlias = Any