next
This commit is contained in:
@@ -574,75 +574,76 @@ class Solver(object):
|
||||
|
||||
layer_dict = {0: 2, 1: 5, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12, 7: 13, 8: 14, 9: 17, 10: 20}
|
||||
|
||||
for layer_num_orig in range(11):
|
||||
# Load the trained generator.
|
||||
self.restore_model(self.test_iters)
|
||||
# for layer_num_orig in range(11):
|
||||
# Load the trained generator.
|
||||
self.restore_model(self.test_iters)
|
||||
|
||||
# Set data loader.
|
||||
if self.dataset == 'CelebA':
|
||||
data_loader = self.celeba_loader
|
||||
elif self.dataset == 'RaFD':
|
||||
data_loader = self.rafd_loader
|
||||
|
||||
# Initialize Metrics
|
||||
l1_error = 0.0
|
||||
l2_error = 0.0
|
||||
perceptual_error = 0.0
|
||||
n_samples = 0
|
||||
|
||||
# 11 layers
|
||||
# layer_num_orig = 1
|
||||
|
||||
# print('Layer ', layer_num_orig)
|
||||
for i, (x_real, c_org) in enumerate(data_loader):
|
||||
# Black image
|
||||
black = np.zeros((1,3,256,256))
|
||||
black = torch.FloatTensor(black).to(self.device)
|
||||
|
||||
# Prepare input images and target domain labels.
|
||||
x_real = x_real.to(self.device)
|
||||
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
|
||||
|
||||
# layer_num = layer_dict[layer_num_orig]
|
||||
layer_num = None
|
||||
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
|
||||
|
||||
# Translate images.
|
||||
x_fake_list = [x_real]
|
||||
|
||||
# if i == 0:
|
||||
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
|
||||
|
||||
for c_trg in c_trg_list:
|
||||
# Attack
|
||||
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
|
||||
# x_adv = x_real + perturb
|
||||
# x_adv = self.blur_tensor(x_adv)
|
||||
|
||||
# Metrics
|
||||
with torch.no_grad():
|
||||
# gen, preproc_x = self.G(x_adv, c_trg)
|
||||
gen, gen_feats = self.G(x_adv, c_trg)
|
||||
|
||||
# Add to lists
|
||||
# x_fake_list.append(preproc_x)
|
||||
x_fake_list.append(gen)
|
||||
|
||||
# No Attack
|
||||
# gen_noattack, _ = self.G(x_real, c_trg)
|
||||
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
|
||||
|
||||
l1_error += F.l1_loss(gen, gen_noattack)
|
||||
l2_error += F.mse_loss(gen, gen_noattack)
|
||||
n_samples += 1
|
||||
|
||||
# Save the translated images.
|
||||
x_concat = torch.cat(x_fake_list, dim=3)
|
||||
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
|
||||
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
|
||||
# print('Saved real and fake images into {}...'.format(result_path))
|
||||
|
||||
# Set data loader.
|
||||
if self.dataset == 'CelebA':
|
||||
data_loader = self.celeba_loader
|
||||
elif self.dataset == 'RaFD':
|
||||
data_loader = self.rafd_loader
|
||||
|
||||
# Initialize Metrics
|
||||
l1_error = 0.0
|
||||
l2_error = 0.0
|
||||
perceptual_error = 0.0
|
||||
n_samples = 0
|
||||
|
||||
# 11 layers
|
||||
# layer_num_orig = 1
|
||||
|
||||
print('Layer ', layer_num_orig)
|
||||
for i, (x_real, c_org) in enumerate(data_loader):
|
||||
# Black image
|
||||
black = np.zeros((1,3,256,256))
|
||||
black = torch.FloatTensor(black).to(self.device)
|
||||
|
||||
# Prepare input images and target domain labels.
|
||||
x_real = x_real.to(self.device)
|
||||
c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)
|
||||
|
||||
layer_num = layer_dict[layer_num_orig]
|
||||
pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num)
|
||||
|
||||
# Translate images.
|
||||
x_fake_list = [x_real]
|
||||
|
||||
# if i == 0:
|
||||
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
|
||||
|
||||
for c_trg in c_trg_list:
|
||||
# Attack
|
||||
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
|
||||
# x_adv = x_real + perturb
|
||||
# x_adv = self.blur_tensor(x_adv)
|
||||
|
||||
# Metrics
|
||||
with torch.no_grad():
|
||||
# gen, preproc_x = self.G(x_adv, c_trg)
|
||||
gen, gen_feats = self.G(x_adv, c_trg)
|
||||
|
||||
# Add to lists
|
||||
# x_fake_list.append(preproc_x)
|
||||
x_fake_list.append(gen)
|
||||
|
||||
# No Attack
|
||||
# gen_noattack, _ = self.G(x_real, c_trg)
|
||||
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
|
||||
|
||||
l1_error += F.l1_loss(gen, gen_noattack)
|
||||
l2_error += F.mse_loss(gen, gen_noattack)
|
||||
n_samples += 1
|
||||
|
||||
# Save the translated images.
|
||||
x_concat = torch.cat(x_fake_list, dim=3)
|
||||
result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
|
||||
save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
|
||||
# print('Saved real and fake images into {}...'.format(result_path))
|
||||
|
||||
if i == 199:
|
||||
break
|
||||
if i == 199:
|
||||
break
|
||||
|
||||
# Print metrics
|
||||
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
|
||||
|
||||
Reference in New Issue
Block a user