varun500 commited on
Commit
9f763c8
1 Parent(s): 1fa55ba

Create connection.py

Browse files
Files changed (1) hide show
  1. connection.py +49 -0
connection.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import pinecone
3
+ import streamlit as st
4
+ from streamlit.connections import ExperimentalBaseConnection
5
+
6
+
7
+ class PineconeConnection(ExperimentalBaseConnection):
8
+ def __init__(
9
+ self,
10
+ connection_name: str,
11
+ environment=None,
12
+ api_key=None,
13
+ **kwargs,
14
+ ) -> None:
15
+ self.environment = environment
16
+ self.api_key = api_key
17
+ super().__init__(connection_name, **kwargs)
18
+
19
+ def _connect(self):
20
+ api_key = self.api_key or self._secrets.get("Pinecone_API_KEY")
21
+ environment = self.environment
22
+ return pinecone.init(api_key=api_key, environment=environment)
23
+
24
+ def list_indexes(self):
25
+ self._connect()
26
+ self.indexes = pinecone.list_indexes()
27
+ return self.indexes
28
+
29
+ def _connect_index(self, index_name):
30
+ self._connect()
31
+ self.index_name = index_name
32
+ self.index = pinecone.Index(index_name)
33
+ return self.index
34
+
35
+ def query(
36
+ self, index_name: str, query_vector, top_k: int = 5, ttl: int = 3600, **kwargs
37
+ ) -> dict:
38
+ @st.cache_resource(ttl=ttl)
39
+ def _query(index_name: str, query_vector, top_k: int = 5, **kwargs):
40
+ index = self._connect_index(index_name)
41
+ query_results = index.query(query_vector, top_k=top_k, **kwargs)
42
+ results = list(query_results["matches"])
43
+ return results
44
+
45
+ results = _query(index_name, query_vector, top_k, **kwargs)
46
+ return results
47
+
48
+ def cursor(self):
49
+ return self._connect()