diff --git a/stargan/solver.py b/stargan/solver.py index 6fff77c..2d14b85 100644 --- a/stargan/solver.py +++ b/stargan/solver.py @@ -77,8 +77,8 @@ class Solver(object): def build_model(self): """Create a generator and a discriminator.""" if self.dataset in ['CelebA', 'RaFD']: - # self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) - self.G = AvgBlurGenerator(self.g_conv_dim, self.c_dim, self.g_repeat_num) + self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) + # self.G = AvgBlurGenerator(self.g_conv_dim, self.c_dim, self.g_repeat_num) self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) elif self.dataset in ['Both']: self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. @@ -618,7 +618,9 @@ class Solver(object): for idx, c_trg in enumerate(c_trg_list): with torch.no_grad(): - gen_noattack, gen_noattack_feats = self.G(x_real, c_trg) + x_real_mod = x_real + x_real_mod = self.blur_tensor(x_real_mod) + gen_noattack, gen_noattack_feats = self.G(x_real_mod, c_trg) # Attack x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg) # _, perturb = x_advs[idx]