# Author: ennioferreirab
# Commit ac03f85: "add model"
import httpx
from pydantic import BaseModel
import tqdm
def load_gold_sample_index(path='gold_sample_index.txt'):
    '''
    Read the Gold-dataset sample ids from a text file, one id per line.

    This code used to sit here disabled inside a bare string literal (a no-op
    at import time). It is now a helper that is NOT called on import. Its
    result is presumably meant for ``SilverDataset(duplicated=...)``, so that
    Gold documents are excluded from the silver pairs. Confirm with callers.

    :param path: path of the text file holding one sample id per line.
    :return: set of stripped ids; blank lines are ignored.
    '''
    gold_sample_index = set()
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            sample_id = line.strip()
            # Skip blank lines so '' never ends up as an "id".
            if sample_id:
                gold_sample_index.add(sample_id)
    return gold_sample_index
class SilverDataset(BaseModel):
    '''
    Build the "silver" dataset used by the Augmented SBERT strategy.

    For every document id in the Solr core, a MoreLikeThis (MLT) query is
    issued. Each returned document is paired with the query document's text
    and labelled with its MLT score divided by the query's maxScore.

    ref: https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation/train_sts_indomain_bm25.py
    '''
    # Base parameters sent to the /mlt handler. Must contain 'mlt.qf', whose
    # value is used as the key of the matched document's text. 'q' is set
    # per request by get_data and does not need to be supplied.
    query_params: dict
    # Ids removed from MLT results (documents that also belong to the Gold
    # dataset, per remove_duplicated).
    duplicated: set = set()
    # Base URL of the Solr core, without a trailing slash.
    solr_url: str = "http://localhost:8983/solr/sei_similaridade_augmented_sbert"

    def get_ids_list(self):
        '''
        Return the ids of every document in the Solr core.

        :raises httpx.HTTPStatusError: if Solr answers with an error status.
        '''
        # rows is presumably chosen far above the core size so that a single
        # request returns every id.
        params = {'q': '*', 'fl': 'id', 'rows': 999999999}
        resp = httpx.get(f'{self.solr_url}/select', params=params)
        # Fail with the HTTP status instead of an opaque KeyError('response')
        # when Solr returns an error body.
        resp.raise_for_status()
        return [doc['id'] for doc in resp.json()['response']['docs']]

    def get_data(self, id):
        '''
        Run an MLT query for the document with the given id.

        :param id: Solr id of the query document.
        :return: dict with
            'query_id'  - the id passed in;
            'query_doc' - the matched document's value for the field named by
                          query_params['mlt.qf'];
            'docs'      - MLT result docs, minus ids in ``self.duplicated``;
            'maxscore'  - the response's maxScore (used to normalise scores).
        :raises httpx.HTTPStatusError: if Solr answers with an error status.
        '''
        # Build per-request parameters. The previous version wrote 'q' into
        # self.query_params, mutating shared (possibly caller-owned) config.
        # NOTE(review): the id is interpolated unescaped; ids containing
        # Lucene query-syntax characters would need escaping. Confirm.
        params = {**self.query_params, 'q': f'id:{id}'}
        resp = httpx.post(f'{self.solr_url}/mlt', data=params)
        resp.raise_for_status()
        r = resp.json()
        response = r['response']
        maxscore = response['maxScore']
        response_docs = self.remove_duplicated(response['docs'])
        query_doc = r['match']['docs'][0][self.query_params['mlt.qf']]
        return {'query_id': id,
                'query_doc': query_doc,
                'docs': response_docs,
                'maxscore': maxscore}

    def remove_duplicated(self, docs):
        '''
        Drop documents whose id is in ``self.duplicated``, i.e. documents
        that also belong to the Gold dataset.
        '''
        return [doc for doc in docs if doc['id'] not in self.duplicated]

    @staticmethod
    def create_sentence_pairs(queries):
        '''
        Build the silver sentence pairs from a list of get_data results.

        Each pair is (query text, candidate 'assunto_text',
        candidate 'score' / query 'maxscore'). The result is a set, so exact
        duplicate pairs collapse into one.
        '''
        return {
            (query['query_doc'],
             doc['assunto_text'],
             doc['score'] / query['maxscore'])
            for query in queries
            for doc in query['docs']
        }

    def run(self):
        '''
        Query MLT for every document id in the core and return the silver
        sentence pairs (see create_sentence_pairs).
        '''
        ids = self.get_ids_list()
        queries = [self.get_data(doc_id) for doc_id in tqdm.tqdm(ids)]
        return self.create_sentence_pairs(queries)