bhulston commited on
Commit
5dc1bc3
·
1 Parent(s): 655400f

Update app.py

Browse files

Add metadata filters for pinecone

Files changed (1) hide show
  1. app.py +27 -6
app.py CHANGED
@@ -48,6 +48,11 @@ units = st.slider(
48
  value = (1, 4)
49
  )
50
 
 
 
 
 
 
51
 
52
  assistant = st.chat_message("assistant")
53
  initial_message = "How can I help you today?"
@@ -58,23 +63,39 @@ def get_rag_results(prompt):
58
  2. Query the Pinecone DB and return the top 25 results based on cosine similarity
59
  3. Rerank the results from vector DB using a BERT-based cross encoder
60
  '''
61
- query = prompt
62
- response = filter_agent(prompt, OPENAI_API)
 
 
63
  query_filter = {
64
- "Units": str(int(units)) + ".0 units",
65
- "start": ${"gte": str(class_time[0])},
66
- "end": ${"lte": str(class_time[1])}
67
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  response = index.query(
69
  vector = embeddings.embed_query(query),
70
  top_k = 25,
71
- # filter = query_filter,
72
  include_metadata = True
73
  )
74
  response = reranker(query, response) # BERT cross encoder for ranking
75
 
76
  return response
77
 
 
 
78
  if "messages" not in st.session_state:
79
  st.session_state.messages = []
80
  with st.chat_message("assistant"):
 
48
  value = (1, 4)
49
  )
50
 
51
+ days = st.checkbox("What days are you free?",
52
+ options = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat"],
53
+ default = None,
54
+ placeholder = "Any day"
55
+ )
56
 
57
  assistant = st.chat_message("assistant")
58
  initial_message = "How can I help you today?"
 
63
  2. Query the Pinecone DB and return the top 25 results based on cosine similarity
64
  3. Rerank the results from vector DB using a BERT-based cross encoder
65
  '''
66
+ query = filter_agent(prompt, OPENAI_API)
67
+
68
+ ##Get metadata filters
69
+ days_filter = list()
70
  query_filter = {
71
+ "start": {"$gte": str(class_time[0])},
72
+ "end": {"$lte": str(class_time[1])}
 
73
  }
74
+
75
+ if units != "any":
76
+ query_filter["units"] = str(int(units)) + ".0 units"
77
+
78
+ if len(days) > 0:
79
+ for i in range(len(days)):
80
+ days_filter.append(days[i])
81
+ for j in range(i+1, len(days)):
82
+ two_day = days[i] + ", " + days[j]
83
+ days_filter.append(two_day)
84
+ query_filter["days"] = {"$in": days_filter}
85
+
86
+ ## Query the pinecone database
87
  response = index.query(
88
  vector = embeddings.embed_query(query),
89
  top_k = 25,
90
+ filter = query_filter,
91
  include_metadata = True
92
  )
93
  response = reranker(query, response) # BERT cross encoder for ranking
94
 
95
  return response
96
 
97
+
98
+
99
  if "messages" not in st.session_state:
100
  st.session_state.messages = []
101
  with st.chat_message("assistant"):