Files
SimSwapPlus/data_tools/test_dataloader_dir.py
T
chenxuanhong 3783ef0e75 init
2022-01-10 15:03:58 +08:00

81 lines
2.9 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: eval_dataloader_DIV2K.py
# Created Date: Tuesday January 12th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 12th October 2021 8:29:51 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import os
import cv2
import glob
import torch
class TestDataset:
    """Sequential batch loader over the images in one flat directory.

    Every file directly inside ``path`` whose extension matches one of
    ``subffix`` is collected once at construction time (sorted per
    extension, extensions in the given order). Each call returns the next
    batch of decoded images plus their file stems; after the last batch the
    next call raises ``StopIteration`` and rewinds to the beginning.

    Args:
        path: Directory to scan (not recursive).
        batch_size: Maximum number of images per batch (must be >= 1); the
            final batch may be smaller.
        subffix: File extension(s) to include, without the leading dot.
            Either a single string or an iterable of strings.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self,
                 path,
                 batch_size=16,
                 subffix=('png', 'jpg')):
        """Store the configuration and scan ``path`` for matching images."""
        if batch_size < 1:
            # A non-positive size would never advance the pointer.
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
        # Tuple default avoids a shared mutable default; a bare string is
        # wrapped so it is not iterated character by character. A list copy
        # keeps the attribute type stable and detaches it from the caller.
        if isinstance(subffix, str):
            subffix = [subffix]
        self.path = path
        self.subffix = list(subffix)
        self.dataset = []  # list of [full_path, file_stem] pairs
        self.pointer = 0   # index of the first image of the next batch
        self.batch_size = batch_size
        self.__preprocess__()
        self.num_images = len(self.dataset)

    def __preprocess__(self):
        """Fill ``self.dataset`` with ``[path, stem]`` pairs for matching files."""
        print("processing content images...")
        # Escape the directory so glob metacharacters in it ('[', '*', '?')
        # are matched literally; only the extension wildcard is a pattern.
        root = glob.escape(os.fspath(self.path))
        for i_suf in self.subffix:
            pattern = os.path.join(root, '*.%s' % (i_suf))
            # glob order is filesystem-dependent; sort for reproducible batches.
            for item in sorted(glob.glob(pattern)):
                stem = os.path.splitext(os.path.basename(item))[0]
                self.dataset.append([item, stem])
        print('Finished preprocessing the content dataset, total image number: %d...'
              % len(self.dataset))

    def __call__(self):
        """Return the next batch as ``(images, names)``.

        Each image is read with ``cv2.imread``, converted BGR -> RGB,
        transposed from HWC to CHW and rescaled from [0, 255] to [-1, 1].
        The per-image tensors are stacked along a new leading batch
        dimension, so every image within one batch must share the same size.
        ``names`` holds the file stems in the same order.

        Raises:
            StopIteration: when every image has already been returned. The
                pointer is reset first, so the following call starts over.
            OSError: if cv2 cannot decode one of the files.
        """
        if self.pointer >= self.num_images:
            self.pointer = 0
            msg = "The end of the story!"
            print(msg)
            raise StopIteration(msg)
        end = min(self.pointer + self.batch_size, self.num_images)
        entries = self.dataset[self.pointer:end]
        images = []
        for filename, _ in entries:
            hr_img = cv2.imread(filename)
            if hr_img is None:
                # cv2.imread reports failure by returning None, not raising.
                raise OSError("cannot read image file: %s" % filename)
            hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
            hr_img = torch.from_numpy(hr_img.transpose((2, 0, 1)))
            hr_img = hr_img / 255.0
            images.append(2 * (hr_img - 0.5))
        nm_ls = [stem for _, stem in entries]
        self.pointer = end
        # A single stack avoids re-copying the growing batch on every image.
        return torch.stack(images, 0), nm_ls

    def __len__(self):
        """Return the number of images found under ``path``."""
        return self.num_images

    def __repr__(self):
        # str() so that path-like objects (e.g. pathlib.Path) also work.
        return self.__class__.__name__ + ' (' + str(self.path) + ')'