diff --git a/stargan/solver.py b/stargan/solver.py index 8f6a1bf..25e4b25 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -587,6 +587,8 @@ class Solver(object): # Initialize Metrics l1_error = 0.0 l2_error = 0.0 + min_dist = 0.0 + l0_error = 0.0 perceptual_error = 0.0 n_samples = 0 @@ -633,6 +635,8 @@ class Solver(object): l1_error += F.l1_loss(gen, gen_noattack) l2_error += F.mse_loss(gen, gen_noattack) + l0_error += np.linalg.norm((gen - gen_noattack).data.cpu(), ord=0) + min_dist += attacks.min_dist((gen - gen_noattack).data.cpu(), ord='-inf') n_samples += 1 # Save the translated images. @@ -645,8 +649,8 @@ class Solver(object): break # Print metrics - print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples, - perceptual_error / n_samples)) + print('{} images. L1 error: {}. L2 error: {}. L0 error: {}. L_-inf error: {}. Perceptual error: {}.'.format(n_samples, + l1_error / n_samples, l2_error / n_samples, l0_error / n_samples, min_dist / n_samples, perceptual_error / n_samples)) def test_multi(self): """Translate images using StarGAN trained on multiple datasets."""