File size: 2,188 Bytes
ac03f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import httpx
from pydantic import BaseModel
import tqdm
# Disabled snippet kept for reference: the triple-quoted string below makes it
# a no-op. When active, it loaded the gold-dataset ids (one per line) from
# 'gold_sample_index.txt' into the `gold_sample_index` set — presumably the
# source of the ids passed to SilverDataset's `duplicated` field (TODO confirm).
'''
gold_sample_index = set()
with open('gold_sample_index.txt', 'r') as f:
    for line in f:
        gold_sample_index.add(line.strip())
'''

class SilverDataset(BaseModel):
    '''
    Builds the "silver" dataset used by the Augmented SBERT strategy.

    Every document in a Solr collection is used as a More-Like-This (MLT)
    query; the scored MLT results become weakly labelled sentence pairs
    (score normalized by the query's max score).
    ref: https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation/train_sts_indomain_bm25.py
    '''
    # Solr parameters sent with every MLT request; the 'q' entry is
    # overwritten per document in get_data(). Must contain 'mlt.qf'
    # (the field used as the query text).
    query_params: dict
    # Ids already present in the gold dataset; matching MLT results are
    # dropped so the silver set never overlaps the gold set.
    duplicated: set = set()
    solr_url: str = "http://localhost:8983/solr/sei_similaridade_augmented_sbert"
    
    def get_ids_list(self):
        '''Return the id of every document in the Solr collection.'''
        # rows is set absurdly high so a single request returns all ids.
        query = f'{self.solr_url}/select?q=*&fl=id&rows=999999999'
        docs = httpx.get(query).json()['response']['docs']
        return [doc['id'] for doc in docs]

    def get_data(self,id):
        '''
        Run an MLT query for *id*.

        Returns a dict with the query id, the query document's text (taken
        from the field named by query_params['mlt.qf']), the matched docs
        (gold duplicates removed) and the response's maxScore.
        '''
        # NOTE: mutates the shared query_params dict; safe here because each
        # call overwrites 'q' before issuing the request.
        self.query_params['q'] = f"id:{id}"
        r = httpx.post(f'{self.solr_url}/mlt', data=self.query_params).json()
        response_docs = self.remove_duplicated(r['response']['docs'])
        return {'query_id': id,
                'query_doc': r['match']['docs'][0][self.query_params['mlt.qf']],
                'docs': response_docs,
                'maxscore': r['response']['maxScore']}
    
    def remove_duplicated(self,docs):
        '''
        Drop documents whose id already appears in the gold dataset
        (i.e. is in self.duplicated).
        '''
        return [doc for doc in docs if doc['id'] not in self.duplicated]


    @staticmethod
    def create_sentence_pairs(queries):
        '''
        Build the silver pairs from the per-query MLT results.

        Returns a set of (query_text, candidate_text, normalized_score)
        triples; the set removes exact duplicate pairs, and each score is
        divided by the query's maxscore.
        '''
        pairs = set()
        for query in queries:
            # Hoist the loop-invariant normalizer out of the inner loop.
            maxscore = query['maxscore']
            for doc in query['docs']:
                pairs.add(
                    (query['query_doc'],
                    doc['assunto_text'],
                    doc['score']/maxscore))
        return pairs
    
    def run(self):
        '''Mine the whole collection and return the silver sentence pairs.'''
        list_ids = self.get_ids_list()
        # tqdm wraps the id list to show mining progress.
        queries = [self.get_data(id) for id in tqdm.tqdm(list_ids)]
        return self.create_sentence_pairs(queries)