diff --git a/cyclegan/test.py b/cyclegan/test.py index f44bb60..56522ef 100644 --- a/cyclegan/test.py +++ b/cyclegan/test.py @@ -68,7 +68,9 @@ if __name__ == '__main__': model.set_input(data) # unpack data from data loader with torch.no_grad(): model.forward_noattack() - input_adv, perturb = model.attack() + if i == 0: + input_adv, perturb = model.attack() + # input_adv, perturb = model.attack() with torch.no_grad(): model.forward_attack(perturb) model.compute_visuals()