118 lines
3.9 KiB
Python
118 lines
3.9 KiB
Python
import os.path
|
|
import torchvision.transforms as transforms
|
|
from data.dataset import DatasetBase
|
|
from PIL import Image
|
|
import random
|
|
import numpy as np
|
|
import pickle
|
|
from utils import cv_utils
|
|
|
|
|
|
class AusDataset(DatasetBase):
    """Dataset pairing images with per-sample condition ("aus") arrays.

    Each sample holds an image loaded from ``<data_dir>/<images_folder>``, the
    condition stored for that sample in the pickled ``aus_file``, and a
    "desired" condition: another random sample's condition plus uniform noise.
    """

    def __init__(self, opt, is_for_train):
        """Build the dataset and load its ids and conditions.

        Args:
            opt: options object. This class reads ``data_dir``,
                ``images_folder``, ``train_ids_file``, ``test_ids_file``,
                ``aus_file`` and ``serial_batches`` from it.
            is_for_train: when true, use ``opt.train_ids_file`` (and the
                training transform); otherwise use ``opt.test_ids_file``.
        """
        # NOTE(review): ``self._opt``, ``self._is_for_train`` and
        # ``self._transform`` are presumably set by ``DatasetBase.__init__``
        # (which likely calls ``_create_transform``) — confirm.
        super(AusDataset, self).__init__(opt, is_for_train)
        self._name = 'AusDataset'

        # read dataset
        self._read_dataset_paths()

    def __getitem__(self, index):
        """Return one sample as a dict.

        When ``opt.serial_batches`` is false, ``index`` only passes the range
        check. A random sample is drawn instead, and redrawn until one loads.
        When it is true, samples that fail to load are skipped and the next
        index (wrapping around) is tried.

        Returns:
            dict with keys ``'real_img'`` (transformed image), ``'real_cond'``
            (this sample's condition), ``'desired_cond'`` (a jittered random
            condition), ``'sample_id'`` and ``'real_img_path'``.

        Raises:
            IndexError: if ``index >= len(self)``.
            RuntimeError: in serial mode, if no sample in the dataset loads.
        """
        if index >= self._dataset_size:
            raise IndexError('index %s out of range for %s of size %d'
                             % (index, self._name, self._dataset_size))

        failed_serial_reads = 0
        while True:
            # if sample randomly: overwrite index
            if not self._opt.serial_batches:
                index = random.randint(0, self._dataset_size - 1)

            # get sample data
            sample_id = self._ids[index]
            real_img, real_img_path = self._get_img_by_id(sample_id)
            real_cond = self._get_cond_by_id(sample_id)

            if real_img is not None and real_cond is not None:
                break

            if real_img is None:
                print('error reading image %s, skipping sample' % sample_id)
            if real_cond is None:
                print('error reading aus %s, skipping sample' % sample_id)

            if self._opt.serial_batches:
                # The index would otherwise never change, so retrying would
                # spin forever. Move on, and give up after trying every sample.
                failed_serial_reads += 1
                if failed_serial_reads >= self._dataset_size:
                    raise RuntimeError('could not load any sample of %s'
                                       % self._name)
                index = (index + 1) % self._dataset_size

        desired_cond = self._generate_random_cond()

        # transform data
        # NOTE(review): assumes read_cv2_img returns an array accepted by
        # PIL's Image.fromarray (e.g. HxWx3 uint8) — confirm.
        img = self._transform(Image.fromarray(real_img))

        # pack data
        sample = {'real_img': img,
                  'real_cond': real_cond,
                  'desired_cond': desired_cond,
                  'sample_id': sample_id,
                  'real_img_path': real_img_path,
                  }
        return sample

    def __len__(self):
        """Return the number of usable sample ids."""
        return self._dataset_size

    def _read_dataset_paths(self):
        """Resolve dataset paths and load the sample ids and conditions.

        Sets ``self._root``, ``self._imgs_dir``, ``self._conds``,
        ``self._ids`` and ``self._dataset_size``. Only ids that have an entry
        in the conditions file are kept, each id at most once, in the order
        they appear in the ids file.
        """
        self._root = self._opt.data_dir
        self._imgs_dir = os.path.join(self._root, self._opt.images_folder)

        # read ids
        if self._is_for_train:
            use_ids_filename = self._opt.train_ids_file
        else:
            use_ids_filename = self._opt.test_ids_file
        use_ids_filepath = os.path.join(self._root, use_ids_filename)
        ids = self._read_ids(use_ids_filepath)

        # read aus
        conds_filepath = os.path.join(self._root, self._opt.aus_file)
        self._conds = self._read_conds(conds_filepath)

        # Filter to ids with a known condition and drop duplicates. A set
        # intersection would do the same, but its order is arbitrary and
        # varies between runs (str hashing is randomized). That would make
        # serial iteration non-reproducible.
        seen = set()
        self._ids = []
        for sample_id in ids:
            if sample_id in self._conds and sample_id not in seen:
                seen.add(sample_id)
                self._ids.append(sample_id)

        # dataset size
        self._dataset_size = len(self._ids)

    def _create_transform(self):
        """Build ``self._transform``, which maps a PIL image to a tensor.

        Applies ToTensor and then per-channel Normalize(mean=0.5, std=0.5).
        Training mode additionally starts with a random horizontal flip.
        """
        transform_list = []
        if self._is_for_train:
            transform_list.append(transforms.RandomHorizontalFlip())
        transform_list += [transforms.ToTensor(),
                           transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5]),
                           ]
        self._transform = transforms.Compose(transform_list)

    def _read_ids(self, file_path):
        """Read sample ids from a tab-delimited text file.

        Each entry is expected to be an image filename with a four-character
        extension (e.g. ``name.jpg``). The last four characters are stripped
        to obtain the id.
        """
        # dtype=str replaces np.str, which was removed in NumPy 1.24.
        # atleast_1d: loadtxt returns a 0-d array for a single-line file,
        # and a 0-d array cannot be iterated.
        ids = np.atleast_1d(np.loadtxt(file_path, delimiter='\t', dtype=str))
        return [str(name)[:-4] for name in ids]

    def _read_conds(self, file_path):
        """Load the pickled mapping from sample id to condition array.

        NOTE(review): ``pickle.load`` can execute arbitrary code, so the
        conditions file must come from a trusted source.
        """
        with open(file_path, 'rb') as f:
            return pickle.load(f)

    def _get_cond_by_id(self, id):
        """Return the stored condition for ``id`` divided by 5.0.

        Returns None if ``id`` has no condition. The division yields a new
        object, so callers may modify the result in place without touching
        ``self._conds``.
        """
        # NOTE(review): 5.0 presumably rescales intensities from [0, 5] to
        # [0, 1] — confirm against how the aus file was produced.
        if id in self._conds:
            return self._conds[id] / 5.0
        return None

    def _get_img_by_id(self, id):
        """Load ``<imgs_dir>/<id>.jpg`` and return it with its path.

        Returns:
            ``(img, filepath)``. ``img`` is whatever ``cv_utils.read_cv2_img``
            returns; ``__getitem__`` treats None as a read failure.
        """
        filepath = os.path.join(self._imgs_dir, id + '.jpg')
        return cv_utils.read_cv2_img(filepath), filepath

    def _generate_random_cond(self):
        """Return a random sample's condition plus U(-0.1, 0.1) noise.

        The noise is drawn independently for each element of the condition.
        """
        cond = None
        while cond is None:
            rand_sample_id = self._ids[random.randint(0, self._dataset_size - 1)]
            cond = self._get_cond_by_id(rand_sample_id)
        # Safe in place: _get_cond_by_id returns a fresh copy, not the
        # stored array. Done after the loop so it never touches None.
        cond += np.random.uniform(-0.1, 0.1, cond.shape)
        return cond
|