geekyrakshit committed on
Commit
b5a3ebb
·
1 Parent(s): 7f98acf

add: docs for BM25sRetriever

Browse files
medrag_multi_modal/retrieval/bm25s_retrieval.py CHANGED
@@ -16,6 +16,17 @@ LANGUAGE_DICT = {
16
 
17
 
18
  class BM25sRetriever(weave.Model):
 
 
 
 
 
 
 
 
 
 
 
19
  language: str
20
  use_stemmer: bool
21
  _retriever: Optional[bm25s.BM25]
@@ -30,6 +41,34 @@ class BM25sRetriever(weave.Model):
30
  self._retriever = retriever or bm25s.BM25()
31
 
32
  def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  chunk_dataset = weave.ref(chunk_dataset_name).get().rows
34
  corpus = [row["text"] for row in chunk_dataset]
35
  corpus_tokens = bm25s.tokenize(
@@ -56,6 +95,23 @@ class BM25sRetriever(weave.Model):
56
 
57
  @classmethod
58
  def from_wandb_artifact(cls, index_artifact_address: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if wandb.run:
60
  artifact = wandb.run.use_artifact(
61
  index_artifact_address, type="bm25s-index"
@@ -76,13 +132,48 @@ class BM25sRetriever(weave.Model):
76
 
77
  @weave.op()
78
  def retrieve(self, query: str, top_k: int = 2):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  query_tokens = bm25s.tokenize(
80
  query,
81
  stopwords=LANGUAGE_DICT[self.language],
82
  stemmer=Stemmer(self.language) if self.use_stemmer else None,
83
  )
84
  results = self._retriever.retrieve(query_tokens, k=top_k)
85
- return {
86
- "results": results.documents,
87
- "scores": results.scores,
88
- }
 
 
 
 
16
 
17
 
18
  class BM25sRetriever(weave.Model):
19
+ """
20
+ `BM25sRetriever` is a class that provides functionality for indexing and
21
+ retrieving documents using the [BM25-Sparse](https://github.com/xhluca/bm25s) library.
22
+
23
+ Args:
24
+ language (str): The language of the documents to be indexed and retrieved.
25
+ use_stemmer (bool): A flag indicating whether to use stemming during tokenization.
26
+ retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not provided,
27
+ a new instance is created.
28
+ """
29
+
30
  language: str
31
  use_stemmer: bool
32
  _retriever: Optional[bm25s.BM25]
 
41
  self._retriever = retriever or bm25s.BM25()
42
 
43
  def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
44
+ """
45
+ Indexes a dataset of text chunks using the BM25 algorithm.
46
+
47
+ This function takes a dataset of text chunks identified by `chunk_dataset_name`,
48
+ tokenizes the text using the BM25 tokenizer with optional stemming, and indexes
49
+ the tokenized text using the BM25 retriever. If an `index_name` is provided, the
50
+ index is saved to disk and logged as a Weights & Biases artifact.
51
+
52
+ !!! example "Example Usage"
53
+ ```python
54
+ import weave
55
+ from dotenv import load_dotenv
56
+
57
+ import wandb
58
+ from medrag_multi_modal.retrieval import BM25sRetriever
59
+
60
+ load_dotenv()
61
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
62
+ wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index")
63
+ retriever = BM25sRetriever()
64
+ retriever.index(chunk_dataset_name="grays-anatomy-text:v13", index_name="grays-anatomy-bm25s")
65
+ ```
66
+
67
+ Args:
68
+ chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
69
+ index_name (Optional[str]): The name to save the index under. If provided, the index
70
+ is saved to disk and logged as a Weights & Biases artifact.
71
+ """
72
  chunk_dataset = weave.ref(chunk_dataset_name).get().rows
73
  corpus = [row["text"] for row in chunk_dataset]
74
  corpus_tokens = bm25s.tokenize(
 
95
 
96
  @classmethod
97
  def from_wandb_artifact(cls, index_artifact_address: str):
98
+ """
99
+ Creates an instance of the class from a Weights & Biases artifact.
100
+
101
+ This class method retrieves a BM25 index artifact from Weights & Biases,
102
+ downloads the artifact, and loads the BM25 retriever with the index and its
103
+ associated corpus. The method also extracts metadata from the artifact to
104
+ initialize the class instance with the appropriate language and stemming
105
+ settings.
106
+
107
+ Args:
108
+ index_artifact_address (str): The address of the Weights & Biases artifact
109
+ containing the BM25 index.
110
+
111
+ Returns:
112
+ An instance of the class initialized with the BM25 retriever and metadata
113
+ from the artifact.
114
+ """
115
  if wandb.run:
116
  artifact = wandb.run.use_artifact(
117
  index_artifact_address, type="bm25s-index"
 
132
 
133
  @weave.op()
134
  def retrieve(self, query: str, top_k: int = 2):
135
+ """
136
+ Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm.
137
+
138
+ This method tokenizes the input query using the BM25 tokenizer, which takes into
139
+ account the language-specific stopwords and optional stemming. It then retrieves
140
+ the top-k most relevant chunks from the BM25 index based on the tokenized query.
141
+ The results are returned as a list of dictionaries, each containing a chunk and
142
+ its corresponding relevance score.
143
+
144
+ !!! example "Example Usage"
145
+ ```python
146
+ import weave
147
+ from dotenv import load_dotenv
148
+
149
+ from medrag_multi_modal.retrieval import BM25sRetriever
150
+
151
+ load_dotenv()
152
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
153
+ retriever = BM25sRetriever.from_wandb_artifact(
154
+ index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:v2"
155
+ )
156
+ retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
157
+ ```
158
+
159
+ Args:
160
+ query (str): The input query string to search for relevant chunks.
161
+ top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
162
+
163
+ Returns:
164
+ list: A list of dictionaries, each containing a retrieved chunk and its
165
+ relevance score.
166
+ """
167
  query_tokens = bm25s.tokenize(
168
  query,
169
  stopwords=LANGUAGE_DICT[self.language],
170
  stemmer=Stemmer(self.language) if self.use_stemmer else None,
171
  )
172
  results = self._retriever.retrieve(query_tokens, k=top_k)
173
+ retrieved_chunks = []
174
+ for chunk, score in zip(
175
+ results["results"].flatten().tolist(),
176
+ results["scores"].flatten().tolist(),
177
+ ):
178
+ retrieved_chunks.append({"chunk": chunk, "score": score})
179
+ return retrieved_chunks