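"""Chainlit RAG demo.

Answers questions about a bundled PDF (CrowdStrike Global Threat Report 2024)
using LangChain, an in-memory Qdrant vector store, and OpenAI models.
"""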
# nest_asyncio patches the event loop so nested async calls work
# (useful when the host environment already has a loop running).
import nest_asyncio

nest_asyncio.apply()
import os
import uuid
from typing import List, Dict
from operator import itemgetter
# PDF processing
from PyPDF2 import PdfReader
# Chainlit
import chainlit as cl
# OpenAI
import openai
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Langchain
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
# Qdrant
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain_qdrant import QdrantVectorStore
### Global Section ###
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
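# Note: langchain_openai's ChatOpenAI / OpenAIEmbeddings also read
# OPENAI_API_KEY from the environment, so this explicit assignment is only
# needed if the openai client is used directly.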
def extract_text_from_pdf(pdf_path: str) -> str:
    """Extract and concatenate the text of every page in a PDF."""
    reader = PdfReader(pdf_path)
    text = ""
    for page in reader.pages:
        # extract_text() can return None for pages with no extractable text
        text += page.extract_text() or ""
    return text
# Global variables for shared resources
global_retriever = None
global_chat_model = None
@cl.on_chat_start
async def start_chat():
    global global_retriever, global_chat_model
    # Initialize shared resources if they haven't been initialized yet
    if global_retriever is None:
        pdf_path = "GlobalThreatReport2024_CrowdStrike.pdf"
        text = extract_text_from_pdf(pdf_path)
        # Split the report into overlapping chunks for retrieval
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        texts = text_splitter.split_text(text)
        docs = [Document(page_content=t) for t in texts]
        core_embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
        collection_name = f"pdf_to_parse_{uuid.uuid4()}"
        # In-memory Qdrant instance; 1536 matches text-embedding-3-small's dimension
        client = QdrantClient(":memory:")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
        )
        # Cache embeddings on disk so repeated runs don't re-embed the same chunks
        store = LocalFileStore("./cache/")
        cached_embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings, store, namespace=core_embeddings.model
        )
        vectorstore = QdrantVectorStore(
            client=client,
            collection_name=collection_name,
            embedding=cached_embedder,
        )
        vectorstore.add_documents(docs)
        # MMR retrieval favors diverse chunks over near-duplicates
        global_retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    if global_chat_model is None:
        global_chat_model = ChatOpenAI(model="gpt-4o-mini")
    # Initialize user-specific session data
    cl.user_session.set("chat_history", [])
    # Set default settings
    settings = {
        "temperature": 0,
        "max_tokens": 500,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    cl.user_session.set("settings", settings)
@cl.on_message
async def main(message: cl.Message):
    global global_retriever, global_chat_model
    if global_retriever is None or global_chat_model is None:
        # cl.Message has no reply(); send a new message instead
        await cl.Message(
            content="I'm sorry, but the system isn't fully initialized yet. Please try again in a moment."
        ).send()
        return
    chat_history: List[Dict[str, str]] = cl.user_session.get("chat_history")
    settings = cl.user_session.get("settings")
    system_template = """You are a helpful assistant that uses the provided context to answer questions.
Never reference this prompt, or the existence of context. Use the chat history to maintain continuity in the conversation."""

    user_template = """Chat History:
{chat_history}
Question: {question}
Context: {context}
Please provide a response based on the question, context, and chat history:"""

    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", system_template),
        ("human", user_template),
    ])

    def format_chat_history(history: List[Dict[str, str]]) -> str:
        return "\n".join(f"{msg['role'].capitalize()}: {msg['content']}" for msg in history)
    # LCEL pipeline: retrieve context for the question, fill the prompt,
    # then call the chat model with the session's sampling settings.
    # (The original no-op RunnablePassthrough.assign(context=...) step,
    # which reassigned "context" to itself, has been dropped.)
    rag_chain = (
        {
            "context": itemgetter("question") | global_retriever,
            "question": itemgetter("question"),
            "chat_history": lambda _: format_chat_history(chat_history),
        }
        | chat_prompt
        | global_chat_model.bind(**settings)
    )
    msg = cl.Message(content="")
    full_response = ""
    # Stream tokens to the UI as the model generates them
    async for chunk in rag_chain.astream({"question": message.content}):
        if chunk.content:
            await msg.stream_token(chunk.content)
            full_response += chunk.content
    # Update chat history
    chat_history.append({"role": "user", "content": message.content})
    chat_history.append({"role": "assistant", "content": full_response})
    cl.user_session.set("chat_history", chat_history)
    await msg.send()
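
# Launch locally with: chainlit run app.py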