diff --git a/ganimation/attacks.py b/ganimation/attacks.py index ca3be29..c879e6a 100644 --- a/ganimation/attacks.py +++ b/ganimation/attacks.py @@ -96,8 +96,8 @@ class LinfPGDAttack(object): self.model.zero_grad() - # Away from black - loss = self.loss_fn(output_att, y) + # loss = self.loss_fn(output_att, y) + loss = -self.loss_fn(out, y) loss.backward() grad = X.grad @@ -131,7 +131,8 @@ class LinfPGDAttack(object): self.model.zero_grad() - loss = -self.loss_fn(output_att, y) + # loss = -self.loss_fn(output_att, y) + loss = -self.loss_fn(out, y) full_loss += loss full_loss.backward()