diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index f30c22f..a6ca04f 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -13,7 +13,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter -from .types import Batch, Embedding, OptimizerConfig +from .types import Batch, Embedding, OptimizerSet CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -54,7 +54,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule): self.log('validation_score', validation_score, prog_bar = True) return validation_score - def configure_optimizers(self) -> OptimizerConfig: + def configure_optimizers(self) -> OptimizerSet: optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate')) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) config =\ diff --git a/embedding_converter/src/types.py b/embedding_converter/src/types.py index 0c89451..a893b5f 100644 --- a/embedding_converter/src/types.py +++ b/embedding_converter/src/types.py @@ -1,4 +1,4 @@ -from typing import Any, TypeAlias, Dict +from typing import Any, Dict, TypeAlias from torch import Tensor @@ -6,4 +6,4 @@ Batch : TypeAlias = Tensor Embedding : TypeAlias = Tensor Config : TypeAlias = Dict[str, Any] -OptimizerConfig : TypeAlias = Any +OptimizerSet : TypeAlias = Any