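"""Retrieve-and-rerank semantic search over locally stored text files.

A bi-encoder (msmarco-distilbert-base-tas-b) embeds every passage for fast
candidate retrieval, and a cross-encoder (ms-marco-MiniLM-L-12-v2) re-scores
the top candidates for the final ranking. Passage embeddings are cached in
corpus.pt and the passage/link table in Responses.csv, so only newly added
files are embedded on later runs.
"""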
import json
import os
import textwrap

import pandas as pd
import torch
from sentence_transformers import CrossEncoder, SentenceTransformer, util
# Bi-encoder for fast candidate retrieval, cross-encoder for accurate re-ranking.
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

bi_encoder = SentenceTransformer(model_bi_encoder)
bi_encoder.max_seq_length = 512  # embed passages of up to 512 tokens
cross_encoder = CrossEncoder(model_cross_encoder)
def collect_data(data_lis, meta_count):
    """Return the file names and links added since the last indexed count."""
    # Convert to plain lists so downstream code can index from zero
    # (a pandas Series slice would keep the original index labels).
    new_files = data_lis['file_name'][meta_count:].tolist()
    new_links = data_lis['link'][meta_count:].tolist()
    return new_files, new_links
def merge_text(text_list):
    """Merge passages of 30 words or fewer into the following passage so
    every retrieved unit carries enough context, then drop the emptied slots."""
    for i in range(len(text_list) - 1):
        if len(text_list[i].split()) <= 30:
            text_list[i + 1] = text_list[i] + " " + text_list[i + 1]
            text_list[i] = " "
    return [passage for passage in text_list if passage != " "]
def make_data(new_files, new_links, local_path):
    """Read each new file, split it into merged passages, and pair every
    passage with the source link of its file."""
    text, links = [], []
    for file_name, link in zip(new_files, new_links):
        sub_text = []
        with open(os.path.join(local_path, file_name), encoding='utf-8') as f:
            for line in f:
                temp_text = line.strip()
                if temp_text:
                    sub_text.append(temp_text)
        sub_text = merge_text(sub_text)
        text.extend(sub_text)
        links.extend([link] * len(sub_text))
    return text, links
def get_final_data():
    # Paths for the metadata, file list, raw documents, passage table, and cached embeddings.
    meta_path = "meta_data.json"
    data_lis_path = "data_url.csv"
    local_path = "Chitti_ver1/Data_final"
    data_path = "Responses.csv"
    corpus_path = "corpus.pt"

    # Load the list of data files.
    data_lis = pd.read_csv(data_lis_path)

    # Create an empty Responses.csv on the first run.
    if not os.path.exists(data_path):
        fresh_data = pd.DataFrame({"text": [], "links": []})
        fresh_data.to_csv(data_path, index=False)
    data = pd.read_csv(data_path)

    # Compare the stored file count against the current one; if new files were
    # added, extend Responses.csv and corpus.pt accordingly.
    act_count = len(data_lis['file_name'])
    with open(meta_path, "r") as jsonFile:
        meta_data = json.load(jsonFile)
    meta_count = meta_data["data"]["count"]

    if meta_count != act_count:
        meta_data["data"]["count"] = act_count
        with open(meta_path, "w") as jsonFile:
            json.dump(meta_data, jsonFile)

        new_files, new_links = collect_data(data_lis, meta_count)
        text, links = make_data(new_files, new_links, local_path)
        df = pd.DataFrame({"text": text, "links": links})
        # ignore_index keeps row positions aligned with the embedding rows,
        # so semantic_search corpus ids can index straight into the table.
        data = pd.concat([data, df], ignore_index=True)
        data.to_csv(data_path, index=False)

        # If a cached corpus already exists, embed only the new passages and append them.
        if os.path.exists(corpus_path):
            corpus_embeddings = torch.load(corpus_path)
            new_embeddings = bi_encoder.encode(df["text"].tolist(), convert_to_tensor=True, show_progress_bar=True)
            corpus_embeddings = torch.cat((corpus_embeddings, new_embeddings), 0)
            torch.save(corpus_embeddings, corpus_path)

    # First run (or missing cache): embed the whole passage table.
    if not os.path.exists(corpus_path):
        corpus_embeddings = bi_encoder.encode(data["text"].tolist(), convert_to_tensor=True, show_progress_bar=True)
        torch.save(corpus_embeddings, corpus_path)

    corpus_embeddings = torch.load(corpus_path)
    return corpus_embeddings, data
def search(query):
    corpus_embeddings, data = get_final_data()

    # Stage 1: bi-encoder retrieval of the top-k candidate passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    top_k = 20
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]

    # Stage 2: cross-encoder re-scores each (query, passage) pair and
    # the candidates are re-sorted by that score.
    cross_inp = [[query, data['text'][hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)

    # Return the five best passages with their scores and source links.
    wrapper = textwrap.TextWrapper(width=50)
    result_table = []
    for hit in hits[0:5]:
        ans = wrapper.fill(text=data['text'][hit['corpus_id']].replace("\n", " "))
        cross_score = str(hit['cross-score'])
        bi_score = str(hit['score'])
        corr_link = str(data['links'][hit['corpus_id']])
        result_table.append([ans, cross_score, bi_score, corr_link])
    return result_table
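
# Minimal usage sketch (hypothetical query string): assumes meta_data.json,
# data_url.csv, and the Chitti_ver1/Data_final directory exist alongside this
# script, as the paths in get_final_data() expect.
if __name__ == "__main__":
    for passage, cross_score, bi_score, link in search("How do I reset my password?"):
        print(f"{cross_score}\t{link}\n{passage}\n")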