diff --git a/ganimation/solver.py b/ganimation/solver.py index 76af574..fcd9f22 100644 --- a/ganimation/solver.py +++ b/ganimation/solver.py @@ -418,15 +418,17 @@ class Solver(Utils): x_adv, perturb = pgd_attack.perturb(image_to_animate, black, targets_au) # x_adv = image_to_animate # print(image_to_animate.shape, x_adv.shape) - resulting_images_att, resulting_images_reg = self.G( - x_adv, targets_au) - resulting_image = self.imFromAttReg( - resulting_images_att, resulting_images_reg, x_adv).cuda() + with torch.no_grad(): + resulting_images_att, resulting_images_reg = self.G( + x_adv, targets_au) + resulting_image = self.imFromAttReg( + resulting_images_att, resulting_images_reg, x_adv).cuda() - resulting_images_att_noattack, resulting_images_reg_noattack = self.G( - image_to_animate, targets_au) - resulting_image_noattack = self.imFromAttReg( - resulting_images_att_noattack, resulting_images_reg_noattack, image_to_animate).cuda() + with torch.no_grad(): + resulting_images_att_noattack, resulting_images_reg_noattack = self.G( + image_to_animate, targets_au) + resulting_image_noattack = self.imFromAttReg( + resulting_images_att_noattack, resulting_images_reg_noattack, image_to_animate).cuda() save_image((resulting_image+1)/2, os.path.join(self.animation_results_dir, image_path.split('/')[-1].split('.')[0]