update
This commit is contained in:
@@ -1,94 +0,0 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
class FaceDataSet(Dataset):
|
||||
def __init__(self, dataset_path, batch_size):
|
||||
super(FaceDataSet, self).__init__()
|
||||
|
||||
|
||||
|
||||
'''picture_dir_list = []
|
||||
for i in range(self.people_num):
|
||||
picture_dir_list.append('/data/home/renwangchen/vgg_align_224/'+self.people_list[i])
|
||||
|
||||
self.people_pic_list = []
|
||||
for i in range(self.people_num):
|
||||
pic_list = os.listdir(picture_dir_list[i])
|
||||
person_pic_list = []
|
||||
for j in range(len(pic_list)):
|
||||
pic_dir = os.path.join(picture_dir_list[i], pic_list[j])
|
||||
person_pic_list.append(pic_dir)
|
||||
self.people_pic_list.append(person_pic_list)'''
|
||||
|
||||
pic_dir = '/data/home/renwangchen/CelebA_224/'
|
||||
latent_dir = '/data/home/renwangchen/CelebA_latent/'
|
||||
|
||||
tmp_list = os.listdir(pic_dir)
|
||||
self.pic_list = []
|
||||
self.latent_list = []
|
||||
for i in range(len(tmp_list)):
|
||||
self.pic_list.append(pic_dir + tmp_list[i])
|
||||
self.latent_list.append(latent_dir + tmp_list[i][:-3] + 'npy')
|
||||
|
||||
self.pic_list = self.pic_list[:29984]
|
||||
'''for i in range(29984):
|
||||
print(self.pic_list[i])'''
|
||||
self.latent_list = self.latent_list[:29984]
|
||||
|
||||
self.people_num = len(self.pic_list)
|
||||
|
||||
self.type = 1
|
||||
self.bs = batch_size
|
||||
self.count = 0
|
||||
|
||||
self.transformer = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
def __getitem__(self, index):
|
||||
p1 = random.randint(0, self.people_num - 1)
|
||||
p2 = p1
|
||||
|
||||
if self.type == 0:
|
||||
# load pictures from the same folder
|
||||
pass
|
||||
else:
|
||||
# load pictures from different folders
|
||||
p2 = p1
|
||||
while p2 == p1:
|
||||
p2 = random.randint(0, self.people_num - 1)
|
||||
|
||||
pic_id_dir = self.pic_list[p1]
|
||||
pic_att_dir = self.pic_list[p2]
|
||||
latent_id_dir = self.latent_list[p1]
|
||||
latent_att_dir = self.latent_list[p2]
|
||||
|
||||
img_id = Image.open(pic_id_dir).convert('RGB')
|
||||
img_id = self.transformer(img_id)
|
||||
latent_id = np.load(latent_id_dir)
|
||||
latent_id = latent_id / np.linalg.norm(latent_id)
|
||||
latent_id = torch.from_numpy(latent_id)
|
||||
|
||||
img_att = Image.open(pic_att_dir).convert('RGB')
|
||||
img_att = self.transformer(img_att)
|
||||
latent_att = np.load(latent_att_dir)
|
||||
latent_att = latent_att / np.linalg.norm(latent_att)
|
||||
latent_att = torch.from_numpy(latent_att)
|
||||
|
||||
self.count += 1
|
||||
data_type = self.type
|
||||
if self.count == self.bs:
|
||||
self.type = 1 - self.type
|
||||
self.count = 0
|
||||
|
||||
return img_id, img_att, latent_id, latent_att, data_type
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pic_list)
|
||||
@@ -1,76 +0,0 @@
|
||||
import os.path
|
||||
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
|
||||
class AlignedDataset(BaseDataset):
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
self.root = opt.dataroot
|
||||
|
||||
### input A (label maps)
|
||||
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
|
||||
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
|
||||
self.A_paths = sorted(make_dataset(self.dir_A))
|
||||
|
||||
### input B (real images)
|
||||
if opt.isTrain or opt.use_encoded_image:
|
||||
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
|
||||
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
|
||||
self.B_paths = sorted(make_dataset(self.dir_B))
|
||||
|
||||
### instance maps
|
||||
if not opt.no_instance:
|
||||
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
|
||||
self.inst_paths = sorted(make_dataset(self.dir_inst))
|
||||
|
||||
### load precomputed instance-wise encoded features
|
||||
if opt.load_features:
|
||||
self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
|
||||
print('----------- loading features from %s ----------' % self.dir_feat)
|
||||
self.feat_paths = sorted(make_dataset(self.dir_feat))
|
||||
|
||||
self.dataset_size = len(self.A_paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
### input A (label maps)
|
||||
A_path = self.A_paths[index]
|
||||
A = Image.open(A_path)
|
||||
params = get_params(self.opt, A.size)
|
||||
if self.opt.label_nc == 0:
|
||||
transform_A = get_transform(self.opt, params)
|
||||
A_tensor = transform_A(A.convert('RGB'))
|
||||
else:
|
||||
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
|
||||
A_tensor = transform_A(A) * 255.0
|
||||
|
||||
B_tensor = inst_tensor = feat_tensor = 0
|
||||
### input B (real images)
|
||||
if self.opt.isTrain or self.opt.use_encoded_image:
|
||||
B_path = self.B_paths[index]
|
||||
B = Image.open(B_path).convert('RGB')
|
||||
transform_B = get_transform(self.opt, params)
|
||||
B_tensor = transform_B(B)
|
||||
|
||||
### if using instance maps
|
||||
if not self.opt.no_instance:
|
||||
inst_path = self.inst_paths[index]
|
||||
inst = Image.open(inst_path)
|
||||
inst_tensor = transform_A(inst)
|
||||
|
||||
if self.opt.load_features:
|
||||
feat_path = self.feat_paths[index]
|
||||
feat = Image.open(feat_path).convert('RGB')
|
||||
norm = normalize()
|
||||
feat_tensor = norm(transform_A(feat))
|
||||
|
||||
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
|
||||
'feat': feat_tensor, 'path': A_path}
|
||||
|
||||
return input_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize
|
||||
|
||||
def name(self):
|
||||
return 'AlignedDataset'
|
||||
@@ -1,90 +0,0 @@
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
class BaseDataset(data.Dataset):
|
||||
def __init__(self):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
def name(self):
|
||||
return 'BaseDataset'
|
||||
|
||||
def initialize(self, opt):
|
||||
pass
|
||||
|
||||
def get_params(opt, size):
|
||||
w, h = size
|
||||
new_h = h
|
||||
new_w = w
|
||||
if opt.resize_or_crop == 'resize_and_crop':
|
||||
new_h = new_w = opt.loadSize
|
||||
elif opt.resize_or_crop == 'scale_width_and_crop':
|
||||
new_w = opt.loadSize
|
||||
new_h = opt.loadSize * h // w
|
||||
|
||||
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
|
||||
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
|
||||
|
||||
flip = random.random() > 0.5
|
||||
return {'crop_pos': (x, y), 'flip': flip}
|
||||
|
||||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
|
||||
transform_list = []
|
||||
if 'resize' in opt.resize_or_crop:
|
||||
osize = [opt.loadSize, opt.loadSize]
|
||||
transform_list.append(transforms.Scale(osize, method))
|
||||
elif 'scale_width' in opt.resize_or_crop:
|
||||
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
|
||||
|
||||
if 'crop' in opt.resize_or_crop:
|
||||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
|
||||
|
||||
if opt.resize_or_crop == 'none':
|
||||
base = float(2 ** opt.n_downsample_global)
|
||||
if opt.netG == 'local':
|
||||
base *= (2 ** opt.n_local_enhancers)
|
||||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
|
||||
|
||||
if opt.isTrain and not opt.no_flip:
|
||||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
|
||||
|
||||
transform_list += [transforms.ToTensor()]
|
||||
|
||||
if normalize:
|
||||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
|
||||
(0.5, 0.5, 0.5))]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
def normalize():
|
||||
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
|
||||
def __make_power_2(img, base, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
h = int(round(oh / base) * base)
|
||||
w = int(round(ow / base) * base)
|
||||
if (h == oh) and (w == ow):
|
||||
return img
|
||||
return img.resize((w, h), method)
|
||||
|
||||
def __scale_width(img, target_width, method=Image.BICUBIC):
|
||||
ow, oh = img.size
|
||||
if (ow == target_width):
|
||||
return img
|
||||
w = target_width
|
||||
h = int(target_width * oh / ow)
|
||||
return img.resize((w, h), method)
|
||||
|
||||
def __crop(img, pos, size):
|
||||
ow, oh = img.size
|
||||
x1, y1 = pos
|
||||
tw = th = size
|
||||
if (ow > tw or oh > th):
|
||||
return img.crop((x1, y1, x1 + tw, y1 + th))
|
||||
return img
|
||||
|
||||
def __flip(img, flip):
|
||||
if flip:
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
return img
|
||||
@@ -1,7 +0,0 @@
|
||||
|
||||
def CreateDataLoader(opt):
|
||||
from data.custom_dataset_data_loader import CustomDatasetDataLoader
|
||||
data_loader = CustomDatasetDataLoader()
|
||||
print(data_loader.name())
|
||||
data_loader.initialize(opt)
|
||||
return data_loader
|
||||
@@ -3,10 +3,8 @@ import glob
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class data_prefetcher():
|
||||
def __init__(self, loader):
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import cv2
|
||||
import torch
|
||||
import fractions
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
Reference in New Issue
Block a user