Cache the usage of glob.glob (#80)

This commit is contained in:
Henry Ruhs
2025-06-05 09:30:04 +02:00
committed by GitHub
parent a602bbd474
commit 94cbcb68f0
2 changed files with 15 additions and 8 deletions
+6 -8
View File
@@ -1,4 +1,3 @@
import glob
import os
import random
from configparser import ConfigParser
@@ -9,7 +8,7 @@ from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import convert_tensor
from .helper import convert_tensor, resolve_static_file_pattern
from .types import Batch, BatchMode, ConvertTemplate, UsageMode
@@ -23,10 +22,9 @@ class DynamicDataset(Dataset[Tensor]):
self.config_batch_ratio = config_parser.getfloat('training.dataset.current', 'batch_ratio')
self.config_parser = config_parser
self.transforms = self.compose_transforms()
self.file_paths = glob.glob(self.config_file_pattern)
def __getitem__(self, index : int) -> Batch:
file_path = self.file_paths[index]
file_path = resolve_static_file_pattern(self.config_file_pattern)[index]
if random.random() < self.config_batch_ratio:
if self.config_batch_mode == 'equal':
@@ -43,7 +41,7 @@ class DynamicDataset(Dataset[Tensor]):
return self.prepare_different_batch(file_path)
def __len__(self) -> int:
return len(self.file_paths)
return len(resolve_static_file_pattern(self.config_file_pattern))
def compose_transforms(self) -> transforms:
return transforms.Compose(
@@ -79,7 +77,7 @@ class DynamicDataset(Dataset[Tensor]):
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
target_path = random.choice(glob.glob(config_file_pattern))
target_path = random.choice(resolve_static_file_pattern(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
@@ -94,7 +92,7 @@ class DynamicDataset(Dataset[Tensor]):
config_file_pattern = config_parser.get(config_section, 'file_pattern')
config_convert_template = cast(ConvertTemplate, config_parser.get(config_section, 'convert_template'))
source_path = random.choice(glob.glob(config_file_pattern))
source_path = random.choice(resolve_static_file_pattern(config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, config_convert_template)
@@ -104,7 +102,7 @@ class DynamicDataset(Dataset[Tensor]):
return source_tensor, target_tensor
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
target_path = random.choice(resolve_static_file_pattern(self.config_file_pattern))
source_tensor = io.read_image(source_path)
source_tensor = self.transforms(source_tensor)
source_tensor = self.conditional_convert_tensor(source_tensor, self.config_convert_template)
+9
View File
@@ -1,3 +1,7 @@
import glob
from functools import lru_cache
from typing import List
import torch
from torch import Tensor, nn
@@ -55,3 +59,8 @@ def apply_noise(input_tensor : Tensor, factor : float) -> Tensor:
noise_tensor = torch.randn_like(input_tensor) * factor
output_tensor = input_tensor + noise_tensor
return output_tensor
@lru_cache(maxsize = None)
def resolve_static_file_pattern(file_pattern : str) -> List[str]:
    """
    Expand a glob pattern into an ordered list of matching paths.

    The pattern is globbed once and the matches are sorted ascending.
    Because the function is memoized per pattern without a size limit,
    later calls with the same pattern return the very same list object
    and do not look at the file system again.

    :param file_pattern: glob expression passed to glob.glob
    :return: sorted matching paths, empty when nothing matches
    """
    file_paths = glob.glob(file_pattern)
    # glob.glob already hands back a fresh list, so sorting it in place is safe
    file_paths.sort()
    return file_paths