chenxuanhong
2022-04-21 00:58:32 +08:00
parent 5181115399
commit 1bbc1eff67
6 changed files with 0 additions and 270 deletions

View File

@@ -1,94 +0,0 @@
import torch
from torch.utils.data import Dataset
import os
import numpy as np
import random
from torchvision import transforms
from PIL import Image
import cv2
class FaceDataSet(Dataset):
    def __init__(self, dataset_path, batch_size):
        super(FaceDataSet, self).__init__()
        '''picture_dir_list = []
        for i in range(self.people_num):
            picture_dir_list.append('/data/home/renwangchen/vgg_align_224/' + self.people_list[i])
        self.people_pic_list = []
        for i in range(self.people_num):
            pic_list = os.listdir(picture_dir_list[i])
            person_pic_list = []
            for j in range(len(pic_list)):
                pic_dir = os.path.join(picture_dir_list[i], pic_list[j])
                person_pic_list.append(pic_dir)
            self.people_pic_list.append(person_pic_list)'''
        pic_dir = '/data/home/renwangchen/CelebA_224/'
        latent_dir = '/data/home/renwangchen/CelebA_latent/'
        tmp_list = os.listdir(pic_dir)
        self.pic_list = []
        self.latent_list = []
        for i in range(len(tmp_list)):
            self.pic_list.append(pic_dir + tmp_list[i])
            self.latent_list.append(latent_dir + tmp_list[i][:-3] + 'npy')
        self.pic_list = self.pic_list[:29984]
        '''for i in range(29984):
            print(self.pic_list[i])'''
        self.latent_list = self.latent_list[:29984]
        self.people_num = len(self.pic_list)
        self.type = 1
        self.bs = batch_size
        self.count = 0
        self.transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        p1 = random.randint(0, self.people_num - 1)
        p2 = p1
        if self.type == 0:
            # load pictures from the same folder
            pass
        else:
            # load pictures from different folders
            while p2 == p1:
                p2 = random.randint(0, self.people_num - 1)

        pic_id_dir = self.pic_list[p1]
        pic_att_dir = self.pic_list[p2]
        latent_id_dir = self.latent_list[p1]
        latent_att_dir = self.latent_list[p2]

        img_id = Image.open(pic_id_dir).convert('RGB')
        img_id = self.transformer(img_id)
        latent_id = np.load(latent_id_dir)
        latent_id = latent_id / np.linalg.norm(latent_id)
        latent_id = torch.from_numpy(latent_id)

        img_att = Image.open(pic_att_dir).convert('RGB')
        img_att = self.transformer(img_att)
        latent_att = np.load(latent_att_dir)
        latent_att = latent_att / np.linalg.norm(latent_att)
        latent_att = torch.from_numpy(latent_att)

        # flip between same-identity and cross-identity sampling every bs items
        self.count += 1
        data_type = self.type
        if self.count == self.bs:
            self.type = 1 - self.type
            self.count = 0
        return img_id, img_att, latent_id, latent_att, data_type

    def __len__(self):
        return len(self.pic_list)
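For context, a minimal sketch (not part of this commit) of how a dataset like FaceDataSet is typically consumed. The path and loader settings are illustrative; note that dataset_path is accepted but ignored, since the CelebA paths are hard-coded above:

from torch.utils.data import DataLoader

dataset = FaceDataSet('/unused/path', batch_size=8)  # dataset_path is unused
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
for img_id, img_att, latent_id, latent_att, data_type in loader:
    # img_*: normalized RGB tensors; latent_*: unit-norm identity embeddings
    break

With num_workers > 0 each worker process holds its own copy of self.type and self.count, so the per-batch same/cross-identity toggle only behaves as intended in a single-process loader.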

View File

@@ -1,76 +0,0 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        ### input A (label maps)
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
        self.A_paths = sorted(make_dataset(self.dir_A))

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
            self.B_paths = sorted(make_dataset(self.dir_B))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        ### input A (label maps)
        A_path = self.A_paths[index]
        A = Image.open(A_path)
        params = get_params(self.opt, A.size)
        if self.opt.label_nc == 0:
            transform_A = get_transform(self.opt, params)
            A_tensor = transform_A(A.convert('RGB'))
        else:
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            A_tensor = transform_A(A) * 255.0

        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)
            B_tensor = transform_B(B)

        ### if using instance maps
        if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)

            if self.opt.load_features:
                feat_path = self.feat_paths[index]
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_A(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': A_path}
        return input_dict

    def __len__(self):
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        return 'AlignedDataset'
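For reference, a sketch of driving AlignedDataset directly. The opt object here is a hypothetical stand-in whose field names mirror the ones the class reads; the dataroot path is illustrative:

from types import SimpleNamespace

opt = SimpleNamespace(dataroot='./datasets/faces', phase='train', label_nc=0,
                      isTrain=True, use_encoded_image=False, no_instance=True,
                      load_features=False, batchSize=4,
                      resize_or_crop='resize_and_crop', loadSize=286,
                      fineSize=256, no_flip=False)

dataset = AlignedDataset()
dataset.initialize(opt)
sample = dataset[0]  # dict with 'label', 'inst', 'image', 'feat', 'path' keys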

View File

@@ -1,90 +0,0 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}


def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
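The point of splitting get_params from get_transform is that one random crop/flip decision can be shared by every image in a pair, keeping label map and photo aligned. A small sketch, assuming a hypothetical opt with only the fields these helpers read and an illustrative image file face.jpg:

from types import SimpleNamespace
from PIL import Image

opt = SimpleNamespace(resize_or_crop='resize_and_crop', loadSize=286,
                      fineSize=256, isTrain=True, no_flip=False)

img = Image.open('face.jpg').convert('RGB')
params = get_params(opt, img.size)      # sample crop position and flip once
transform = get_transform(opt, params)  # reuse them for each image in the pair
tensor = transform(img)                 # 3x256x256 float tensor in [-1, 1]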

View File

@@ -1,7 +0,0 @@
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
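Training scripts in the pix2pixHD lineage usually consume this factory as below; that load_data() returns an iterable DataLoader is an assumption based on that codebase, not something shown in this diff:

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()        # assumed pix2pixHD-style API
print('#training images = %d' % len(data_loader))
for i, data in enumerate(dataset):
    pass  # data is the input_dict produced by AlignedDataset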

View File

@@ -3,10 +3,8 @@ import glob
import torch
import random
from PIL import Image
from pathlib import Path
from torch.utils import data
from torchvision import transforms as T
# from StyleResize import StyleResize
class data_prefetcher():
    def __init__(self, loader):
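The hunk cuts off here. For orientation, a sketch of the standard CUDA prefetching pattern a class like this implements (modeled on NVIDIA's Apex examples, not necessarily this repo's exact code):

class PrefetcherSketch():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()  # side stream for host-to-device copies
        self.preload()

    def preload(self):
        try:
            self.next_batch = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return
        with torch.cuda.stream(self.stream):
            # overlap copying the next batch with compute on the current one
            self.next_batch = self.next_batch.cuda(non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        self.preload()
        return batch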

View File

@@ -2,7 +2,6 @@
import cv2
import torch
import fractions
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms