This commit is contained in:
Nataniel Ruiz
2019-12-23 17:16:55 -04:00
parent 598606ffa1
commit 2da37248f2

View File

@@ -574,75 +574,76 @@ class Solver(object):
layer_dict = {0: 2, 1: 5, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12, 7: 13, 8: 14, 9: 17, 10: 20}
for layer_num_orig in range(11):
# Load the trained generator.
self.restore_model(self.test_iters)
# for layer_num_orig in range(11):
# Load the trained generator.
self.restore_model(self.test_iters)
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
# Initialize Metrics
l1_error = 0.0
l2_error = 0.0
perceptual_error = 0.0
n_samples = 0
# 11 layers
# layer_num_orig = 1
# print('Layer ', layer_num_orig)
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
black = np.zeros((1,3,256,256))
black = torch.FloatTensor(black).to(self.device)
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
# layer_num = layer_dict[layer_num_orig]
layer_num = None
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
# Translate images.
x_fake_list = [x_real]
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
for c_trg in c_trg_list:
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# x_adv = x_real + perturb
# x_adv = self.blur_tensor(x_adv)
# Metrics
with torch.no_grad():
# gen, preproc_x = self.G(x_adv, c_trg)
gen, gen_feats = self.G(x_adv, c_trg)
# Add to lists
# x_fake_list.append(preproc_x)
x_fake_list.append(gen)
# No Attack
# gen_noattack, _ = self.G(x_real, c_trg)
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
l1_error += F.l1_loss(gen, gen_noattack)
l2_error += F.mse_loss(gen, gen_noattack)
n_samples += 1
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
# Initialize Metrics
l1_error = 0.0
l2_error = 0.0
perceptual_error = 0.0
n_samples = 0
# 11 layers
# layer_num_orig = 1
print('Layer ', layer_num_orig)
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
black = np.zeros((1,3,256,256))
black = torch.FloatTensor(black).to(self.device)
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
layer_num = layer_dict[layer_num_orig]
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
# Translate images.
x_fake_list = [x_real]
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
for c_trg in c_trg_list:
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# x_adv = x_real + perturb
# x_adv = self.blur_tensor(x_adv)
# Metrics
with torch.no_grad():
# gen, preproc_x = self.G(x_adv, c_trg)
gen, gen_feats = self.G(x_adv, c_trg)
# Add to lists
# x_fake_list.append(preproc_x)
x_fake_list.append(gen)
# No Attack
# gen_noattack, _ = self.G(x_real, c_trg)
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
l1_error += F.l1_loss(gen, gen_noattack)
l2_error += F.mse_loss(gen, gen_noattack)
n_samples += 1
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))
if i == 199:
break
if i == 199:
break
# Print metrics
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,