ppsingh committed
Commit 50fbfdd · 1 Parent(s): b603692

hybrid embeddings

Files changed (3):
  1. app.py +81 -43
  2. iati_files/data_giz_website.json +0 -3
  3. requirements.txt +2 -1
app.py CHANGED
@@ -9,7 +9,10 @@ from qdrant_client import QdrantClient
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder
+from langchain_qdrant import FastEmbedSparse, RetrievalMode
 
+
+# get the device to be used eithe gpu or cpu
 device = 'cuda' if cuda.is_available() else 'cpu'
 
 
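The one functional change in this hunk is the `langchain_qdrant` import, which wires in the sparse side of hybrid search: `FastEmbedSparse` wraps fastembed's BM25 model, while `RetrievalMode` selects between dense, sparse, and hybrid retrieval (imported here but not yet used in this commit). A minimal sketch of the sparse embedder on its own, assuming langchain_qdrant's documented interface; the query string is made up:

```python
from langchain_qdrant import FastEmbedSparse

sparse = FastEmbedSparse(model_name="Qdrant/bm25")  # same model as in embed_chunks below
vec = sparse.embed_query("climate adaptation projects")  # made-up query
# a sparse vector pairs token indices with weights instead of a dense array
print(vec.indices[:5], vec.values[:5])
```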
@@ -20,70 +23,97 @@ var=st.text_input("enter keyword")
 
 def create_chunks(text):
     """TAKES A TEXT AND CERATES CREATES CHUNKS"""
+    # chunk size in terms of token
     text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
     texts = text_splitter.split_text(text)
     return texts
 
 def get_chunks():
-    #orgas_df = pd.read_csv("iati_files/project_orgas.csv")
-    #region_df = pd.read_csv("iati_files/project_region.csv")
-    #sector_df = pd.read_csv("iati_files/project_sector.csv")
-    #status_df = pd.read_csv("iati_files/project_status.csv")
-    #texts_df = pd.read_csv("iati_files/project_texts.csv")
-
-    #projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
-    #projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
-    #projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
-    #projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')
-    #giz_df = projects_df[projects_df.client.str.contains('bmz')].reset_index(drop=True)
-
-    #giz_df.drop(columns= ['orga_abbreviation', 'client',
-    #                      'orga_full_name', 'country',
-    #                      'country_flag', 'crs_5_code', 'crs_3_code',
-    #                      'sgd_pred_code'], inplace=True)
-
-    giz_df = pd.read_json('iati_files/data_giz_website.json')
-    giz_df = giz_df.rename(columns={'content':'project_description'})
+    """
+    this will read the iati files and create the chunks
+    """
+    orgas_df = pd.read_csv("iati_files/project_orgas.csv")
+    region_df = pd.read_csv("iati_files/project_region.csv")
+    sector_df = pd.read_csv("iati_files/project_sector.csv")
+    status_df = pd.read_csv("iati_files/project_status.csv")
+    texts_df = pd.read_csv("iati_files/project_texts.csv")
+
+    projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
+    projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
+    projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
+    projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')
+    giz_df = projects_df[projects_df.client.str.contains('bmz')].reset_index(drop=True)
+
+    giz_df.drop(columns= ['orga_abbreviation', 'client',
+                          'orga_full_name', 'country',
+                          'country_flag', 'crs_5_code', 'crs_3_code',
+                          'sgd_pred_code'], inplace=True)
+
+    #### code for eading the giz_worldwide data
+    #giz_df = pd.read_json('iati_files/data_giz_website.json')
+    #giz_df = giz_df.rename(columns={'content':'project_description'})
 
 
-    giz_df['text_size'] = giz_df.apply(lambda x: len((x['project_name'] + x['project_description']).split()), axis=1)
-    giz_df['chunks'] = giz_df.apply(lambda x:create_chunks(x['project_name'] + x['project_description']),axis=1)
+    #giz_df['text_size'] = giz_df.apply(lambda x: len((x['project_name'] + x['project_description']).split()), axis=1)
+    #giz_df['chunks'] = giz_df.apply(lambda x:create_chunks(x['project_name'] + x['project_description']),axis=1)
+    #giz_df = giz_df.explode(column=['chunks'], ignore_index=True)
+
+
+    giz_df['text_size'] = giz_df.apply(lambda x: len((x['title_main'] + x['description_main']).split()), axis=1)
+    giz_df['chunks'] = giz_df.apply(lambda x:create_chunks(x['title_main'] + x['description_main']),axis=1)
     giz_df = giz_df.explode(column=['chunks'], ignore_index=True)
 
-
     placeholder= []
     for i in range(len(giz_df)):
         placeholder.append(Document(page_content= giz_df.loc[i,'chunks'],
-                                    metadata={
-                                    "title_main":giz_df.loc[i,'title_main'],
-                                    "country_name":str(giz_df.loc[i,'countries']),
-                                    "client": giz_df_new.loc[i, 'client'],
-                                    "language":giz_df_new.loc[i, 'language'],
-                                    "political_sponsor":giz_df.loc[i, 'poli_trager'],
-                                    "url": giz_df.loc[i, 'url']
-                                    #"iati_id": giz_df.loc[i,'iati_id'],
-                                    #"iati_orga_id":giz_df.loc[i,'iati_orga_id'],
-                                    #"crs_5_name": giz_df.loc[i,'crs_5_name'],
-                                    #"crs_3_name": giz_df.loc[i,'crs_3_name'],
-                                    #"sgd_pred_str":giz_df.loc[i,'sgd_pred_str'],
-                                    #"status":giz_df.loc[i,'status'],
-                                    }))
+                                    metadata={"iati_id": giz_df.loc[i,'iati_id'],
+                                    "iati_orga_id":giz_df.loc[i,'iati_orga_id'],
+                                    "country_name":str(giz_df.loc[i,'country_name']),
+                                    "crs_5_name": giz_df.loc[i,'crs_5_name'],
+                                    "crs_3_name": giz_df.loc[i,'crs_3_name'],
+                                    "sgd_pred_str":giz_df.loc[i,'sgd_pred_str'],
+                                    "status":giz_df.loc[i,'status'],
+                                    "title_main":giz_df.loc[i,'title_main'],}))
     return placeholder
 
+    # placeholder= []
+    # for i in range(len(giz_df)):
+    #     placeholder.append(Document(page_content= giz_df.loc[i,'chunks'],
+    #                                 metadata={
+    #                                 "title_main":giz_df.loc[i,'title_main'],
+    #                                 "country_name":str(giz_df.loc[i,'countries']),
+    #                                 "client": giz_df_new.loc[i, 'client'],
+    #                                 "language":giz_df_new.loc[i, 'language'],
+    #                                 "political_sponsor":giz_df.loc[i, 'poli_trager'],
+    #                                 "url": giz_df.loc[i, 'url']
+    #                                 #"iati_id": giz_df.loc[i,'iati_id'],
+    #                                 #"iati_orga_id":giz_df.loc[i,'iati_orga_id'],
+    #                                 #"crs_5_name": giz_df.loc[i,'crs_5_name'],
+    #                                 #"crs_3_name": giz_df.loc[i,'crs_3_name'],
+    #                                 #"sgd_pred_str":giz_df.loc[i,'sgd_pred_str'],
+    #                                 #"status":giz_df.loc[i,'status'],
+    #                                 }))
+    # return placeholder
+
 def embed_chunks(chunks):
+    """
+    takes the chunks and does the hybrid embedding for the list of chunks
+    """
     embeddings = HuggingFaceEmbeddings(
         model_kwargs = {'device': device},
         encode_kwargs = {'normalize_embeddings': True},
         model_name='BAAI/bge-m3'
     )
+    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
     # placeholder for collection
     print("starting embedding")
     qdrant_collections = {}
-    qdrant_collections['all'] = Qdrant.from_documents(
+    qdrant_collections['iati'] = Qdrant.from_documents(
         chunks,
         embeddings,
+        sparse_embeddings = sparse_embeddings,
         path="/data/local_qdrant",
-        collection_name='all',
+        collection_name='iati',
     )
 
     print(qdrant_collections)
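
`create_chunks` splits on tokens, not words, so the 500-token chunks do not line up one-to-one with the word counts stored in `text_size`. A quick standalone sketch of the splitter's behaviour, using a synthetic stand-in for the concatenated `title_main` and `description_main` fields:

```python
from langchain.text_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
sample = "renewable energy " * 1000  # synthetic stand-in for a project description
chunks = splitter.split_text(sample)
# each chunk holds at most 500 tokens as counted by tiktoken, not 500 words
print(len(chunks), len(chunks[0].split()))
```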
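On the indexing side, the commit passes `sparse_embeddings=` into `Qdrant.from_documents`, a keyword the legacy `Qdrant` class may not accept. In the langchain_qdrant package, the documented hybrid route goes through `QdrantVectorStore` with a singular `sparse_embedding` argument and `retrieval_mode=RetrievalMode.HYBRID`, which would also put the otherwise-unused `RetrievalMode` import to work. A sketch of that pattern, assuming the pinned versions expose it; the model, path, and collection name are taken from the diff:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings  # assumed import in app.py
from langchain_qdrant import QdrantVectorStore, FastEmbedSparse, RetrievalMode

dense = HuggingFaceEmbeddings(
    model_name='BAAI/bge-m3',
    model_kwargs={'device': device},              # device as defined at the top of app.py
    encode_kwargs={'normalize_embeddings': True},
)
sparse = FastEmbedSparse(model_name="Qdrant/bm25")

vectorstore = QdrantVectorStore.from_documents(
    chunks,                                       # the Documents returned by get_chunks()
    embedding=dense,
    sparse_embedding=sparse,
    retrieval_mode=RetrievalMode.HYBRID,          # dense + BM25, fused by Qdrant
    path="/data/local_qdrant",
    collection_name='iati',
)
```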
@@ -122,11 +152,19 @@ def get_context(vectorstore,query):
     print(f"retrieved paragraphs:{len(context_retrieved)}")
 
     return context_retrieved
-#chunks = get_chunks()
-vectorstores = get_local_qdrant()
-vectorstore = vectorstores['all']
-button=st.button("search")
-results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
+
+# first we create the chunks for iati documents
+chunks = get_chunks()
+print("chunking done")
+
+# once the chunks are done, we perform hybrid emebddings
+qdrant_collections = embed_chunks(chunks)
+print(qdrant_collections.keys())
+
+# vectorstores = get_local_qdrant()
+# vectorstore = vectorstores['all']
+# button=st.button("search")
+# results= get_context(vectorstore, f"find the relvant paragraphs for: {var}")
 if button:
     st.write(f"Found {len(results)} results for query:{var}")
 
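Further up, app.py imports `ContextualCompressionRetriever`, `CrossEncoderReranker`, and `HuggingFaceCrossEncoder`, which suggests `get_context` (whose body is not shown in this diff) reranks the retrieved chunks with a cross-encoder. Note also that this hunk comments out the definitions of `button` and `results` while the `if button:` block below still references them, so the search path would raise a NameError at this revision. A hypothetical reconstruction of a `get_context`-style helper from those imports alone; the reranker model name and the k/top_n values are assumptions:

```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

def get_context(vectorstore, query):
    # assumed reranker model; the one actually used is defined outside this diff
    cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
    compressor = CrossEncoderReranker(model=cross_encoder, top_n=5)
    retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=vectorstore.as_retriever(search_kwargs={"k": 20}),
    )
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs:{len(context_retrieved)}")
    return context_retrieved
```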
iati_files/data_giz_website.json DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:be70c4b250aad01e53543bdd07c1d9f9fdd8a23be65e4a1d8c64f2272f2bbf03
-size 13980616
requirements.txt CHANGED
@@ -7,4 +7,5 @@ langchain==0.1.20
 langsmith==0.1.99
 qdrant-client==1.10.1
 tiktoken
-torch==2.4.0
+torch==2.4.0
+fastembed
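
The new `fastembed` dependency (left unpinned here) is what backs `FastEmbedSparse`. A standalone way to sanity-check it after installing, assuming fastembed's `SparseTextEmbedding` API; the sample sentence is made up:

```python
from fastembed import SparseTextEmbedding

model = SparseTextEmbedding(model_name="Qdrant/bm25")
emb = next(iter(model.embed(["rural electrification project"])))  # made-up text
print(emb.indices[:5], emb.values[:5])  # token ids and BM25-style weights
```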