diff --git a/stargan/solver.py b/stargan/solver.py index 60bad1e..7966142 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -609,15 +609,19 @@ class Solver(object): # Translate images. x_fake_list = [x_real] + x_advs = {} if i == 0: - x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0]) + for c_trg in c_trg_list: + x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) + x_advs[c_trg] = x_adv, perturb for c_trg in c_trg_list: with torch.no_grad(): gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) # Attack - x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) - # x_adv = x_real + perturb + # x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) + _, perturb = x_advs[c_trg] + x_adv = x_real + perturb # x_adv = self.blur_tensor(x_adv) # Metrics