# Graduation/pipelines/bm25s/tests/comparison/test_bm25s_indexing.py
# DuyTa's picture
# Upload folder using huggingface_hub
# 74b1bac verified
import json
import os
import time
import unittest
from pathlib import Path
import warnings
import logging
import numpy as np
from beir.datasets.data_loader import GenericDataLoader
from beir.util import download_and_unzip
import Stemmer
import bm25s
def check_scores_all_close(score1, score2, **kwargs):
    """Return True when two score mappings hold numerically equal entries.

    The mappings must expose exactly the same keys, and every pair of values
    stored under a key must have the same shape and be element-wise close
    according to ``np.allclose``.

    Args:
        score1: Mapping from key to an array-like (or scalar) value.
        score2: Mapping compared against ``score1``.
        **kwargs: Forwarded unchanged to ``np.allclose`` (e.g. ``rtol``,
            ``atol``, ``equal_nan``).

    Returns:
        bool: ``True`` if the key sets match and every value pair agrees in
        shape and in value within tolerance, ``False`` otherwise.
    """
    # A key that exists on only one side is a mismatch. Without this check,
    # extra entries in score2 would be ignored and a missing one would raise
    # KeyError instead of reporting a difference.
    if set(score1.keys()) != set(score2.keys()):
        return False

    for key, value1 in score1.items():
        value2 = score2[key]
        # np.shape uses the object's own ``.shape`` when it has one and also
        # accepts plain scalars/sequences, so non-array entries do not raise.
        if np.shape(value1) != np.shape(value2):
            return False
        if not np.allclose(value1, value2, **kwargs):
            return False
    return True
class BM25SIndexing(unittest.TestCase):
    """Checks that BM25 indexing gives matching scores for each corpus input form."""

    def test_indexing_by_corpus_type(self):
        """Index the same corpus four ways and compare the resulting scores.

        The scifact corpus is tokenized once into token lists and once into
        an ids/vocab result. It is then indexed from the token lists, from an
        ``(ids, vocab)`` tuple, from a plain object carrying ``ids`` and
        ``vocab`` attributes, and from ``bm25s.tokenization.Tokenized``. The
        token-list index is the reference the other three are compared to.
        """
        warnings.filterwarnings("ignore", category=ResourceWarning)

        class Tokenized:
            # Plain object that only carries ``ids`` and ``vocab`` attributes.
            def __init__(self, ids, vocab):
                self.ids = ids
                self.vocab = vocab

        def build_index(corpus_repr):
            # Every variant uses the same BM25 hyper-parameters.
            retriever = bm25s.BM25(k1=0.9, b=0.4)
            retriever.index(corpus_repr)
            return retriever

        dataset = "scifact"
        rel_save_dir = "datasets"

        # Download and prepare dataset
        base_url = (
            "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip"
        )
        target_dir = Path(__file__).parent / rel_save_dir
        data_path = download_and_unzip(base_url.format(dataset), str(target_dir))
        loader = GenericDataLoader(data_folder=data_path)
        corpus, queries, qrels = loader.load(split="test")

        corpus_ids = list(corpus.keys())
        corpus_lst = [doc["title"] + " " + doc["text"] for doc in corpus.values()]

        stemmer = Stemmer.Stemmer("english")
        # Options shared by both tokenization passes; only return_ids differs.
        tokenize_options = {"stopwords": "en", "stemmer": stemmer, "leave": False}
        corpus_tokens_lst = bm25s.tokenize(
            corpus_lst, return_ids=False, **tokenize_options
        )
        corpus_tokenized = bm25s.tokenize(
            corpus_lst, return_ids=True, **tokenize_options
        )
        token_ids = corpus_tokenized.ids
        vocabulary = corpus_tokenized.vocab

        bm25_tokens = build_index(corpus_tokens_lst)
        bm25_tuples = build_index((token_ids, vocabulary))
        bm25_objects = build_index(Tokenized(ids=token_ids, vocab=vocabulary))
        bm25_namedtuple = build_index(
            bm25s.tokenization.Tokenized(ids=token_ids, vocab=vocabulary)
        )

        # now, verify that the sparse matrix matches
        reference_scores = bm25_tokens.scores
        comparisons = (
            (bm25_tuples, "Tokenized and Tuple indexing do not match"),
            (bm25_objects, "Tokenized and Object indexing do not match"),
            (bm25_namedtuple, "Tokenized and NamedTuple indexing do not match"),
        )
        for candidate, failure_message in comparisons:
            self.assertTrue(
                check_scores_all_close(reference_scores, candidate.scores),
                failure_message,
            )