This commit is contained in:
Nataniel Ruiz
2019-12-23 16:59:04 -04:00
parent 44095641b1
commit b741dc2827
+21 -20
View File
@@ -571,25 +571,26 @@ class Solver(object):
def test_attack(self):
"""Translate images using StarGAN trained on a single dataset."""
# Load the trained generator.
self.restore_model(self.test_iters)
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
# Initialize Metrics
l1_error = 0.0
l2_error = 0.0
perceptual_error = 0.0
n_samples = 0
# 11 layers
# layer_num_orig = 1
for layer_num_orig in range(11):
# Load the trained generator.
self.restore_model(self.test_iters)
# Set data loader.
if self.dataset == 'CelebA':
data_loader = self.celeba_loader
elif self.dataset == 'RaFD':
data_loader = self.rafd_loader
# Initialize Metrics
l1_error = 0.0
l2_error = 0.0
perceptual_error = 0.0
n_samples = 0
# 11 layers
# layer_num_orig = 1
print('Layer ', layer_num_orig)
for i, (x_real, c_org) in enumerate(data_loader):
# Black image
@@ -641,9 +642,9 @@ class Solver(object):
if i == 199:
break
# Print metrics
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
perceptual_error / n_samples))
# Print metrics
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
perceptual_error / n_samples))
def test_multi(self):
"""Translate images using StarGAN trained on multiple datasets."""