diff --git a/stargan/solver.py b/stargan/solver.py index 2863d1e..4c1c54e 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -571,25 +571,26 @@ class Solver(object): def test_attack(self): """Translate images using StarGAN trained on a single dataset.""" - # Load the trained generator. - self.restore_model(self.test_iters) - # Set data loader. - if self.dataset == 'CelebA': - data_loader = self.celeba_loader - elif self.dataset == 'RaFD': - data_loader = self.rafd_loader - - # Initialize Metrics - l1_error = 0.0 - l2_error = 0.0 - perceptual_error = 0.0 - n_samples = 0 - - # 11 layers - # layer_num_orig = 1 - for layer_num_orig in range(11): + # Load the trained generator. + self.restore_model(self.test_iters) + + # Set data loader. + if self.dataset == 'CelebA': + data_loader = self.celeba_loader + elif self.dataset == 'RaFD': + data_loader = self.rafd_loader + + # Initialize Metrics + l1_error = 0.0 + l2_error = 0.0 + perceptual_error = 0.0 + n_samples = 0 + + # 11 layers + # layer_num_orig = 1 + print('Layer ', layer_num_orig) for i, (x_real, c_org) in enumerate(data_loader): # Black image @@ -641,9 +642,9 @@ class Solver(object): if i == 199: break - # Print metrics - print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples, - perceptual_error / n_samples)) + # Print metrics + print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples, + perceptual_error / n_samples)) def test_multi(self): """Translate images using StarGAN trained on multiple datasets."""