Update test_one_image.py
This commit is contained in:
@@ -49,16 +49,31 @@ if __name__ == '__main__':
|
||||
img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
|
||||
|
||||
# convert numpy to tensor
|
||||
img_id = img_id.cuda()
|
||||
img_att = img_att.cuda()
|
||||
|
||||
if not torch.backends.mps.is_available():
|
||||
if not torch.backends.mps.is_built():
|
||||
img_id = img_id.cuda()
|
||||
img_att = img_att.cuda()
|
||||
else:
|
||||
print("ERROR")
|
||||
else:
|
||||
img_id = img_id.has_mps()
|
||||
img_att = img_att.has_mps()
|
||||
#create latent id
|
||||
img_id_downsample = F.interpolate(img_id, size=(112,112))
|
||||
latend_id = model.netArc(img_id_downsample)
|
||||
latend_id = latend_id.detach().to('cpu')
|
||||
latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
|
||||
latend_id = latend_id.to('cuda')
|
||||
|
||||
if not torch.backends.mps.is_available():
|
||||
if not torch.backends.mps.is_built():
|
||||
img_id_downsample = F.interpolate(img_id, size=(112,112))
|
||||
latend_id = model.netArc(img_id_downsample)
|
||||
latend_id = latend_id.detach().to('cpu')
|
||||
latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
|
||||
latend_id = latend_id.to('cuda')
|
||||
else:
|
||||
print("ERROR")
|
||||
else:
|
||||
img_id_downsample = F.interpolate(img_id, size=(112,112))
|
||||
latend_id = model.netArc(img_id_downsample)
|
||||
latend_id = latend_id.detach().to('cpu')
|
||||
latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True)
|
||||
latend_id = latend_id.to('has_mps')
|
||||
|
||||
############## Forward Pass ######################
|
||||
img_fake = model(img_id, img_att, latend_id, latend_id, True)
|
||||
@@ -83,4 +98,4 @@ if __name__ == '__main__':
|
||||
|
||||
output = output*255
|
||||
|
||||
cv2.imwrite(opt.output_path + 'result.jpg',output)
|
||||
cv2.imwrite(opt.output_path + 'result.jpg',output)
|
||||
|
||||
Reference in New Issue
Block a user