This commit is contained in:
Nataniel Ruiz
2019-12-23 17:22:38 -04:00
parent 0b90dda881
commit f6bc1c6cb1
2 changed files with 73 additions and 73 deletions
+1 -1
View File
@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
class LinfPGDAttack(object):
def __init__(self, model=None, device=None, epsilon=0.05, k=1, a=0.05, feat = None):
def __init__(self, model=None, device=None, epsilon=0.05, k=20, a=0.01, feat = None):
self.model = model
self.epsilon = epsilon
self.k = k
+72 -72
View File
@@ -574,79 +574,79 @@ class Solver(object):
layer_dict = {0: 2, 1: 5, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12, 7: 13, 8: 14, 9: 17, 10: 20, 11: None}
# for layer_num_orig in range(12):
# Load the trained generator.
self.restore_model(self.test_iters)
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
# Initialize Metrics
l1_error = 0.0
l2_error = 0.0
perceptual_error = 0.0
n_samples = 0
# 11 layers + output
layer_num_orig = 11
print('Layer ', layer_num_orig)
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
black = np.zeros((1,3,256,256))
black = torch.FloatTensor(black).to(self.device)
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
layer_num = layer_dict[layer_num_orig]
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
# Translate images.
x_fake_list = [x_real]
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
for c_trg in c_trg_list:
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# x_adv = x_real + perturb
# x_adv = self.blur_tensor(x_adv)
# Metrics
with torch.no_grad():
# gen, preproc_x = self.G(x_adv, c_trg)
gen, gen_feats = self.G(x_adv, c_trg)
# Add to lists
# x_fake_list.append(preproc_x)
x_fake_list.append(gen)
# No Attack
# gen_noattack, _ = self.G(x_real, c_trg)
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
l1_error += F.l1_loss(gen, gen_noattack)
l2_error += F.mse_loss(gen, gen_noattack)
n_samples += 1
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))
for layer_num_orig in range(12):
# Load the trained generator.
self.restore_model(self.test_iters)
if i == 199:
break
# Print metrics
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
perceptual_error / n_samples))
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
# Initialize Metrics
l1_error = 0.0
l2_error = 0.0
perceptual_error = 0.0
n_samples = 0
# 11 layers + output
# layer_num_orig = 11
print('Layer', layer_num_orig)
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
black = np.zeros((1,3,256,256))
black = torch.FloatTensor(black).to(self.device)
# Prepare input images and target domain labels.
x_real = x_real.to(self.device)
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
layer_num = layer_dict[layer_num_orig]
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
# Translate images.
x_fake_list = [x_real]
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
for c_trg in c_trg_list:
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# x_adv = x_real + perturb
# x_adv = self.blur_tensor(x_adv)
# Metrics
with torch.no_grad():
# gen, preproc_x = self.G(x_adv, c_trg)
gen, gen_feats = self.G(x_adv, c_trg)
# Add to lists
# x_fake_list.append(preproc_x)
x_fake_list.append(gen)
# No Attack
# gen_noattack, _ = self.G(x_real, c_trg)
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
l1_error += F.l1_loss(gen, gen_noattack)
l2_error += F.mse_loss(gen, gen_noattack)
n_samples += 1
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
# print('Saved real and fake images into {}...'.format(result_path))
if i == 199:
break
# Print metrics
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
perceptual_error / n_samples))
def test_multi(self):
"""Translate images using StarGAN trained on multiple datasets."""