diff --git a/stargan/solver.py b/stargan/solver.py index 4c1c54e..610e547 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -571,7 +571,9 @@ class Solver(object): def test_attack(self): """Translate images using StarGAN trained on a single dataset.""" - + + layer_dict = {0: 2, 1: 5, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12, 7: 13, 8: 14, 9: 17, 10: 20} + for layer_num_orig in range(11): # Load the trained generator. self.restore_model(self.test_iters) @@ -601,7 +603,7 @@ class Solver(object): x_real = x_real.to(self.device) c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) - layer_num = (layer_num_orig + 1) * 3 - 1 + layer_num = layer_dict[layer_num_orig] pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num) # Translate images.