#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: cos.py
# Created Date: Monday February 7th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 7th February 2022 6:26:23 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
|
|
|
|
def cosin_metric(x1, x2, dim=1, eps=1e-8):
    """Return the cosine similarity between ``x1`` and ``x2`` along ``dim``.

    Computes ``sum(x1 * x2, dim) / (||x1|| * ||x2||)``, where each norm is the
    2-norm taken over ``dim``.

    Args:
        x1: Tensor of vectors. Presumably shaped ``(batch, features)`` when
            ``dim=1`` -- TODO confirm against callers.
        x2: Tensor with the same (or a broadcast-compatible) shape as ``x1``.
        dim: Dimension holding the vector components. Defaults to ``1``, the
            axis the original implementation hard-coded.
        eps: Lower bound applied to each norm before dividing, so an all-zero
            vector yields a similarity of 0 instead of NaN. Vectors whose norm
            is at least ``eps`` are unaffected.

    Returns:
        Tensor of cosine similarities with ``dim`` reduced away.
    """
    dot = torch.sum(x1 * x2, dim=dim)
    # Clamp each norm separately: an all-zero vector would otherwise produce
    # 0 / 0 = NaN, which silently propagates into any loss or metric.
    norm1 = torch.norm(x1, dim=dim).clamp_min(eps)
    norm2 = torch.norm(x2, dim=dim).clamp_min(eps)
    return dot / (norm1 * norm2)