Update test_one_image.py

This commit is contained in:
Smiril
2022-07-12 12:19:20 +02:00
committed by GitHub
parent cd3d89638d
commit 8ee255f35f

View File

@@ -49,16 +49,31 @@ if __name__ == '__main__':
img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
# convert numpy to tensor
img_id = img_id.cuda()
img_att = img_att.cuda()
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
img_id = img_id.cuda()
img_att = img_att.cuda()
else:
print("ERROR")
else:
img_id = img_id.has_mps()
img_att = img_att.has_mps()
#create latent id
img_id_downsample = F.interpolate(img_id, size=(112,112))
latend_id = model.netArc(img_id_downsample)
latend_id = latend_id.detach().to('cpu')
latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
latend_id = latend_id.to('cuda')
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
img_id_downsample = F.interpolate(img_id, size=(112,112))
latend_id = model.netArc(img_id_downsample)
latend_id = latend_id.detach().to('cpu')
latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
latend_id = latend_id.to('cuda')
else:
print("ERROR")
else:
img_id_downsample = F.interpolate(img_id, size=(112,112))
latend_id = model.netArc(img_id_downsample)
latend_id = latend_id.detach().to('cpu')
latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
latend_id = latend_id.to('has_mps')
############## Forward Pass ######################
img_fake = model(img_id, img_att, latend_id, latend_id, True)
@@ -83,4 +98,4 @@ if __name__ == '__main__':
output = output*255
cv2.imwrite(opt.output_path + 'result.jpg',output)
cv2.imwrite(opt.output_path + 'result.jpg',output)