diff --git a/embedding_converter/src/dataset.py b/embedding_converter/src/dataset.py index 79c25a4..21057a4 100644 --- a/embedding_converter/src/dataset.py +++ b/embedding_converter/src/dataset.py @@ -9,10 +9,10 @@ from .types import Batch class StaticDataset(Dataset[Tensor]): - def __init__(self, config : ConfigParser) -> None: + def __init__(self, config_parser : ConfigParser) -> None: self.config =\ { - 'file_pattern': config.get('training.dataset', 'file_pattern') + 'file_pattern': config_parser.get('training.dataset', 'file_pattern') } self.file_paths = glob.glob(self.config.get('file_pattern')) self.transforms = self.compose_transforms() diff --git a/embedding_converter/src/exporting.py b/embedding_converter/src/exporting.py index 4222295..3c76a4f 100644 --- a/embedding_converter/src/exporting.py +++ b/embedding_converter/src/exporting.py @@ -1,11 +1,11 @@ -import configparser -from os import makedirs +import os +from configparser import ConfigParser import torch from .training import EmbeddingConverterTrainer -CONFIG_PARSER = configparser.ConfigParser() +CONFIG_PARSER = ConfigParser() CONFIG_PARSER.read('config.ini') @@ -19,7 +19,7 @@ def export() -> None: 'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version') } - makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type] + os.makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type] model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu') model.eval() model.ir_version = torch.tensor(config.get('ir_version'))