Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import gradio as gr | |
from init import ( | |
get_secrets, initialize_data, | |
update_dataframe, initialize_repos | |
) | |
from gen.openllm import GradioMistralChatPPManager | |
from gen.gemini_chat import GradioGeminiChatPPManager | |
from constants.js import ( | |
UPDATE_SEARCH_RESULTS, OPEN_CHAT_IF, | |
CLOSE_CHAT_IF, UPDATE_CHAT_HISTORY | |
) | |
from datetime import datetime, timedelta | |
from background import process_arxiv_ids | |
from apscheduler.schedulers.background import BackgroundScheduler | |
gemini_api_key, hf_token, dataset_repo_id, request_arxiv_repo_id, restart_repo_id = get_secrets() | |
initialize_repos(dataset_repo_id, request_arxiv_repo_id, hf_token) | |
titles, date_dict, requested_arxiv_ids_df, arxivid2data = initialize_data(dataset_repo_id, request_arxiv_repo_id) | |
from ui import ( | |
get_paper_by_year, get_paper_by_month, get_paper_by_day, | |
set_papers, set_paper, set_date, change_exp_type, add_arxiv_ids_to_queue, | |
before_chat_begin, chat_stream, chat_reset | |
) | |
if len(date_dict.keys()) > 0: | |
sorted_year = sorted(date_dict.keys()) | |
last_year = sorted_year[-1] if len(sorted_year) > 0 else "" | |
sorted_month = sorted(date_dict[last_year].keys()) | |
last_month = sorted_month[-1] if len(sorted_year) > 0 else "" | |
sorted_day = sorted(date_dict[last_year][last_month].keys()) | |
last_day = sorted_day[-1] if len(sorted_year) > 0 else "" | |
last_papers = date_dict[last_year][last_month][last_day] if len(sorted_year) > 0 else [""] | |
selected_paper = last_papers[0] | |
visible = True | |
else: | |
sorted_year = ["2024"] | |
last_year = sorted_year[-1] | |
sorted_month = ["01"] | |
last_month = sorted_month[-1] | |
sorted_day = ["01"] | |
last_day = sorted_day[-1] | |
selected_paper = {} | |
selected_paper["title"] = "" | |
selected_paper["summary"] = "" | |
selected_paper["arxiv_id"] = "" | |
selected_paper["target_date"] = "2024-01-01" | |
for idx in range(10): | |
selected_paper[f"{idx}_question"] = "" | |
selected_paper[f"{idx}_answers:eli5"] = "" | |
selected_paper[f"{idx}_answers:expert"] = "" | |
selected_paper[f"{idx}_additional_depth_q:follow up question"] = "" | |
selected_paper[f"{idx}_additional_depth_q:answers:eli5"] = "" | |
selected_paper[f"{idx}_additional_depth_q:answers:expert"] = "" | |
selected_paper[f"{idx}_additional_breath_q:follow up question"] = "" | |
selected_paper[f"{idx}_additional_breath_q:answers:eli5"] = "" | |
selected_paper[f"{idx}_additional_breath_q:answers:expert"] = "" | |
last_papers = [selected_paper] | |
visible = False | |
with gr.Blocks(css="constants/styles.css", theme=gr.themes.Soft()) as demo: | |
cur_arxiv_id = gr.Textbox(selected_paper['arxiv_id'], visible=False) | |
local_data = gr.JSON({}, visible=False) | |
chat_state = gr.State({ | |
"ppmanager_type": GradioGeminiChatPPManager # GradioMistralChatPPManager # GradioLLaMA2ChatPPManager | |
}) | |
with gr.Column(elem_id="chatbot-back"): | |
with gr.Column(elem_id="chatbot", elem_classes=["hover-opacity"]): | |
close = gr.Button("π", elem_id="chatbot-right-button") #elem_id="chatbot-right-button") | |
chatbot = gr.Chatbot( | |
label="Gemini 1.0 Pro", show_label=True, | |
show_copy_button=True, show_share_button=True, | |
visible=True, elem_id="chatbot-inside" | |
) | |
with gr.Row(elem_id="chatbot-bottm"): | |
reset = gr.Button("ποΈ Reset") | |
regen = gr.Button("π Regenerate", visible=False) | |
prompt_txtbox = gr.Textbox(placeholder="Ask anything.....", elem_id="chatbot-txtbox", elem_classes=["textbox-no-label"]) | |
gr.Markdown("# Let's explore papers with auto generated Q&As") | |
with gr.Column(elem_id="control-panel", elem_classes=["group"], visible=visible): | |
with gr.Column(): | |
with gr.Row(): | |
year_dd = gr.Dropdown(sorted_year, value=last_year, label="Year", interactive=True, filterable=False) | |
month_dd = gr.Dropdown(sorted_month, value=last_month, label="Month", interactive=True, filterable=False) | |
day_dd = gr.Dropdown(sorted_day, value=last_day, label="Day", interactive=True, filterable=False) | |
papers_dd = gr.Dropdown( | |
list(set([paper["title"] for paper in last_papers])), | |
value=selected_paper["title"], | |
label="Select paper title", | |
interactive=True, | |
filterable=False | |
) | |
with gr.Column(elem_classes=["no-gap"]): | |
search_in = gr.Textbox("", placeholder="Enter keywords to search...", elem_classes=["textbox-no-label"]) | |
search_r1 = gr.Button(visible=False, elem_id="search_r1", elem_classes=["no-radius"]) | |
search_r2 = gr.Button(visible=False, elem_id="search_r2", elem_classes=["no-radius"]) | |
search_r3 = gr.Button(visible=False, elem_id="search_r3", elem_classes=["no-radius"]) | |
search_r4 = gr.Button(visible=False, elem_id="search_r4", elem_classes=["no-radius"]) | |
search_r5 = gr.Button(visible=False, elem_id="search_r5", elem_classes=["no-radius"]) | |
search_r6 = gr.Button(visible=False, elem_id="search_r6", elem_classes=["no-radius"]) | |
search_r7 = gr.Button(visible=False, elem_id="search_r7", elem_classes=["no-radius"]) | |
search_r8 = gr.Button(visible=False, elem_id="search_r8", elem_classes=["no-radius"]) | |
search_r9 = gr.Button(visible=False, elem_id="search_r9", elem_classes=["no-radius"]) | |
search_r10 = gr.Button(visible=False, elem_id="search_r10", elem_classes=["no-radius"]) | |
with gr.Column(scale=7, visible=visible): | |
title = gr.Markdown(f"# {selected_paper['title']}", elem_classes=["markdown-center"]) | |
# with gr.Row(): | |
with gr.Row(): | |
arxiv_link = gr.Markdown( | |
"[![arXiv](https://img.shields.io/badge/arXiv-%s-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/%s)" % (selected_paper['arxiv_id'], selected_paper['arxiv_id']) + " " | |
"[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-lg.svg)](https://huggingface.co/papers/%s)" % selected_paper['arxiv_id'] + " ", | |
elem_id="link-md", | |
) | |
chat_button = gr.Button("Chat about any custom questions", interactive=True, elem_id="chat-button") | |
summary = gr.Markdown(f"{selected_paper['summary']}", elem_classes=["small-font"]) | |
with gr.Column(elem_id="qna_block", visible=True): | |
with gr.Row(): | |
with gr.Column(scale=7): | |
gr.Markdown("## Auto generated Questions & Answers") | |
exp_type = gr.Radio(choices=["ELI5", "Technical"], value="ELI5", elem_classes=["exp-type"], scale=3) | |
# 1 | |
with gr.Column(elem_classes=["group"], visible=True) as q_0: | |
basic_q_0 = gr.Markdown(f"### π {selected_paper['0_question']}") | |
basic_q_eli5_0 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['0_answers:eli5']}", elem_classes=["small-font"]) | |
basic_q_expert_0 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['0_answers:expert']}", visible=False, elem_classes=["small-font"]) | |
with gr.Accordion("Additional question #1", open=False, elem_classes=["accordion"]) as aq_0_0: | |
depth_q_0 = gr.Markdown(f"### ππ {selected_paper['0_additional_depth_q:follow up question']}") | |
depth_q_eli5_0 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['0_additional_depth_q:answers:eli5']}", elem_classes=["small-font"]) | |
depth_q_expert_0 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['0_additional_depth_q:answers:expert']}", visible=False, elem_classes=["small-font"]) | |
with gr.Accordion("Additional question #2", open=False, elem_classes=["accordion"]) as aq_0_1: | |
breath_q_0 = gr.Markdown(f"### ππ {selected_paper['0_additional_breath_q:follow up question']}") | |
breath_q_eli5_0 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['0_additional_breath_q:answers:eli5']}", elem_classes=["small-font"]) | |
breath_q_expert_0 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['0_additional_breath_q:answers:expert']}", visible=False, elem_classes=["small-font"]) | |
# 2 | |
with gr.Column(elem_classes=["group"], visible=True) as q_1: | |
basic_q_1 = gr.Markdown(f"### π {selected_paper['1_question']}") | |
basic_q_eli5_1 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['1_answers:eli5']}", elem_classes=["small-font"]) | |
basic_q_expert_1 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['1_answers:expert']}", visible=False, elem_classes=["small-font"]) | |
with gr.Accordion("Additional question #1", open=False, elem_classes=["accordion"]) as aq_1_0: | |
depth_q_1 = gr.Markdown(f"### ππ {selected_paper['1_additional_depth_q:follow up question']}") | |
depth_q_eli5_1 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['1_additional_depth_q:answers:eli5']}", elem_classes=["small-font"]) | |
depth_q_expert_1 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['1_additional_depth_q:answers:expert']}", visible=False, elem_classes=["small-font"]) | |
with gr.Accordion("Additional question #2", open=False, elem_classes=["accordion"]) as aq_1_1: | |
breath_q_1 = gr.Markdown(f"### ππ {selected_paper['1_additional_breath_q:follow up question']}") | |
breath_q_eli5_1 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['1_additional_breath_q:answers:eli5']}", elem_classes=["small-font"]) | |
breath_q_expert_1 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['1_additional_breath_q:answers:expert']}", visible=False, elem_classes=["small-font"]) | |
# 3 | |
with gr.Column(elem_classes=["group"], visible=True) as q_2: | |
basic_q_2 = gr.Markdown(f"### π {selected_paper['2_question']}") | |
basic_q_eli5_2 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['2_answers:eli5']}", elem_classes=["small-font"]) | |
basic_q_expert_2 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['2_answers:expert']}", visible=False, elem_classes=["small-font"]) | |
with gr.Accordion("Additional question #1", open=False, elem_classes=["accordion"]) as aq_2_0: | |
depth_q_2 = gr.Markdown(f"### ππ {selected_paper['2_additional_depth_q:follow up question']}") | |
depth_q_eli5_2 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['2_additional_depth_q:answers:eli5']}", elem_classes=["small-font"]) | |
depth_q_expert_2 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['2_additional_depth_q:answers:expert']}", visible=False, elem_classes=["small-font"]) | |
with gr.Accordion("Additional question #2", open=False, elem_classes=["accordion"]) as aq_2_1: | |
breath_q_2 = gr.Markdown(f"### ππ {selected_paper['2_additional_breath_q:follow up question']}") | |
breath_q_eli5_2 = gr.Markdown(f"βͺ **(ELI5)** {selected_paper['2_additional_breath_q:answers:eli5']}", elem_classes=["small-font"]) | |
breath_q_expert_2 = gr.Markdown(f"βͺ **(Technical)** {selected_paper['2_additional_breath_q:answers:expert']}", visible=False, elem_classes=["small-font"]) | |
gr.Markdown("## Request any arXiv ids") | |
arxiv_queue = gr.Dataframe( | |
headers=["Requested arXiv IDs"], col_count=(1, "fixed"), | |
value=update_dataframe, | |
every=180, | |
datatype=["str"], | |
interactive=False, | |
) | |
arxiv_id_enter = gr.Textbox(placeholder="Enter comma separated arXiv IDs...", elem_classes=["textbox-no-label"]) | |
arxiv_id_enter.submit( | |
add_arxiv_ids_to_queue, | |
[arxiv_queue, arxiv_id_enter], | |
[arxiv_queue, arxiv_id_enter], | |
concurrency_limit=20, | |
) | |
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button") | |
gr.Markdown("The target papers are collected from [Hugging Face π€ Daily Papers](https://huggingface.co/papers) on a daily basis. " | |
"The entire data is generated by [Google's Gemini 1.0](https://deepmind.google/technologies/gemini/) Pro. " | |
"If you are curious how it is done, visit the [Auto Paper Q&A Generation project repository](https://github.com/deep-diver/auto-paper-analysis) " | |
"Also, the generated dataset is hosted on Hugging Face π€ Dataset repository as well([Link](https://huggingface.co/datasets/chansung/auto-paper-qa2)). ") | |
search_r1.click(set_date, search_r1, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r1], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r2.click(set_date, search_r2, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r2], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r3.click(set_date, search_r3, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r3], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r4.click(set_date, search_r4, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r4], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r5.click(set_date, search_r5, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r5], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r6.click(set_date, search_r6, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r6], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r7.click(set_date, search_r7, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r7], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r8.click(set_date, search_r8, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r8], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r9.click(set_date, search_r9, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r9], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
search_r10.click(set_date, search_r10, [year_dd, month_dd, day_dd]).then( | |
set_papers, | |
inputs=[year_dd, month_dd, day_dd, search_r10], | |
outputs=[cur_arxiv_id, papers_dd, search_in], | |
concurrency_limit=20, | |
) | |
year_dd.input(get_paper_by_year, inputs=[year_dd], outputs=[month_dd, day_dd, papers_dd]).then( | |
set_paper, [year_dd, month_dd, day_dd, papers_dd], | |
[ | |
cur_arxiv_id, | |
title, arxiv_link, summary, | |
basic_q_0, basic_q_eli5_0, basic_q_expert_0, | |
depth_q_0, depth_q_eli5_0, depth_q_expert_0, | |
breath_q_0, breath_q_eli5_0, breath_q_expert_0, | |
basic_q_1, basic_q_eli5_1, basic_q_expert_1, | |
depth_q_1, depth_q_eli5_1, depth_q_expert_1, | |
breath_q_1, breath_q_eli5_1, breath_q_expert_1, | |
basic_q_2, basic_q_eli5_2, basic_q_expert_2, | |
depth_q_2, depth_q_eli5_2, depth_q_expert_2, | |
breath_q_2, breath_q_eli5_2, breath_q_expert_2 | |
], | |
concurrency_limit=20, | |
) | |
month_dd.input(get_paper_by_month, inputs=[year_dd, month_dd], outputs=[day_dd, papers_dd]).then( | |
set_paper, [year_dd, month_dd, day_dd, papers_dd], | |
[ | |
cur_arxiv_id, | |
title, arxiv_link, summary, | |
basic_q_0, basic_q_eli5_0, basic_q_expert_0, | |
depth_q_0, depth_q_eli5_0, depth_q_expert_0, | |
breath_q_0, breath_q_eli5_0, breath_q_expert_0, | |
basic_q_1, basic_q_eli5_1, basic_q_expert_1, | |
depth_q_1, depth_q_eli5_1, depth_q_expert_1, | |
breath_q_1, breath_q_eli5_1, breath_q_expert_1, | |
basic_q_2, basic_q_eli5_2, basic_q_expert_2, | |
depth_q_2, depth_q_eli5_2, depth_q_expert_2, | |
breath_q_2, breath_q_eli5_2, breath_q_expert_2 | |
], | |
concurrency_limit=20, | |
) | |
day_dd.input(get_paper_by_day, inputs=[year_dd, month_dd, day_dd], outputs=[papers_dd]).then( | |
set_paper, [year_dd, month_dd, day_dd, papers_dd], | |
[ | |
cur_arxiv_id, | |
title, arxiv_link, summary, | |
basic_q_0, basic_q_eli5_0, basic_q_expert_0, | |
depth_q_0, depth_q_eli5_0, depth_q_expert_0, | |
breath_q_0, breath_q_eli5_0, breath_q_expert_0, | |
basic_q_1, basic_q_eli5_1, basic_q_expert_1, | |
depth_q_1, depth_q_eli5_1, depth_q_expert_1, | |
breath_q_1, breath_q_eli5_1, breath_q_expert_1, | |
basic_q_2, basic_q_eli5_2, basic_q_expert_2, | |
depth_q_2, depth_q_eli5_2, depth_q_expert_2, | |
breath_q_2, breath_q_eli5_2, breath_q_expert_2 | |
], | |
concurrency_limit=20, | |
) | |
papers_dd.change(set_paper, [year_dd, month_dd, day_dd, papers_dd], | |
[ | |
cur_arxiv_id, | |
title, arxiv_link, summary, | |
basic_q_0, basic_q_eli5_0, basic_q_expert_0, | |
depth_q_0, depth_q_eli5_0, depth_q_expert_0, | |
breath_q_0, breath_q_eli5_0, breath_q_expert_0, | |
basic_q_1, basic_q_eli5_1, basic_q_expert_1, | |
depth_q_1, depth_q_eli5_1, depth_q_expert_1, | |
breath_q_1, breath_q_eli5_1, breath_q_expert_1, | |
basic_q_2, basic_q_eli5_2, basic_q_expert_2, | |
depth_q_2, depth_q_eli5_2, depth_q_expert_2, | |
breath_q_2, breath_q_eli5_2, breath_q_expert_2 | |
], | |
concurrency_limit=20, | |
) | |
search_in.change( | |
inputs=[search_in], | |
outputs=[ | |
search_r1, search_r2, search_r3, search_r4, search_r5, | |
search_r6, search_r7, search_r8, search_r9, search_r10 | |
], | |
js=UPDATE_SEARCH_RESULTS % str(list(titles)), | |
fn=None | |
) | |
exp_type.select( | |
change_exp_type, | |
exp_type, | |
[ | |
basic_q_eli5_0, basic_q_expert_0, depth_q_eli5_0, depth_q_expert_0, breath_q_eli5_0, breath_q_expert_0, | |
basic_q_eli5_1, basic_q_expert_1, depth_q_eli5_1, depth_q_expert_1, breath_q_eli5_1, breath_q_expert_1, | |
basic_q_eli5_2, basic_q_expert_2, depth_q_eli5_2, depth_q_expert_2, breath_q_eli5_2, breath_q_expert_2 | |
], | |
concurrency_limit=20, | |
) | |
chat_button.click(None, [cur_arxiv_id], [local_data, chatbot], js=OPEN_CHAT_IF) | |
chat_event1 = prompt_txtbox.submit( | |
before_chat_begin, None, [reset, regen], | |
concurrency_limit=20, | |
) | |
chat_event2 = chat_event1.then( | |
chat_stream, | |
[cur_arxiv_id, local_data, prompt_txtbox, chat_state], | |
[prompt_txtbox, chatbot, local_data, reset, regen], | |
concurrency_limit=20, queue=True | |
) | |
chat_event2.then( | |
None, [cur_arxiv_id, local_data], None, | |
js=UPDATE_CHAT_HISTORY | |
) | |
close.click( | |
None, None, None,js=CLOSE_CHAT_IF | |
) | |
close.click( | |
None, None, None, cancels=[chat_event1, chat_event2] | |
) | |
reset.click( | |
before_chat_begin, None, [reset, regen], | |
concurrency_limit=20, | |
).then( | |
chat_reset, | |
[local_data, chat_state], | |
[prompt_txtbox, chatbot, local_data, reset, regen], | |
concurrency_limit=20, | |
).then( | |
None, [cur_arxiv_id, local_data], None, | |
js=UPDATE_CHAT_HISTORY | |
) | |
# demo.load(lambda: update_dataframe(request_arxiv_repo_id), None, arxiv_queue, every=180) | |
# demo.load(None, None, [chatbot, local_data], js=GET_LOCAL_STORAGE % idx.value) | |
start_date = datetime.now() + timedelta(minutes=1) | |
scheduler = BackgroundScheduler() | |
scheduler.add_job( | |
process_arxiv_ids, | |
trigger='interval', | |
seconds=300, | |
args=[ | |
gemini_api_key, | |
dataset_repo_id, | |
request_arxiv_repo_id, | |
hf_token, | |
restart_repo_id | |
], | |
start_date=start_date | |
) | |
scheduler.start() | |
demo.queue( | |
default_concurrency_limit=20, | |
max_size=256 | |
).launch( | |
share=True, debug=True | |
) |