From 9c2cb063beec716cfcc3aa19748377cb6bc66c90 Mon Sep 17 00:00:00 2001 From: Nataniel Ruiz Date: Tue, 24 Dec 2019 14:07:25 -0400 Subject: [PATCH] next --- stargan/attacks.py | 2 +- stargan/solver.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/stargan/attacks.py b/stargan/attacks.py index cbf3f79..49a29ee 100644 --- a/stargan/attacks.py +++ b/stargan/attacks.py @@ -35,7 +35,7 @@ class LinfPGDAttack(object): y = torch.FloatTensor(y).to(self.device) self.model.zero_grad() - loss = self.loss_fn(output, y) + loss = -self.loss_fn(output, y) loss.backward() grad = X.grad diff --git a/stargan/solver.py b/stargan/solver.py index 6d38c41..10e8d07 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -610,8 +610,10 @@ class Solver(object): # x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0]) for c_trg in c_trg_list: + with torch.no_grad(): + gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) # Attack - x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) + x_adv, perturb = pgd_attack.perturb(x_real, gen_noattack, c_trg) # x_adv = x_real + perturb # x_adv = self.blur_tensor(x_adv) @@ -627,7 +629,7 @@ class Solver(object): # No Attack # gen_noattack, _ = self.G(x_real, c_trg) - gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) + # gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) l1_error += F.l1_loss(gen, gen_noattack) l2_error += F.mse_loss(gen, gen_noattack)