# Spaces: Runtime error
# https://python.langchain.com/docs/tutorials/rag/
from pathlib import Path
from urllib.parse import urljoin

import bs4
import gradio as gr
import requests
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.retrievers import ArxivRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai import ChatMistralAI
from langchain_mistralai import MistralAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Client-side request throttle (the original note reads "MistralAI free").
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,  # <-- 0.1 req/s: at most one request every 10 seconds
    check_every_n_seconds=0.01,  # Wake up every 10 ms to check whether allowed to make a request,
    max_bucket_size=10,  # Controls the maximum burst size.
)

# Disabled RAG pipeline. Everything between the triple quotes below is a bare
# string literal, so none of it executes; it is kept for reference only.
"""
# get data
with open("urls.txt") as urlsfile:
    # Skip blank lines so no empty URL is handed to the loader.
    urls = [url.strip() for url in urlsfile if url.strip()]

# Load, chunk and index the contents of the blog.
loader = WebBaseLoader(urls)
docs = loader.load()

# load arxiv papers
with open("arxiv.txt") as arxivfile:
    arxivs = [arxiv.strip() for arxiv in arxivfile if arxiv.strip()]

retriever = ArxivRetriever(
    load_max_docs=2,
    get_full_documents=True,
)

for arxiv in arxivs:
    doc = retriever.invoke(arxiv)
    doc[0].metadata['Published'] = str(doc[0].metadata['Published'])
    docs.append(doc[0])


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


def RAG(llm, docs, embeddings):
    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)

    # Retrieve and generate using the relevant snippets of the documents
    retriever = vectorstore.as_retriever()

    # Prompt basis example for RAG systems
    prompt = hub.pull("rlm/rag-prompt")

    # Create the chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain


# LLM model
llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)

# Embeddings
embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
# embed_model = "nvidia/NV-Embed-v2"
embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
# embeddings = MistralAIEmbeddings()

# RAG chain
rag_chain = RAG(llm, docs, embeddings)


def handle_prompt(message, history):
    try:
        # Stream output
        out = ""
        for chunk in rag_chain.stream(message):
            out += chunk
            yield out
    except Exception as err:
        # `except Exception` (not a bare except) so GeneratorExit and
        # KeyboardInterrupt are not reported as a rate-limit failure.
        raise gr.Error("Requests rate limit exceeded") from err
"""
def handle_prompt(message, history):
    """Chat callback: write the submitted message to stdout.

    Args:
        message: The text the user just submitted.
        history: Previous conversation turns; unused by this implementation.

    Returns:
        None. No reply is produced.
    """
    # NOTE(review): this callback is handed to gr.ChatInterface, which
    # presumably expects a reply to be returned or yielded -- confirm whether
    # returning None here is intentional.
    print(message)
# Greeting text, passed to the chat interface as its description.
greetingsmessage = "Hi, I'm your personal arXiv reader. Input the arXiv number of the paper:"

# Build the chat UI around handle_prompt and start serving it.
demo = gr.ChatInterface(
    handle_prompt,
    type="messages",
    title="ChangBot",
    theme=gr.themes.Soft(),
    description=greetingsmessage,
)
demo.launch()
# Sample questions; not referenced by any other code visible in this file.
example_questions = [
    "Tell me more about SimBIG",
    "How can you constrain neutrino mass with galaxies?",
    "What is the DESI BGS?",
    "What is SEDflow?",
    "What are normalizing flows?",
]