Improve imports

This commit is contained in:
henryruhs
2025-03-06 14:24:55 +01:00
parent d215b6f98b
commit a0c42bedbe
2 changed files with 6 additions and 6 deletions
+2 -2
View File
@@ -9,10 +9,10 @@ from .types import Batch
class StaticDataset(Dataset[Tensor]):
def __init__(self, config : ConfigParser) -> None:
def __init__(self, config_parser : ConfigParser) -> None:
self.config =\
{
'file_pattern': config.get('training.dataset', 'file_pattern')
'file_pattern': config_parser.get('training.dataset', 'file_pattern')
}
self.file_paths = glob.glob(self.config.get('file_pattern'))
self.transforms = self.compose_transforms()
+4 -4
View File
@@ -1,11 +1,11 @@
import configparser
from os import makedirs
import os
from configparser import ConfigParser
import torch
from .training import EmbeddingConverterTrainer
CONFIG_PARSER = configparser.ConfigParser()
CONFIG_PARSER = ConfigParser()
CONFIG_PARSER.read('config.ini')
@@ -19,7 +19,7 @@ def export() -> None:
'opset_version': CONFIG_PARSER.getint('exporting', 'opset_version')
}
makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
os.makedirs(config.get('directory_path'), exist_ok = True) # type:ignore[arg-type]
model = EmbeddingConverterTrainer.load_from_checkpoint(config.get('source_path'), map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(config.get('ir_version'))