#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: cos.py
# Created Date: Monday February 7th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 7th February 2022 6:26:23 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import torch


def cosin_metric(x1, x2, dim=1, eps=1e-8):
    """Return the cosine similarity between ``x1`` and ``x2`` along ``dim``.

    The similarity is computed as ``sum(x1 * x2) / max(||x1|| * ||x2||, eps)``,
    where the sum and both L2 norms are taken over ``dim``.

    Args:
        x1: Tensor of vectors. With the default ``dim=1`` this is presumably
            laid out as (batch, features); confirm against callers.
        x2: Tensor broadcast-compatible with ``x1``.
        dim: Axis holding the vector components. It defaults to 1, which
            matches the original behaviour.
        eps: Lower bound applied to the product of the norms. It keeps
            zero-length vectors from producing NaN (0/0); such pairs score 0.

    Returns:
        Tensor of similarities with ``dim`` reduced away.
    """
    dot = torch.sum(x1 * x2, dim=dim)
    denom = torch.norm(x1, dim=dim) * torch.norm(x2, dim=dim)
    # Clamp instead of dividing by a raw product that may be zero.
    # clamp_min leaves every denominator >= eps unchanged.
    return dot / denom.clamp_min(eps)