Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import operator | |
import datasets | |
import pandas as pd | |
from huggingface_hub import HfApi | |
from ragatouille import RAGPretrainedModel | |
# Hub client used to fetch the prebuilt abstract index.
api = HfApi()

# Local directory the index snapshot is downloaded into and loaded from.
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/ICLR2024-papers-abstract-index/"

# Pull the index files (stored as a dataset repo) into INDEX_DIR_PATH.
api.snapshot_download(
    local_dir=INDEX_DIR_PATH,
    repo_type="dataset",
    repo_id="ICLR2024/ICLR2024-papers-abstract-index",
)

# Retriever over paper abstracts, loaded from the downloaded index.
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)

# Warm-up: issue one query at import time so the retriever is initialized
# before the first real search.
ABSTRACT_RETRIEVER.search("LLM")
class PaperList:
    """Table of ICLR 2024 papers with filtering and abstract search.

    On construction the raw table is loaded with ``get_df`` and a
    display-ready copy is built with ``prettify``.
    """

    # (display column name, datatype) pairs in display order;
    # ``get_column_datatypes`` looks datatypes up by name.
    # NOTE(review): "π", "π¬" (and the "β " marker in ``prettify``) look like
    # mis-encoded emoji (likely 👍, 💬, ✅). They are kept byte-for-byte because
    # callers select columns by these names — confirm against the UI code
    # before changing them.
    COLUMN_INFO = [
        ["Title", "str"],
        ["Authors", "str"],
        ["Type", "str"],
        ["Paper page", "markdown"],
        ["π", "number"],
        ["π¬", "number"],
        ["OpenReview", "markdown"],
        ["Project page", "markdown"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    ]

    def __init__(self):
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    # These helpers take no ``self``; without @staticmethod the ``self.get_df()``
    # and ``self.prettify(...)`` calls would pass the instance as an extra
    # positional argument and raise TypeError.
    @staticmethod
    def get_df() -> pd.DataFrame:
        """Load paper metadata and left-join the per-paper stats on ``id``.

        NaNs in the stats columns (e.g. papers with no stats row) are replaced
        by ``-1`` and the columns cast to ``int``. A ``paper_page`` URL column
        is derived from ``arxiv_id`` ("" when ``arxiv_id`` is falsy).
        """
        df = pd.merge(
            left=datasets.load_dataset("ICLR2024/ICLR2024-papers", split="train").to_pandas(),
            right=datasets.load_dataset("ICLR2024/ICLR2024-paper-stats", split="train").to_pandas(),
            on="id",
            how="left",
        )
        keys = ["n_authors", "n_linked_authors", "upvotes", "num_comments"]
        # -1 is the "no stats" sentinel; prettify() renders it as "".
        df[keys] = df[keys].fillna(-1).astype(int)
        df["paper_page"] = df["arxiv_id"].apply(
            lambda arxiv_id: f"https://huggingface.co/papers/{arxiv_id}" if arxiv_id else ""
        )
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        """Return an HTML anchor to *url* labelled *text* that opens in a new tab."""
        return f'<a href="{url}" target="_blank">{text}</a>'

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """Convert raw rows (shaped like ``get_df`` output) into display rows.

        Link columns become HTML anchors (several links are joined with
        newlines), ``-1`` stat sentinels become "", and the result has exactly
        the columns of ``get_column_names()`` in that order.
        """
        rows = []
        for _, row in df.iterrows():
            author_linked = "β " if row.n_linked_authors > 0 else ""
            n_linked_authors = "" if row.n_linked_authors == -1 else row.n_linked_authors
            n_authors = "" if row.n_authors == -1 else row.n_authors
            claimed_paper = "" if n_linked_authors == "" else f"{n_linked_authors}/{n_authors} {author_linked}"
            upvotes = "" if row.upvotes == -1 else row.upvotes
            num_comments = "" if row.num_comments == -1 else row.num_comments
            new_row = {
                "Title": row["title"],
                "Authors": ", ".join(row["authors"]),
                "Type": row["type"],
                # Only link when a paper page exists; otherwise the anchor would
                # be empty, or labelled "None" when arxiv_id is None.
                "Paper page": (
                    PaperList.create_link(row["arxiv_id"], row["paper_page"]) if row["paper_page"] else ""
                ),
                "Project page": (
                    PaperList.create_link("Project page", row["project_page"]) if row["project_page"] else ""
                ),
                "π": upvotes,
                "π¬": num_comments,
                "OpenReview": PaperList.create_link("OpenReview", row["OpenReview"]),
                "GitHub": "\n".join([PaperList.create_link("GitHub", url) for url in row["GitHub"]]),
                "Spaces": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/spaces/{repo_id}")
                        for repo_id in row["Space"]
                    ]
                ),
                "Models": "\n".join(
                    [PaperList.create_link(repo_id, f"https://huggingface.co/{repo_id}") for repo_id in row["Model"]]
                ),
                "Datasets": "\n".join(
                    [
                        PaperList.create_link(repo_id, f"https://huggingface.co/datasets/{repo_id}")
                        for repo_id in row["Dataset"]
                    ]
                ),
                "claimed": claimed_paper,
            }
            rows.append(new_row)
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names() -> list[str]:
        """Return the display column names from ``COLUMN_INFO``, in order."""
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        """Return the ``COLUMN_INFO`` datatype for each name in *column_names*.

        Raises:
            KeyError: if a name is not a column in ``COLUMN_INFO``.
        """
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        presentation_type: str,
        columns_names: list[str],
    ) -> pd.DataFrame:
        """Filter ``df_raw`` and return the matching rows, prettified.

        Args:
            title_search_query: case-insensitive substring, matched literally
                against the title ("" matches every row).
            abstract_search_query: if non-empty, keep only rows returned by the
                abstract retriever for this query, ordered by retrieval order.
            max_num_to_retrieve: ``k`` passed to the abstract retriever.
            filter_names: any of "Paper page", "GitHub", "Space", "Model",
                "Dataset"; keeps only rows where that resource is present.
            presentation_type: keep only rows of this ``type``; "(ALL)"
                disables the filter.
            columns_names: display columns to return, in this order.

        Returns:
            The prettified rows restricted to ``columns_names``.
        """
        df = self.df_raw.copy()
        # As ragatouille uses str for document_id
        df["id"] = df["id"].astype(str)

        # Filter by title. regex=False: the query is user text, so characters
        # like "+" or "(" must match literally instead of raising re.error.
        # na=False drops rows with a missing title instead of breaking the mask.
        df = df[df["title"].str.contains(title_search_query, case=False, regex=False, na=False)]

        # Filter by presentation type
        if presentation_type != "(ALL)":
            df = df[df["type"] == presentation_type]

        if "Paper page" in filter_names:
            df = df[df["paper_page"] != ""]
        # Each of these filter names is also the name of a list-valued column;
        # keep rows whose list is non-empty.
        for name in ("GitHub", "Space", "Model", "Dataset"):
            if name in filter_names:
                df = df[df[name].apply(len) > 0]

        # Filter by abstract
        if abstract_search_query:
            df = self._rank_by_abstract(df, abstract_search_query, max_num_to_retrieve)

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]

    @staticmethod
    def _rank_by_abstract(df: pd.DataFrame, query: str, k: int) -> pd.DataFrame:
        """Keep the rows of *df* hit by the abstract retriever, in first-hit order."""
        results = ABSTRACT_RETRIEVER.search(query, k=k)
        remaining_ids = set(df["id"])
        # dict.fromkeys dedupes repeated hits while preserving first-hit order;
        # hits for papers already filtered out of df are dropped.
        hit_ids = dict.fromkeys(x["document_id"] for x in results)
        found_ids = [paper_id for paper_id in hit_ids if paper_id in remaining_ids]
        return df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()