next
This commit is contained in:
+21
-20
@@ -571,25 +571,26 @@ class Solver(object):
|
||||
|
||||
def test_attack(self):
|
||||
"""Translate images using StarGAN trained on a single dataset."""
|
||||
# Load the trained generator.
|
||||
self.restore_model(self.test_iters)
|
||||
|
||||
# Set data loader.
|
||||
if self.dataset == 'CelebA':
|
||||
data_loader = self.celeba_loader
|
||||
elif self.dataset == 'RaFD':
|
||||
data_loader = self.rafd_loader
|
||||
|
||||
# Initialize Metrics
|
||||
l1_error = 0.0
|
||||
l2_error = 0.0
|
||||
perceptual_error = 0.0
|
||||
n_samples = 0
|
||||
|
||||
# 11 layers
|
||||
# layer_num_orig = 1
|
||||
|
||||
for layer_num_orig in range(11):
|
||||
# Load the trained generator.
|
||||
self.restore_model(self.test_iters)
|
||||
|
||||
# Set data loader.
|
||||
if self.dataset == 'CelebA':
|
||||
data_loader = self.celeba_loader
|
||||
elif self.dataset == 'RaFD':
|
||||
data_loader = self.rafd_loader
|
||||
|
||||
# Initialize Metrics
|
||||
l1_error = 0.0
|
||||
l2_error = 0.0
|
||||
perceptual_error = 0.0
|
||||
n_samples = 0
|
||||
|
||||
# 11 layers
|
||||
# layer_num_orig = 1
|
||||
|
||||
print('Layer ', layer_num_orig)
|
||||
for i, (x_real, c_org) in enumerate(data_loader):
|
||||
# Black image
|
||||
@@ -641,9 +642,9 @@ class Solver(object):
|
||||
if i == 199:
|
||||
break
|
||||
|
||||
# Print metrics
|
||||
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
|
||||
perceptual_error / n_samples))
|
||||
# Print metrics
|
||||
print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,
|
||||
perceptual_error / n_samples))
|
||||
|
||||
def test_multi(self):
|
||||
"""Translate images using StarGAN trained on multiple datasets."""
|
||||
|
||||
Reference in New Issue
Block a user