Spaces:
Running
Running
Commit
·
b5a3ebb
1
Parent(s):
7f98acf
add: docs for BM25sRetriever
Browse files
medrag_multi_modal/retrieval/bm25s_retrieval.py
CHANGED
@@ -16,6 +16,17 @@ LANGUAGE_DICT = {
|
|
16 |
|
17 |
|
18 |
class BM25sRetriever(weave.Model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
language: str
|
20 |
use_stemmer: bool
|
21 |
_retriever: Optional[bm25s.BM25]
|
@@ -30,6 +41,34 @@ class BM25sRetriever(weave.Model):
|
|
30 |
self._retriever = retriever or bm25s.BM25()
|
31 |
|
32 |
def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
chunk_dataset = weave.ref(chunk_dataset_name).get().rows
|
34 |
corpus = [row["text"] for row in chunk_dataset]
|
35 |
corpus_tokens = bm25s.tokenize(
|
@@ -56,6 +95,23 @@ class BM25sRetriever(weave.Model):
|
|
56 |
|
57 |
@classmethod
|
58 |
def from_wandb_artifact(cls, index_artifact_address: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
if wandb.run:
|
60 |
artifact = wandb.run.use_artifact(
|
61 |
index_artifact_address, type="bm25s-index"
|
@@ -76,13 +132,48 @@ class BM25sRetriever(weave.Model):
|
|
76 |
|
77 |
@weave.op()
|
78 |
def retrieve(self, query: str, top_k: int = 2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
query_tokens = bm25s.tokenize(
|
80 |
query,
|
81 |
stopwords=LANGUAGE_DICT[self.language],
|
82 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
83 |
)
|
84 |
results = self._retriever.retrieve(query_tokens, k=top_k)
|
85 |
-
|
86 |
-
|
87 |
-
"
|
88 |
-
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
class BM25sRetriever(weave.Model):
|
19 |
+
"""
|
20 |
+
`BM25sRetriever` is a class that provides functionality for indexing and
|
21 |
+
retrieving documents using the [BM25-Sparse](https://github.com/xhluca/bm25s).
|
22 |
+
|
23 |
+
Args:
|
24 |
+
language (str): The language of the documents to be indexed and retrieved.
|
25 |
+
use_stemmer (bool): A flag indicating whether to use stemming during tokenization.
|
26 |
+
retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not provided,
|
27 |
+
a new instance is created.
|
28 |
+
"""
|
29 |
+
|
30 |
language: str
|
31 |
use_stemmer: bool
|
32 |
_retriever: Optional[bm25s.BM25]
|
|
|
41 |
self._retriever = retriever or bm25s.BM25()
|
42 |
|
43 |
def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
|
44 |
+
"""
|
45 |
+
Indexes a dataset of text chunks using the BM25 algorithm.
|
46 |
+
|
47 |
+
This function takes a dataset of text chunks identified by `chunk_dataset_name`,
|
48 |
+
tokenizes the text using the BM25 tokenizer with optional stemming, and indexes
|
49 |
+
the tokenized text using the BM25 retriever. If an `index_name` is provided, the
|
50 |
+
index is saved to disk and logged as a Weights & Biases artifact.
|
51 |
+
|
52 |
+
!!! example "Example Usage"
|
53 |
+
```python
|
54 |
+
import weave
|
55 |
+
from dotenv import load_dotenv
|
56 |
+
|
57 |
+
import wandb
|
58 |
+
from medrag_multi_modal.retrieval import BM25sRetriever
|
59 |
+
|
60 |
+
load_dotenv()
|
61 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
62 |
+
wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index")
|
63 |
+
retriever = BM25sRetriever()
|
64 |
+
retriever.index(chunk_dataset_name="grays-anatomy-text:v13", index_name="grays-anatomy-bm25s")
|
65 |
+
```
|
66 |
+
|
67 |
+
Args:
|
68 |
+
chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
|
69 |
+
index_name (Optional[str]): The name to save the index under. If provided, the index
|
70 |
+
is saved to disk and logged as a Weights & Biases artifact.
|
71 |
+
"""
|
72 |
chunk_dataset = weave.ref(chunk_dataset_name).get().rows
|
73 |
corpus = [row["text"] for row in chunk_dataset]
|
74 |
corpus_tokens = bm25s.tokenize(
|
|
|
95 |
|
96 |
@classmethod
|
97 |
def from_wandb_artifact(cls, index_artifact_address: str):
|
98 |
+
"""
|
99 |
+
Creates an instance of the class from a Weights & Biases artifact.
|
100 |
+
|
101 |
+
This class method retrieves a BM25 index artifact from Weights & Biases,
|
102 |
+
downloads the artifact, and loads the BM25 retriever with the index and its
|
103 |
+
associated corpus. The method also extracts metadata from the artifact to
|
104 |
+
initialize the class instance with the appropriate language and stemming
|
105 |
+
settings.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
index_artifact_address (str): The address of the Weights & Biases artifact
|
109 |
+
containing the BM25 index.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
An instance of the class initialized with the BM25 retriever and metadata
|
113 |
+
from the artifact.
|
114 |
+
"""
|
115 |
if wandb.run:
|
116 |
artifact = wandb.run.use_artifact(
|
117 |
index_artifact_address, type="bm25s-index"
|
|
|
132 |
|
133 |
@weave.op()
|
134 |
def retrieve(self, query: str, top_k: int = 2):
|
135 |
+
"""
|
136 |
+
Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm.
|
137 |
+
|
138 |
+
This method tokenizes the input query using the BM25 tokenizer, which takes into
|
139 |
+
account the language-specific stopwords and optional stemming. It then retrieves
|
140 |
+
the top-k most relevant chunks from the BM25 index based on the tokenized query.
|
141 |
+
The results are returned as a list of dictionaries, each containing a chunk and
|
142 |
+
its corresponding relevance score.
|
143 |
+
|
144 |
+
!!! example "Example Usage"
|
145 |
+
```python
|
146 |
+
import weave
|
147 |
+
from dotenv import load_dotenv
|
148 |
+
|
149 |
+
from medrag_multi_modal.retrieval import BM25sRetriever
|
150 |
+
|
151 |
+
load_dotenv()
|
152 |
+
weave.init(project_name="ml-colabs/medrag-multi-modal")
|
153 |
+
retriever = BM25sRetriever.from_wandb_artifact(
|
154 |
+
index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:v2"
|
155 |
+
)
|
156 |
+
retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
|
157 |
+
```
|
158 |
+
|
159 |
+
Args:
|
160 |
+
query (str): The input query string to search for relevant chunks.
|
161 |
+
top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
list: A list of dictionaries, each containing a retrieved chunk and its
|
165 |
+
relevance score.
|
166 |
+
"""
|
167 |
query_tokens = bm25s.tokenize(
|
168 |
query,
|
169 |
stopwords=LANGUAGE_DICT[self.language],
|
170 |
stemmer=Stemmer(self.language) if self.use_stemmer else None,
|
171 |
)
|
172 |
results = self._retriever.retrieve(query_tokens, k=top_k)
|
173 |
+
retrieved_chunks = []
|
174 |
+
for chunk, score in zip(
|
175 |
+
results["results"].flatten().tolist(),
|
176 |
+
results["scores"].flatten().tolist(),
|
177 |
+
):
|
178 |
+
retrieved_chunks.append({"chunk": chunk, "score": score})
|
179 |
+
return retrieved_chunks
|