This commit is contained in:
Nataniel Ruiz
2020-01-17 14:36:24 -05:00
parent 9421ed626b
commit e7a447d416
+5 -3
View File
@@ -77,8 +77,8 @@ class Solver(object):
def build_model(self):
"""Create a generator and a discriminator."""
if self.dataset in ['CelebA', 'RaFD']:
# self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
self.G = AvgBlurGenerator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
# self.G = AvgBlurGenerator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
elif self.dataset in ['Both']:
self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector.
@@ -618,7 +618,9 @@ class Solver(object):
for idx, c_trg in enumerate(c_trg_list):
with torch.no_grad():
gen_noattack, gen_noattack_feats = self.G(x_real, c_trg)
x_real_mod = x_real
x_real_mod = self.blur_tensor(x_real_mod)
gen_noattack, gen_noattack_feats = self.G(x_real_mod, c_trg)
# Attack
x_adv, perturb = pgd_attack.perturb(x_real, black, c_trg)
# _, perturb = x_advs[idx]