import os

import streamlit as st
from dotenv import load_dotenv
from llama_index.core import Settings, StorageContext, load_index_from_storage
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.retrievers.bm25 import BM25Retriever
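
# This app expects GROQ_API_KEY in the environment (or a local .env file) and
# a prebuilt index persisted at ./freud_index. The index is assumed to have
# been built offline with the same embedding model, roughly:
#   docs = SimpleDirectoryReader("freud_texts").load_data()  # path illustrative
#   index = VectorStoreIndex.from_documents(docs)
#   index.storage_context.persist("freud_index")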

st.set_page_config(
    page_title="Freud Works Search",
    page_icon="🧠",
    layout="wide",
)

# Pull GROQ_API_KEY (and any other secrets) from a local .env file.
load_dotenv()

# Persistent chat state across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []
if "sources" not in st.session_state:
    st.session_state.sources = {}
if "system_prompt" not in st.session_state:
    st.session_state.system_prompt = """You are Sigmund Freud, speaking from your historical context and perspective. As the founder of psychoanalysis, you should:

1. Only engage with topics related to:
   - Psychoanalysis and its theories
   - Dreams and their interpretation
   - The unconscious mind
   - Human sexuality and development
   - Your published works and case studies
   - Your historical context and contemporaries

2. Politely decline to answer:
   - Questions about events after your death in 1939
   - Medical advice or diagnosis
   - Topics outside your expertise or historical context
   - Personal matters unrelated to psychoanalysis

3. Maintain your characteristic style:
   - Speak with authority on psychoanalytic matters
   - Use psychoanalytic terminology when appropriate
   - Reference your own works and theories
   - Interpret questions through a psychoanalytic lens

If a question is inappropriate or outside your scope, explain why you cannot answer it from your perspective as Freud."""

# Default LLM for all llama_index calls. llama3-8b-8192 is served by Groq;
# capping context_window below the model's 8192-token limit leaves headroom.
Settings.llm = Groq(
    model="llama3-8b-8192",
    api_key=os.getenv("GROQ_API_KEY"),
    max_tokens=6000,
    context_window=6000,
)
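
# max_tokens caps the length of each completion, while context_window is the
# prompt budget llama_index assumes for this model when assembling the system
# prompt, chat history, and retrieved passages.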

@st.cache_resource
def load_indices():
    """Load the persisted index and build vector, BM25, and hybrid retrievers.

    Cached with st.cache_resource so the heavy objects are created once per
    process rather than on every Streamlit rerun.
    """
    # The embedding model must match the one used when the index was built.
    embed_model = HuggingFaceEmbedding(model_name="multi-qa-MiniLM-L6-cos-v1")
    Settings.embed_model = embed_model

    # Load the prebuilt index from disk.
    storage_context = StorageContext.from_defaults(persist_dir="freud_index")
    index = load_index_from_storage(storage_context=storage_context)

    # Semantic retriever (embeddings) and keyword retriever (BM25).
    vector_retriever = index.as_retriever(similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index, similarity_top_k=10)

    # Fuse both ranked lists; num_queries=1 disables automatic query expansion.
    hybrid_retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=10,
        num_queries=1,
        mode="reciprocal_rerank",
        use_async=True,
        verbose=True,
    )

    return index, vector_retriever, bm25_retriever, hybrid_retriever

index, vector_retriever, bm25_retriever, hybrid_retriever = load_indices()
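
# The hybrid retriever merges the vector and BM25 result lists with
# reciprocal-rank fusion: each node is scored by summing 1 / (rank + k)
# across the lists it appears in (k is a smoothing constant, typically 60),
# so passages ranked highly by either retriever surface near the top.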

def chat_with_rag(message, history, retriever):
    """Generate a reply, optionally grounded in passages retrieved from the index."""
    if st.session_state.get('use_rag', True):
        nodes = retriever.retrieve(message)

        # Keep only the top-scoring chunks (scores may be None for some nodes).
        nodes = sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)
        nodes = nodes[:st.session_state.get('num_chunks', 1)]
        context = "\n\n".join([node.text for node in nodes])
        system_prompt = f"""{st.session_state.system_prompt}

Use the following passages from my writings to inform your response:

{context}
"""

        # The assistant's reply will be appended at this index, so the sources
        # can be looked up when the history is re-rendered.
        message_index = len(st.session_state.messages)
        st.session_state.sources[message_index] = nodes
    else:
        system_prompt = st.session_state.system_prompt

    # Rebuild the full conversation: system prompt, prior turns, new message.
    messages = [ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)]
    for h in history:
        role = MessageRole.ASSISTANT if h["role"] == "assistant" else MessageRole.USER
        messages.append(ChatMessage(role=role, content=h["content"]))
    messages.append(ChatMessage(role=MessageRole.USER, content=message))

    response = Settings.llm.chat(messages)
    return response.message.content

st.title("Freud Explorer")

# Name the tabs for what they hold; Chat is the default (first) tab.
chat_tab, search_tab = st.tabs(["Chat", "Search"])

with search_tab:
    st.title("Freud Works Hybrid Search")
    st.markdown("""
This demo allows you to search through Freud's complete works using a hybrid approach combining:

- BM25 (keyword-based search)
- Vector search (semantic similarity)
""")

    search_query = st.text_input("Enter your search query:", placeholder="e.g. Oedipus complex")
    top_k = st.slider("Number of results to return:", min_value=1, max_value=20, value=10)

    # Apply the chosen result count to all three retrievers.
    vector_retriever.similarity_top_k = top_k
    bm25_retriever.similarity_top_k = top_k
    hybrid_retriever.similarity_top_k = top_k

    search_type = st.radio(
        "Select search method:",
        ["Hybrid", "Vector", "BM25"],
        horizontal=True,
        help="""
- **BM25**: Keyword-based search that works best for exact matches and specific terms.
- **Vector**: Semantic search that understands the meaning of your query.
- **Hybrid**: Combines both approaches for better overall results.
""",
    )

    if search_query:
        with st.spinner("Searching..."):
            if search_type == "Hybrid":
                nodes = hybrid_retriever.retrieve(search_query)
            elif search_type == "Vector":
                nodes = vector_retriever.retrieve(search_query)
            else:
                nodes = bm25_retriever.retrieve(search_query)

        st.subheader("Search Results")

        for i, node in enumerate(nodes, 1):
            preview = node.text[:200] + "..." if len(node.text) > 200 else node.text
            # node.score can be None, so guard the format call rather than
            # checking hasattr (NodeWithScore always has the attribute).
            score = f"{node.score:.3f}" if node.score is not None else "N/A"

            with st.expander(f"Result {i} (score: {score})\n\n{preview}", expanded=False):
                st.markdown(node.text)
                if node.metadata:
                    st.markdown("---")
                    st.markdown("**Source:**")
                    st.json(node.metadata)

with st.sidebar:
    st.header("About")
    st.markdown("""
This demo searches through Freud's complete works using:

- **BM25**: Traditional keyword-based search
- **Vector Search**: Semantic similarity using embeddings
- **Hybrid**: Combines both approaches
""")

with chat_tab:
    st.header("Chat with Freud's Works")

    chat_container = st.container()
    input_container = st.container()
    options_container = st.container()

    with options_container:
        st.info("💡 The system prompt defines the AI's persona and behavior. It's like giving stage directions to an actor.")
        with st.expander("System Prompt"):
            # The widget is keyed directly to st.session_state.system_prompt;
            # Streamlit keeps widget and session state in sync, so no value=
            # default or on_change callback is needed.
            st.text_area(
                "Edit System Prompt",
                height=100,
                help="This prompt sets the AI's personality and behavior. When RAG is enabled, relevant passages will be automatically added after this prompt.",
                key="system_prompt",
            )

    col1, col2, col3 = st.columns([2, 2, 1])
    with col1:
        st.session_state.use_rag = st.toggle(
            "Enable RAG (Retrieval Augmented Generation)",
            value=st.session_state.get('use_rag', True),
            key='rag_toggle',
        )
    with col2:
        if st.session_state.use_rag:
            # The key must be 'num_chunks' so that chat_with_rag's
            # st.session_state.get('num_chunks', 1) lookup sees the value.
            st.slider(
                "Number of chunks to retrieve",
                min_value=1,
                max_value=3,
                value=1,
                key='num_chunks',
            )
    with col3:
        if st.button("Clear Chat", use_container_width=True):
            st.session_state.messages = []
            st.session_state.sources = {}
            st.rerun()

    with chat_container:
        # Re-render the conversation; assistant turns get a sources expander.
        for i, message in enumerate(st.session_state.messages):
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
                if (message["role"] == "assistant"
                        and i in st.session_state.sources
                        and st.session_state.sources[i]):
                    with st.expander("View Sources"):
                        nodes = st.session_state.sources[i]
                        for j, node in enumerate(nodes, 1):
                            st.markdown(f"**Source {j}:**")
                            st.markdown(node.text)
                            if node.metadata:
                                st.markdown("---")
                                st.markdown("**Metadata:**")
                                st.json(node.metadata)

    with input_container:
        if prompt := st.chat_input("What would you like to know about Freud's works?", key="chat_input"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with chat_container:
                with st.chat_message("user"):
                    st.markdown(prompt)

                with st.chat_message("assistant"):
                    with st.spinner("Thinking..."):
                        response = chat_with_rag(
                            prompt,
                            st.session_state.messages[:-1],
                            hybrid_retriever if st.session_state.use_rag else None,
                        )
                    st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})

            # Rerun so the history above picks up the new turn and its sources.
            st.rerun()
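
# Run locally with:
#   streamlit run app.py   # file name illustrative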