diff --git a/hyperswap/src/dataset.py b/hyperswap/src/dataset.py index 77df6cc..3d9bdb4 100644 --- a/hyperswap/src/dataset.py +++ b/hyperswap/src/dataset.py @@ -1,4 +1,3 @@ -import glob import os import random from configparser import ConfigParser @@ -9,7 +8,7 @@ from torch import Tensor from torch.utils.data import Dataset from torchvision import io, transforms -from .helper import convert_tensor +from .helper import convert_tensor, resolve_static_file_pattern from .types import Batch, BatchMode, ConvertTemplate, UsageMode @@ -23,10 +22,9 @@ class DynamicDataset(Dataset[Tensor]): self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio') self.config_parser = config_parser self.transforms = self.compose_transforms() - self.file_paths = glob.glob(self.config_file_pattern) def __getitem__(self, index : int) -> Batch: - file_path = self.file_paths[index] + file_path = resolve_static_file_pattern(self.config_file_pattern)[index] if random.random() < self.config_batch_ratio: if self.config_batch_mode == 'equal': @@ -43,7 +41,7 @@ class DynamicDataset(Dataset[Tensor]): return self.prepare_different_batch(file_path) def __len__(self) -> int: - return len(self.file_paths) + return len(resolve_static_file_pattern(self.config_file_pattern)) def compose_transforms(self) -> transforms: return transforms.Compose( @@ -79,7 +77,7 @@ class DynamicDataset(Dataset[Tensor]): config_file_pattern = config_parser.get(config_section, 'file_pattern') config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) - target_path = random.choice(glob.glob(config_file_pattern)) + target_path = random.choice(resolve_static_file_pattern(config_file_pattern)) source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) @@ -94,7 +92,7 @@ class DynamicDataset(Dataset[Tensor]): config_file_pattern = config_parser.get(config_section, 'file_pattern') config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template')) - source_path = random.choice(glob.glob(config_file_pattern)) + source_path = random.choice(resolve_static_file_pattern(config_file_pattern)) source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template) @@ -104,7 +102,7 @@ class DynamicDataset(Dataset[Tensor]): return source_tensor, target_tensor def prepare_different_batch(self, source_path : str) -> Batch: - target_path = random.choice(self.file_paths) + target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern)) source_tensor = io.read_image(source_path) source_tensor = self.transforms(source_tensor) source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template) diff --git a/hyperswap/src/helper.py b/hyperswap/src/helper.py index a14ea72..7d99e03 100644 --- a/hyperswap/src/helper.py +++ b/hyperswap/src/helper.py @@ -1,3 +1,7 @@ +import glob +from functools import lru_cache +from typing import List + import torch from torch import Tensor, nn @@ -55,3 +59,8 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor: noise_tensor = torch.randn_like(input_tensor) * factor output_tensor = input_tensor + noise_tensor return output_tensor + + +@lru_cache(maxsize = None) +def resolve_static_file_pattern(file_pattern : str) -> List[str]: + return sorted(glob.glob(file_pattern))