From ca800ef00b6b8c1016e6e0b7d5c394e26810216f Mon Sep 17 00:00:00 2001 From: NNNNAI <844294823@qq.com> Date: Mon, 19 Jul 2021 11:48:17 +0800 Subject: [PATCH] Update test_one_image.py --- test_one_image.py | 83 ++++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/test_one_image.py b/test_one_image.py index 0915550..a05032b 100644 --- a/test_one_image.py +++ b/test_one_image.py @@ -22,10 +22,10 @@ transformer_Arcface = transforms.Compose([ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) -# detransformer = transforms.Compose([ -# transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), -# transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) -# ]) +detransformer = transforms.Compose([ + transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]), + transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1]) + ]) if __name__ == '__main__': opt = TestOptions().parse() @@ -35,51 +35,52 @@ if __name__ == '__main__': model = create_model(opt) model.eval() + with torch.no_grad(): + + pic_a = opt.pic_a_path + img_a = Image.open(pic_a).convert('RGB') + img_a = transformer_Arcface(img_a) + img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) - pic_a = opt.pic_a_path - img_a = Image.open(pic_a).convert('RGB') - img_a = transformer_Arcface(img_a) - img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2]) + pic_b = opt.pic_b_path - pic_b = opt.pic_b_path + img_b = Image.open(pic_b).convert('RGB') + img_b = transformer(img_b) + img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) - img_b = Image.open(pic_b).convert('RGB') - img_b = transformer(img_b) - img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) + # convert numpy to tensor + img_id = img_id.cuda() + img_att = img_att.cuda() - # convert numpy to tensor - img_id = img_id.cuda() - img_att = img_att.cuda() - - #create latent id - img_id_downsample = F.interpolate(img_id, scale_factor=0.5) - latend_id = model.netArc(img_id_downsample) - latend_id = latend_id.detach().to('cpu') - latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) - latend_id = latend_id.to('cuda') + #create latent id + img_id_downsample = F.interpolate(img_id, scale_factor=0.5) + latend_id = model.netArc(img_id_downsample) + latend_id = latend_id.detach().to('cpu') + latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) + latend_id = latend_id.to('cuda') - ############## Forward Pass ###################### - img_fake = model(img_id, img_att, latend_id, latend_id, True) + ############## Forward Pass ###################### + img_fake = model(img_id, img_att, latend_id, latend_id, True) - for i in range(img_id.shape[0]): - if i == 0: - row1 = img_id[i] - row2 = img_att[i] - row3 = img_fake[i] - else: - row1 = torch.cat([row1, img_id[i]], dim=2) - row2 = torch.cat([row2, img_att[i]], dim=2) - row3 = torch.cat([row3, img_fake[i]], dim=2) + for i in range(img_id.shape[0]): + if i == 0: + row1 = img_id[i] + row2 = img_att[i] + row3 = img_fake[i] + else: + row1 = torch.cat([row1, img_id[i]], dim=2) + row2 = torch.cat([row2, img_att[i]], dim=2) + row3 = torch.cat([row3, img_fake[i]], dim=2) - #full = torch.cat([row1, row2, row3], dim=1).detach() - full = row3.detach() - full = full.permute(1, 2, 0) - output = full.to('cpu') - output = np.array(output) - output = output[..., ::-1] + #full = torch.cat([row1, row2, row3], dim=1).detach() + full = row3.detach() + full = full.permute(1, 2, 0) + output = full.to('cpu') + output = np.array(output) + output = output[..., ::-1] - output = output*255 + output = output*255 - cv2.imwrite(opt.output_path + 'result.jpg',output) \ No newline at end of file + cv2.imwrite(opt.output_path + 'result.jpg',output) \ No newline at end of file