87 lines
2.6 KiB
Python
87 lines
2.6 KiB
Python
from torch.utils import data
|
|
from torchvision import transforms as T
|
|
from torchvision.datasets import ImageFolder
|
|
from PIL import Image
|
|
import torch
|
|
import os
|
|
import random
|
|
import numpy as np
|
|
|
|
|
|
class CelebA(data.Dataset):
    """CelebA face dataset yielding (transformed image, attribute vector) pairs.

    The attribute file is expected to have two header lines (image count and
    attribute names) followed by ``filename v1 v2 ...`` lines. The first 100
    entries of a seeded shuffle form the test split; the rest are training.
    """

    def __init__(self, image_dir, attr_path, transform, mode, c_dim):
        """Initialize and preprocess the CelebA dataset.

        Args:
            image_dir: Directory containing the image files.
            attr_path: Path to the attribute annotation file.
            transform: Callable applied to each PIL image in ``__getitem__``.
            mode: ``'train'`` selects the training split; anything else
                selects the test split.
            c_dim: Number of attribute values kept per image.
        """
        self.image_dir = image_dir
        self.attr_path = attr_path
        self.transform = transform
        self.mode = mode
        self.c_dim = c_dim

        self.train_dataset = []
        self.test_dataset = []

        # Fills train_dataset and test_dataset --> [filename, boolean attribute vector]
        self.preprocess()

        if mode == 'train':
            self.num_images = len(self.train_dataset)
        else:
            self.num_images = len(self.test_dataset)

        print("------------------------------------------------")
        print("Training images: ", len(self.train_dataset))
        print("Testing images: ", len(self.test_dataset))

    def preprocess(self):
        """Parse the attribute file and populate the train/test splits."""
        # BUG FIX: the original called open() without ever closing the file;
        # a context manager closes the handle deterministically.
        with open(self.attr_path, 'r') as attr_file:
            lines = [line.rstrip() for line in attr_file]
        # Skip the two header lines (image count, attribute names).
        lines = lines[2:]

        # Fixed seed so the train/test split is reproducible across runs.
        random.seed(1234)
        random.shuffle(lines)

        # Extract the info from each line
        for idx, line in enumerate(lines):
            split = line.split()
            filename = split[0]
            values = split[1:]

            # Vector representing the presence of each attribute in each image.
            # Raw values are scaled by 1/5 (presumably raw annotations lie in
            # [-5, 5] or {-1, 1} -- TODO confirm against the attribute file).
            label = [float(values[n]) / 5. for n in range(self.c_dim)]

            # The first 100 shuffled entries become the test split.
            if idx < 100:
                self.test_dataset.append([filename, label])
            else:
                self.train_dataset.append([filename, label])

        print('Dataset ready!...')

    def __getitem__(self, index):
        """Return one (transformed image, FloatTensor label) pair."""
        dataset = self.train_dataset if self.mode == 'train' else self.test_dataset
        filename, label = dataset[index]
        # BUG FIX: convert to RGB so grayscale/RGBA files do not break a
        # 3-channel Normalize transform downstream.
        image = Image.open(os.path.join(self.image_dir, filename)).convert('RGB')
        return self.transform(image), torch.FloatTensor(label)

    def __len__(self):
        """Return the number of images in the active split."""
        return self.num_images
|
|
|
|
|
|
def get_loader(image_dir, attr_path, c_dim, image_size=128,
               batch_size=25, mode='train', num_workers=1):
    """Build a DataLoader over the CelebA attribute dataset.

    Args:
        image_dir: Directory containing the image files.
        attr_path: Path to the attribute annotation file.
        c_dim: Number of attributes kept per image.
        image_size: NOTE(review): currently unused -- no Resize/crop is
            applied, so images are presumably already the desired size;
            confirm before relying on this parameter.
        batch_size: Samples per yielded batch.
        mode: ``'train'`` for the training split, anything else for test.
        num_workers: Number of worker processes for loading.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding (images, labels) batches.
    """
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    dataset = CelebA(image_dir, attr_path, transform, mode, c_dim)

    # NOTE(review): shuffle=True and drop_last=True are applied even when
    # mode != 'train' -- verify that shuffling/dropping the test split is
    # intended.
    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=True,
                           num_workers=num_workers,
                           drop_last=True)
|