import os
import os.path

from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms


class DatasetFactory:
    """Factory that builds a dataset object from its short name."""

    def __init__(self):
        pass

    @staticmethod
    def get_by_name(dataset_name, opt, is_for_train):
        """Create the dataset registered under ``dataset_name``.

        Args:
            dataset_name: Short identifier of the dataset; only ``'aus'`` is
                recognized here.
            opt: Options object, forwarded unchanged to the dataset
                constructor.
            is_for_train: Flag forwarded unchanged to the dataset constructor.

        Returns:
            The newly constructed dataset instance.

        Raises:
            ValueError: If ``dataset_name`` is not a recognized name.
        """
        if dataset_name == 'aus':
            # Imported inside the branch so the module is only loaded when
            # this dataset is actually requested. `data` here is resolved by
            # the import system as a package name; it is unrelated to the
            # module-level alias of `torch.utils.data`.
            from data.dataset_aus import AusDataset
            dataset = AusDataset(opt, is_for_train)
        else:
            raise ValueError("Dataset [%s] not recognized." % dataset_name)

        print('Dataset {} was created'.format(dataset.name))
        return dataset


class DatasetBase(data.Dataset):
    """Base class holding shared state and file-discovery helpers for datasets.

    It stores the options object and the train flag, exposes ``name`` and
    ``path`` properties, builds a transform via the ``_create_transform``
    hook, and provides helpers for collecting files by extension.
    """

    def __init__(self, opt, is_for_train):
        """Store the configuration and build the transform.

        Args:
            opt: Options object, kept as ``self._opt``.
            is_for_train: Train flag, kept as ``self._is_for_train``.
        """
        super(DatasetBase, self).__init__()
        self._name = 'BaseDataset'
        self._root = None
        self._opt = opt
        self._is_for_train = is_for_train

        # Suffixes accepted by _is_image_file (matching is case-sensitive,
        # hence both lower- and upper-case variants are listed).
        self._IMG_EXTENSIONS = [
            '.jpg', '.JPG', '.jpeg', '.JPEG',
            '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
        ]

        # Call the overridable hook last so that an override sees a fully
        # initialized instance (including _IMG_EXTENSIONS).
        self._create_transform()

    @property
    def name(self):
        """Name of the dataset (``self._name``)."""
        return self._name

    @property
    def path(self):
        """Root location of the dataset (``self._root``); ``None`` by default."""
        return self._root

    def _create_transform(self):
        """Build ``self._transform``; the base version is an empty Compose."""
        self._transform = transforms.Compose([])

    def get_transform(self):
        """Return the transform built by ``_create_transform``."""
        return self._transform

    def _is_image_file(self, filename):
        """Return True if ``filename`` ends with one of ``_IMG_EXTENSIONS``."""
        # str.endswith accepts a tuple of suffixes: one call instead of a
        # Python-level loop. An empty tuple yields False, like any([]).
        return filename.endswith(tuple(self._IMG_EXTENSIONS))

    def _is_csv_file(self, filename):
        """Return True if ``filename`` ends with ``'.csv'``."""
        return filename.endswith('.csv')

    def _get_all_files_in_subfolders(self, dir, is_file):
        """Collect paths of all matching files under ``dir``, recursively.

        Args:
            dir: Directory to walk.
            is_file: Predicate called with each bare file name; the file is
                kept when it returns a truthy value.

        Returns:
            List of paths (``os.path.join(root, fname)``), ordered by
            directory and then by file name.

        Raises:
            AssertionError: If ``dir`` is not an existing directory.
        """
        if not os.path.isdir(dir):
            # Raised explicitly instead of via `assert` so the check is not
            # stripped under `python -O`; the exception type is kept for
            # callers that already catch AssertionError.
            raise AssertionError('%s is not a valid directory' % dir)

        paths = []
        for root, _, fnames in sorted(os.walk(dir)):
            # os.walk yields file names in arbitrary order; sort them so the
            # result is reproducible.
            for fname in sorted(fnames):
                if is_file(fname):
                    paths.append(os.path.join(root, fname))
        return paths