diff --git a/embedding_converter/src/dataset.py b/embedding_converter/src/dataset.py
index 5680d59..77b5ab1 100644
--- a/embedding_converter/src/dataset.py
+++ b/embedding_converter/src/dataset.py
@@ -4,12 +4,13 @@ from torch import Tensor
 from torch.utils.data import Dataset
 from torchvision import io, transforms
 
-from .types import Batch
+from .types import Batch, Config
 
 
 class StaticDataset(Dataset[Tensor]):
-	def __init__(self, file_pattern : str) -> None:
-		self.file_paths = glob.glob(file_pattern)
+	def __init__(self, config : Config) -> None:
+		self.config = config
+		self.file_paths = glob.glob(self.config.get('file_pattern'))
 		self.transforms = self.compose_transforms()
 
 	def __getitem__(self, index : int) -> Batch:
diff --git a/embedding_converter/src/exporting.py b/embedding_converter/src/exporting.py
index 9102c71..b9f048d 100644
--- a/embedding_converter/src/exporting.py
+++ b/embedding_converter/src/exporting.py
@@ -10,15 +10,18 @@ CONFIG.read('config.ini')
 
 
 def export() -> None:
-	directory_path = CONFIG.get('exporting', 'directory_path')
-	source_path = CONFIG.get('exporting', 'source_path')
-	target_path = CONFIG.get('exporting', 'target_path')
-	ir_version = CONFIG.getint('exporting', 'ir_version')
-	opset_version = CONFIG.getint('exporting', 'opset_version')
+	config =\
+	{
+		'directory_path': CONFIG.get('exporting', 'directory_path'),
+		'source_path': CONFIG.get('exporting', 'source_path'),
+		'target_path': CONFIG.get('exporting', 'target_path'),
+		'ir_version': CONFIG.getint('exporting', 'ir_version'),
+		'opset_version': CONFIG.getint('exporting', 'opset_version')
+	}
 
-	makedirs(directory_path, exist_ok = True)
-	model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
+	makedirs(config.get('directory_path'), exist_ok = True)
+	model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
 	model.eval()
-	model.ir_version = torch.tensor(ir_version)
+	model.ir_version = torch.tensor(config.get('ir_version'))
 	input_tensor = torch.randn(1, 512)
-	torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
+	torch.onnx.export(model, input_tensor, config.get('target_path'), input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config.get('opset_version'))
diff --git a/embedding_converter/src/training.py b/embedding_converter/src/training.py
index 1fa0009..f30c22f 100644
--- a/embedding_converter/src/training.py
+++ b/embedding_converter/src/training.py
@@ -22,12 +22,16 @@ CONFIG.read('config.ini')
 class EmbeddingConverterTrainer(lightning.LightningModule):
 	def __init__(self) -> None:
 		super(EmbeddingConverterTrainer, self).__init__()
-		source_path = CONFIG.get('training.model', 'source_path')
-		target_path = CONFIG.get('training.model', 'target_path')
+		self.config =\
+		{
+			'source_path': CONFIG.get('training.model', 'source_path'),
+			'target_path': CONFIG.get('training.model', 'target_path'),
+			'learning_rate': CONFIG.getfloat('training.trainer', 'learning_rate')
+		}
 
 		self.embedding_converter = EmbeddingConverter()
-		self.source_embedder = torch.jit.load(source_path, map_location = 'cpu') # type:ignore[no-untyped-call]
-		self.target_embedder = torch.jit.load(target_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.source_embedder = torch.jit.load(self.config.get('source_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
+		self.target_embedder = torch.jit.load(self.config.get('target_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
 		self.mse_loss = nn.MSELoss()
 
 	def forward(self, source_embedding : Embedding) -> Embedding:
@@ -51,8 +55,7 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
 		return validation_score
 
 	def configure_optimizers(self) -> OptimizerConfig:
-		learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')
-		optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate)
+		optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
 		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
 		config =\
 		{
@@ -70,42 +73,52 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
 
 
 def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
-	batch_size = CONFIG.getint('training.loader', 'batch_size')
-	num_workers = CONFIG.getint('training.loader', 'num_workers')
+	config =\
+	{
+		'batch_size': CONFIG.getint('training.loader', 'batch_size'),
+		'num_workers': CONFIG.getint('training.loader', 'num_workers')
+	}
 	training_dataset, validate_dataset = split_dataset(dataset)
-	training_loader = StatefulDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
-	validation_loader = StatefulDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True, persistent_workers = True)
+	training_loader = StatefulDataLoader(training_dataset, batch_size = config.get('batch_size'), shuffle = True, num_workers = config.get('num_workers'), drop_last = True, pin_memory = True, persistent_workers = True)
+	validation_loader = StatefulDataLoader(validate_dataset, batch_size = config.get('batch_size'), shuffle = False, num_workers = config.get('num_workers'), pin_memory = True, persistent_workers = True)
 
 	return training_loader, validation_loader
 
 
 def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
-	split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
+	config =\
+	{
+		'split_ratio': CONFIG.getfloat('training.loader', 'split_ratio')
+	}
+
 	dataset_size = len(dataset) # type:ignore[arg-type]
-	training_size = int(dataset_size * split_ratio)
+	training_size = int(dataset_size * config.get('split_ratio'))
 	validation_size = int(dataset_size - training_size)
 	training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
 
 	return training_dataset, validate_dataset
 
 
 def create_trainer() -> Trainer:
-	trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
-	output_directory_path = CONFIG.get('training.output', 'directory_path')
-	output_file_pattern = CONFIG.get('training.output', 'file_pattern')
-	trainer_precision = CONFIG.get('training.trainer', 'precision')
+	config =\
+	{
+		'max_epochs': CONFIG.getint('training.trainer', 'max_epochs'),
+		'directory_path': CONFIG.get('training.output', 'directory_path'),
+		'file_pattern': CONFIG.get('training.output', 'file_pattern'),
+		'precision': CONFIG.get('training.trainer', 'precision')
+	}
 	logger = TensorBoardLogger('.logs', name = 'embedding_converter')
 	return Trainer(
 		logger = logger,
 		log_every_n_steps = 10,
-		max_epochs = trainer_max_epochs,
-		precision = trainer_precision, # type:ignore[arg-type]
+		max_epochs = config.get('max_epochs'),
+		precision = config.get('precision'), # type:ignore[arg-type]
 		callbacks =
 		[
 			ModelCheckpoint(
 				monitor = 'training_loss',
-				dirpath = output_directory_path,
-				filename = output_file_pattern,
+				dirpath = config.get('directory_path'),
+				filename = config.get('file_pattern'),
 				every_n_epochs = 1,
 				save_top_k = 3,
 				save_last = True
@@ -115,18 +128,21 @@ def create_trainer() -> Trainer:
 
 
 def train() -> None:
-	dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
-	output_resume_path = CONFIG.get('training.output', 'resume_path')
+	config =\
+	{
+		'file_pattern': CONFIG.get('training.dataset', 'file_pattern'),
+		'resume_path': CONFIG.get('training.output', 'resume_path')
+	}
 
 	if torch.cuda.is_available():
 		torch.set_float32_matmul_precision('high')
 
-	dataset = StaticDataset(dataset_file_pattern)
+	dataset = StaticDataset(config)
 	training_loader, validation_loader = create_loaders(dataset)
 	embedding_converter_trainer = EmbeddingConverterTrainer()
 	trainer = create_trainer()
 
-	if os.path.exists(output_resume_path):
-		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = output_resume_path)
+	if os.path.exists(config.get('resume_path')):
+		trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
 	else:
 		trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
diff --git a/embedding_converter/src/types.py b/embedding_converter/src/types.py
index 7f41a40..0c89451 100644
--- a/embedding_converter/src/types.py
+++ b/embedding_converter/src/types.py
@@ -1,8 +1,9 @@
-from typing import Any, TypeAlias
+from typing import Any, TypeAlias, Dict
 
 from torch import Tensor
 
 
 Batch : TypeAlias = Tensor
 Embedding : TypeAlias = Tensor
+Config : TypeAlias = Dict[str, Any]
 OptimizerConfig : TypeAlias = Any
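For reference, a minimal config.ini that satisfies every section and key the refactored code reads. The section and key names are taken verbatim from the CONFIG.get / CONFIG.getint / CONFIG.getfloat calls above; all values are illustrative placeholders, not values from the repository:

[training.dataset]
file_pattern = .datasets/**/*.jpg

[training.model]
source_path = .models/source_embedder.pt
target_path = .models/target_embedder.pt

[training.loader]
batch_size = 64
num_workers = 4
split_ratio = 0.8

[training.trainer]
learning_rate = 0.0001
max_epochs = 100
precision = 16-mixed

[training.output]
directory_path = .checkpoints
file_pattern = embedding_converter_{epoch}
resume_path = .checkpoints/last.ckpt

[exporting]
directory_path = .exports
source_path = .checkpoints/last.ckpt
target_path = .exports/embedding_converter.onnx
ir_version = 10
opset_version = 17

Note that the refactor trades configparser's strict lookups for plain dict.get: a mistyped key now surfaces as None at the call site rather than raising NoOptionError when the config is read.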