AgentReview / agentreview /utility /metrics_utils.py
Yiqiao Jin
Update demo
53709ed
raw
history blame
556 Bytes
import numpy as np
import torch
def pairwise_cos_sim(A: torch.Tensor, B: torch.Tensor, device="cuda:0"):
if isinstance(A, np.ndarray):
A = torch.from_numpy(A).to(device)
if isinstance(B, np.ndarray):
B = torch.from_numpy(B).to(device)
from torch.nn.functional import normalize
A_norm = normalize(A, dim=1) # Normalize the rows of A
B_norm = normalize(B, dim=1) # Normalize the rows of B
cos_sim = torch.matmul(A_norm,
B_norm.t()) # Calculate the cosine similarity
return cos_sim