|
import html
import os
import shutil

import streamlit as st
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain_groq import ChatGroq
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from streamlit_chat import message
|
|
|
|
|
load_dotenv()


def _export_env(target, source):
    """Copy environment variable ``source`` into ``target`` when it is set.

    Args:
        target: Name LangChain integrations read (e.g. ``GROQ_API_KEY``).
        source: Project-specific name used in the ``.env`` file.

    Returns:
        True if a non-empty value was copied, False otherwise.
    """
    value = os.getenv(source)
    if value:
        os.environ[target] = value
        return True
    # Assigning None into os.environ raises TypeError, so a missing variable
    # is skipped here instead of crashing the app at import time.
    return False


_export_env('GROQ_API_KEY', 'GROQ_API')

# LangSmith tracing needs an API key; only switch it on when one is configured.
if _export_env('LANGCHAIN_API_KEY', 'LANGSMITH_API'):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Directory where uploaded PDFs are written before being parsed.
UPLOAD_DIR = "uploaded_files"
|
|
|
|
|
def cleanup_files():
    """Close the session's stored file handle (if any) and delete UPLOAD_DIR.

    The handle is closed *before* the directory is removed: an open file
    inside the directory can make the removal fail on some platforms, and
    because ``ignore_errors=True`` suppresses that error the directory would
    otherwise be left behind silently.
    """
    # NOTE(review): nothing visible here stores 'file_handle' in session
    # state; confirm whether this branch is still needed.
    if 'file_handle' in st.session_state:
        st.session_state.file_handle.close()

    if os.path.isdir(UPLOAD_DIR):
        # Best-effort removal; errors are deliberately ignored.
        shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
|
|
|
|
|
# Streamlit re-executes this whole script on every interaction, so the
# "cleanup already ran" flag is kept in per-session state.
if 'cleanup_done' not in st.session_state:
    st.session_state.cleanup_done = False

if not st.session_state.cleanup_done:
    cleanup_files()
    # Mark the startup cleanup as done so later reruns in this session skip it.
    st.session_state.cleanup_done = True

# (Re)create the upload directory; exist_ok makes this a no-op if present.
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|
|
|
|
# Global stylesheet for the page. The <style> tag must start at column 0:
# markdown treats lines indented 4+ spaces as a code block, not HTML.
# NOTE(review): Streamlit draws its own app container over <body>, so the
# body background may not be visible; if it isn't, target `.stApp` instead
# (class names are version-dependent — verify).
_APP_CSS = """
<style>
body {
    background-color: #FFF7F0;
    color: #333333;
    font-family: 'Helvetica Neue', sans-serif;
    background-image: url('https://drive.google.com/uc?export=view&id=17Vg5hM0-X7fUy2BdYCFqSAQtJVDqYErU');
    background-size: cover;
    background-position: top center;
}
.stButton button {
    background-color: #FF5000;
    color: white;
    border-radius: 12px;
    border: none;
    padding: 10px 20px;
    font-weight: bold;
}
.stButton button:hover {
    background-color: #E64500;
}
.stTextInput > div > input {
    border: 1px solid #FF5000;
    border-radius: 10px;
    padding: 10px;
}
.stFileUploader > div {
    border: 2px dashed #FF5000;
    border-radius: 10px;
    padding: 10px;
}
.header {
    display: flex;
    align-items: center;
    gap: 10px;
    padding-top: 50px;
}
</style>
"""

st.markdown(_APP_CSS, unsafe_allow_html=True)
|
|
|
|
|
# Branded page header (styled by the `.header` CSS class) and a fixed-height
# spacer that pushes the main content down.
_HEADER_HTML = """
<div class="header" style="display: flex; align-items: center; gap: 10px;">
<h1 style="color: #FF5000; font-weight: bold;">Hi, we're Wattpad.</h1>
</div>
"""
_SPACER_HTML = "<div style='height: 100px;'></div>"

st.markdown(_HEADER_HTML, unsafe_allow_html=True)
st.write(_SPACER_HTML, unsafe_allow_html=True)
|
|
|
st.title("Chat with Your PDF!!")

# Restrict uploads to PDFs: the app parses the file with PyPDFLoader, which
# cannot read other formats and would otherwise fail with a raw traceback.
uploaded_file = st.file_uploader("Upload a file", type=["pdf"])
|
|
|
# Prompt for the retrieval chain: {context} receives the retrieved PDF text and
# {question} the user's query.
PROMPT_TEMPLATE = (
    "You are WattBot, Wattpad's friendly chatbot assistant, designed to help "
    "readers and writers with insightful answers about stories, writing tips, "
    "and the Wattpad platform. Please answer the {question} based only on the "
    "given {context}. If the question is unrelated to the context or beyond "
    "your knowledge, respond with \"I'm not sure about that, but feel free to "
    "explore more on Wattpad!\" Keep your responses concise, using a maximum "
    "of three sentences."
)


@st.cache_resource
def _load_embedder():
    """Return the HuggingFace embedding model wrapped in an on-disk cache.

    ``st.cache_resource`` keeps a single instance per server process, so the
    model is loaded once rather than on every Streamlit rerun. The
    ``CacheBackedEmbeddings`` wrapper stores document vectors under
    ``./cache/`` so re-indexing the same text reuses previously computed
    embeddings.
    """
    base = HuggingFaceEmbeddings(model_name='BAAI/llm-embedder')
    store = LocalFileStore("./cache/")
    return CacheBackedEmbeddings.from_bytes_store(
        base, store, namespace='embeddings'
    )


def _save_upload(uploaded):
    """Write an uploaded file into UPLOAD_DIR and return its absolute path.

    Only the final component of the client-supplied name is used, so a
    crafted name such as ``../../x.pdf`` cannot escape UPLOAD_DIR.
    """
    name = os.path.basename(uploaded.name)
    if name in ("", ".", ".."):
        name = "upload.pdf"
    path = os.path.abspath(os.path.join(UPLOAD_DIR, name))
    with open(path, 'wb') as f:
        f.write(uploaded.getbuffer())
    return path


def _format_docs(documents):
    """Join the retrieved chunks' text into one context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in documents)


def _render_history(history):
    """Render previous question/answer pairs, newest last.

    Both the user's question and the model's answer are HTML-escaped because
    they are injected into raw HTML (``unsafe_allow_html=True``). The markup
    is kept on unindented lines: markdown treats 4+-space indentation as a
    code block, which would show the tags as text.
    """
    if not history:
        return
    st.write("### Previous Questions and Answers")
    for number, entry in enumerate(history, start=1):
        question = html.escape(entry['question'])
        answer = html.escape(entry['answer'])
        st.markdown(
            '<div style="background-color: #FFFAF5; padding: 10px; '
            'border-radius: 10px; margin-bottom: 10px;">'
            f'<p style="font-weight: bold; color: #FF5000;">Q{number}: {question}</p>'
            f'<p style="color: #333333;">A{number}: {answer}</p>'
            '</div>',
            unsafe_allow_html=True,
        )


if uploaded_file is not None:
    file_path = _save_upload(uploaded_file)
    st.write("You're Ready For a Chat with your PDF")

    docs = PyPDFLoader(file_path).load_and_split()

    if not docs:
        # FAISS cannot build an index from zero chunks (e.g. a scanned PDF
        # without a text layer), so stop here with a message instead.
        st.warning("No extractable text was found in this PDF.")
    else:
        # Use the cache-backed embedder so chunks already embedded (on this
        # or an earlier rerun) are read from ./cache/ instead of recomputed.
        vector_base = FAISS.from_documents(docs, _load_embedder())
        retriever = vector_base.as_retriever()

        prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
        llm = ChatGroq(
            model='mixtral-8x7b-32768',
            temperature=0,
        )
        chain = (
            {'context': retriever | _format_docs, 'question': RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )

        if 'history' not in st.session_state:
            st.session_state.history = []

        query = st.text_input("Enter your question", placeholder="Ask something interesting...")

        # st.button is evaluated first so the widget is always rendered.
        if st.button("Submit!", key="submit_button") and query.strip():
            answer = chain.invoke(query)
            st.session_state.history.append({'question': query, 'answer': answer})

        _render_history(st.session_state.history)
|
|
|
# End-of-run cleanup: once this session's cleanup flag is set, remove the
# uploaded files when the script finishes executing.
if st.session_state["cleanup_done"]:
    cleanup_files()
|
|