import pandas as pd
import arxiv
import requests
from pinecone import Pinecone, ServerlessSpec
import logging
import os
from dotenv import load_dotenv
load_dotenv(".env")
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
def get_zotero_ids(api_key, library_id, tag):
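    """Return the arXiv ids of all items in a Zotero user library that carry
    the given tag, read from each item's archiveID field."""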
    base_url = "https://api.zotero.org"
    suffix = "/users/" + library_id + "/items?tag=" + tag
    header = {"Authorization": "Bearer " + api_key}
    response = requests.get(base_url + suffix, headers=header)
    response.raise_for_status()  # surface HTTP errors instead of failing on .json()
    # Assumes every tagged item is an arXiv paper whose archiveID looks like "arXiv:2401.12345".
    return [data["data"]["archiveID"].replace("arXiv:", "") for data in response.json()]
def get_arxiv_papers(ids=None, category=None, comment=None):
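    """Fetch paper metadata from arXiv, either by an explicit id list or by
    category (optionally narrowed by a comment query).

    Returns a DataFrame with Title, Abstract, Date, and id columns; when
    called with ids (initialization), the result is also cached to
    arxiv-scrape.csv for later de-duplication.
    """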
logging.getLogger("arxiv").setLevel(logging.WARNING)
client = arxiv.Client()
    if ids is None and category is None:
        raise ValueError("Provide either a list of arXiv ids or a category to query.")
    if category is None:
        search = arxiv.Search(
            id_list=ids,
            max_results=len(ids),
        )
    else:
        if comment is None:
            custom_query = f"cat:{category}"
        else:
            custom_query = f"cat:{category} AND co:{comment}"
        search = arxiv.Search(
            query=custom_query,
            max_results=15,
            sort_by=arxiv.SortCriterion.SubmittedDate,
        )
    # Materialize the generator once: each client.results() call issues a fresh
    # API query, so iterating it per column would re-query arXiv four times.
    results = list(client.results(search))
    df = pd.DataFrame(
        {
            "Title": [result.title for result in results],
            "Abstract": [result.summary.replace("\n", " ") for result in results],
            "Date": [
                result.published.date().strftime("%Y-%m-%d") for result in results
            ],
            "id": [result.entry_id for result in results],
        }
    )
if ids:
df.to_csv("arxiv-scrape.csv", index=False)
return df
def get_hf_embeddings(api_key, df):
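    """Embed each paper with SciNCL via the Hugging Face Inference API.

    SciNCL expects each paper as a single "title[SEP]abstract" string; the API
    returns one embedding vector per input. Returns the embeddings and their
    dimensionality.
    """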
title_abs = [
title + "[SEP]" + abstract
for title, abstract in zip(df["Title"], df["Abstract"])
]
    API_URL = "https://api-inference.huggingface.co/models/malteos/scincl"
    headers = {"Authorization": f"Bearer {api_key}"}
    # Per the HF Inference API, wait_for_model belongs under the "options" key.
    response = requests.post(
        API_URL,
        headers=headers,
        json={"inputs": title_abs, "options": {"wait_for_model": False}},
    )
    # A 503 means the model is still loading; retry once and block until it is ready.
    if response.status_code == 503:
        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": title_abs, "options": {"wait_for_model": True}},
        )
    response.raise_for_status()  # fail loudly here rather than with a KeyError on the JSON
    embeddings = response.json()
    return embeddings, len(embeddings[0])
def upload_to_pinecone(api_key, index, namespace, embeddings, dim, df):
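    """Create a serverless Pinecone index (if absent) and upsert the embeddings,
    one vector per paper, keyed by the paper's arXiv entry URL."""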
    vectors = [
        {"id": df["id"][i], "values": embeddings[i]} for i in range(len(embeddings))
    ]
    pc = Pinecone(api_key=api_key)
    if index in pc.list_indexes().names():
        logging.warning(f"Index name : {index} already exists.")
        return f"Index name : {index} already exists"
    pc.create_index(
        name=index,
        dimension=dim,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    index = pc.Index(index)
    return index.upsert(vectors=vectors, namespace=namespace)
def main():
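    """Initialization: pull tagged papers from Zotero, embed them with SciNCL,
    and seed the Pinecone index that later recommendations are scored against."""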
    # Module import already chdir'ed next to this script; just make sure the
    # log directory exists before basicConfig opens the file.
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(
        filename="logs/logfile.log",
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
logging.getLogger("arxiv").setLevel(logging.WARNING)
logging.info("Project Initialization Script Started (Serverless)")
ids = get_zotero_ids(
os.getenv("ZOTERO_API_KEY"),
os.getenv("ZOTERO_LIBRARY_ID"),
os.getenv("ZOTERO_TAG"),
)
    logging.info(f"Retrieved Zotero ids: {ids}")
df = get_arxiv_papers(ids=ids)
embeddings, dim = get_hf_embeddings(os.getenv("HF_API_KEY"), df)
feedback = upload_to_pinecone(
api_key=os.getenv("PINECONE_API_KEY"),
index=os.getenv("INDEX_NAME"),
namespace=os.getenv("NAMESPACE_NAME"),
embeddings=embeddings,
dim=dim,
df=df,
)
logging.info(feedback)
    # upload_to_pinecone returns an upsert response on success, or a plain
    # string when the index already existed; `feedback is dict` was always False.
    if not isinstance(feedback, str):
return f"Retrieved {len(ids)} papers from Zotero. Successfully upserted {feedback['upserted_count']} embeddings in {os.getenv('NAMESPACE_NAME')} namespace."
else:
return feedback
def get_new_papers(df):
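    """Diff freshly scraped papers against the arxiv-scrape.csv cache.

    Returns only the unseen rows (and appends them to the cache), or a
    message string when nothing new was found."""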
df_main = pd.read_csv("arxiv-scrape.csv")
    df.reset_index(drop=True, inplace=True)
union_df = df.merge(df_main, how="left", indicator=True)
df = union_df[union_df["_merge"] == "left_only"].drop(columns=["_merge"])
if df.empty:
return "No New Papers Found"
else:
df_main = pd.concat([df_main, df], ignore_index=True)
df_main.drop_duplicates(inplace=True)
df_main.to_csv("arxiv-scrape.csv", index=False)
        # Downstream df["id"][i] lookups are positional, so restore a 0..n-1 index.
        return df.reset_index(drop=True)
def recommend_papers(api_key, index, namespace, embeddings, df, threshold):
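    """Query the Pinecone index with each candidate embedding and collect
    markdown links for papers whose summed top-3 similarity clears the
    threshold."""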
pc = Pinecone(api_key=api_key)
if index in pc.list_indexes().names():
index = pc.Index(index)
else:
raise ValueError(f"{index} doesnt exist. Project isnt initialized properly")
    results = []
    for i, embedding in enumerate(embeddings):
        result = index.query(
            namespace=namespace, vector=embedding, top_k=3, include_values=False
        )
        # The summed top-3 similarity gates the recommendation; the reported
        # score is the per-match average.
        sum_score = sum(match["score"] for match in result["matches"])
        if sum_score > threshold:
            results.append(
                f"Paper-URL : [{df['id'][i]}]({df['id'][i]}) with score: {sum_score / 3} <br />"
            )
if results:
return "\n".join(results)
else:
return "No Interesting Paper"
def recs(threshold):
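    """Weekly run: scrape the latest category papers, keep only unseen ones,
    and score them against the Zotero-seeded index."""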
logging.info("Weekly Script Started (Serverless)")
df = get_arxiv_papers(
category=os.getenv("ARXIV_CATEGORY_NAME"),
comment=os.getenv("ARXIV_COMMENT_QUERY"),
)
df = get_new_papers(df)
if not isinstance(df, pd.DataFrame):
return df
embeddings, _ = get_hf_embeddings(os.getenv("HF_API_KEY"), df)
results = recommend_papers(
os.getenv("PINECONE_API_KEY"),
os.getenv("INDEX_NAME"),
os.getenv("NAMESPACE_NAME"),
embeddings,
df,
threshold,
)
return results
if __name__ == "__main__":
choice = int(input("1. Initialize\n2. Recommend Papers\n"))
if choice == 1:
print(main())
elif choice == 2:
    threshold = float(input("Enter Similarity Threshold: "))
print(recs(threshold))
else:
raise ValueError("Invalid Input")