diff --git a/stargan/solver.py b/stargan/solver.py index 610e547..eda5505 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -574,75 +574,76 @@ class Solver(object): layer_dict = {0: 2, 1: 5, 2: 8, 3: 9, 4: 10, 5: 11, 6: 12, 7: 13, 8: 14, 9: 17, 10: 20} - for layer_num_orig in range(11): - # Load the trained generator. - self.restore_model(self.test_iters) + # for layer_num_orig in range(11): + # Load the trained generator. + self.restore_model(self.test_iters) + + # Set data loader. + if self.dataset == 'CelebA': + data_loader = self.celeba_loader + elif self.dataset == 'RaFD': + data_loader = self.rafd_loader + + # Initialize Metrics + l1_error = 0.0 + l2_error = 0.0 + perceptual_error = 0.0 + n_samples = 0 + + # 11 layers + # layer_num_orig = 1 + + # print('Layer ', layer_num_orig) + for i, (x_real, c_org) in enumerate(data_loader): + # Black image + black = np.zeros((1,3,256,256)) + black = torch.FloatTensor(black).to(self.device) + + # Prepare input images and target domain labels. + x_real = x_real.to(self.device) + c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) + + # layer_num = layer_dict[layer_num_orig] + layer_num = None + pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num) + + # Translate images. + x_fake_list = [x_real] + + # if i == 0: + # x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0]) + + for c_trg in c_trg_list: + # Attack + x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) + # x_adv = x_real + perturb + # x_adv = self.blur_tensor(x_adv) + + # Metrics + with torch.no_grad(): + # gen, preproc_x = self.G(x_adv, c_trg) + gen, gen_feats = self.G(x_adv, c_trg) + + # Add to lists + # x_fake_list.append(preproc_x) + x_fake_list.append(gen) + + # No Attack + # gen_noattack, _ = self.G(x_real, c_trg) + gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) + + l1_error += F.l1_loss(gen, gen_noattack) + l2_error += F.mse_loss(gen, gen_noattack) + n_samples += 1 + + # Save the translated images. + x_concat = torch.cat(x_fake_list, dim=3) + result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) + save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) + # print('Saved real and fake images into {}...'.format(result_path)) - # Set data loader. - if self.dataset == 'CelebA': - data_loader = self.celeba_loader - elif self.dataset == 'RaFD': - data_loader = self.rafd_loader - - # Initialize Metrics - l1_error = 0.0 - l2_error = 0.0 - perceptual_error = 0.0 - n_samples = 0 - - # 11 layers - # layer_num_orig = 1 - - print('Layer ', layer_num_orig) - for i, (x_real, c_org) in enumerate(data_loader): - # Black image - black = np.zeros((1,3,256,256)) - black = torch.FloatTensor(black).to(self.device) - - # Prepare input images and target domain labels. - x_real = x_real.to(self.device) - c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) - - layer_num = layer_dict[layer_num_orig] - pgd_attack = attacks.LinfPGDAttack(model=self.G, device=self.device, feat=layer_num) - - # Translate images. - x_fake_list = [x_real] - - # if i == 0: - # x_adv, perturb = pgd_attack.perturb(x_real, x_real, c_trg_list[0]) - - for c_trg in c_trg_list: - # Attack - x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) - # x_adv = x_real + perturb - # x_adv = self.blur_tensor(x_adv) - - # Metrics - with torch.no_grad(): - # gen, preproc_x = self.G(x_adv, c_trg) - gen, gen_feats = self.G(x_adv, c_trg) - - # Add to lists - # x_fake_list.append(preproc_x) - x_fake_list.append(gen) - - # No Attack - # gen_noattack, _ = self.G(x_real, c_trg) - gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) - - l1_error += F.l1_loss(gen, gen_noattack) - l2_error += F.mse_loss(gen, gen_noattack) - n_samples += 1 - - # Save the translated images. - x_concat = torch.cat(x_fake_list, dim=3) - result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) - save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) - # print('Saved real and fake images into {}...'.format(result_path)) - - if i == 199: - break + if i == 199: + break # Print metrics print('{} images. L1 error: {}. L2 error: {}. Perceptual error: {}.'.format(n_samples, l1_error / n_samples, l2_error / n_samples,