import os
import time
from typing import Dict, List, Optional, TypeAlias

import gradio as gr
import torch
import weave
from transformers import pipeline

from papersai.utils import load_paper_as_context

os.environ["TOKENIZERS_PARALLELISM"] = "false"

HistoryType: TypeAlias = List[Dict[str, str]]
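
# An illustrative history value (example data, not produced by the app): a
# list of role/content dicts, matching the "messages" format used by
# gr.Chatbot below.
#
#   [
#       {"role": "user", "content": "What problem does the paper tackle?"},
#       {"role": "assistant", "content": "It studies ..."},
#   ]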

# Initialize the LLM and Weave client
client = weave.init("papersai")
checkpoint: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
pipe = pipeline(
    model=checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
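
# A minimal smoke test for the pipeline (a sketch, kept commented out so it
# never runs in the app; the prompt and token count are arbitrary):
#
#   probe = pipe("Hello", max_new_tokens=5)[0]["generated_text"]
#   assert isinstance(probe, str)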


class ChatState:
    """Utility class to store the paper context and the Weave call for the
    last response."""

    def __init__(self):
        self.context = None
        self.last_response = None
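
# Note: a single module-level ChatState (created in create_interface) is
# shared across sessions, so concurrent users would share context; per-session
# isolation would need gr.State instead.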


def record_feedback(x: gr.LikeData) -> None:
    """
    Logs user feedback on the assistant's response in the form of a
    like/dislike reaction.

    Reference:
        * https://weave-docs.wandb.ai/guides/tracking/feedback

    Args:
        x (gr.LikeData): User feedback data

    Returns:
        None
    """
    call = state.last_response

    # Remove any existing feedback before adding new feedback
    for existing_feedback in list(call.feedback):
        call.feedback.purge(existing_feedback.id)

    if x.liked:
        call.feedback.add_reaction("👍")
    else:
        call.feedback.add_reaction("👎")
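
# Recorded reactions show up on the call's trace in the Weave UI; they can
# also be read back from code, e.g. (a sketch, assuming the client returned
# by weave.init above and a stored `call`):
#
#   feedbacks = list(client.get_call(call.id).feedback)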


@weave.op()
def invoke(history: HistoryType) -> str:
    """
    Simple wrapper around LLM inference, wrapped in a Weave op so each call
    is traced and can receive feedback.

    Args:
        history (HistoryType): Chat history

    Returns:
        str: Response from the model
    """
    input_text = pipe.tokenizer.apply_chat_template(
        history,
        tokenize=False,
    )
    response = pipe(input_text, do_sample=True, top_p=0.95, max_new_tokens=100)[0][
        "generated_text"
    ]
    # The generated text echoes the prompt, so keep only the assistant's reply
    response = response.split("\nassistant\n")[-1]
    return response
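
# Because `invoke` is a Weave op it can be used two ways: calling it directly
# returns just the text, while `.call()` (used in `bot` below) also returns
# the Call object that feedback gets attached to. Illustrative usage:
#
#   text = invoke([{"role": "user", "content": "Hi"}])
#   text, call = invoke.call([{"role": "user", "content": "Hi"}])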


def update_state(history: HistoryType, message: Optional[Dict[str, str]]):
    """
    Update history and app state with the latest user input.

    Args:
        history (HistoryType): Chat history
        message (Optional[Dict[str, str]]): User input message

    Returns:
        Tuple[HistoryType, gr.MultimodalTextbox]: Updated history and chat input
    """
    if message is None:
        return history, gr.MultimodalTextbox(value=None, interactive=True)

    # Initialize history if None
    if history is None:
        history = []

    # Handle file uploads without adding to visible history
    if isinstance(message, dict) and "files" in message:
        for file_path in message["files"]:
            try:
                text = load_paper_as_context(file_path=file_path)
                doc_context = [x.get_content() for x in text]
                # Truncate the context (note: this slices characters, a rough
                # proxy for the model's token limit)
                state.context = " ".join(doc_context)[
                    : pipe.model.config.max_position_embeddings
                ]
                history.append(
                    {"role": "system", "content": f"Context: {state.context}\n"}
                )
            except Exception as e:
                history.append(
                    {"role": "assistant", "content": f"Error loading file: {str(e)}"}
                )

    # Handle text input
    if isinstance(message, dict) and message.get("text"):
        history.append({"role": "user", "content": message["text"]})

    return history, gr.MultimodalTextbox(value=None, interactive=True)
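
# The MultimodalTextbox submits values shaped like the dict below, which is
# what `update_state` receives as `message` (the file path is illustrative):
#
#   {"text": "What is the main contribution?", "files": ["/tmp/paper.pdf"]}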


def bot(history: HistoryType):
    """
    Generate a response from the LLM and stream it back to the user.

    Args:
        history (HistoryType): Chat history

    Yields:
        HistoryType: Chat history with the streamed response appended
    """
    if not history:
        yield history
        return

    try:
        # Get the response along with the Weave call object, so feedback can
        # later be attached to this specific trace
        response, call = invoke.call(history)
        state.last_response = call

        # Add an empty assistant message, then stream characters into it
        history.append({"role": "assistant", "content": ""})
        for character in response:
            history[-1]["content"] += character
            time.sleep(0.02)
            yield history
    except Exception as e:
        history.append({"role": "assistant", "content": f"Error: {str(e)}"})
        yield history
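
# Each yield re-renders the chat with one more character appended, producing a
# typewriter effect; at 0.02 s per character this streams roughly 50
# characters per second.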


def create_interface():
    with gr.Blocks() as demo:
        global state
        state = ChatState()

        gr.Markdown(
            """
            <a href="https://github.com/SauravMaheshkar/papersai">
                <div align="center"><h1>papers.ai</h1></div>
            </a>
            """,
        )

        chatbot = gr.Chatbot(
            show_label=False,
            height=600,
            type="messages",
            show_copy_all_button=True,
            placeholder="Upload a research paper and ask questions!!",
        )

        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="single",
            placeholder="Upload a document or type your message...",
            show_label=False,
        )
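
        # Chain the events: `.submit(...)` first commits the user's message to
        # the history, then `.then(...)` streams the assistant's reply, so the
        # message shows up immediately while `bot` yields partial responses.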
        chat_msg = chat_input.submit(
            fn=update_state,
            inputs=[chatbot, chat_input],
            outputs=[chatbot, chat_input],
        )
        bot_msg = chat_msg.then(  # noqa: F841
            fn=bot, inputs=[chatbot], outputs=chatbot, api_name="bot_response"
        )

        chatbot.like(
            fn=record_feedback,
            inputs=None,
            outputs=None,
            like_user_message=True,
        )

    return demo


def main():
    demo = create_interface()
    demo.launch(share=False)


if __name__ == "__main__":
    main()