#%%
import joblib
import numpy as np
from sentence_transformers.cross_encoder import CrossEncoder
from tqdm import tqdm

from request_solr import SilverDataset
from solr_query_params import params

############################################################################
#
# https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation/train_sts_indomain_bm25.py
# Step 2: Label BM25 sampled STSb (silver dataset) using cross-encoder model
#
############################################################################

cross_encoder_path = 'anatel/cross-encoder-pt-anatel-metadados-assunto'

# Identifiers listed in this file are handed to SilverDataset as `duplicated`
# (one entry per line, surrounding whitespace stripped).
with open('gold_sample_index.txt', 'r') as f:
    gold_sample_index = {line.strip() for line in f}

# Reuse the cached silver dataset when it exists; otherwise build it and
# cache it. The loaded value must be bound to `silver_data`, because
# everything below reads it. Only a missing cache file triggers a rebuild;
# any other load failure is raised rather than silently re-running the query.
try:
    silver_data = joblib.load('silver_data_v2.pkl')
except FileNotFoundError:
    print('Creating silver data...')
    silver_data = SilverDataset(query_params=params, duplicated=gold_sample_index).run()
    joblib.dump(silver_data, 'silver_data_v2.pkl')
    print('Done!')

# Each silver record unpacks into three fields; only the first two (the
# sentence pair) are fed to the cross-encoder.
sentences = [(sent_1, sent_2) for sent_1, sent_2, _ in silver_data]

cross_encoder = CrossEncoder(cross_encoder_path, max_length=512)

# Score one pair per call, with a tqdm progress bar over the pairs.
# NOTE(review): scoring the whole list in one predict() call is probably much
# faster — confirm the scores match before switching.
cross_silver_scores = [cross_encoder.predict(pair) for pair in tqdm(sentences)]

# All model predictions should be between [0,1]. Checked with an explicit
# raise (not `assert`) so the guard still runs under `python -O`, and done
# before anything is assembled or written.
if not all(0.0 <= score <= 1.0 for score in cross_silver_scores):
    raise ValueError('Cross-encoder produced scores outside [0, 1]')

# Append the cross-encoder score as an extra column next to the silver records.
cross_silver_data = np.c_[np.array(silver_data), np.array(cross_silver_scores)]

joblib.dump(cross_silver_data, 'cross_silver_scores_2.pkl')