GANimation conditional attacks
This commit is contained in:
@@ -30,8 +30,8 @@ class LinfPGDAttack(object):
|
||||
out = imFromAttReg(output_att, output_img, X)
|
||||
|
||||
self.model.zero_grad()
|
||||
# loss = -self.loss_fn(output_att, y) + self.loss_fn(output_img, y)
|
||||
loss = -self.loss_fn(out, y)
|
||||
loss = -self.loss_fn(output_att, y)
|
||||
# loss = -self.loss_fn(out, y)
|
||||
loss.backward()
|
||||
grad = X.grad
|
||||
|
||||
@@ -96,8 +96,8 @@ class LinfPGDAttack(object):
|
||||
|
||||
self.model.zero_grad()
|
||||
|
||||
# loss = self.loss_fn(output_att, y)
|
||||
loss = -self.loss_fn(out, y)
|
||||
loss = -self.loss_fn(output_att, y)
|
||||
# loss = -self.loss_fn(out, y)
|
||||
loss.backward()
|
||||
grad = X.grad
|
||||
|
||||
@@ -131,8 +131,8 @@ class LinfPGDAttack(object):
|
||||
|
||||
out = imFromAttReg(output_att, output_img, X)
|
||||
|
||||
# loss = -self.loss_fn(output_att, y)
|
||||
loss = -self.loss_fn(out, y)
|
||||
loss = -self.loss_fn(output_att, y)
|
||||
# loss = -self.loss_fn(out, y)
|
||||
full_loss += loss
|
||||
|
||||
full_loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user