This commit is contained in:
henryruhs
2025-03-06 08:44:02 +01:00
parent c8953ce8a1
commit e61e470432
2 changed files with 4 additions and 4 deletions
+2 -2
View File
@@ -13,7 +13,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader
from .dataset import StaticDataset
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Embedding, OptimizerConfig
from .types import Batch, Embedding, OptimizerSet
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -54,7 +54,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
self.log('validation_score', validation_score, prog_bar = True)
return validation_score
def configure_optimizers(self) -> OptimizerConfig:
def configure_optimizers(self) -> OptimizerSet:
optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
config =\
+2 -2
View File
@@ -1,4 +1,4 @@
from typing import Any, TypeAlias, Dict
from typing import Any, Dict, TypeAlias
from torch import Tensor
@@ -6,4 +6,4 @@ Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
Config : TypeAlias = Dict[str, Any]
OptimizerConfig : TypeAlias = Any
OptimizerSet : TypeAlias = Any