This commit is contained in:
harisreedhar
2025-02-20 18:09:37 +05:30
committed by henryruhs
parent dcf19634d1
commit b47c6b72ee
10 changed files with 112 additions and 170 deletions
+41
View File
@@ -0,0 +1,41 @@
import glob
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from .helper import read_image
from .types import Batch, ImagePathList
class DataLoaderRecognition(Dataset[torch.Tensor]):
def __init__(self, dataset_path : str, dataset_image_pattern : str) -> None:
self.image_paths = self.prepare_image_paths(dataset_path, dataset_image_pattern)
self.dataset_total = len(self.image_paths)
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
target_image_path = random.choice(self.image_paths)
target_vision_frame = read_image(target_image_path)
target_tensor = self.transforms(target_vision_frame)
return target_tensor
def __len__(self) -> int:
return self.dataset_total
def prepare_image_paths(self, dataset_path : str, dataset_image_pattern : str) -> ImagePathList:
image_paths = glob.glob(dataset_image_pattern.format(dataset_path))
return image_paths
def compose_transforms(self) -> transforms:
transform = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.ToTensor(),
transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
return transform