diff --git a/test_one_image.py b/test_one_image.py index a3a0cc6..ea930dd 100644 --- a/test_one_image.py +++ b/test_one_image.py @@ -55,7 +55,7 @@ img_att = img_att.cuda() img_id_downsample = F.interpolate(img_id, scale_factor=0.5) latend_id = model.netArc(img_id_downsample) latend_id = latend_id.detach().to('cpu') -latend_id = latend_id/np.linalg.norm(latend_id) +latend_id = latend_id/np.linalg.norm(latend_id,axis=1,keepdims=True) latend_id = latend_id.to('cuda')