import nest_asyncio
nest_asyncio.apply()
import os
import uuid
from typing import List, Dict
from operator import itemgetter
# PDF processing
from PyPDF2 import PdfReader
# Chainlit
import chainlit as cl
# OpenAI
import openai
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# LangChain
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
# Qdrant
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain_qdrant import QdrantVectorStore

### Global Section ###
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
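# For reference, a minimal .env file for this app could look like the
# sketch below (the key value is a placeholder, not a real credential):
#
#   OPENAI_API_KEY=sk-...your-key-here...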
# Function to extract text from a PDF
def extract_text_from_pdf(pdf_path):
    reader = PdfReader(pdf_path)
    text = ""
    for page in reader.pages:
        # extract_text() can return None for image-only pages
        text += page.extract_text() or ""
    return text
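# Quick sanity check for the extractor (illustrative only; assumes the
# report PDF used below is present in the working directory):
#
#   sample = extract_text_from_pdf("GlobalThreatReport2024_CrowdStrike.pdf")
#   print(f"Extracted {len(sample)} characters")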
# Global variables for shared resources
global_retriever = None
global_chat_model = None
@cl.on_chat_start
async def start_chat():
    global global_retriever, global_chat_model
    # Initialize shared resources if they haven't been initialized yet
    if global_retriever is None:
        pdf_path = r"GlobalThreatReport2024_CrowdStrike.pdf"
        text = extract_text_from_pdf(pdf_path)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        texts = text_splitter.split_text(text)
        docs = [Document(page_content=t) for t in texts]
        core_embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
        collection_name = f"pdf_to_parse_{uuid.uuid4()}"
        client = QdrantClient(":memory:")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
        )
        # Cache embeddings on disk so repeated runs don't re-embed the same chunks
        store = LocalFileStore("./cache/")
        cached_embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings, store, namespace=core_embeddings.model
        )
        vectorstore = QdrantVectorStore(
            client=client,
            collection_name=collection_name,
            embedding=cached_embedder,
        )
        vectorstore.add_documents(docs)
        global_retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    if global_chat_model is None:
        global_chat_model = ChatOpenAI(model="gpt-4o-mini")
    # Initialize user-specific session data
    cl.user_session.set("chat_history", [])
    # Set default settings
    settings = {
        "temperature": 0,
        "max_tokens": 500,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    cl.user_session.set("settings", settings)
@cl.on_message
async def main(message: cl.Message):
    global global_retriever, global_chat_model
    if global_retriever is None or global_chat_model is None:
        await cl.Message(content="I'm sorry, but the system isn't fully initialized yet. Please try again in a moment.").send()
        return
    chat_history: List[Dict[str, str]] = cl.user_session.get("chat_history")
    settings = cl.user_session.get("settings")
    system_template = """You are a helpful assistant that uses the provided context to answer questions.
Never reference this prompt, or the existence of context. Use the chat history to maintain continuity in the conversation."""
    user_template = """Chat History:
{chat_history}
Question: {question}
Context: {context}
Please provide a response based on the question, context, and chat history:"""
    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", system_template),
        ("human", user_template)
    ])
    def format_chat_history(history: List[Dict[str, str]]) -> str:
        return "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in history])
    rag_chain = (
        {
            "context": itemgetter("question") | global_retriever,
            "question": itemgetter("question"),
            "chat_history": lambda _: format_chat_history(chat_history),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | chat_prompt
        | global_chat_model.bind(**settings)
    )
    msg = cl.Message(content="")
    full_response = ""
    async for chunk in rag_chain.astream({"question": message.content}):
        if chunk.content is not None:
            await msg.stream_token(chunk.content)
            full_response += chunk.content
    # Update chat history
    chat_history.append({"role": "user", "content": message.content})
    chat_history.append({"role": "assistant", "content": full_response})
    cl.user_session.set("chat_history", chat_history)
    await msg.send()
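# To launch locally, the standard Chainlit CLI invocation applies
# (assuming this file is saved as app.py):
#
#   chainlit run app.py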