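"""Streamlit app for exploring Freud's complete works: hybrid (BM25 + vector)
search plus a retrieval-augmented chat with a Freud persona, built on
LlamaIndex with a Groq-hosted Llama 3 model."""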
import os

import streamlit as st
from dotenv import load_dotenv
from llama_index.core import Settings, StorageContext, load_index_from_storage
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.retrievers.bm25 import BM25Retriever
# Page config
st.set_page_config(
page_title="Freud Works Search",
page_icon="πŸ“š",
layout="wide"
)
# Load environment variables
load_dotenv()
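# Expects GROQ_API_KEY to be set in the environment (e.g. via a .env file)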
# Initialize session state
if "messages" not in st.session_state:
st.session_state.messages = []
if "sources" not in st.session_state:
st.session_state.sources = {}
if "system_prompt" not in st.session_state:
st.session_state.system_prompt = """You are Sigmund Freud, speaking from your historical context and perspective. As the founder of psychoanalysis, you should:
1. Only engage with topics related to:
- Psychoanalysis and its theories
- Dreams and their interpretation
- The unconscious mind
- Human sexuality and development
- Your published works and case studies
- Your historical context and contemporaries
2. Politely decline to answer:
- Questions about events after your death in 1939
- Medical advice or diagnosis
- Topics outside your expertise or historical context
- Personal matters unrelated to psychoanalysis
3. Maintain your characteristic style:
- Speak with authority on psychoanalytic matters
- Use psychoanalytic terminology when appropriate
- Reference your own works and theories
- Interpret questions through a psychoanalytic lens
If a question is inappropriate or outside your scope, explain why you cannot answer it from your perspective as Freud."""
# Configure the global LlamaIndex LLM (llama3-8b-8192 on Groq has an
# 8192-token context window; the caps below stay under it)
Settings.llm = Groq(
    model="llama3-8b-8192",
    api_key=os.getenv("GROQ_API_KEY"),
    max_tokens=6000,
    context_window=6000,
)
@st.cache_resource
def load_indices():
"""Load the index and create retrievers"""
# Load embeddings
embed_model = HuggingFaceEmbedding(model_name="multi-qa-MiniLM-L6-cos-v1")
Settings.embed_model = embed_model
# Load index
storage_context = StorageContext.from_defaults(persist_dir="freud_index")
index = load_index_from_storage(storage_context=storage_context)
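    # "freud_index" must hold an index persisted ahead of time. A minimal
    # sketch of how such a directory can be produced (illustrative, not part
    # of this app; assumes the source texts live in a "freud_texts" folder):
    #     from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
    #     docs = SimpleDirectoryReader("freud_texts").load_data()
    #     VectorStoreIndex.from_documents(docs).storage_context.persist(
    #         persist_dir="freud_index")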
    # Create retrievers: dense vector search plus sparse keyword (BM25) search
    vector_retriever = index.as_retriever(similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index=index, similarity_top_k=10)
    # Fuse both result lists with reciprocal rank fusion; num_queries=1 keeps
    # only the user's original query (no LLM-generated query variants)
    hybrid_retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=10,
        num_queries=1,
        mode="reciprocal_rerank",
        use_async=True,
        verbose=True,
    )
return index, vector_retriever, bm25_retriever, hybrid_retriever
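# For reference: reciprocal rank fusion scores each node as
#     score(d) = sum_r 1 / (k + rank_r(d))
# summed over the retrievers r that returned d (k is a smoothing constant,
# 60 in the original RRF paper), so a document ranked well by either
# retriever rises to the top of the fused list.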
# Load indices
index, vector_retriever, bm25_retriever, hybrid_retriever = load_indices()
# Function to process chat with RAG
def chat_with_rag(message, history, retriever):
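    """Generate a persona reply to `message`, optionally grounded in passages
    fetched from `retriever` when RAG is enabled in session state."""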
    # Retrieve supporting context when RAG is enabled
    if st.session_state.get('use_rag', True):
        nodes = retriever.retrieve(message)
        # Sort nodes by score, best first
        nodes = sorted(nodes, key=lambda x: x.score, reverse=True)
        # Keep only as many chunks as the slider allows; the slider widget
        # stores its value under the 'num_chunks_slider' session-state key
        nodes = nodes[:st.session_state.get('num_chunks_slider', 1)]
context = "\n\n".join([node.text for node in nodes])
system_prompt = f"""{st.session_state.system_prompt}
Use the following passages from my writings to inform your response:
{context}
"""
        # Store sources for this turn; the assistant reply will land at index
        # len(messages), since the user's message has already been appended
        message_index = len(st.session_state.messages)
        st.session_state.sources[message_index] = nodes
else:
system_prompt = st.session_state.system_prompt
nodes = []
# Prepare messages for the API call
messages = [ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)]
for h in history:
role = MessageRole.ASSISTANT if h["role"] == "assistant" else MessageRole.USER
messages.append(ChatMessage(role=role, content=h["content"]))
messages.append(ChatMessage(role=MessageRole.USER, content=message))
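    # Final prompt shape: system persona (+ retrieved passages when RAG is on),
    # then the prior user/assistant turns, then the current user message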
    # Call the Groq model configured in Settings (LlamaIndex chat API)
    response = Settings.llm.chat(messages)
    assistant_response = response.message.content
return assistant_response
# Title at the top, before the tabs
st.title("Freud Explorer")
# Tab selection (Chat is shown first)
chat_tab, search_tab = st.tabs(["Chat", "Search"])
with search_tab:
    st.header("Freud Works Hybrid Search")
st.markdown("""
This demo allows you to search through Freud's complete works using a hybrid approach combining:
- BM25 (keyword-based search)
- Vector search (semantic similarity)
""")
# Search interface
search_query = st.text_input("Enter your search query:", placeholder="e.g. Oedipus complex")
# Add top_k selector
top_k = st.slider("Number of results to return:", min_value=1, max_value=20, value=10)
# Update retrievers with new top_k
vector_retriever.similarity_top_k = top_k
bm25_retriever.similarity_top_k = top_k
hybrid_retriever.similarity_top_k = top_k
# Search type selector
search_type = st.radio(
"Select search method:",
["Hybrid", "Vector", "BM25"],
horizontal=True,
help="""
- **BM25**: Keyword-based search that works best for exact matches and specific terms.
- **Vector**: Semantic search that understands the meaning of your query.
- **Hybrid**: Combines both approaches for better overall results.
"""
)
if search_query:
with st.spinner('Searching...'):
if search_type == "Hybrid":
nodes = hybrid_retriever.retrieve(search_query)
elif search_type == "Vector":
nodes = vector_retriever.retrieve(search_query)
else: # BM25
nodes = bm25_retriever.retrieve(search_query)
# Display results
        st.subheader("Search Results")
for i, node in enumerate(nodes, 1):
preview = node.text[:200] + "..." if len(node.text) > 200 else node.text
score = f"{node.score:.3f}" if hasattr(node, 'score') else "N/A"
with st.expander(f"Result {i} (score: {score})\n\n{preview}", expanded=False):
st.markdown(node.text)
if node.metadata:
st.markdown("---")
st.markdown("**Source:**")
st.json(node.metadata)
# Add sidebar with information
with st.sidebar:
st.header("About")
st.markdown("""
This demo searches through Freud's complete works using:
- **BM25**: Traditional keyword-based search
- **Vector Search**: Semantic similarity using embeddings
- **Hybrid**: Combines both approaches
""")
with chat_tab:
st.header("Chat with Freud's Works")
# Create containers in the right order
chat_container = st.container()
input_container = st.container()
options_container = st.container()
# System prompt editor in an expander with help text above
with options_container:
st.info("πŸ’‘ The system prompt defines the AI's persona and behavior. It's like giving stage directions to an actor.")
        with st.expander("System Prompt"):
            # The key alone keeps this widget in sync with
            # st.session_state.system_prompt (pre-seeded above); passing a
            # value= default as well would trigger Streamlit's
            # duplicate-default warning, and the old on_change callback only
            # assigned the value back to itself
            st.text_area(
                "Edit System Prompt",
                height=100,
                help="This prompt sets the AI's personality and behavior. When RAG is enabled, relevant passages will be automatically added after this prompt.",
                key="system_prompt",
            )
# Put the RAG toggle, chunks slider, and clear button in a horizontal layout
col1, col2, col3 = st.columns([2, 2, 1])
with col1:
st.session_state.use_rag = st.toggle(
"Enable RAG (Retrieval Augmented Generation)",
value=st.session_state.get('use_rag', True),
key='rag_toggle'
)
        with col2:
            if st.session_state.use_rag:
                # The keyed widget persists its own value across reruns;
                # chat_with_rag reads it back via st.session_state
                st.slider(
                    "Number of chunks to retrieve",
                    min_value=1,
                    max_value=3,
                    value=1,
                    key='num_chunks_slider',
                )
with col3:
if st.button("Clear Chat", use_container_width=True):
st.session_state.messages = []
st.session_state.sources = {}
st.rerun()
# Display chat messages in the chat container
with chat_container:
for i, message in enumerate(st.session_state.messages):
with st.chat_message(message["role"]):
st.markdown(message["content"])
if (message["role"] == "assistant" and
i in st.session_state.sources and
st.session_state.sources[i]):
with st.expander("View Sources"):
nodes = st.session_state.sources[i]
for j, node in enumerate(nodes, 1):
st.markdown(f"**Source {j}:**")
st.markdown(node.text)
if node.metadata:
st.markdown("---")
st.markdown("**Metadata:**")
st.json(node.metadata)
# Chat input at the bottom
with input_container:
if prompt := st.chat_input("What would you like to know about Freud's works?", key="chat_input"):
st.session_state.messages.append({"role": "user", "content": prompt})
with chat_container:
with st.chat_message("user"):
st.markdown(prompt)
with chat_container:
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
response = chat_with_rag(
prompt,
st.session_state.messages[:-1],
hybrid_retriever if st.session_state.use_rag else None
)
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
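            # Rerun so the history loop re-renders this turn with its sources expander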
st.rerun()