diff --git a/ganimation/solver.py b/ganimation/solver.py index 824e7ed..dfffe46 100644 --- a/ganimation/solver.py +++ b/ganimation/solver.py @@ -388,6 +388,14 @@ class Solver(Utils): black = np.zeros((1,3,128,128)) black = torch.FloatTensor(black).to(self.device) + # Initialize Metrics + l1_error = 0.0 + l2_error = 0.0 + min_dist = 0.0 + l0_error = 0.0 + perceptual_error = 0.0 + n_samples = 0 + pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device) images_to_animate_path = sorted(glob.glob( @@ -400,21 +408,26 @@ class Solver(Utils): all_images = torch.cat([regular_image_transform(Image.open(path)).unsqueeze(0) for path in images_to_animate_path], dim=0).cuda() for target_idx in range(targets.size(0)): - if target_idx == 0: - img = regular_image_transform(Image.open(images_to_animate_path[idx])).unsqueeze(0).cuda() - # x_adv, perturb = pgd_attack.perturb(img, black, targets[0, :].unsqueeze(0).cuda()) - x_adv, perturb = pgd_attack.perturb_iter_class(image_to_animate, black, targets[:, :].cuda()) - # _, perturb = pgd_attack.perturb_iter_data(image_to_animate, all_images, black, targets[68, :].unsqueeze(0).cuda()) + # if target_idx == 0: + # img = regular_image_transform(Image.open(images_to_animate_path[idx])).unsqueeze(0).cuda() + # # x_adv, perturb = pgd_attack.perturb(img, black, targets[0, :].unsqueeze(0).cuda()) + # x_adv, perturb = pgd_attack.perturb_iter_class(image_to_animate, black, targets[:, :].cuda()) + # # _, perturb = pgd_attack.perturb_iter_data(image_to_animate, all_images, black, targets[68, :].unsqueeze(0).cuda()) targets_au = targets[target_idx, :].unsqueeze(0).cuda() - # x_adv, perturb = pgd_attack.perturb(image_to_animate, black, targets_au) - x_adv = image_to_animate + x_adv, perturb = pgd_attack.perturb(image_to_animate, black, targets_au) + # x_adv = image_to_animate # print(image_to_animate.shape, x_adv.shape) resulting_images_att, resulting_images_reg = self.G( x_adv, targets_au) resulting_image = self.imFromAttReg( resulting_images_att, resulting_images_reg, x_adv).cuda() + resulting_images_att_noattack, resulting_images_reg_noattack = self.G( + image_to_animate, targets_au) + resulting_image_noattack = self.imFromAttReg( + resulting_images_att_noattack, resulting_images_reg_noattack, image_to_animate).cuda() + save_image((resulting_image+1)/2, os.path.join(self.animation_results_dir, image_path.split('/')[-1].split('.')[0] + '_' + reference_expression_images[target_idx])) @@ -423,6 +436,16 @@ class Solver(Utils): image_path.split('/')[-1].split('.')[0] + '_ref.jpg')) + l1_error += F.l1_loss(resulting_image, gen_noattack) + l2_error += F.mse_loss(resulting_image, gen_noattack) + l0_error += (resulting_image - gen_noattack).norm(0) + min_dist += (resulting_image - gen_noattack).norm(float('-inf')) + n_samples += 1 + + # Print metrics + print('{} images. L1 error: {}. L2 error: {}. L0 error: {}. L_-inf error: {}. Perceptual error: {}.'.format(n_samples, + l1_error / n_samples, l2_error / n_samples, l0_error / n_samples, min_dist / n_samples, perceptual_error / n_samples)) + # """ Code to modify single Action Units """ # Set data loader.