sivan22 commited on
Commit
64dd69a
1 Parent(s): 956d65c

using semantic_search

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -7,6 +7,7 @@ from langchain_openai import ChatOpenAI
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.messages import HumanMessage, SystemMessage
9
  from sentence_transformers import util
 
10
 
11
 
12
 
@@ -39,9 +40,9 @@ def get_chat_api(api_key:str):
39
 
40
  def get_results(embeddings_model,input,df,num_of_results) -> pd.DataFrame:
41
  embeddings = embeddings_model.embed_query('query: '+ input)
42
- df['similarity'] = df['embeddings'].apply(lambda x: util.dot_score(x,embeddings))
43
- results = df.sort_values(by='similarity', ascending=False)
44
- return results.head(num_of_results)
45
 
46
  def get_llm_results(query,chat,results):
47
 
 
7
  from langchain_core.prompts import PromptTemplate
8
  from langchain_core.messages import HumanMessage, SystemMessage
9
  from sentence_transformers import util
10
+ from torch import tensor
11
 
12
 
13
 
 
40
 
41
  def get_results(embeddings_model,input,df,num_of_results) -> pd.DataFrame:
42
  embeddings = embeddings_model.embed_query('query: '+ input)
43
+ hits = util.semantic_search(tensor(embeddings), tensor(df['embeddings'].tolist()), top_k=num_of_results)
44
+ hit_list = [hit['corpus_id'] for hit in hits[0]]
45
+ return df.iloc[hit_list]
46
 
47
  def get_llm_results(query,chat,results):
48