|
import httpx |
|
from pydantic import BaseModel |
|
import tqdm |
|
''' |
|
gold_sample_index = set() |
|
with open('gold_sample_index.txt', 'r') as f: |
|
for line in f: |
|
gold_sample_index.add(line.strip()) |
|
''' |
|
|
|
class SilverDataset(BaseModel):
    '''
    Builds the "silver" dataset used by the Augmented SBERT strategy.

    Every document id in the Solr core is sent to the MoreLikeThis
    (``/mlt``) handler; each query document is paired with its similar
    documents, scored relative to the query's ``maxScore``.

    ref: https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation/train_sts_indomain_bm25.py
    '''

    # Base parameters for the Solr /mlt handler; 'q' is filled in per id
    # and 'mlt.qf' names the field that holds the document text.
    query_params: dict
    # Ids already present in the gold dataset; matching docs are dropped.
    # NOTE: pydantic deep-copies mutable defaults per instance, so the
    # shared-mutable-default pitfall does not apply here.
    duplicated: set = set()
    # Solr core endpoint.
    solr_url: str = "http://localhost:8983/solr/sei_similaridade_augmented_sbert"
    # Upper bound on ids fetched by get_ids_list(); exposed as a field so a
    # caller can limit the crawl (default preserves the original behavior).
    max_rows: int = 999999999

    def get_ids_list(self):
        '''Return the list of all document ids stored in the Solr core.'''
        # Let httpx build and URL-encode the query string instead of
        # hand-rolling it into an f-string.
        params = {'q': '*', 'fl': 'id', 'rows': self.max_rows}
        docs = httpx.get(f'{self.solr_url}/select', params=params).json()['response']['docs']
        return [doc['id'] for doc in docs]

    def get_data(self, id):
        '''
        Query Solr's MoreLikeThis handler for the document with ``id``.

        Returns a dict with the query id, the query document text
        (taken from the ``mlt.qf`` field), the similar docs minus gold
        duplicates, and the query's maxScore.
        '''
        # Build a per-call copy so repeated calls do not mutate the shared
        # query_params field (the original overwrote 'q' in place).
        params = {**self.query_params, 'q': f"id:{id}"}
        r = httpx.post(f'{self.solr_url}/mlt', data=params).json()
        # maxScore may be 0.0 when Solr finds no scored matches; the
        # consumer (create_sentence_pairs) guards against that.
        maxscore = r['response']['maxScore']
        response_docs = self.remove_duplicated(r['response']['docs'])
        return {'query_id': id,
                'query_doc': r['match']['docs'][0][self.query_params['mlt.qf']],
                'docs': response_docs,
                'maxscore': maxscore}

    def remove_duplicated(self, docs):
        '''
        Drop documents whose id also appears in the gold dataset
        (``self.duplicated``), so silver pairs never repeat gold ones.
        '''
        return [doc for doc in docs if doc['id'] not in self.duplicated]

    @staticmethod
    def create_sentence_pairs(queries):
        '''
        Build the set of (query_text, doc_text, normalized_score) triples
        for the silver dataset.

        Scores are normalized by the query's maxScore; queries whose
        maxScore is zero or None are skipped to avoid a division error
        (they have no meaningful matches anyway).
        '''
        pairs = set()
        for query in queries:
            maxscore = query['maxscore']
            if not maxscore:  # no scored matches: nothing to pair
                continue
            for doc in query['docs']:
                pairs.add(
                    (query['query_doc'],
                     doc['assunto_text'],
                     doc['score'] / maxscore))
        return pairs

    def run(self):
        '''Fetch every document from Solr and return the silver pairs.'''
        # doc_id (not `id`) avoids shadowing the builtin in this scope.
        queries = [self.get_data(doc_id) for doc_id in tqdm.tqdm(self.get_ids_list())]
        return self.create_sentence_pairs(queries)