import streamlit as st
import pandas as pd
from langchain_text_splitters import TokenTextSplitter
from langchain.docstore.document import Document
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_qdrant import QdrantVectorStore, FastEmbedSparse, RetrievalMode
# get the device to be used: either gpu or cpu
device = 'cuda' if cuda.is_available() else 'cpu'

st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("SEARCH IATI Database")
var = st.text_input("Enter search keyword")
def create_chunks(text):
    """Takes a text and creates chunks from it."""
    # chunk size in terms of tokens
    text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
    texts = text_splitter.split_text(text)
    return texts
def get_chunks():
    """
    Reads the IATI files, merges them into one dataframe and creates the chunks.
    """
    orgas_df = pd.read_csv("iati_files/project_orgas.csv")
    region_df = pd.read_csv("iati_files/project_region.csv")
    sector_df = pd.read_csv("iati_files/project_sector.csv")
    status_df = pd.read_csv("iati_files/project_status.csv")
    texts_df = pd.read_csv("iati_files/project_texts.csv")

    projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')

    giz_df = projects_df[projects_df.client.str.contains('bmz')].reset_index(drop=True)
    giz_df.drop(columns=['orga_abbreviation', 'client',
                         'orga_full_name', 'country',
                         'country_flag', 'crs_5_code', 'crs_3_code',
                         'sgd_pred_code'], inplace=True)

    #### code for reading the giz_worldwide data
    #giz_df = pd.read_json('iati_files/data_giz_website.json')
    #giz_df = giz_df.rename(columns={'content':'project_description'})
    #giz_df['text_size'] = giz_df.apply(lambda x: len((x['project_name'] + x['project_description']).split()), axis=1)
    #giz_df['chunks'] = giz_df.apply(lambda x: create_chunks(x['project_name'] + x['project_description']), axis=1)
    #giz_df = giz_df.explode(column=['chunks'], ignore_index=True)

    giz_df['text_size'] = giz_df.apply(lambda x: len((x['title_main'] + x['description_main']).split()), axis=1)
    giz_df['chunks'] = giz_df.apply(lambda x: create_chunks(x['title_main'] + x['description_main']), axis=1)
    giz_df = giz_df.explode(column=['chunks'], ignore_index=True)

    # wrap each chunk in a Document with the project metadata attached
    placeholder = []
    for i in range(len(giz_df)):
        placeholder.append(Document(page_content=giz_df.loc[i, 'chunks'],
                                    metadata={"iati_id": giz_df.loc[i, 'iati_id'],
                                              "iati_orga_id": giz_df.loc[i, 'iati_orga_id'],
                                              "country_name": str(giz_df.loc[i, 'country_name']),
                                              "crs_5_name": giz_df.loc[i, 'crs_5_name'],
                                              "crs_3_name": giz_df.loc[i, 'crs_3_name'],
                                              "sgd_pred_str": giz_df.loc[i, 'sgd_pred_str'],
                                              "status": giz_df.loc[i, 'status'],
                                              "title_main": giz_df.loc[i, 'title_main'], }))
    return placeholder
# placeholder= []
# for i in range(len(giz_df)):
# placeholder.append(Document(page_content= giz_df.loc[i,'chunks'],
# metadata={
# "title_main":giz_df.loc[i,'title_main'],
# "country_name":str(giz_df.loc[i,'countries']),
# "client": giz_df_new.loc[i, 'client'],
# "language":giz_df_new.loc[i, 'language'],
# "political_sponsor":giz_df.loc[i, 'poli_trager'],
# "url": giz_df.loc[i, 'url']
# #"iati_id": giz_df.loc[i,'iati_id'],
# #"iati_orga_id":giz_df.loc[i,'iati_orga_id'],
# #"crs_5_name": giz_df.loc[i,'crs_5_name'],
# #"crs_3_name": giz_df.loc[i,'crs_3_name'],
# #"sgd_pred_str":giz_df.loc[i,'sgd_pred_str'],
# #"status":giz_df.loc[i,'status'],
# }))
# return placeholder
def embed_chunks(chunks):
    """
    Takes the list of chunks and indexes them with hybrid (dense + sparse) embeddings.
    """
    # dense embeddings
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3'
    )
    # sparse (BM25) embeddings for the hybrid part
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")

    # placeholder for collection
    print("starting embedding")
    qdrant_collections = {}
    # hybrid retrieval uses langchain_qdrant's QdrantVectorStore with RetrievalMode.HYBRID;
    # the community Qdrant class only handles dense vectors
    qdrant_collections['iati'] = QdrantVectorStore.from_documents(
        chunks,
        embedding=embeddings,
        sparse_embedding=sparse_embeddings,
        retrieval_mode=RetrievalMode.HYBRID,
        path="/data/local_qdrant",
        collection_name='iati',
    )
    print(qdrant_collections)
    print("vector embeddings done")
    return qdrant_collections
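# Quick sanity check (sketch, commented out): once embed_chunks() has run, the hybrid
# collection can be queried directly; the query string below is only illustrative.
# hits = qdrant_collections['iati'].similarity_search("climate adaptation", k=3)
# for hit in hits:
#     print(hit.metadata.get('title_main'), "->", hit.page_content[:100])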
@st.cache_resource
def get_local_qdrant():
    """Once the local Qdrant store has been created, this makes the connection to the existing store."""
    qdrant_collections = {}
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3')
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections['all'] = Qdrant(client=client, collection_name='all', embeddings=embeddings)
    return qdrant_collections
def get_context(vectorstore, query):
    """Retrieves the most similar chunks for the query from the vectorstore."""
    # create metadata filter (not applied yet)
    # getting context
    retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
                                         search_kwargs={"score_threshold": 0.5,
                                                        "k": 10, })
    # # re-ranking the retrieved results
    # model = HuggingFaceCrossEncoder(model_name=model_config.get('ranker','MODEL'))
    # compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker','TOP_K')))
    # compression_retriever = ContextualCompressionRetriever(
    #     base_compressor=compressor, base_retriever=retriever
    # )
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs: {len(context_retrieved)}")
    return context_retrieved
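# Optional re-ranking sketch: the commented block above relies on a `model_config`
# object that is not defined in this script. The function below shows one way to
# wire up the imported cross-encoder reranker; the model name and top_n value are
# assumptions, not part of the original configuration.
def get_reranked_context(vectorstore, query):
    retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
    # cross-encoder model name is an assumption; swap in the model from your own config
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
    compressor = CrossEncoderReranker(model=model, top_n=5)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    return compression_retriever.invoke(query)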
# first we create the chunks for the iati documents
chunks = get_chunks()
print("chunking done")

# once the chunks are ready, we perform the hybrid embeddings
qdrant_collections = embed_chunks(chunks)
print(qdrant_collections.keys())
# vectorstores = get_local_qdrant()
# vectorstore = vectorstores['all']
vectorstore = qdrant_collections['iati']
button = st.button("search")
if button:
    results = get_context(vectorstore, f"find the relevant paragraphs for: {var}")
    st.write(f"Found {len(results)} results for query: {var}")
    for i in results:
        st.subheader(i.metadata['iati_id'] + ":" + i.metadata['title_main'])
        st.caption(f"Status: {i.metadata['status']}, Country: {i.metadata['country_name']}")
        st.write(i.page_content)
        st.divider()