"""Test that bm25s.BM25.index accepts all supported tokenized-corpus input formats."""

import unittest
import warnings
from pathlib import Path

import numpy as np
import Stemmer
from beir.datasets.data_loader import GenericDataLoader
from beir.util import download_and_unzip

import bm25s


def check_scores_all_close(score1, score2, **kwargs):
    """Return True iff both score dicts have the same keys and, for every key,
    matrices of identical shape whose entries are close under np.allclose
    (``kwargs`` are forwarded to np.allclose, e.g. atol/rtol)."""
    if score1.keys() != score2.keys():
        return False

    for key in score1:
        matrix1 = score1[key]
        matrix2 = score2[key]

        if matrix1.shape != matrix2.shape:
            return False
        if not np.allclose(matrix1, matrix2, **kwargs):
            return False

    return True


class BM25SIndexing(unittest.TestCase):
    def test_indexing_by_corpus_type(self):
        # Silence ResourceWarning noise (e.g. from the dataset download below).
        warnings.filterwarnings("ignore", category=ResourceWarning)

        # Minimal stand-in for any object exposing `ids` and `vocab` attributes.
        class Tokenized:
            def __init__(self, ids, vocab):
                self.ids = ids
                self.vocab = vocab

        dataset = "scifact"
        rel_save_dir = "datasets"

        # Download and unpack the BEIR scifact dataset next to this file.
        base_url = (
            "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
        )
        url = base_url.format(dataset)
        out_dir = Path(__file__).parent / rel_save_dir
        data_path = download_and_unzip(url, str(out_dir))

        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(
            split="test"
        )

        # Flatten the corpus into parallel lists of ids and "title + text" strings.
        corpus_ids, corpus_lst = [], []
        for key, val in corpus.items():
            corpus_ids.append(key)
            corpus_lst.append(val["title"] + " " + val["text"])

        stemmer = Stemmer.Stemmer("english")

        # return_ids=False yields a plain list of token lists ...
        corpus_tokens_lst = bm25s.tokenize(
            corpus_lst,
            stopwords="en",
            stemmer=stemmer,
            leave=False,
            return_ids=False,
        )

        # ... while return_ids=True yields a Tokenized namedtuple of (ids, vocab).
        corpus_tokenized = bm25s.tokenize(
            corpus_lst,
            stopwords="en",
            stemmer=stemmer,
            leave=False,
            return_ids=True,
        )

        # Index the same corpus four ways: token lists, an (ids, vocab) tuple,
        # a duck-typed object, and the library's own Tokenized namedtuple.
        bm25_tokens = bm25s.BM25(k1=0.9, b=0.4)
        bm25_tokens.index(corpus_tokens_lst)

        bm25_tuples = bm25s.BM25(k1=0.9, b=0.4)
        bm25_tuples.index((corpus_tokenized.ids, corpus_tokenized.vocab))

        bm25_objects = bm25s.BM25(k1=0.9, b=0.4)
        bm25_objects.index(
            Tokenized(ids=corpus_tokenized.ids, vocab=corpus_tokenized.vocab)
        )

        bm25_namedtuple = bm25s.BM25(k1=0.9, b=0.4)
        named_tuple = bm25s.tokenization.Tokenized(
            ids=corpus_tokenized.ids, vocab=corpus_tokenized.vocab
        )
        bm25_namedtuple.index(named_tuple)
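
        # Illustrative sketch, not executed as part of this test: once indexed,
        # each retriever above can be queried identically through bm25s's
        # retrieve API (the query string below is a made-up placeholder):
        #   query_tokens = bm25s.tokenize("some query", stemmer=stemmer)
        #   results, scores = bm25_tokens.retrieve(query_tokens, corpus=corpus_ids, k=10)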

        self.assertTrue(
            check_scores_all_close(bm25_tokens.scores, bm25_tuples.scores),
            "Tokenized and Tuple indexing do not match",
        )
        self.assertTrue(
            check_scores_all_close(bm25_tokens.scores, bm25_objects.scores),
            "Tokenized and Object indexing do not match",
        )
        self.assertTrue(
            check_scores_all_close(bm25_tokens.scores, bm25_namedtuple.scores),
            "Tokenized and NamedTuple indexing do not match",
        )
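

if __name__ == "__main__":
    # Convenience entry point (an assumption about how the file is invoked):
    # lets this suite run directly via `python path/to/this_test.py` in
    # addition to normal unittest/pytest discovery.
    unittest.main()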
|