import os

import streamlit as st
from dotenv import load_dotenv
from llama_index.core import Settings, StorageContext, load_index_from_storage
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq
from llama_index.retrievers.bm25 import BM25Retriever

# Page config
st.set_page_config(
    page_title="Freud Works Search",
    page_icon="πŸ“š",
    layout="wide"
)

# Load environment variables
load_dotenv()

# Initialize session state
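# st.session_state persists across Streamlit reruns within a browser session;
# these keys hold the chat transcript, per-message sources, and the persona prompt.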
if "messages" not in st.session_state:
    st.session_state.messages = []
if "sources" not in st.session_state:
    st.session_state.sources = {}
if "system_prompt" not in st.session_state:
    st.session_state.system_prompt = """You are Sigmund Freud, speaking from your historical context and perspective. As the founder of psychoanalysis, you should:

1. Only engage with topics related to:
   - Psychoanalysis and its theories
   - Dreams and their interpretation
   - The unconscious mind
   - Human sexuality and development
   - Your published works and case studies
   - Your historical context and contemporaries

2. Politely decline to answer:
   - Questions about events after your death in 1939
   - Medical advice or diagnosis
   - Topics outside your expertise or historical context
   - Personal matters unrelated to psychoanalysis

3. Maintain your characteristic style:
   - Speak with authority on psychoanalytic matters
   - Use psychoanalytic terminology when appropriate
   - Reference your own works and theories
   - Interpret questions through a psychoanalytic lens

If a question is inappropriate or outside your scope, explain why you cannot answer it from your perspective as Freud."""

# Configure LlamaIndex settings
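# llama3-8b-8192 exposes an 8,192-token window; capping both values at 6,000
# is a conservative choice (an assumption, not a model requirement) that
# leaves headroom for the system prompt plus retrieved passages.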
Settings.llm = Groq(
    model="llama3-8b-8192",
    api_key=os.getenv("GROQ_API_KEY"),
    max_tokens=6000,
    context_window=6000
)

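# st.cache_resource runs this loader once per server process, so the index
# and retrievers survive Streamlit's per-interaction script reruns.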
@st.cache_resource
def load_indices():
    """Load the index and create retrievers"""
    # Load embeddings
    embed_model = HuggingFaceEmbedding(model_name="multi-qa-MiniLM-L6-cos-v1")
    Settings.embed_model = embed_model
    
    # Load index
    storage_context = StorageContext.from_defaults(persist_dir="freud_index")
    index = load_index_from_storage(storage_context=storage_context)
    
    # Create retrievers
    vector_retriever = index.as_retriever(similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(
        index, similarity_top_k=10
    )
    
    # Create hybrid retriever
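    # QueryFusionRetriever merges the two candidate lists with reciprocal rank
    # fusion: each node scores sum(1 / (k + rank)) across retrievers, so items
    # ranked highly by either method surface; num_queries=1 skips LLM-based
    # query rewriting.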
    hybrid_retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=10,
        num_queries=1,
        mode="reciprocal_rerank",
        use_async=True,
        verbose=True,
    )
    
    return index, vector_retriever, bm25_retriever, hybrid_retriever

# Load indices
index, vector_retriever, bm25_retriever, hybrid_retriever = load_indices()

# Function to process chat with RAG
def chat_with_rag(message, history, retriever):
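    """Answer `message` in Freud's voice, optionally grounding the reply in
    passages fetched by `retriever` and folded into the system prompt."""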
    # Get context from the index if RAG is enabled
    if st.session_state.get('use_rag', True):
        nodes = retriever.retrieve(message)
        # Sort by relevance score (None-safe) and keep the top-N chunks
        # selected by the "Number of chunks" slider
        nodes = sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)
        nodes = nodes[:st.session_state.get('num_chunks', 1)]
        context = "\n\n".join([node.text for node in nodes])
        system_prompt = f"""{st.session_state.system_prompt}

        Use the following passages from my writings to inform your response:
        
        {context}
        """
        
        # Key the sources by the index the upcoming assistant reply will
        # occupy: the caller has already appended the user message, so that
        # index is simply the current message count.
        message_index = len(st.session_state.messages)
        st.session_state.sources[message_index] = nodes
    else:
        system_prompt = st.session_state.system_prompt
        nodes = []

    # Prepare messages for the API call
    messages = [ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)]
    for h in history:
        role = MessageRole.ASSISTANT if h["role"] == "assistant" else MessageRole.USER
        messages.append(ChatMessage(role=role, content=h["content"]))
    messages.append(ChatMessage(role=MessageRole.USER, content=message))
    
    # Send the conversation to the Groq model configured in Settings.llm
    response = Settings.llm.chat(messages)
    assistant_response = response.message.content
    
    return assistant_response

# App title, shown above both tabs
st.title("Freud Explorer")

# Add tab selection
chat_tab, search_tab = st.tabs(["Chat", "Search"])

with search_tab:
    st.title("Freud Works Hybrid Search")
    st.markdown("""
    This demo allows you to search through Freud's complete works using a hybrid approach combining:
    - BM25 (keyword-based search)
    - Vector search (semantic similarity)
    """)
    
    # Search interface
    search_query = st.text_input("Enter your search query:", placeholder="e.g. Oedipus complex")

    # Add top_k selector
    top_k = st.slider("Number of results to return:", min_value=1, max_value=20, value=10)

    # Update retrievers with new top_k
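    # The fusion retriever holds references to both sub-retrievers, so
    # mutating their similarity_top_k here takes effect on the next query.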
    vector_retriever.similarity_top_k = top_k
    bm25_retriever.similarity_top_k = top_k
    hybrid_retriever.similarity_top_k = top_k

    # Search type selector
    search_type = st.radio(
        "Select search method:",
        ["Hybrid", "Vector", "BM25"],
        horizontal=True,
        help="""
        - **BM25**: Keyword-based search that works best for exact matches and specific terms.
        - **Vector**: Semantic search that understands the meaning of your query.
        - **Hybrid**: Combines both approaches for better overall results.
        """
    )

    if search_query:
        with st.spinner('Searching...'):
            if search_type == "Hybrid":
                nodes = hybrid_retriever.retrieve(search_query)
            elif search_type == "Vector":
                nodes = vector_retriever.retrieve(search_query)
            else:  # BM25
                nodes = bm25_retriever.retrieve(search_query)
            
            # Display results
            st.subheader("Search Results")
            
            for i, node in enumerate(nodes, 1):
                preview = node.text[:200] + "..." if len(node.text) > 200 else node.text
                score = f"{node.score:.3f}" if node.score is not None else "N/A"
                
                with st.expander(f"Result {i} (score: {score})\n\n{preview}", expanded=False):
                    st.markdown(node.text)
                    if node.metadata:
                        st.markdown("---")
                        st.markdown("**Source:**")
                        st.json(node.metadata)

    # Add sidebar with information
    with st.sidebar:
        st.header("About")
        st.markdown("""
        This demo searches through Freud's complete works using:
        
        - **BM25**: Traditional keyword-based search
        - **Vector Search**: Semantic similarity using embeddings
        - **Hybrid**: Combines both approaches
        """)

with chat_tab:
    st.header("Chat with Freud's Works")
    
    # Declare containers up front to fix the render order: chat history on
    # top, the input box beneath it, and options at the bottom
    chat_container = st.container()
    input_container = st.container()
    options_container = st.container()
    
    # System prompt editor in an expander with help text above
    with options_container:
        st.info("πŸ’‘ The system prompt defines the AI's persona and behavior. It's like giving stage directions to an actor.")
        with st.expander("System Prompt"):
            st.text_area(
                "Edit System Prompt",
                height=100,
                help="This prompt sets the AI's personality and behavior. When RAG is enabled, relevant passages are appended after this prompt.",
                key="system_prompt",  # bound directly to st.session_state.system_prompt
            )
        
        # Put the RAG toggle, chunks slider, and clear button in a horizontal layout
        col1, col2, col3 = st.columns([2, 2, 1])
        with col1:
            st.session_state.use_rag = st.toggle(
                "Enable RAG (Retrieval Augmented Generation)", 
                value=st.session_state.get('use_rag', True),
                key='rag_toggle'
            )
        with col2:
            if st.session_state.use_rag:
                # Persist the value under 'num_chunks', the key chat_with_rag() reads
                st.session_state.num_chunks = st.slider(
                    "Number of chunks to retrieve",
                    min_value=1,
                    max_value=3,
                    value=st.session_state.get('num_chunks', 1),
                    key='num_chunks_slider'
                )
        with col3:
            if st.button("Clear Chat", use_container_width=True):
                st.session_state.messages = []
                st.session_state.sources = {}
                st.rerun()
    
    # Display chat messages in the chat container
    with chat_container:
        for i, message in enumerate(st.session_state.messages):
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
                if (message["role"] == "assistant" and 
                    i in st.session_state.sources and 
                    st.session_state.sources[i]):
                    with st.expander("View Sources"):
                        nodes = st.session_state.sources[i]
                        for j, node in enumerate(nodes, 1):
                            st.markdown(f"**Source {j}:**")
                            st.markdown(node.text)
                            if node.metadata:
                                st.markdown("---")
                                st.markdown("**Metadata:**")
                                st.json(node.metadata)
    
    # Chat input at the bottom
    with input_container:
        if prompt := st.chat_input("What would you like to know about Freud's works?", key="chat_input"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with chat_container:
                with st.chat_message("user"):
                    st.markdown(prompt)

            with chat_container:
                with st.chat_message("assistant"):
                    with st.spinner("Thinking..."):
                        response = chat_with_rag(
                            prompt, 
                            st.session_state.messages[:-1],
                            hybrid_retriever if st.session_state.use_rag else None
                        )
                        st.markdown(response)
                        st.session_state.messages.append({"role": "assistant", "content": response})
            
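            # Rerun so the transcript (including the sources expander) re-renders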
            st.rerun()
