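"""Retrieve the stored issues most similar to a free-text query.

The query is embedded with a SentenceTransformer model, scored against a
pre-computed matrix of issue embeddings using cosine similarity, and the
three closest issues are printed.

Expected inputs (assumed to be produced by a separate indexing step):
  - issue_embeddings.npy: float array of shape (num_issues, embedding_dim)
  - embedding_index_to_issue.json: maps each embedding row index (as a string) to an issue id
  - issues_dict.json: maps each issue id to an object with "title" and "body" fields
"""
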
import argparse
import json
import pprint

import numpy as np
from sentence_transformers import SentenceTransformer

def cosine_similarity(a, b):
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    # Promote 1-D vectors to single-row matrices so both inputs are 2-D.
    if a.ndim == 1:
        a = a.reshape(1, -1)

    if b.ndim == 1:
        b = b.reshape(1, -1)

    # (m, d) @ (d, n) -> (m, n), normalised by the outer product of the row norms.
    return np.dot(a, b.T) / np.outer(np.linalg.norm(a, axis=1), np.linalg.norm(b, axis=1))


def retrieve_issue_rankings(
    query: str,
    model_id: str,
    input_embedding_filename: str,
):
    """
    Given a query, return the embedding-index positions of all issues,
    sorted from most to least similar to the query.
    """
    model = SentenceTransformer(model_id)

    # Pre-computed issue embeddings, one row per indexed issue
    embeddings = np.load(input_embedding_filename)

    # Embed the query with the selected model
    query_embedding = model.encode(query)

    # Calculate the cosine similarity between the query and all the issues
    cosine_similarities = cosine_similarity(query_embedding, embeddings)

    # Sort the embedding indices from most to least similar to the query
    most_similar_indices = np.argsort(cosine_similarities[0])[::-1]
    return most_similar_indices


def print_issue(issues, issue_id):
    # Look up the stored issue and print its title and body
    issue_info = issues[issue_id]

    print(f"#{issue_id}", issue_info["title"])
    print(issue_info["body"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Retrieve the stored issues most similar to a query.")
    parser.add_argument("query", type=str, help="free-text query to search for")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2",
                        help="SentenceTransformer model used to embed the query")
    parser.add_argument("--input_embedding_filename", type=str, default="issue_embeddings.npy",
                        help="NumPy file of pre-computed issue embeddings")
    parser.add_argument("--input_index_filename", type=str, default="embedding_index_to_issue.json",
                        help="JSON mapping from embedding row index to issue id")

    args = parser.parse_args()

    issue_rankings = retrieve_issue_rankings(
        query=args.query,
        model_id=args.model_id,
        input_embedding_filename=args.input_embedding_filename,
    )

    # Load the full issue records (title and body), keyed by issue id;
    # note that this filename is hardcoded, unlike the other input paths.
    with open("issues_dict.json", "r") as f:
        issues = json.load(f)

    with open(args.input_index_filename, "r") as f:
        embedding_index_to_issue = json.load(f)

    # Map each ranked embedding row index back to its issue id (JSON keys are strings)
    issue_ids = [embedding_index_to_issue[str(i)] for i in issue_rankings]

    # Print the three closest issues
    for issue_id in issue_ids[:3]:
        print(issue_id)
        print_issue(issues, issue_id)
        print("\n\n\n")
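
# Example invocation, assuming this file is saved as retrieve_issues.py:
#   python retrieve_issues.py "crash when loading a large model checkpoint"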