This commit is contained in:
Nataniel Ruiz
2019-12-26 09:57:29 -04:00
parent 466d927a7c
commit 982d09fe84
3 changed files with 6 additions and 4 deletions
+1 -1
View File
@@ -77,7 +77,7 @@ class TestModel(BaseModel):
pgd_attack = attacks.LinfPGDAttack(model=self.netG)
black = np.zeros((1, 3, image.size(2), image.size(3)))
black = torch.FloatTensor(black).cuda()
input_adv, perturb = pgd_attack.perturb(image, image)
input_adv, perturb = pgd_attack.perturb(image, black)
return input_adv, perturb
+1 -1
View File
@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
class LinfPGDAttack(object):
def __init__(self, model=None, epsilon=0.15, k=40, a=0.01):
def __init__(self, model=None, epsilon=0.05, k=1, a=0.01):
self.model = model
self.epsilon = epsilon
self.k = k
+4 -2
View File
@@ -609,8 +609,8 @@ class Solver(object):
# Translate images.
x_fake_list = [x_real]
# if i == 0:
# x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
if i == 0:
x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0])
for c_trg in c_trg_list:
with torch.no_grad():
@@ -639,6 +639,8 @@ class Solver(object):
l0_error += (gen - gen_noattack).norm(0)
min_dist += (gen - gen_noattack).norm(float('-inf'))
n_samples += 1
break
# Save the translated images.
x_concat = torch.cat(x_fake_list, dim=3)