From e61e470432c3caa2b2087fff737cb091b959a461 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 6 Mar 2025 08:44:02 +0100 Subject: [PATCH] Fix CI --- embedding_converter/src/training.py | 4 ++-- embedding_converter/src/types.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index f30c22f..a6ca04f 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -13,7 +13,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter -from .types import Batch, Embedding, OptimizerConfig +from .types import Batch, Embedding, OptimizerSet CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -54,7 +54,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule): self.log('validation_score', validation_score, prog_bar = True) return validation_score - def configure_optimizers(self) -> OptimizerConfig: + def configure_optimizers(self) -> OptimizerSet: optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate')) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) config =\ diff --git a/embedding_converter/src/types.py b/embedding_converter/src/types.py index 0c89451..a893b5f 100644 --- a/embedding_converter/src/types.py +++ b/embedding_converter/src/types.py @@ -1,4 +1,4 @@ -from typing import Any, TypeAlias, Dict +from typing import Any, Dict, TypeAlias from torch import Tensor @@ -6,4 +6,4 @@ Batch : TypeAlias = Tensor Embedding : TypeAlias = Tensor Config : TypeAlias = Dict[str, Any] -OptimizerConfig : TypeAlias = Any +OptimizerSet : TypeAlias = Any