diff --git a/ganimation/solver.py b/ganimation/solver.py index d2fbca2..2e6611b 100644 --- a/ganimation/solver.py +++ b/ganimation/solver.py @@ -411,7 +411,7 @@ class Solver(Utils): if idx == 0: for target_idx in range(targets.size(0)): - x_adv, perturb = pgd_attack.perturb_iter_class(image_to_animate, black, targets[target_idx, :].cuda()) + x_adv, perturb = pgd_attack.perturb_iter_class(image_to_animate, black, targets[target_idx, :].unsqueeze(0).cuda()) x_advs.append(x_adv, perturb) for target_idx in range(targets.size(0)):