import os
import re

import bs4
import streamlit as st
import torch
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import pipeline


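# Lightweight wrapper around SentenceTransformer that exposes the
# embed_documents/embed_query interface LangChain vector stores expect.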
class SentenceTransformerEmbedding:
    def __init__(self, model_name):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        embeddings = self.model.encode(texts, convert_to_tensor=True)
        if isinstance(embeddings, torch.Tensor):
            return embeddings.cpu().detach().numpy().tolist()
        return embeddings

    def embed_query(self, query):
        embedding = self.model.encode([query], convert_to_tensor=True)
        if isinstance(embedding, torch.Tensor):
            return embedding.cpu().detach().numpy().tolist()[0]
        return embedding[0]


st.title("🤖 Chatbot with URL-based Document Retrieval")

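# Gradient background for the sidebar.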
sidebar_bg_style = """
<style>
[data-testid="stSidebar"] {
    background: linear-gradient(135deg, #ffafbd, #ffc3a0, #2193b0, #6dd5ed);
}
</style>
"""
st.markdown(sidebar_bg_style, unsafe_allow_html=True)

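# Gradient background for the main content area.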
main_bg_style = """
<style>
.main .block-container {
    background: linear-gradient(135deg, #ff9a9e, #fad0c4, #fbc2eb, #a18cd1);
    padding: 2rem;
}
.css-18e3th9 {
    background: linear-gradient(135deg, #ff9a9e, #fad0c4, #fbc2eb, #a18cd1);
}
</style>
"""
st.markdown(main_bg_style, unsafe_allow_html=True)

st.sidebar.title("Settings")

url_input = st.sidebar.text_input(
    "Enter Blog Post URL",
    placeholder="e.g., https://example.com/blog",
    help="Paste the full URL of the blog post you want to retrieve data from",
)

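# Light sanity check that the input looks like an HTTP(S) URL.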
if url_input:
    if re.match(r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+", url_input):
        st.sidebar.markdown('<p style="color:green; font-weight:bold;">URL is correctly entered</p>', unsafe_allow_html=True)
    else:
        st.sidebar.markdown('<p style="color:red; font-weight:bold;">Invalid URL, please enter a valid one</p>', unsafe_allow_html=True)

use_preprovided_keys = st.sidebar.checkbox("Use pre-provided API keys")

if not use_preprovided_keys:
    api_key_1 = st.sidebar.text_input("Enter LangChain API Key", type="password", placeholder="Enter your LangChain API Key", help="Please enter a valid LangChain API key here")
    api_key_2 = st.sidebar.text_input("Enter Groq API Key", type="password", placeholder="Enter your Groq API Key", help="Please enter your Groq API key here")
else:
    api_key_1 = "your-preprovided-langchain-api-key"
    api_key_2 = "your-preprovided-groq-api-key"
    st.sidebar.markdown('<p style="color:blue; font-weight:bold;">Using pre-provided API keys</p>', unsafe_allow_html=True)

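# Persist the submitted keys as environment variables.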
if st.sidebar.button("Submit API Keys"):
    if use_preprovided_keys or (api_key_1 and api_key_2):
        os.environ["LANGCHAIN_API_KEY"] = api_key_1
        os.environ["GROQ_API_KEY"] = api_key_2
        st.sidebar.markdown('<p style="color:green; font-weight:bold;">API keys are set</p>', unsafe_allow_html=True)
    else:
        st.sidebar.markdown('<p style="color:red; font-weight:bold;">Please fill in both API keys or select the option to use pre-provided keys</p>', unsafe_allow_html=True)

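# Scrolling author credit banner.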
st.markdown(""" |
|
<marquee behavior="scroll" direction="left" scrollamount="10"> |
|
<p style='font-size:24px; color:#FF5733; font-weight:bold;'> |
|
Created by: <a href="https://www.linkedin.com/in/datascientisthameshraj/" target="_blank" style="color:#1E90FF; text-decoration:none;">Engr. Hamesh Raj</a> |
|
</p> |
|
</marquee> |
|
""", unsafe_allow_html=True) |
|
|
|
|
|
query = st.text_input(
    "Ask a question based on the blog post",
    placeholder="Type your question here...",
    help="Enter a question related to the content of the blog post",
)

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []


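# Local stand-in for an LLM: responses are produced by summarizing the
# retrieved context with a BART summarization pipeline.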
class CustomLanguageModel:
    def __init__(self):
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

    def generate(self, prompt, context):
        summary = self.summarize_context(context)
        return f"Generated response: '{prompt}'. Summary: '{summary}'."

    def summarize_context(self, context):
        # truncation=True keeps long contexts within the model's maximum input length.
        summarized = self.summarizer(context, max_length=200, min_length=100, do_sample=False, truncation=True)
        return summarized[0]['summary_text']


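# Packs the question and retrieved context into a simple prompt dict.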
class RAGPrompt:
    def __call__(self, data):
        return {"question": data["question"], "context": data["context"]}


if st.button("Submit Query"):
    if not query:
        st.warning("Please enter a query before submitting!")
    elif not url_input:
        st.warning("Please enter a valid URL in the sidebar.")
    else:
        try:
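            # Load the blog post; SoupStrainer() with no arguments leaves parsing unrestricted.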
            loader = WebBaseLoader(
                web_paths=(url_input,),
                bs_kwargs=dict(
                    parse_only=bs4.SoupStrainer()
                ),
            )
            docs = loader.load()

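            # Split the page into overlapping chunks for retrieval.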
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=300)
            splits = text_splitter.split_documents(docs)

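            # Embed the chunks and index them in an in-memory Chroma collection.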
            embedding_model = SentenceTransformerEmbedding('all-MiniLM-L6-v2')
            vectorstore = Chroma.from_documents(documents=splits, embedding=embedding_model)
            retriever = vectorstore.as_retriever()

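            # Retrieve the chunks most relevant to the question.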
            retrieved_docs = retriever.invoke(query)  # invoke() supersedes the deprecated get_relevant_documents()

            def format_docs(docs):
                return "\n\n".join(doc.page_content for doc in docs)

            context = format_docs(retrieved_docs)

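            # Build the answer by summarizing the retrieved context.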
            custom_llm = CustomLanguageModel()
            prompt = RAGPrompt()
            prompt_data = prompt({"question": query, "context": context})
            result = custom_llm.generate(prompt_data["question"], prompt_data["context"])

            st.session_state['chat_history'].append((query, result))
        except Exception as e:
            st.error(f"An error occurred: {e}")

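# Render the running conversation.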
for q, r in st.session_state['chat_history']:
    st.write(f"**User:** {q}")
    st.write(f"**Bot:** {r}")