diff --git a/stargan/attacks.py b/stargan/attacks.py index cbf3f79..49a29ee 100644 --- a/stargan/attacks.py +++ b/stargan/attacks.py @@ -35,7 +35,7 @@ class LinfPGDAttack(object): y = torch.FloatTensor(y).to(self.device) self.model.zero_grad() - loss = self.loss_fn(output, y) + loss = -self.loss_fn(output, y) loss.backward() grad = X.grad diff --git a/stargan/solver.py b/stargan/solver.py index 6d38c41..10e8d07 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -610,8 +610,10 @@ class Solver(object): # x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0]) for c_trg in c_trg_list: + with torch.no_grad(): + gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) # Attack - x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) + x_adv, perturb = pgd_attack.perturb(x_real, gen_noattack, c_trg) # x_adv = x_real + perturb # x_adv = self.blur_tensor(x_adv) @@ -627,7 +629,7 @@ class Solver(object): # No Attack # gen_noattack, _ = self.G(x_real, c_trg) - gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) + # gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) l1_error += F.l1_loss(gen, gen_noattack) l2_error += F.mse_loss(gen, gen_noattack)