From 5ede7105ffc36340500b4b2175a4417d78d5bb99 Mon Sep 17 00:00:00 2001 From: Nataniel Ruiz Date: Thu, 9 Jan 2020 18:06:27 -0400 Subject: [PATCH] GANimation conditional attacks --- ganimation/attacks.py | 4 ++-- ganimation/solver.py | 27 ++++++++++++++------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/ganimation/attacks.py b/ganimation/attacks.py index cd35251..040e9cf 100644 --- a/ganimation/attacks.py +++ b/ganimation/attacks.py @@ -96,8 +96,8 @@ class LinfPGDAttack(object): self.model.zero_grad() - loss = self.loss_fn(output_att, y) - # loss = -self.loss_fn(out, y) + # loss = self.loss_fn(output_att, y) + loss = -self.loss_fn(out, y) loss.backward() grad = X.grad diff --git a/ganimation/solver.py b/ganimation/solver.py index f47e1bf..238c95b 100644 --- a/ganimation/solver.py +++ b/ganimation/solver.py @@ -451,11 +451,11 @@ class Solver(Utils): resulting_image = self.imFromAttReg( resulting_images_att, resulting_images_reg, x_adv).cuda() - # with torch.no_grad(): - # resulting_images_att_noattack, resulting_images_reg_noattack = self.G( - # image_to_animate, targets_au) - # resulting_image_noattack = self.imFromAttReg( - # resulting_images_att_noattack, resulting_images_reg_noattack, image_to_animate).cuda() + with torch.no_grad(): + resulting_images_att_noattack, resulting_images_reg_noattack = self.G( + image_to_animate, targets_au) + resulting_image_noattack = self.imFromAttReg( + resulting_images_att_noattack, resulting_images_reg_noattack, image_to_animate).cuda() save_image((resulting_image+1)/2, os.path.join(self.animation_results_dir, image_path.split('/')[-1].split('.')[0] @@ -465,16 +465,17 @@ class Solver(Utils): image_path.split('/')[-1].split('.')[0] + '_ref.jpg')) - # l1_error += F.l1_loss(resulting_image, resulting_image_noattack) - # l2_error += F.mse_loss(resulting_image, resulting_image_noattack) - # l0_error += (resulting_image - resulting_image_noattack).norm(0) - # min_dist += (resulting_image - resulting_image_noattack).norm(float('-inf')) + l1_error += F.l1_loss(resulting_image, resulting_image_noattack) + l2_error += F.mse_loss(resulting_image, resulting_image_noattack) + l0_error += (resulting_image - resulting_image_noattack).norm(0) + min_dist += (resulting_image - resulting_image_noattack).norm(float('-inf')) # Compare to input image - l1_error += F.l1_loss(resulting_image, image_to_animate) - l2_error += F.mse_loss(resulting_image, image_to_animate) - l0_error += (resulting_image - image_to_animate).norm(0) - min_dist += (resulting_image - image_to_animate).norm(float('-inf')) + # l1_error += F.l1_loss(resulting_image, image_to_animate) + # l2_error += F.mse_loss(resulting_image, image_to_animate) + # l0_error += (resulting_image - image_to_animate).norm(0) + # min_dist += (resulting_image - image_to_animate).norm(float('-inf')) + n_samples += 1 # Print metrics