diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 9c72530..d95035f 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -113,7 +113,7 @@ def create_trainer() -> Trainer: def train() -> None: trainer = create_trainer() training_loader, validation_loader = create_loaders() - embedding_converter = EmbeddingConverterTrainer() + embedding_converter_trainer = EmbeddingConverterTrainer() tuner = Tuner(trainer) - tuner.lr_find(embedding_converter, training_loader, validation_loader) - trainer.fit(embedding_converter, training_loader, validation_loader) + tuner.lr_find(embedding_converter_trainer, training_loader, validation_loader) + trainer.fit(embedding_converter_trainer, training_loader, validation_loader)