This commit is contained in:
Nataniel Ruiz
2019-12-23 16:56:46 -04:00
parent 97ad0ac45e
commit 44095641b1
+43 -41
View File
@@ -587,57 +587,59 @@ class Solver(object):
n_samples = 0
# 11 layers
layer_num_orig = 1
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
black = np.zeros((1,3,256,256))
black = torch.FloatTensor(black).to(self.device)
# layer_num_orig = 1
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
for layer_num_orig in range(11):
print('Layer ', layer_num_orig)
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
black = np.zeros((1,3,256,256))
black = torch.FloatTensor(black).to(self.device)
layer_num = (layer_num_orig + 1) * 3 - 1
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
# Translate images.
x_fake_list = [x_real]
layer_num = (layer_num_orig + 1) * 3 - 1
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
# Translate images.
x_fake_list = [x_real]
for c_trg in c_trg_list:
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# x_adv = x_real + perturb
# x_adv = self.blur_tensor(x_adv)
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
# Metrics
with torch.no_grad():
# gen, preproc_x = self.G(x_adv, c_trg)
gen, gen_feats = self.G(x_adv, c_trg)
for c_trg in c_trg_list:
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# x_adv = x_real + perturb
# x_adv = self.blur_tensor(x_adv)
# Add to lists
# x_fake_list.append(preproc_x)
x_fake_list.append(gen)
# Metrics
with torch.no_grad():
# gen, preproc_x = self.G(x_adv, c_trg)
gen, gen_feats = self.G(x_adv, c_trg)
# No Attack
# gen_noattack, _ = self.G(x_real, c_trg)
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
# Add to lists
# x_fake_list.append(preproc_x)
x_fake_list.append(gen)
l1_error += F.l1_loss(gen, gen_noattack)
l2_error += F.mse_loss(gen, gen_noattack)
n_samples += 1
# No Attack
# gen_noattack, _ = self.G(x_real, c_trg)
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))
if i == 199:
break
l1_error += F.l1_loss(gen, gen_noattack)
l2_error += F.mse_loss(gen, gen_noattack)
n_samples += 1
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))
if i == 199:
break
# Print metrics
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,