From 8ee255f35fd56bfdc8c33b324809802fbf0b3867 Mon Sep 17 00:00:00 2001 From: Smiril Date: Tue, 12 Jul 2022 12:19:20 +0200 Subject: [PATCH] Update test_one_image.py --- test_one_image.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/test_one_image.py b/test_one_image.py index 4eabd8e..8713428 100644 --- a/test_one_image.py +++ b/test_one_image.py @@ -49,16 +49,31 @@ if __name__ == '__main__': img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2]) # convert numpy to tensor - img_id = img_id.cuda() - img_att = img_att.cuda() - + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + img_id = img_id.cuda() + img_att = img_att.cuda() + else: + print("ERROR") + else: + img_id = img_id.has_mps() + img_att = img_att.has_mps() #create latent id - img_id_downsample = F.interpolate(img_id, size=(112,112)) - latend_id = model.netArc(img_id_downsample) - latend_id = latend_id.detach().to('cpu') - latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) - latend_id = latend_id.to('cuda') - + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + img_id_downsample = F.interpolate(img_id, size=(112,112)) + latend_id = model.netArc(img_id_downsample) + latend_id = latend_id.detach().to('cpu') + latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) + latend_id = latend_id.to('cuda') + else: + print("ERROR") + else: + img_id_downsample = F.interpolate(img_id, size=(112,112)) + latend_id = model.netArc(img_id_downsample) + latend_id = latend_id.detach().to('cpu') + latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) + latend_id = latend_id.to('has_mps') ############## Forward Pass ###################### img_fake = model(img_id, img_att, latend_id, latend_id, True) @@ -83,4 +98,4 @@ if __name__ == '__main__': output = output*255 - cv2.imwrite(opt.output_path + 'result.jpg',output) \ No newline at end of file + cv2.imwrite(opt.output_path + 'result.jpg',output)