diff --git a/embedding_converter/src/dataset.py b/embedding_converter/src/dataset.py index 62a8fae..6ed1a4f 100644 --- a/embedding_converter/src/dataset.py +++ b/embedding_converter/src/dataset.py @@ -9,13 +9,13 @@ from torchvision import transforms from .types import Batch -class DynamicDataset(Dataset[Tensor]): +class StaticDataset(Dataset[Tensor]): def __init__(self, file_pattern : str) -> None: self.file_paths = glob.glob(file_pattern) self.transforms = self.compose_transforms() def __getitem__(self, index : int) -> Batch: - file_path = random.choice(self.file_paths) + file_path = self.file_paths[index] vision_frame = cv2.imread(file_path) return self.transforms(vision_frame) diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py index 14358fa..b43429b 100644 --- a/embedding_converter/src/training.py +++ b/embedding_converter/src/training.py @@ -11,7 +11,7 @@ from lightning.pytorch.tuner import Tuner from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset, random_split -from .dataset import DynamicDataset +from .dataset import StaticDataset from .models.embedding_converter import EmbeddingConverter from .types import Batch, Embedding, OptimizerConfig @@ -24,11 +24,13 @@ class EmbeddingConverterTrainer(lightning.LightningModule): super(EmbeddingConverterTrainer, self).__init__() source_path = CONFIG.get('training.model', 'source_path') target_path = CONFIG.get('training.model', 'target_path') + learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate') self.embedding_converter = EmbeddingConverter() self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.mse_loss = nn.MSELoss() + self.lr = learning_rate def forward(self, source_embedding : Embedding) -> Embedding: return self.embedding_converter(source_embedding) @@ -115,7 +117,7 @@ def train() -> None: dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern') output_resume_path = CONFIG.get('training.output', 'resume_path') - dataset = DynamicDataset(dataset_file_pattern) + dataset = StaticDataset(dataset_file_pattern) training_loader, validation_loader = create_loaders(dataset) embedding_converter_trainer = EmbeddingConverterTrainer() trainer = create_trainer()