geekyrakshit commited on
Commit
170d9a9
·
1 Parent(s): e6f968c

update: app

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +89 -72
  2. docs/app.md +0 -61
  3. docs/assistant/figure_annotation.md +0 -3
  4. docs/assistant/llm_client.md +0 -3
  5. docs/assistant/medqa_assistant.md +0 -3
  6. docs/chunking.md +0 -3
  7. docs/document_loader/image_loader/base_img_loader.md +0 -3
  8. docs/document_loader/image_loader/fitzpil_img_loader.md +0 -22
  9. docs/document_loader/image_loader/marker_img_loader.md +0 -21
  10. docs/document_loader/image_loader/pdf2image_img_loader.md +0 -26
  11. docs/document_loader/image_loader/pdfplumber_img_loader.md +0 -22
  12. docs/document_loader/image_loader/pymupdf_img_loader.md +0 -23
  13. docs/document_loader/text_loader/base_text_loader.md +0 -3
  14. docs/document_loader/text_loader/marker_text_loader.md +0 -23
  15. docs/document_loader/text_loader/pdfplumber_text_loader.md +0 -22
  16. docs/document_loader/text_loader/pymupdf4llm_text_loader.md +0 -23
  17. docs/document_loader/text_loader/pypdf2_text_loader.md +0 -23
  18. docs/index.md +0 -40
  19. docs/installation/development.md +0 -40
  20. docs/installation/install.md +0 -9
  21. docs/retreival/bm25s.md +0 -3
  22. docs/retreival/colpali.md +0 -3
  23. docs/retreival/contriever.md +0 -3
  24. docs/retreival/medcpt.md +0 -3
  25. docs/retreival/nv_embed_2.md +0 -3
  26. install.sh +0 -30
  27. medrag_multi_modal/assistant/figure_annotation.py +4 -13
  28. medrag_multi_modal/assistant/llm_client.py +19 -11
  29. medrag_multi_modal/assistant/medqa_assistant.py +94 -28
  30. medrag_multi_modal/assistant/schema.py +27 -0
  31. medrag_multi_modal/cli.py +54 -3
  32. medrag_multi_modal/document_loader/image_loader/base_img_loader.py +80 -29
  33. medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py +16 -16
  34. medrag_multi_modal/document_loader/image_loader/marker_img_loader.py +15 -26
  35. medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py +7 -16
  36. medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py +16 -16
  37. medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py +16 -16
  38. medrag_multi_modal/document_loader/text_loader/base_text_loader.py +58 -20
  39. medrag_multi_modal/document_loader/text_loader/marker_text_loader.py +8 -15
  40. medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py +7 -13
  41. medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py +7 -15
  42. medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py +7 -13
  43. medrag_multi_modal/metrics/__init__.py +3 -0
  44. medrag_multi_modal/metrics/base.py +108 -0
  45. medrag_multi_modal/metrics/mmlu.py +24 -0
  46. medrag_multi_modal/retrieval/__init__.py +1 -13
  47. medrag_multi_modal/retrieval/colpali_retrieval.py +1 -1
  48. medrag_multi_modal/retrieval/common.py +0 -23
  49. medrag_multi_modal/retrieval/text_retrieval/__init__.py +11 -0
  50. medrag_multi_modal/retrieval/{bm25s_retrieval.py → text_retrieval/bm25s_retrieval.py} +87 -61
app.py CHANGED
@@ -1,26 +1,20 @@
1
- import os
2
- import wandb
3
-
4
- wandb.login(relogin=True, key=os.getenv("WANDB_API_KEY"))
5
-
6
-
7
  import streamlit as st
8
- import weave
9
 
10
- from medrag_multi_modal.assistant import (
11
- FigureAnnotatorFromPageImage,
12
- LLMClient,
13
- MedQAAssistant,
14
- )
15
- from medrag_multi_modal.assistant.llm_client import (
16
- GOOGLE_MODELS,
17
- MISTRAL_MODELS,
18
- OPENAI_MODELS,
19
  )
20
- from medrag_multi_modal.retrieval import MedCPTRetriever
21
 
22
  # Define constants
23
- ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS
 
 
 
 
 
24
 
25
  # Sidebar for configuration settings
26
  st.sidebar.title("Configuration Settings")
@@ -30,68 +24,91 @@ project_name = st.sidebar.text_input(
30
  placeholder="wandb project name",
31
  help="format: wandb_username/wandb_project_name",
32
  )
33
- chunk_dataset_name = st.sidebar.text_input(
34
- label="Text Chunk WandB Dataset Name",
35
- value="grays-anatomy-chunks:v0",
36
- placeholder="wandb dataset name",
37
- help="format: wandb_dataset_name:version",
38
  )
39
- index_artifact_address = st.sidebar.text_input(
40
- label="WandB Index Artifact Address",
41
- value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
42
- placeholder="wandb artifact address",
43
- help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
44
  )
45
- image_artifact_address = st.sidebar.text_input(
46
- label="WandB Image Artifact Address",
47
- value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
48
- placeholder="wandb artifact address",
49
- help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
50
  )
51
- llm_client_model_name = st.sidebar.selectbox(
52
- label="LLM Client Model Name",
53
- options=ALL_AVAILABLE_MODELS,
54
- index=ALL_AVAILABLE_MODELS.index("gemini-1.5-flash"),
55
- help="select a model from the list",
56
  )
57
- figure_extraction_model_name = st.sidebar.selectbox(
58
- label="Figure Extraction Model Name",
59
- options=ALL_AVAILABLE_MODELS,
60
- index=ALL_AVAILABLE_MODELS.index("pixtral-12b-2409"),
61
- help="select a model from the list",
62
  )
63
- structured_output_model_name = st.sidebar.selectbox(
64
- label="Structured Output Model Name",
65
- options=ALL_AVAILABLE_MODELS,
66
- index=ALL_AVAILABLE_MODELS.index("gpt-4o"),
67
- help="select a model from the list",
 
 
 
 
68
  )
69
 
70
- # Streamlit app layout
71
- st.title("MedQA Assistant App")
72
 
73
- # Initialize Weave
74
- weave.init(project_name=project_name)
75
 
76
- # Initialize clients and assistants
77
- llm_client = LLMClient(model_name=llm_client_model_name)
78
- retriever = MedCPTRetriever.from_wandb_artifact(
79
- chunk_dataset_name=chunk_dataset_name,
80
- index_artifact_address=index_artifact_address,
81
- )
82
- figure_annotator = FigureAnnotatorFromPageImage(
83
- figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
84
- structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
85
- image_artifact_address=image_artifact_address,
86
- )
87
- medqa_assistant = MedQAAssistant(
88
- llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
89
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- query = st.chat_input("Enter your question here")
92
- if query:
93
- with st.chat_message("user"):
94
- st.markdown(query)
95
- response = medqa_assistant.predict(query=query)
96
  with st.chat_message("assistant"):
97
- st.markdown(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
2
 
3
+ from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
4
+ from medrag_multi_modal.retrieval.text_retrieval import (
5
+ BM25sRetriever,
6
+ ContrieverRetriever,
7
+ MedCPTRetriever,
8
+ NVEmbed2Retriever,
 
 
 
9
  )
 
10
 
11
  # Define constants
12
+ ALL_AVAILABLE_MODELS = [
13
+ "gemini-1.5-flash-latest",
14
+ "gemini-1.5-pro-latest",
15
+ "gpt-4o",
16
+ "gpt-4o-mini",
17
+ ]
18
 
19
  # Sidebar for configuration settings
20
  st.sidebar.title("Configuration Settings")
 
24
  placeholder="wandb project name",
25
  help="format: wandb_username/wandb_project_name",
26
  )
27
+ chunk_dataset_id = st.sidebar.selectbox(
28
+ label="Chunk Dataset ID",
29
+ options=["ashwiniai/medrag-text-corpus-chunks"],
 
 
30
  )
31
+ llm_model = st.sidebar.selectbox(
32
+ label="LLM Model",
33
+ options=ALL_AVAILABLE_MODELS,
 
 
34
  )
35
+ top_k_chunks_for_query = st.sidebar.slider(
36
+ label="Top K Chunks for Query",
37
+ min_value=1,
38
+ max_value=20,
39
+ value=5,
40
  )
41
+ top_k_chunks_for_options = st.sidebar.slider(
42
+ label="Top K Chunks for Options",
43
+ min_value=1,
44
+ max_value=20,
45
+ value=3,
46
  )
47
+ rely_only_on_context = st.sidebar.checkbox(
48
+ label="Rely Only on Context",
49
+ value=False,
 
 
50
  )
51
+ retriever_type = st.sidebar.selectbox(
52
+ label="Retriever Type",
53
+ options=[
54
+ "",
55
+ "BM25S",
56
+ "Contriever",
57
+ "MedCPT",
58
+ "NV-Embed-v2",
59
+ ],
60
  )
61
 
62
+ if retriever_type != "":
 
63
 
64
+ llm_model = LLMClient(model_name=llm_model)
 
65
 
66
+ retriever = None
67
+
68
+ if retriever_type == "BM25S":
69
+ retriever = BM25sRetriever.from_index(
70
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
71
+ )
72
+ elif retriever_type == "Contriever":
73
+ retriever = ContrieverRetriever.from_index(
74
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-contriever",
75
+ chunk_dataset_id=chunk_dataset_id,
76
+ )
77
+ elif retriever_type == "MedCPT":
78
+ retriever = MedCPTRetriever.from_index(
79
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-medcpt",
80
+ chunk_dataset_id=chunk_dataset_id,
81
+ )
82
+ elif retriever_type == "NV-Embed-v2":
83
+ retriever = NVEmbed2Retriever.from_index(
84
+ index_repo_id="ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
85
+ chunk_dataset_id=chunk_dataset_id,
86
+ )
87
+
88
+ medqa_assistant = MedQAAssistant(
89
+ llm_client=llm_model,
90
+ retriever=retriever,
91
+ top_k_chunks_for_query=top_k_chunks_for_query,
92
+ top_k_chunks_for_options=top_k_chunks_for_options,
93
+ )
94
 
 
 
 
 
 
95
  with st.chat_message("assistant"):
96
+ st.markdown(
97
+ """
98
+ Hi! I am Medrag, your medical assistant. You can ask me any questions about the medical and the life sciences.
99
+ I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.
100
+
101
+ **Note:** that I am not a medical professional, so please do not rely on my answers for medical decisions.
102
+ Please consult a medical professional for any medical advice.
103
+
104
+ In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
105
+ """,
106
+ unsafe_allow_html=True,
107
+ )
108
+ query = st.chat_input("Enter your question here")
109
+ if query:
110
+ with st.chat_message("user"):
111
+ st.markdown(query)
112
+ response = medqa_assistant.predict(query=query)
113
+ with st.chat_message("assistant"):
114
+ st.markdown(response.response)
docs/app.md DELETED
@@ -1,61 +0,0 @@
1
- # MedQA Assistant App
2
-
3
- The MedQA Assistant App is a Streamlit-based application designed to provide a chat interface for medical question answering. It leverages advanced language models (LLMs) and retrieval augmented generation (RAG) techniques to deliver accurate and informative responses to medical queries.
4
-
5
- ## Features
6
-
7
- - **Interactive Chat Interface**: Engage with the app through a user-friendly chat interface.
8
- - **Configurable Settings**: Customize model selection and data sources via the sidebar.
9
- - **Retrieval-Augmented Generation**: Ensures precise and contextually relevant responses.
10
- - **Figure Annotation Capabilities**: Extracts and annotates figures from medical texts.
11
-
12
- ## Usage
13
-
14
- 1. Install the package using:
15
- ```bash
16
- uv pip install .
17
- ```
18
- 1. **Launch the App**: Start the application using Streamlit:
19
- ```bash
20
- medrag run
21
- ```
22
- 2. **Configure Settings**: Adjust configuration settings in the sidebar to suit your needs.
23
- 3. **Ask a Question**: Enter your medical question in the chat input field.
24
- 4. **Receive a Response**: Get a detailed answer from the MedQA Assistant.
25
-
26
- ## Configuration
27
-
28
- The app allows users to customize various settings through the sidebar:
29
-
30
- - **Project Name**: Specify the WandB project name.
31
- - **Text Chunk WandB Dataset Name**: Define the dataset containing text chunks.
32
- - **WandB Index Artifact Address**: Provide the address of the index artifact.
33
- - **WandB Image Artifact Address**: Provide the address of the image artifact.
34
- - **LLM Client Model Name**: Choose a language model for generating responses.
35
- - **Figure Extraction Model Name**: Select a model for extracting figures from images.
36
- - **Structured Output Model Name**: Choose a model for generating structured outputs.
37
-
38
- ## Technical Details
39
-
40
- The app is built using the following components:
41
-
42
- - **Streamlit**: For the user interface.
43
- - **Weave**: For project initialization and artifact management.
44
- - **MedQAAssistant**: For processing queries and generating responses.
45
- - **LLMClient**: For interacting with language models.
46
- - **MedCPTRetriever**: For retrieving relevant text chunks.
47
- - **FigureAnnotatorFromPageImage**: For annotating figures in medical texts.
48
-
49
- ## Development and Deployment
50
-
51
- - **Environment Setup**: Ensure all dependencies are installed as per the `pyproject.toml`.
52
- - **Running the App**: Use Streamlit to run the app locally.
53
- - **Deployment**: coming soon...
54
-
55
- ## Additional Resources
56
-
57
- For more detailed information on the components and their usage, refer to the following documentation sections:
58
-
59
- - [MedQA Assistant](/assistant/medqa_assistant)
60
- - [LLM Client](/assistant/llm_client)
61
- - [Figure Annotation](/assistant/figure_annotation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/assistant/figure_annotation.md DELETED
@@ -1,3 +0,0 @@
1
- # Figure Annotation
2
-
3
- ::: medrag_multi_modal.assistant.figure_annotation
 
 
 
 
docs/assistant/llm_client.md DELETED
@@ -1,3 +0,0 @@
1
- # LLM Client
2
-
3
- ::: medrag_multi_modal.assistant.llm_client
 
 
 
 
docs/assistant/medqa_assistant.md DELETED
@@ -1,3 +0,0 @@
1
- # MedQA Assistant
2
-
3
- ::: medrag_multi_modal.assistant.medqa_assistant
 
 
 
 
docs/chunking.md DELETED
@@ -1,3 +0,0 @@
1
- # Chunking
2
-
3
- ::: medrag_multi_modal.semantic_chunking
 
 
 
 
docs/document_loader/image_loader/base_img_loader.md DELETED
@@ -1,3 +0,0 @@
1
- ## Load images from PDF files
2
-
3
- ::: medrag_multi_modal.document_loader.image_loader.base_img_loader
 
 
 
 
docs/document_loader/image_loader/fitzpil_img_loader.md DELETED
@@ -1,22 +0,0 @@
1
- # Load images from PDF files (using Fitz & PIL)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `fitz` & `pillow`
5
-
6
- Extract images from PDF files using `fitz` and `pillow`.
7
-
8
- Use it in our library with:
9
- ```python
10
- from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
11
- ```
12
-
13
- For more details, please refer to the sources below.
14
-
15
- **Sources:**
16
-
17
- - [Docs](https://pymupdf.readthedocs.io/en/latest/intro.html)
18
- - [GitHub](https://github.com/kastman/fitz)
19
- - [PyPI](https://pypi.org/project/fitz/)
20
- - [PyPI](https://pypi.org/project/pillow/)
21
-
22
- ::: medrag_multi_modal.document_loader.image_loader.fitzpil_img_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/image_loader/marker_img_loader.md DELETED
@@ -1,21 +0,0 @@
1
- # Load images from PDF files (using Marker)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `marker-pdf`
5
-
6
- Extract images from PDF files using `marker-pdf`.
7
-
8
- Use it in our library with:
9
- ```python
10
- from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
11
- ```
12
-
13
- For details, please refer to the sources below.
14
-
15
- **Sources:**
16
-
17
- - [DataLab](https://www.datalab.to)
18
- - [GitHub](https://github.com/VikParuchuri/marker)
19
- - [PyPI](https://pypi.org/project/marker-pdf/)
20
-
21
- ::: medrag_multi_modal.document_loader.image_loader.marker_img_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/image_loader/pdf2image_img_loader.md DELETED
@@ -1,26 +0,0 @@
1
- # Load images from PDF files (using PDF2Image)
2
-
3
- !!! danger "Warning"
4
- Unlike other image extraction methods in `document_loader.image_loader`, this loader does not extract embedded images from the PDF.
5
- Instead, it creates a snapshot image version of each selected page from the PDF.
6
-
7
- ??? note "Note"
8
- **Underlying Library:** `pdf2image`
9
-
10
- Extract images from PDF files using `pdf2image`.
11
-
12
-
13
- Use it in our library with:
14
- ```python
15
- from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
16
- ```
17
-
18
- For details and available `**kwargs`, please refer to the sources below.
19
-
20
- **Sources:**
21
-
22
- - [DataLab](https://www.datalab.to)
23
- - [GitHub](https://github.com/VikParuchuri/marker)
24
- - [PyPI](https://pypi.org/project/marker-pdf/)
25
-
26
- ::: medrag_multi_modal.document_loader.image_loader.pdf2image_img_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/image_loader/pdfplumber_img_loader.md DELETED
@@ -1,22 +0,0 @@
1
- # Load images from PDF files (using PDFPlumber)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `pdfplumber`
5
-
6
- Extract images from PDF files using `pdfplumber`.
7
-
8
- You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
9
-
10
- Use it in our library with:
11
- ```python
12
- from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
13
- ```
14
-
15
- For details, please refer to the sources below.
16
-
17
- **Sources:**
18
-
19
- - [GitHub](https://github.com/jsvine/pdfplumber)
20
- - [PyPI](https://pypi.org/project/pdfplumber/)
21
-
22
- ::: medrag_multi_modal.document_loader.image_loader.pdfplumber_img_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/image_loader/pymupdf_img_loader.md DELETED
@@ -1,23 +0,0 @@
1
- # Load images from PDF files (using PyMuPDF)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `pymupdf`
5
-
6
- PyMuPDF is a high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents.
7
-
8
- You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
9
-
10
- Use it in our library with:
11
- ```python
12
- from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
13
- ```
14
-
15
- For details, please refer to the sources below.
16
-
17
- **Sources:**
18
-
19
- - [Docs](https://pymupdf.readthedocs.io/en/latest/)
20
- - [GitHub](https://github.com/pymupdf/PyMuPDF)
21
- - [PyPI](https://pypi.org/project/PyMuPDF/)
22
-
23
- ::: medrag_multi_modal.document_loader.image_loader.pymupdf_img_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/text_loader/base_text_loader.md DELETED
@@ -1,3 +0,0 @@
1
- ## Load text from PDF files
2
-
3
- ::: medrag_multi_modal.document_loader.text_loader.base_text_loader
 
 
 
 
docs/document_loader/text_loader/marker_text_loader.md DELETED
@@ -1,23 +0,0 @@
1
- ## Load text from PDF files (using Marker)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `marker-pdf`
5
-
6
- Convert PDF to markdown quickly and accurately using a pipeline of deep learning models.
7
-
8
- You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
9
-
10
- Use it in our library with:
11
- ```python
12
- from medrag_multi_modal.document_loader.text_loader import MarkerTextLoader
13
- ```
14
-
15
- For details and available `**kwargs`, please refer to the sources below.
16
-
17
- **Sources:**
18
-
19
- - [DataLab](https://www.datalab.to)
20
- - [GitHub](https://github.com/VikParuchuri/marker)
21
- - [PyPI](https://pypi.org/project/marker-pdf/)
22
-
23
- ::: medrag_multi_modal.document_loader.text_loader.marker_text_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/text_loader/pdfplumber_text_loader.md DELETED
@@ -1,22 +0,0 @@
1
- ## Load text from PDF files (using PDFPlumber)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `pdfplumber`
5
-
6
- Plumb a PDF for detailed information about each char, rectangle, line, et cetera — and easily extract text and tables.
7
-
8
- You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
9
-
10
- Use it in our library with:
11
- ```python
12
- from medrag_multi_modal.document_loader.text_loader import PDFPlumberTextLoader
13
- ```
14
-
15
- For details and available `**kwargs`, please refer to the sources below.
16
-
17
- **Sources:**
18
-
19
- - [GitHub](https://github.com/jsvine/pdfplumber)
20
- - [PyPI](https://pypi.org/project/pdfplumber/)
21
-
22
- ::: medrag_multi_modal.document_loader.text_loader.pdfplumber_text_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/text_loader/pymupdf4llm_text_loader.md DELETED
@@ -1,23 +0,0 @@
1
- ## Load text from PDF files (using PyMuPDF4LLM)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `pymupdf4llm`
5
-
6
- PyMuPDF is a high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents.
7
-
8
- You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
9
-
10
- Use it in our library with:
11
- ```python
12
- from medrag_multi_modal.document_loader.text_loader import PyMuPDF4LLMTextLoader
13
- ```
14
-
15
- For details and available `**kwargs`, please refer to the sources below.
16
-
17
- **Sources:**
18
-
19
- - [Docs](https://pymupdf.readthedocs.io/en/latest/pymupdf4llm/)
20
- - [GitHub](https://github.com/pymupdf/PyMuPDF)
21
- - [PyPI](https://pypi.org/project/pymupdf4llm/)
22
-
23
- ::: medrag_multi_modal.document_loader.text_loader.pymupdf4llm_text_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/document_loader/text_loader/pypdf2_text_loader.md DELETED
@@ -1,23 +0,0 @@
1
- ## Load text from PDF files (using PyPDF2)
2
-
3
- ??? note "Note"
4
- **Underlying Library:** `pypdf2`
5
-
6
- A pure-python PDF library capable of splitting, merging, cropping, and transforming the pages of PDF files
7
-
8
- You can interact with the underlying library and fine-tune the outputs via `**kwargs`.
9
-
10
- Use it in our library with:
11
- ```python
12
- from medrag_multi_modal.document_loader.text_loader import PyPDF2TextLoader
13
- ```
14
-
15
- For details and available `**kwargs`, please refer to the sources below.
16
-
17
- **Sources:**
18
-
19
- - [Docs](https://pypdf2.readthedocs.io/en/3.x/)
20
- - [GitHub](https://github.com/py-pdf/pypdf)
21
- - [PyPI](https://pypi.org/project/PyPDF2/)
22
-
23
- ::: medrag_multi_modal.document_loader.text_loader.pypdf2_text_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/index.md DELETED
@@ -1,40 +0,0 @@
1
- # MedRAG Multi-Modal
2
-
3
- Multi-modal RAG for medical docmain.
4
-
5
- ## Installation
6
-
7
- ### For Development
8
-
9
- For MacOS, you need to run
10
-
11
- ```bash
12
- brew install poppler
13
- ```
14
-
15
- For Debian/Ubuntu, you need to run
16
-
17
- ```bash
18
- sudo apt-get install -y poppler-utils
19
- ```
20
-
21
- Then, you can install the dependencies using uv in the virtual environment `.venv` using
22
-
23
- ```bash
24
- git clone https://github.com/soumik12345/medrag-multi-modal
25
- cd medrag-multi-modal
26
- pip install -U pip uv
27
- uv sync
28
- ```
29
-
30
- After this, you need to activate the virtual environment using
31
-
32
- ```bash
33
- source .venv/bin/activate
34
- ```
35
-
36
- In the activated virtual environment, you can optionally install Flash Attention (required for ColPali) using
37
-
38
- ```bash
39
- uv pip install flash-attn --no-build-isolation
40
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/installation/development.md DELETED
@@ -1,40 +0,0 @@
1
- # Setting up the development environment
2
-
3
- ## Install Poppler
4
-
5
- For MacOS, you need to run
6
-
7
- ```bash
8
- brew install poppler
9
- ```
10
-
11
- For Debian/Ubuntu, you need to run
12
-
13
- ```bash
14
- sudo apt-get install -y poppler-utils
15
- ```
16
-
17
- ## Install the dependencies
18
-
19
- Then, you can install the dependencies using uv in the virtual environment `.venv` using
20
-
21
- ```bash
22
- git clone https://github.com/soumik12345/medrag-multi-modal
23
- cd medrag-multi-modal
24
- pip install -U pip uv
25
- uv sync
26
- ```
27
-
28
- After this, you need to activate the virtual environment using
29
-
30
- ```bash
31
- source .venv/bin/activate
32
- ```
33
-
34
- ## [Optional] Install Flash Attention
35
-
36
- In the activated virtual environment, you can optionally install Flash Attention (required for ColPali) using
37
-
38
- ```bash
39
- uv pip install flash-attn --no-build-isolation
40
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docs/installation/install.md DELETED
@@ -1,9 +0,0 @@
1
- # Installation
2
-
3
- You just need to clone the repository and run the install.sh script
4
-
5
- ```bash
6
- git clone https://github.com/soumik12345/medrag-multi-modal
7
- cd medrag-multi-modal
8
- sh install.sh
9
- ```
 
 
 
 
 
 
 
 
 
 
docs/retreival/bm25s.md DELETED
@@ -1,3 +0,0 @@
1
- # BM25-Sparse Retrieval
2
-
3
- ::: medrag_multi_modal.retrieval.bm25s_retrieval
 
 
 
 
docs/retreival/colpali.md DELETED
@@ -1,3 +0,0 @@
1
- # ColPali Retrieval
2
-
3
- ::: medrag_multi_modal.retrieval.colpali_retrieval
 
 
 
 
docs/retreival/contriever.md DELETED
@@ -1,3 +0,0 @@
1
- # Contriever Retrieval
2
-
3
- ::: medrag_multi_modal.retrieval.contriever_retrieval
 
 
 
 
docs/retreival/medcpt.md DELETED
@@ -1,3 +0,0 @@
1
- # MedCPT Retrieval
2
-
3
- ::: medrag_multi_modal.retrieval.medcpt_retrieval
 
 
 
 
docs/retreival/nv_embed_2.md DELETED
@@ -1,3 +0,0 @@
1
- # NV-Embed-v2 Retrieval
2
-
3
- ::: medrag_multi_modal.retrieval.nv_embed_2
 
 
 
 
install.sh DELETED
@@ -1,30 +0,0 @@
1
- #!/bin/bash
2
-
3
- OS_TYPE=$(uname -s)
4
-
5
- if [ "$OS_TYPE" = "Darwin" ]; then
6
- echo "Detected macOS."
7
- brew install poppler
8
- elif [ "$OS_TYPE" = "Linux" ]; then
9
- if [ -f /etc/os-release ]; then
10
- . /etc/os-release
11
- if [ "$ID" = "ubuntu" ] || [ "$ID" = "debian" ]; then
12
- echo "Detected Ubuntu/Debian."
13
- sudo apt-get update
14
- sudo apt-get install -y poppler-utils
15
- else
16
- echo "Unsupported Linux distribution: $ID"
17
- exit 1
18
- fi
19
- else
20
- echo "Cannot detect Linux distribution."
21
- exit 1
22
- fi
23
- else
24
- echo "Unsupported OS: $OS_TYPE"
25
- exit 1
26
- fi
27
-
28
- git clone https://github.com/soumik12345/medrag-multi-modal
29
- cd medrag-multi-modal
30
- pip install -U .[core]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
medrag_multi_modal/assistant/figure_annotation.py CHANGED
@@ -5,19 +5,10 @@ from typing import Optional, Union
5
  import cv2
6
  import weave
7
  from PIL import Image
8
- from pydantic import BaseModel
9
 
10
- from ..utils import get_wandb_artifact, read_jsonl_file
11
- from .llm_client import LLMClient
12
-
13
-
14
- class FigureAnnotation(BaseModel):
15
- figure_id: str
16
- figure_description: str
17
-
18
-
19
- class FigureAnnotations(BaseModel):
20
- annotations: list[FigureAnnotation]
21
 
22
 
23
  class FigureAnnotatorFromPageImage(weave.Model):
@@ -108,7 +99,7 @@ Here are some clues you need to follow:
108
  )
109
 
110
  @weave.op()
111
- def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]:
112
  """
113
  Predicts figure annotations for a specific page in a document.
114
 
 
5
  import cv2
6
  import weave
7
  from PIL import Image
 
8
 
9
+ from medrag_multi_modal.assistant.llm_client import LLMClient
10
+ from medrag_multi_modal.assistant.schema import FigureAnnotations
11
+ from medrag_multi_modal.utils import get_wandb_artifact, read_jsonl_file
 
 
 
 
 
 
 
 
12
 
13
 
14
  class FigureAnnotatorFromPageImage(weave.Model):
 
99
  )
100
 
101
  @weave.op()
102
+ def predict(self, page_idx: int) -> dict[int, list[FigureAnnotations]]:
103
  """
104
  Predicts figure annotations for a specific page in a document.
105
 
medrag_multi_modal/assistant/llm_client.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from enum import Enum
3
  from typing import Any, Optional, Union
@@ -93,6 +94,7 @@ class LLMClient(weave.Model):
93
  schema: Optional[Any] = None,
94
  ) -> Union[str, Any]:
95
  import google.generativeai as genai
 
96
 
97
  system_prompt = (
98
  [system_prompt] if isinstance(system_prompt, str) else system_prompt
@@ -100,18 +102,25 @@ class LLMClient(weave.Model):
100
  user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
101
 
102
  genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
103
- model = genai.GenerativeModel(self.model_name)
104
  generation_config = (
105
  None
106
  if schema is None
107
  else genai.GenerationConfig(
108
- response_mime_type="application/json", response_schema=list[schema]
109
  )
110
  )
111
  response = model.generate_content(
112
- system_prompt + user_prompt, generation_config=generation_config
 
 
 
 
 
 
 
113
  )
114
- return response.text if schema is None else response
115
 
116
  @weave.op()
117
  def execute_mistral_sdk(
@@ -146,14 +155,13 @@ class LLMClient(weave.Model):
146
  client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
147
  client = instructor.from_mistral(client) if schema is not None else client
148
 
149
- response = (
150
- client.chat.complete(model=self.model_name, messages=messages)
151
- if schema is None
152
- else client.messages.create(
153
- response_model=schema, messages=messages, temperature=0
154
  )
155
- )
156
- return response.choices[0].message.content
 
157
 
158
  @weave.op()
159
  def execute_openai_sdk(
 
1
+ import json
2
  import os
3
  from enum import Enum
4
  from typing import Any, Optional, Union
 
94
  schema: Optional[Any] = None,
95
  ) -> Union[str, Any]:
96
  import google.generativeai as genai
97
+ from google.generativeai.types import HarmBlockThreshold, HarmCategory
98
 
99
  system_prompt = (
100
  [system_prompt] if isinstance(system_prompt, str) else system_prompt
 
102
  user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
103
 
104
  genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
105
+ model = genai.GenerativeModel(self.model_name, system_instruction=system_prompt)
106
  generation_config = (
107
  None
108
  if schema is None
109
  else genai.GenerationConfig(
110
+ response_mime_type="application/json", response_schema=schema
111
  )
112
  )
113
  response = model.generate_content(
114
+ user_prompt,
115
+ generation_config=generation_config,
116
+ # This is necessary in order to answer questions about anatomy, sexual diseases,
117
+ # medical devices, medicines, etc.
118
+ safety_settings={
119
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
120
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
121
+ },
122
  )
123
+ return response.text if schema is None else json.loads(response.text)
124
 
125
  @weave.op()
126
  def execute_mistral_sdk(
 
155
  client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
156
  client = instructor.from_mistral(client) if schema is not None else client
157
 
158
+ if schema is None:
159
+ raise NotImplementedError(
160
+ "Mistral does not support structured output using a schema"
 
 
161
  )
162
+ else:
163
+ response = client.chat.complete(model=self.model_name, messages=messages)
164
+ return response.choices[0].message.content
165
 
166
  @weave.op()
167
  def execute_openai_sdk(
medrag_multi_modal/assistant/medqa_assistant.py CHANGED
@@ -1,8 +1,16 @@
 
 
1
  import weave
2
 
3
- from ..retrieval import SimilarityMetric
4
- from .figure_annotation import FigureAnnotatorFromPageImage
5
- from .llm_client import LLMClient
 
 
 
 
 
 
6
 
7
 
8
  class MedQAAssistant(weave.Model):
@@ -47,39 +55,68 @@ class MedQAAssistant(weave.Model):
47
  llm_client (LLMClient): The language model client used to generate responses.
48
  retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
49
  figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
50
- top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
 
51
  retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
52
  """
53
 
54
  llm_client: LLMClient
55
  retriever: weave.Model
56
- figure_annotator: FigureAnnotatorFromPageImage
57
- top_k_chunks: int = 2
 
 
58
  retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
59
 
60
  @weave.op()
61
- def predict(self, query: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  """
63
  Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
64
  from a medical document and using a language model to generate the final response.
65
 
66
  This function performs the following steps:
67
- 1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
 
68
  2. Extracts the text and page indices from the retrieved chunks.
69
  3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
70
- 4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
71
- 5. Uses the language model client to generate a response based on the constructed prompts.
72
- 6. Appends the source information (page numbers) to the generated response.
 
 
 
 
 
 
73
 
74
  Args:
75
  query (str): The medical query to be answered.
 
 
76
 
77
  Returns:
78
- str: The generated response to the query, including source information.
79
  """
80
- retrieved_chunks = self.retriever.predict(
81
- query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
82
- )
83
 
84
  retrieved_chunk_texts = []
85
  page_indices = set()
@@ -88,21 +125,50 @@ class MedQAAssistant(weave.Model):
88
  page_indices.add(int(chunk["page_idx"]))
89
 
90
  figure_descriptions = []
91
- for page_idx in page_indices:
92
- figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
93
- page_idx
94
- ]
95
- figure_descriptions += [
96
- item["figure_description"] for item in figure_annotations
97
- ]
98
-
99
- system_prompt = """
100
- You are an expert in medical science. You are given a query and a list of chunks from a medical document.
 
 
 
 
101
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  response = self.llm_client.predict(
103
  system_prompt=system_prompt,
104
  user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
 
105
  )
106
- page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
107
- response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
108
- return response
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
  import weave
4
 
5
+ from medrag_multi_modal.assistant.figure_annotation import FigureAnnotatorFromPageImage
6
+ from medrag_multi_modal.assistant.llm_client import LLMClient
7
+ from medrag_multi_modal.assistant.schema import (
8
+ MedQACitation,
9
+ MedQAMCQResponse,
10
+ MedQAResponse,
11
+ )
12
+ from medrag_multi_modal.retrieval.common import SimilarityMetric
13
+ from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
14
 
15
 
16
  class MedQAAssistant(weave.Model):
 
55
  llm_client (LLMClient): The language model client used to generate responses.
56
  retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
57
  figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
58
+ top_k_chunks_for_query (int): The number of top chunks to retrieve based on similarity metric for the query.
59
+ top_k_chunks_for_options (int): The number of top chunks to retrieve based on similarity metric for the options.
60
  retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
61
  """
62
 
63
  llm_client: LLMClient
64
  retriever: weave.Model
65
+ figure_annotator: Optional[FigureAnnotatorFromPageImage] = None
66
+ top_k_chunks_for_query: int = 2
67
+ top_k_chunks_for_options: int = 2
68
+ rely_only_on_context: bool = True
69
  retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
70
 
71
  @weave.op()
72
+ def retrieve_chunks_for_query(self, query: str) -> list[dict]:
73
+ retriever_kwargs = {"top_k": self.top_k_chunks_for_query}
74
+ if not isinstance(self.retriever, BM25sRetriever):
75
+ retriever_kwargs["metric"] = self.retrieval_similarity_metric
76
+ return self.retriever.predict(query, **retriever_kwargs)
77
+
78
+ @weave.op()
79
+ def retrieve_chunks_for_options(self, options: list[str]) -> list[dict]:
80
+ retriever_kwargs = {"top_k": self.top_k_chunks_for_options}
81
+ if not isinstance(self.retriever, BM25sRetriever):
82
+ retriever_kwargs["metric"] = self.retrieval_similarity_metric
83
+ retrieved_chunks = []
84
+ for option in options:
85
+ retrieved_chunks += self.retriever.predict(query=option, **retriever_kwargs)
86
+ return retrieved_chunks
87
+
88
+ @weave.op()
89
+ def predict(self, query: str, options: Optional[list[str]] = None) -> MedQAResponse:
90
  """
91
  Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
92
  from a medical document and using a language model to generate the final response.
93
 
94
  This function performs the following steps:
95
+ 1. Retrieves relevant text chunks from the medical document based on the query and any provided options
96
+ using the retriever model.
97
  2. Extracts the text and page indices from the retrieved chunks.
98
  3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
99
+ 4. Constructs a system prompt and user prompt combining the query, options (if provided), retrieved text chunks,
100
+ and figure descriptions.
101
+ 5. Uses the language model client to generate a response based on the constructed prompts, either choosing
102
+ from provided options or generating a free-form response.
103
+ 6. Returns the generated response, which includes the answer and explanation if options were provided.
104
+
105
+ The function can operate in two modes:
106
+ - Multiple choice: When options are provided, it selects the best answer from the options and explains the choice
107
+ - Free response: When no options are provided, it generates a comprehensive response based on the context
108
 
109
  Args:
110
  query (str): The medical query to be answered.
111
+ options (Optional[list[str]]): The list of options to choose from.
112
+ rely_only_on_context (bool): Whether to rely only on the context provided or not during response generation.
113
 
114
  Returns:
115
+ MedQAResponse: The generated response to the query, including source information.
116
  """
117
+ retrieved_chunks = self.retrieve_chunks_for_query(query)
118
+ options = options or []
119
+ retrieved_chunks += self.retrieve_chunks_for_options(options)
120
 
121
  retrieved_chunk_texts = []
122
  page_indices = set()
 
125
  page_indices.add(int(chunk["page_idx"]))
126
 
127
  figure_descriptions = []
128
+ if self.figure_annotator is not None:
129
+ for page_idx in page_indices:
130
+ figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
131
+ page_idx
132
+ ]
133
+ figure_descriptions += [
134
+ item["figure_description"] for item in figure_annotations
135
+ ]
136
+
137
+ system_prompt = """You are an expert in medical science. You are given a question
138
+ and a list of excerpts from various medical documents.
139
+ """
140
+ query = f"""# Question
141
+ {query}
142
  """
143
+
144
+ if len(options) > 0:
145
+ system_prompt += """\nYou are also given a list of options to choose your answer from.
146
+ You are supposed to choose the best possible option based on the context provided. You should also
147
+ explain your answer to justify why you chose that option.
148
+ """
149
+ query += "## Options\n"
150
+ for option in options:
151
+ query += f"- {option}\n"
152
+ else:
153
+ system_prompt += "\nYou are supposed to answer the question based on the context provided."
154
+
155
+ if self.rely_only_on_context:
156
+ system_prompt += """\n\nYou are only allowed to use the context provided to answer the question.
157
+ You are not allowed to use any external knowledge to answer the question.
158
+ """
159
+
160
  response = self.llm_client.predict(
161
  system_prompt=system_prompt,
162
  user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
163
+ schema=MedQAMCQResponse if len(options) > 0 else None,
164
  )
165
+
166
+ # TODO: Add figure citations
167
+ # TODO: Add source document name from retrieved chunks as citations
168
+ citations = []
169
+ for page_idx in page_indices:
170
+ citations.append(
171
+ MedQACitation(page_number=page_idx + 1, document_name="Gray's Anatomy")
172
+ )
173
+
174
+ return MedQAResponse(response=response, citations=citations)
medrag_multi_modal/assistant/schema.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class FigureAnnotation(BaseModel):
7
+ figure_id: str
8
+ figure_description: str
9
+
10
+
11
+ class FigureAnnotations(BaseModel):
12
+ annotations: list[FigureAnnotation]
13
+
14
+
15
+ class MedQAMCQResponse(BaseModel):
16
+ answer: str
17
+ explanation: str
18
+
19
+
20
+ class MedQACitation(BaseModel):
21
+ page_number: int
22
+ document_name: str
23
+
24
+
25
+ class MedQAResponse(BaseModel):
26
+ response: Union[str, MedQAMCQResponse]
27
+ citations: list[MedQACitation]
medrag_multi_modal/cli.py CHANGED
@@ -1,16 +1,67 @@
1
  import argparse
 
2
  import subprocess
3
  import sys
4
 
5
 
6
  def main():
7
  parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
8
- parser.add_argument("command", choices=["run"], help="Command to execute")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  args = parser.parse_args()
10
 
11
  if args.command == "run":
12
- # Assuming your Streamlit app is in app.py
13
- subprocess.run([sys.executable, "-m", "streamlit", "run", "app.py"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  if __name__ == "__main__":
 
1
  import argparse
2
+ import os
3
  import subprocess
4
  import sys
5
 
6
 
7
  def main():
8
  parser = argparse.ArgumentParser(description="MedRAG Multi-Modal CLI")
9
+ subparsers = parser.add_subparsers(dest="command", required=True)
10
+
11
+ # Run subcommand
12
+ run_parser = subparsers.add_parser("run", help="Run the Streamlit application")
13
+ run_parser.add_argument(
14
+ "--port", type=int, default=8501, help="Port to run Streamlit on"
15
+ )
16
+
17
+ # Evaluate subcommand
18
+ eval_parser = subparsers.add_parser("evaluate", help="Run evaluation tests")
19
+ eval_parser.add_argument(
20
+ "--test-file",
21
+ default=os.path.join("tests", "evals", "test_assistant_mmlu_anatomy.py"),
22
+ help="Path to test file",
23
+ )
24
+ eval_parser.add_argument(
25
+ "--test-case",
26
+ type=str,
27
+ help="Only run tests which match the given substring expression",
28
+ )
29
+ eval_parser.add_argument(
30
+ "--model-name",
31
+ type=str,
32
+ default="gemini-1.5-flash",
33
+ help="Model name to use for evaluation",
34
+ )
35
+
36
  args = parser.parse_args()
37
 
38
  if args.command == "run":
39
+ subprocess.run(
40
+ [
41
+ sys.executable,
42
+ "-m",
43
+ "streamlit",
44
+ "run",
45
+ "app.py",
46
+ "--server.port",
47
+ str(args.port),
48
+ ]
49
+ )
50
+
51
+ elif args.command == "evaluate":
52
+ test_file = (
53
+ args.test_file + "::" + args.test_case if args.test_case else args.test_file
54
+ )
55
+ cmd = [
56
+ sys.executable,
57
+ "-m",
58
+ "pytest",
59
+ "-s",
60
+ test_file,
61
+ "-v",
62
+ f"--model-name={args.model_name}",
63
+ ]
64
+ subprocess.run(cmd)
65
 
66
 
67
  if __name__ == "__main__":
medrag_multi_modal/document_loader/image_loader/base_img_loader.py CHANGED
@@ -1,11 +1,21 @@
1
  import asyncio
2
  import os
3
  from abc import abstractmethod
 
4
  from typing import Dict, List, Optional
5
 
 
6
  import jsonlines
7
  import rich
8
- import wandb
 
 
 
 
 
 
 
 
9
 
10
  from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
11
  BaseTextLoader,
@@ -36,14 +46,72 @@ class BaseImageLoader(BaseTextLoader):
36
  """
37
  pass
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  async def load_data(
40
  self,
41
  start_page: Optional[int] = None,
42
  end_page: Optional[int] = None,
43
- wandb_artifact_name: Optional[str] = None,
 
44
  image_save_dir: str = "./images",
45
  exclude_file_extensions: list[str] = [],
46
- cleanup: bool = False,
47
  **kwargs,
48
  ) -> List[Dict[str, str]]:
49
  """
@@ -65,21 +133,15 @@ class BaseImageLoader(BaseTextLoader):
65
  Args:
66
  start_page (Optional[int]): The starting page index (0-based) to process.
67
  end_page (Optional[int]): The ending page index (0-based) to process.
68
- wandb_artifact_name (Optional[str]): The name of the WandB artifact to publish the pages to, if provided.
 
69
  image_save_dir (str): The directory to save the extracted images.
70
  exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
71
- cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact.
72
  **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
73
 
74
  Returns:
75
- List[Dict[str, Any]]: A list of dictionaries, each containing the image and metadata for a processed page.
76
- Each dictionary will have the following keys and values:
77
-
78
- - "page_idx": (int) the index of the page.
79
- - "document_name": (str) the name of the document.
80
- - "file_path": (str) the local file path where the PDF is stored.
81
- - "file_url": (str) the URL of the PDF file.
82
- - "image_file_path" or "image_file_paths": (str) the local file path where the image/images are stored.
83
  Raises:
84
  ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
85
  """
@@ -111,19 +173,8 @@ class BaseImageLoader(BaseTextLoader):
111
  if file.endswith(tuple(exclude_file_extensions)):
112
  os.remove(os.path.join(image_save_dir, file))
113
 
114
- if wandb_artifact_name:
115
- artifact = wandb.Artifact(
116
- name=wandb_artifact_name,
117
- type="dataset",
118
- metadata={"loader_name": self.__class__.__name__},
119
- )
120
- artifact.add_dir(local_path=image_save_dir)
121
- artifact.save()
122
- rich.print("Artifact saved and uploaded to wandb!")
123
-
124
- if cleanup:
125
- for file in os.listdir(image_save_dir):
126
- file_path = os.path.join(image_save_dir, file)
127
- if os.path.isfile(file_path):
128
- os.remove(file_path)
129
- return pages
 
1
  import asyncio
2
  import os
3
  from abc import abstractmethod
4
+ from glob import glob
5
  from typing import Dict, List, Optional
6
 
7
+ import huggingface_hub
8
  import jsonlines
9
  import rich
10
+ from datasets import (
11
+ Dataset,
12
+ Features,
13
+ Image,
14
+ Sequence,
15
+ Value,
16
+ concatenate_datasets,
17
+ load_dataset,
18
+ )
19
 
20
  from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
21
  BaseTextLoader,
 
46
  """
47
  pass
48
 
49
+ def save_as_dataset(
50
+ self,
51
+ start_page: int,
52
+ end_page: int,
53
+ image_save_dir: str,
54
+ dataset_repo_id: Optional[str] = None,
55
+ overwrite_dataset: bool = False,
56
+ ):
57
+ features = Features(
58
+ {
59
+ "page_image": Image(decode=True),
60
+ "page_figure_images": Sequence(Image(decode=True)),
61
+ "document_name": Value(dtype="string"),
62
+ "page_idx": Value(dtype="int32"),
63
+ }
64
+ )
65
+
66
+ all_examples = []
67
+ for page_idx in range(start_page, end_page):
68
+ page_image_file_paths = glob(
69
+ os.path.join(image_save_dir, f"page{page_idx}*.png")
70
+ )
71
+ if len(page_image_file_paths) > 0:
72
+ page_image_path = page_image_file_paths[0]
73
+ figure_image_paths = [
74
+ image_file_path
75
+ for image_file_path in glob(
76
+ os.path.join(image_save_dir, f"page{page_idx}*_fig*.png")
77
+ )
78
+ ]
79
+
80
+ example = {
81
+ "page_image": page_image_path,
82
+ "page_figure_images": figure_image_paths,
83
+ "document_name": self.document_name,
84
+ "page_idx": page_idx,
85
+ }
86
+ all_examples.append(example)
87
+
88
+ dataset = Dataset.from_list(all_examples, features=features)
89
+
90
+ if dataset_repo_id:
91
+ if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
92
+ if not overwrite_dataset:
93
+ dataset = concatenate_datasets(
94
+ [dataset, load_dataset(dataset_repo_id)["corpus"]]
95
+ )
96
+
97
+ dataset.push_to_hub(dataset_repo_id, split="corpus")
98
+
99
+ return dataset
100
+
101
+ def cleanup_image_dir(self, image_save_dir: str = "./images"):
102
+ for file in os.listdir(image_save_dir):
103
+ file_path = os.path.join(image_save_dir, file)
104
+ if os.path.isfile(file_path):
105
+ os.remove(file_path)
106
+
107
  async def load_data(
108
  self,
109
  start_page: Optional[int] = None,
110
  end_page: Optional[int] = None,
111
+ dataset_repo_id: Optional[str] = None,
112
+ overwrite_dataset: bool = False,
113
  image_save_dir: str = "./images",
114
  exclude_file_extensions: list[str] = [],
 
115
  **kwargs,
116
  ) -> List[Dict[str, str]]:
117
  """
 
133
  Args:
134
  start_page (Optional[int]): The starting page index (0-based) to process.
135
  end_page (Optional[int]): The ending page index (0-based) to process.
136
+ dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
137
+ overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
138
  image_save_dir (str): The directory to save the extracted images.
139
  exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
 
140
  **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
141
 
142
  Returns:
143
+ Dataset: A HuggingFace dataset containing the processed pages.
144
+
 
 
 
 
 
 
145
  Raises:
146
  ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
147
  """
 
173
  if file.endswith(tuple(exclude_file_extensions)):
174
  os.remove(os.path.join(image_save_dir, file))
175
 
176
+ dataset = self.save_as_dataset(
177
+ start_page, end_page, image_save_dir, dataset_repo_id, overwrite_dataset
178
+ )
179
+
180
+ return dataset
 
 
 
 
 
 
 
 
 
 
 
medrag_multi_modal/document_loader/image_loader/fitzpil_img_loader.py CHANGED
@@ -3,9 +3,12 @@ import os
3
  from typing import Any, Dict
4
 
5
  import fitz
 
6
  from PIL import Image, ImageOps, UnidentifiedImageError
7
 
8
- from .base_img_loader import BaseImageLoader
 
 
9
 
10
 
11
  class FitzPILImageLoader(BaseImageLoader):
@@ -20,27 +23,16 @@ class FitzPILImageLoader(BaseImageLoader):
20
  ```python
21
  import asyncio
22
 
23
- import weave
24
-
25
- import wandb
26
  from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
27
 
28
- weave.init(project_name="ml-colabs/medrag-multi-modal")
29
- wandb.init(project="medrag-multi-modal", entity="ml-colabs")
30
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
31
  loader = FitzPILImageLoader(
32
- url=url,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
- asyncio.run(
37
- loader.load_data(
38
- start_page=32,
39
- end_page=37,
40
- wandb_artifact_name="grays-anatomy-images-fitzpil",
41
- cleanup=False,
42
- )
43
- )
44
  ```
45
 
46
  Args:
@@ -118,6 +110,14 @@ class FitzPILImageLoader(BaseImageLoader):
118
 
119
  pdf_document.close()
120
 
 
 
 
 
 
 
 
 
121
  return {
122
  "page_idx": page_idx,
123
  "document_name": self.document_name,
 
3
  from typing import Any, Dict
4
 
5
  import fitz
6
+ from pdf2image.pdf2image import convert_from_path
7
  from PIL import Image, ImageOps, UnidentifiedImageError
8
 
9
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
10
+ BaseImageLoader,
11
+ )
12
 
13
 
14
  class FitzPILImageLoader(BaseImageLoader):
 
23
  ```python
24
  import asyncio
25
 
 
 
 
26
  from medrag_multi_modal.document_loader.image_loader import FitzPILImageLoader
27
 
28
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
29
+
 
30
  loader = FitzPILImageLoader(
31
+ url=URL,
32
  document_name="Gray's Anatomy",
33
  document_file_path="grays_anatomy.pdf",
34
  )
35
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
 
 
 
 
 
 
 
36
  ```
37
 
38
  Args:
 
110
 
111
  pdf_document.close()
112
 
113
+ page_image = convert_from_path(
114
+ self.document_file_path,
115
+ first_page=page_idx + 1,
116
+ last_page=page_idx + 1,
117
+ **kwargs,
118
+ )[0]
119
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
120
+
121
  return {
122
  "page_idx": page_idx,
123
  "document_name": self.document_name,
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py CHANGED
@@ -5,7 +5,9 @@ from marker.convert import convert_single_pdf
5
  from marker.models import load_all_models
6
  from pdf2image.pdf2image import convert_from_path
7
 
8
- from .base_img_loader import BaseImageLoader
 
 
9
 
10
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
11
 
@@ -22,27 +24,16 @@ class MarkerImageLoader(BaseImageLoader):
22
  ```python
23
  import asyncio
24
 
25
- import weave
26
-
27
- import wandb
28
  from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
29
 
30
- weave.init(project_name="ml-colabs/medrag-multi-modal")
31
- wandb.init(project="medrag-multi-modal", entity="ml-colabs")
32
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
33
  loader = MarkerImageLoader(
34
- url=url,
35
  document_name="Gray's Anatomy",
36
  document_file_path="grays_anatomy.pdf",
37
  )
38
- asyncio.run(
39
- loader.load_data(
40
- start_page=31,
41
- end_page=36,
42
- wandb_artifact_name="grays-anatomy-images-marker",
43
- cleanup=False,
44
- )
45
- )
46
  ```
47
 
48
  Args:
@@ -84,7 +75,7 @@ class MarkerImageLoader(BaseImageLoader):
84
  - "file_url": (str) the URL of the PDF file.
85
  - "image_file_path": (str) the local file path where the image is stored.
86
  """
87
- _, images, out_meta = convert_single_pdf(
88
  self.document_file_path,
89
  self.model_lst,
90
  max_pages=1,
@@ -101,14 +92,13 @@ class MarkerImageLoader(BaseImageLoader):
101
  image.save(image_file_path, "png")
102
  image_file_paths.append(image_file_path)
103
 
104
- if self.save_page_image:
105
- page_image = convert_from_path(
106
- self.document_file_path,
107
- first_page=page_idx + 1,
108
- last_page=page_idx + 1,
109
- **kwargs,
110
- )[0]
111
- page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
112
 
113
  return {
114
  "page_idx": page_idx,
@@ -116,7 +106,6 @@ class MarkerImageLoader(BaseImageLoader):
116
  "file_path": self.document_file_path,
117
  "file_url": self.url,
118
  "image_file_paths": os.path.join(image_save_dir, "*.png"),
119
- "meta": out_meta,
120
  }
121
 
122
  def load_data(
 
5
  from marker.models import load_all_models
6
  from pdf2image.pdf2image import convert_from_path
7
 
8
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
9
+ BaseImageLoader,
10
+ )
11
 
12
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
13
 
 
24
  ```python
25
  import asyncio
26
 
 
 
 
27
  from medrag_multi_modal.document_loader.image_loader import MarkerImageLoader
28
 
29
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
30
+
 
31
  loader = MarkerImageLoader(
32
+ url=URL,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
 
 
 
 
 
 
 
37
  ```
38
 
39
  Args:
 
75
  - "file_url": (str) the URL of the PDF file.
76
  - "image_file_path": (str) the local file path where the image is stored.
77
  """
78
+ _, images, _ = convert_single_pdf(
79
  self.document_file_path,
80
  self.model_lst,
81
  max_pages=1,
 
92
  image.save(image_file_path, "png")
93
  image_file_paths.append(image_file_path)
94
 
95
+ page_image = convert_from_path(
96
+ self.document_file_path,
97
+ first_page=page_idx + 1,
98
+ last_page=page_idx + 1,
99
+ **kwargs,
100
+ )[0]
101
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
 
102
 
103
  return {
104
  "page_idx": page_idx,
 
106
  "file_path": self.document_file_path,
107
  "file_url": self.url,
108
  "image_file_paths": os.path.join(image_save_dir, "*.png"),
 
109
  }
110
 
111
  def load_data(
medrag_multi_modal/document_loader/image_loader/pdf2image_img_loader.py CHANGED
@@ -3,7 +3,9 @@ from typing import Any, Dict
3
 
4
  from pdf2image.pdf2image import convert_from_path
5
 
6
- from .base_img_loader import BaseImageLoader
 
 
7
 
8
 
9
  class PDF2ImageLoader(BaseImageLoader):
@@ -19,27 +21,16 @@ class PDF2ImageLoader(BaseImageLoader):
19
  ```python
20
  import asyncio
21
 
22
- import weave
23
-
24
- import wandb
25
  from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
26
 
27
- weave.init(project_name="ml-colabs/medrag-multi-modal")
28
- wandb.init(project="medrag-multi-modal", entity="ml-colabs")
29
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
30
  loader = PDF2ImageLoader(
31
- url=url,
32
  document_name="Gray's Anatomy",
33
  document_file_path="grays_anatomy.pdf",
34
  )
35
- asyncio.run(
36
- loader.load_data(
37
- start_page=31,
38
- end_page=36,
39
- wandb_artifact_name="grays-anatomy-images-pdf2image",
40
- cleanup=False,
41
- )
42
- )
43
  ```
44
 
45
  Args:
 
3
 
4
  from pdf2image.pdf2image import convert_from_path
5
 
6
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
7
+ BaseImageLoader,
8
+ )
9
 
10
 
11
  class PDF2ImageLoader(BaseImageLoader):
 
21
  ```python
22
  import asyncio
23
 
 
 
 
24
  from medrag_multi_modal.document_loader.image_loader import PDF2ImageLoader
25
 
26
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
27
+
 
28
  loader = PDF2ImageLoader(
29
+ url=URL,
30
  document_name="Gray's Anatomy",
31
  document_file_path="grays_anatomy.pdf",
32
  )
33
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
 
 
 
 
 
 
 
34
  ```
35
 
36
  Args:
medrag_multi_modal/document_loader/image_loader/pdfplumber_img_loader.py CHANGED
@@ -2,8 +2,11 @@ import os
2
  from typing import Any, Dict
3
 
4
  import pdfplumber
 
5
 
6
- from .base_img_loader import BaseImageLoader
 
 
7
 
8
 
9
  class PDFPlumberImageLoader(BaseImageLoader):
@@ -18,27 +21,16 @@ class PDFPlumberImageLoader(BaseImageLoader):
18
  ```python
19
  import asyncio
20
 
21
- import weave
22
-
23
- import wandb
24
  from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
25
 
26
- weave.init(project_name="ml-colabs/medrag-multi-modal")
27
- wandb.init(project="medrag-multi-modal", entity="ml-colabs")
28
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
29
  loader = PDFPlumberImageLoader(
30
- url=url,
31
  document_name="Gray's Anatomy",
32
  document_file_path="grays_anatomy.pdf",
33
  )
34
- asyncio.run(
35
- loader.load_data(
36
- start_page=32,
37
- end_page=37,
38
- wandb_artifact_name="grays-anatomy-images-pdfplumber",
39
- cleanup=False,
40
- )
41
- )
42
  ```
43
 
44
  Args:
@@ -92,6 +84,14 @@ class PDFPlumberImageLoader(BaseImageLoader):
92
  extracted_image.save(image_file_path, "png")
93
  image_file_paths.append(image_file_path)
94
 
 
 
 
 
 
 
 
 
95
  return {
96
  "page_idx": page_idx,
97
  "document_name": self.document_name,
 
2
  from typing import Any, Dict
3
 
4
  import pdfplumber
5
+ from pdf2image.pdf2image import convert_from_path
6
 
7
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
8
+ BaseImageLoader,
9
+ )
10
 
11
 
12
  class PDFPlumberImageLoader(BaseImageLoader):
 
21
  ```python
22
  import asyncio
23
 
 
 
 
24
  from medrag_multi_modal.document_loader.image_loader import PDFPlumberImageLoader
25
 
26
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
27
+
 
28
  loader = PDFPlumberImageLoader(
29
+ url=URL,
30
  document_name="Gray's Anatomy",
31
  document_file_path="grays_anatomy.pdf",
32
  )
33
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
 
 
 
 
 
 
 
34
  ```
35
 
36
  Args:
 
84
  extracted_image.save(image_file_path, "png")
85
  image_file_paths.append(image_file_path)
86
 
87
+ page_image = convert_from_path(
88
+ self.document_file_path,
89
+ first_page=page_idx + 1,
90
+ last_page=page_idx + 1,
91
+ **kwargs,
92
+ )[0]
93
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
94
+
95
  return {
96
  "page_idx": page_idx,
97
  "document_name": self.document_name,
medrag_multi_modal/document_loader/image_loader/pymupdf_img_loader.py CHANGED
@@ -3,9 +3,12 @@ import os
3
  from typing import Any, Dict
4
 
5
  import fitz
 
6
  from PIL import Image
7
 
8
- from .base_img_loader import BaseImageLoader
 
 
9
 
10
 
11
  class PyMuPDFImageLoader(BaseImageLoader):
@@ -20,27 +23,16 @@ class PyMuPDFImageLoader(BaseImageLoader):
20
  ```python
21
  import asyncio
22
 
23
- import weave
24
-
25
- import wandb
26
  from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
27
 
28
- weave.init(project_name="ml-colabs/medrag-multi-modal")
29
- wandb.init(project="medrag-multi-modal", entity="ml-colabs")
30
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
31
  loader = PyMuPDFImageLoader(
32
- url=url,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
- asyncio.run(
37
- loader.load_data(
38
- start_page=32,
39
- end_page=37,
40
- wandb_artifact_name="grays-anatomy-images-pymupdf",
41
- cleanup=False,
42
- )
43
- )
44
  ```
45
 
46
  Args:
@@ -115,6 +107,14 @@ class PyMuPDFImageLoader(BaseImageLoader):
115
 
116
  pdf_document.close()
117
 
 
 
 
 
 
 
 
 
118
  return {
119
  "page_idx": page_idx,
120
  "document_name": self.document_name,
 
3
  from typing import Any, Dict
4
 
5
  import fitz
6
+ from pdf2image.pdf2image import convert_from_path
7
  from PIL import Image
8
 
9
+ from medrag_multi_modal.document_loader.image_loader.base_img_loader import (
10
+ BaseImageLoader,
11
+ )
12
 
13
 
14
  class PyMuPDFImageLoader(BaseImageLoader):
 
23
  ```python
24
  import asyncio
25
 
 
 
 
26
  from medrag_multi_modal.document_loader.image_loader import PyMuPDFImageLoader
27
 
28
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
29
+
 
30
  loader = PyMuPDFImageLoader(
31
+ url=URL,
32
  document_name="Gray's Anatomy",
33
  document_file_path="grays_anatomy.pdf",
34
  )
35
+ dataset = asyncio.run(loader.load_data(start_page=32, end_page=37))
 
 
 
 
 
 
 
36
  ```
37
 
38
  Args:
 
107
 
108
  pdf_document.close()
109
 
110
+ page_image = convert_from_path(
111
+ self.document_file_path,
112
+ first_page=page_idx + 1,
113
+ last_page=page_idx + 1,
114
+ **kwargs,
115
+ )[0]
116
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
117
+
118
  return {
119
  "page_idx": page_idx,
120
  "document_name": self.document_name,
medrag_multi_modal/document_loader/text_loader/base_text_loader.py CHANGED
@@ -1,12 +1,13 @@
1
  import asyncio
2
  import os
3
  from abc import ABC, abstractmethod
4
- from typing import Dict, List, Optional
5
 
 
6
  import PyPDF2
7
- import rich
8
- import weave
9
  from firerequests import FireRequests
 
10
 
11
 
12
  class BaseTextLoader(ABC):
@@ -22,14 +23,22 @@ class BaseTextLoader(ABC):
22
  url (str): The URL of the PDF file to download if not present locally.
23
  document_name (str): The name of the document for metadata purposes.
24
  document_file_path (str): The local file path where the PDF is stored or will be downloaded.
 
25
  """
26
 
27
- def __init__(self, url: str, document_name: str, document_file_path: str):
 
 
 
 
 
 
28
  self.url = url
29
  self.document_name = document_name
30
  self.document_file_path = document_file_path
 
31
  if not os.path.exists(self.document_file_path):
32
- FireRequests().download(url, filename=self.document_file_path)
33
  with open(self.document_file_path, "rb") as file:
34
  pdf_reader = PyPDF2.PdfReader(file)
35
  self.page_count = len(pdf_reader.pages)
@@ -85,9 +94,11 @@ class BaseTextLoader(ABC):
85
  self,
86
  start_page: Optional[int] = None,
87
  end_page: Optional[int] = None,
88
- weave_dataset_name: Optional[str] = None,
 
 
89
  **kwargs,
90
- ) -> List[Dict[str, str]]:
91
  """
92
  Asynchronously loads text from a PDF file specified by a URL or local file path.
93
  The overrided processing abstract method then processes the text into markdown format,
@@ -102,23 +113,26 @@ class BaseTextLoader(ABC):
102
  each page, extract the text from the PDF, and convert it to markdown.
103
  It processes pages concurrently using `asyncio` for efficiency.
104
 
105
- If a weave_dataset_name is provided, the processed pages are published to a Weave dataset.
106
 
107
  Args:
108
  start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
109
  end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
110
- weave_dataset_name (Optional[str]): The name of the Weave dataset to publish the pages to, if provided.
 
 
111
  **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
112
 
113
  Returns:
114
- List[Dict[str, str]]: A list of dictionaries, each containing the text and metadata for a processed page.
115
- Each dictionary will have the following keys and values:
116
 
117
  - "text": (str) the processed page data in markdown format.
118
  - "page_idx": (int) the index of the page.
119
  - "document_name": (str) the name of the document.
120
  - "file_path": (str) the local file path where the PDF is stored.
121
  - "file_url": (str) the URL of the PDF file.
 
122
 
123
  Raises:
124
  ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
@@ -127,21 +141,45 @@ class BaseTextLoader(ABC):
127
  pages = []
128
  processed_pages_counter: int = 1
129
  total_pages = end_page - start_page
 
130
 
131
  async def process_page(page_idx):
132
  nonlocal processed_pages_counter
133
  page_data = await self.extract_page_data(page_idx, **kwargs)
134
  page_data["loader_name"] = self.__class__.__name__
 
 
 
135
  pages.append(page_data)
136
- rich.print(
137
- f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
 
 
138
  )
139
  processed_pages_counter += 1
140
 
141
- tasks = [process_page(page_idx) for page_idx in range(start_page, end_page)]
142
- for task in asyncio.as_completed(tasks):
143
- await task
144
-
145
- if weave_dataset_name:
146
- weave.publish(weave.Dataset(name=weave_dataset_name, rows=pages))
147
- return pages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  import os
3
  from abc import ABC, abstractmethod
4
+ from typing import Any, Dict, Optional
5
 
6
+ import huggingface_hub
7
  import PyPDF2
8
+ from datasets import Dataset, concatenate_datasets, load_dataset
 
9
  from firerequests import FireRequests
10
+ from rich.progress import Progress
11
 
12
 
13
  class BaseTextLoader(ABC):
 
23
  url (str): The URL of the PDF file to download if not present locally.
24
  document_name (str): The name of the document for metadata purposes.
25
  document_file_path (str): The local file path where the PDF is stored or will be downloaded.
26
+ metadata (Optional[dict[str, any]]): Additional metadata to be added to each row of the dataset.
27
  """
28
 
29
+ def __init__(
30
+ self,
31
+ url: str,
32
+ document_name: str,
33
+ document_file_path: str,
34
+ metadata: Optional[dict[str, Any]] = None,
35
+ ):
36
  self.url = url
37
  self.document_name = document_name
38
  self.document_file_path = document_file_path
39
+ self.metadata = metadata or {}
40
  if not os.path.exists(self.document_file_path):
41
+ FireRequests().download(url, filenames=self.document_file_path)
42
  with open(self.document_file_path, "rb") as file:
43
  pdf_reader = PyPDF2.PdfReader(file)
44
  self.page_count = len(pdf_reader.pages)
 
94
  self,
95
  start_page: Optional[int] = None,
96
  end_page: Optional[int] = None,
97
+ exclude_pages: Optional[list[int]] = None,
98
+ dataset_repo_id: Optional[str] = None,
99
+ overwrite_dataset: bool = False,
100
  **kwargs,
101
+ ) -> Dataset:
102
  """
103
  Asynchronously loads text from a PDF file specified by a URL or local file path.
104
  The overrided processing abstract method then processes the text into markdown format,
 
113
  each page, extract the text from the PDF, and convert it to markdown.
114
  It processes pages concurrently using `asyncio` for efficiency.
115
 
116
+ If a `dataset_repo_id` is provided, the processed pages are published to a HuggingFace dataset.
117
 
118
  Args:
119
  start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
120
  end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
121
+ exclude_pages (Optional[list[int]]): The list of page indices to exclude from processing.
122
+ dataset_repo_id (Optional[str]): The repository ID of the HuggingFace dataset to publish the pages to, if provided.
123
+ overwrite_dataset (bool): Whether to overwrite the existing dataset if it exists. Defaults to False.
124
  **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
125
 
126
  Returns:
127
+ Dataset: A HuggingFace Dataset object containing the text and metadata for processed pages.
128
+ Each entry in the dataset will have the following keys and values:
129
 
130
  - "text": (str) the processed page data in markdown format.
131
  - "page_idx": (int) the index of the page.
132
  - "document_name": (str) the name of the document.
133
  - "file_path": (str) the local file path where the PDF is stored.
134
  - "file_url": (str) the URL of the PDF file.
135
+ - "loader_name": (str) the name of the loader class used to process the page.
136
 
137
  Raises:
138
  ValueError: If the specified start_page or end_page is out of bounds of the document's page count.
 
141
  pages = []
142
  processed_pages_counter: int = 1
143
  total_pages = end_page - start_page
144
+ exclude_pages = exclude_pages or []
145
 
146
  async def process_page(page_idx):
147
  nonlocal processed_pages_counter
148
  page_data = await self.extract_page_data(page_idx, **kwargs)
149
  page_data["loader_name"] = self.__class__.__name__
150
+ for key, value in self.metadata.items():
151
+ if key not in page_data:
152
+ page_data[key] = value
153
  pages.append(page_data)
154
+ progress.update(
155
+ task_id,
156
+ advance=1,
157
+ description=f"Loading page {page_idx} using {self.__class__.__name__}",
158
  )
159
  processed_pages_counter += 1
160
 
161
+ progress = Progress()
162
+ with progress:
163
+ task_id = progress.add_task("Starting...", total=total_pages)
164
+ tasks = [
165
+ process_page(page_idx)
166
+ for page_idx in range(start_page, end_page + 1)
167
+ if page_idx not in exclude_pages
168
+ ]
169
+ for task in asyncio.as_completed(tasks):
170
+ await task
171
+
172
+ pages.sort(key=lambda x: x["page_idx"])
173
+
174
+ dataset = Dataset.from_list(pages)
175
+ if dataset_repo_id:
176
+ if huggingface_hub.repo_exists(dataset_repo_id, repo_type="dataset"):
177
+ print("Dataset already exists")
178
+ if not overwrite_dataset:
179
+ print("Not overwriting dataset")
180
+ dataset = concatenate_datasets(
181
+ [dataset, load_dataset(dataset_repo_id, split="corpus")]
182
+ )
183
+ dataset.push_to_hub(repo_id=dataset_repo_id, split="corpus", private=False)
184
+
185
+ return dataset
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py CHANGED
@@ -4,7 +4,9 @@ from typing import Dict
4
  from marker.convert import convert_single_pdf
5
  from marker.models import load_all_models
6
 
7
- from .base_text_loader import BaseTextLoader
 
 
8
 
9
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
10
 
@@ -26,24 +28,16 @@ class MarkerTextLoader(BaseTextLoader):
26
  ```python
27
  import asyncio
28
 
29
- import weave
30
 
31
- from medrag_multi_modal.document_loader.text_loader import MarkerTextLoader
32
 
33
- weave.init(project_name="ml-colabs/medrag-multi-modal")
34
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
35
  loader = MarkerTextLoader(
36
- url=url,
37
  document_name="Gray's Anatomy",
38
  document_file_path="grays_anatomy.pdf",
39
  )
40
- asyncio.run(
41
- loader.load_data(
42
- start_page=31,
43
- end_page=36,
44
- weave_dataset_name="grays-anatomy-text",
45
- )
46
- )
47
  ```
48
 
49
  Args:
@@ -76,7 +70,7 @@ class MarkerTextLoader(BaseTextLoader):
76
  """
77
  model_lst = load_all_models()
78
 
79
- text, _, out_meta = convert_single_pdf(
80
  self.document_file_path,
81
  model_lst,
82
  max_pages=1,
@@ -92,5 +86,4 @@ class MarkerTextLoader(BaseTextLoader):
92
  "document_name": self.document_name,
93
  "file_path": self.document_file_path,
94
  "file_url": self.url,
95
- "meta": out_meta,
96
  }
 
4
  from marker.convert import convert_single_pdf
5
  from marker.models import load_all_models
6
 
7
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
8
+ BaseTextLoader,
9
+ )
10
 
11
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
12
 
 
28
  ```python
29
  import asyncio
30
 
31
+ from medrag_multi_modal.document_loader import MarkerTextLoader
32
 
33
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
34
 
 
 
35
  loader = MarkerTextLoader(
36
+ url=URL,
37
  document_name="Gray's Anatomy",
38
  document_file_path="grays_anatomy.pdf",
39
  )
40
+ dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
 
 
 
 
 
 
41
  ```
42
 
43
  Args:
 
70
  """
71
  model_lst = load_all_models()
72
 
73
+ text, _, _ = convert_single_pdf(
74
  self.document_file_path,
75
  model_lst,
76
  max_pages=1,
 
86
  "document_name": self.document_name,
87
  "file_path": self.document_file_path,
88
  "file_url": self.url,
 
89
  }
medrag_multi_modal/document_loader/text_loader/pdfplumber_text_loader.py CHANGED
@@ -2,7 +2,9 @@ from typing import Dict
2
 
3
  import pdfplumber
4
 
5
- from .base_text_loader import BaseTextLoader
 
 
6
 
7
 
8
  class PDFPlumberTextLoader(BaseTextLoader):
@@ -22,24 +24,16 @@ class PDFPlumberTextLoader(BaseTextLoader):
22
  ```python
23
  import asyncio
24
 
25
- import weave
26
 
27
- from medrag_multi_modal.document_loader.text_loader import PDFPlumberTextLoader
28
 
29
- weave.init(project_name="ml-colabs/medrag-multi-modal")
30
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
31
  loader = PDFPlumberTextLoader(
32
- url=url,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
- asyncio.run(
37
- loader.load_data(
38
- start_page=31,
39
- end_page=36,
40
- weave_dataset_name="grays-anatomy-text",
41
- )
42
- )
43
  ```
44
 
45
  Args:
 
2
 
3
  import pdfplumber
4
 
5
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
6
+ BaseTextLoader,
7
+ )
8
 
9
 
10
  class PDFPlumberTextLoader(BaseTextLoader):
 
24
  ```python
25
  import asyncio
26
 
27
+ from medrag_multi_modal.document_loader import PDFPlumberTextLoader
28
 
29
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
30
 
 
 
31
  loader = PDFPlumberTextLoader(
32
+ url=URL,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
+ dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
 
 
 
 
 
 
37
  ```
38
 
39
  Args:
medrag_multi_modal/document_loader/text_loader/pymupdf4llm_text_loader.py CHANGED
@@ -2,7 +2,9 @@ from typing import Dict
2
 
3
  import pymupdf4llm
4
 
5
- from .base_text_loader import BaseTextLoader
 
 
6
 
7
 
8
  class PyMuPDF4LLMTextLoader(BaseTextLoader):
@@ -20,26 +22,16 @@ class PyMuPDF4LLMTextLoader(BaseTextLoader):
20
  ```python
21
  import asyncio
22
 
23
- import weave
24
 
25
- from medrag_multi_modal.document_loader.text_loader import (
26
- PyMuPDF4LLMTextLoader
27
- )
28
 
29
- weave.init(project_name="ml-colabs/medrag-multi-modal")
30
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
31
  loader = PyMuPDF4LLMTextLoader(
32
- url=url,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
- asyncio.run(
37
- loader.load_data(
38
- start_page=31,
39
- end_page=36,
40
- weave_dataset_name="grays-anatomy-text",
41
- )
42
- )
43
  ```
44
 
45
  Args:
 
2
 
3
  import pymupdf4llm
4
 
5
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
6
+ BaseTextLoader,
7
+ )
8
 
9
 
10
  class PyMuPDF4LLMTextLoader(BaseTextLoader):
 
22
  ```python
23
  import asyncio
24
 
25
+ from medrag_multi_modal.document_loader import PyMuPDF4LLMTextLoader
26
 
27
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
 
 
28
 
 
 
29
  loader = PyMuPDF4LLMTextLoader(
30
+ url=URL,
31
  document_name="Gray's Anatomy",
32
  document_file_path="grays_anatomy.pdf",
33
  )
34
+ dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
 
 
 
 
 
 
35
  ```
36
 
37
  Args:
medrag_multi_modal/document_loader/text_loader/pypdf2_text_loader.py CHANGED
@@ -2,7 +2,9 @@ from typing import Dict
2
 
3
  import PyPDF2
4
 
5
- from .base_text_loader import BaseTextLoader
 
 
6
 
7
 
8
  class PyPDF2TextLoader(BaseTextLoader):
@@ -22,24 +24,16 @@ class PyPDF2TextLoader(BaseTextLoader):
22
  ```python
23
  import asyncio
24
 
25
- import weave
26
 
27
- from medrag_multi_modal.document_loader.text_loader import PyPDF2TextLoader
28
 
29
- weave.init(project_name="ml-colabs/medrag-multi-modal")
30
- url = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
31
  loader = PyPDF2TextLoader(
32
- url=url,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
- asyncio.run(
37
- loader.load_data(
38
- start_page=31,
39
- end_page=36,
40
- weave_dataset_name="grays-anatomy-text",
41
- )
42
- )
43
  ```
44
 
45
  Args:
 
2
 
3
  import PyPDF2
4
 
5
+ from medrag_multi_modal.document_loader.text_loader.base_text_loader import (
6
+ BaseTextLoader,
7
+ )
8
 
9
 
10
  class PyPDF2TextLoader(BaseTextLoader):
 
24
  ```python
25
  import asyncio
26
 
27
+ from medrag_multi_modal.document_loader import PyPDF2TextLoader
28
 
29
+ URL = "https://archive.org/download/GraysAnatomy41E2015PDF/Grays%20Anatomy-41%20E%20%282015%29%20%5BPDF%5D.pdf"
30
 
 
 
31
  loader = PyPDF2TextLoader(
32
+ url=URL,
33
  document_name="Gray's Anatomy",
34
  document_file_path="grays_anatomy.pdf",
35
  )
36
+ dataset = asyncio.run(loader.load_data(start_page=31, end_page=36))
 
 
 
 
 
 
37
  ```
38
 
39
  Args:
medrag_multi_modal/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .mmlu import MMLUOptionAccuracy
2
+
3
+ __all__ = ["MMLUOptionAccuracy"]
medrag_multi_modal/metrics/base.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import weave
5
+
6
+
7
+ class BaseAccuracyMetric(weave.Scorer):
8
+ """
9
+ BaseAccuracyMetric is a class that extends the
10
+ [`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers)
11
+ to provide a comprehensive evaluation of accuracy metrics for a given set of score rows.
12
+
13
+ This class is designed to process a list of score rows, each containing a
14
+ 'correct' key that indicates whether a particular prediction was correct.
15
+ The `summarize` method calculates various statistical measures and metrics
16
+ based on this data, including:
17
+
18
+ - True and false counts: The number of true and false predictions.
19
+ - True and false fractions: The proportion of true and false predictions.
20
+ - Standard error: The standard error of the mean for the true predictions.
21
+ - Precision: The ratio of true positive predictions to the total number of
22
+ positive predictions.
23
+ - Recall: The ratio of true positive predictions to the total number of
24
+ actual positives.
25
+ - F1 Score: The harmonic mean of precision and recall, providing a balance
26
+ between the two metrics.
27
+
28
+ The `summarize` method returns a dictionary containing these metrics,
29
+ allowing for a detailed analysis of the model's performance.
30
+
31
+ Methods:
32
+ summarize(score_rows: list) -> Optional[dict]:
33
+ Processes the input score rows to compute and return a dictionary
34
+ of accuracy metrics.
35
+ """
36
+ @weave.op()
37
+ def summarize(self, score_rows: list) -> Optional[dict]:
38
+ """
39
+ Summarizes the accuracy metrics from a list of score rows.
40
+
41
+ This method processes a list of score rows, each containing a 'correct' key
42
+ that indicates whether a particular prediction was correct. It calculates
43
+ various statistical measures and metrics based on this data, including:
44
+
45
+ - True and false counts: The number of true and false predictions.
46
+ - True and false fractions: The proportion of true and false predictions.
47
+ - Standard error: The standard error of the mean for the true predictions.
48
+ - Precision: The ratio of true positive predictions to the total number of
49
+ positive predictions.
50
+ - Recall: The ratio of true positive predictions to the total number of
51
+ actual positives.
52
+ - F1 Score: The harmonic mean of precision and recall, providing a balance
53
+ between the two metrics.
54
+
55
+ The method returns a dictionary containing these metrics, allowing for a
56
+ detailed analysis of the model's performance.
57
+
58
+ Args:
59
+ score_rows (list): A list of dictionaries, each containing a 'correct'
60
+ key with a boolean value indicating the correctness of a prediction.
61
+
62
+ Returns:
63
+ Optional[dict]: A dictionary containing the calculated accuracy metrics,
64
+ or None if the input list is empty.
65
+ """
66
+ valid_data = [
67
+ x.get("correct") for x in score_rows if x.get("correct") is not None
68
+ ]
69
+ count_true = list(valid_data).count(True)
70
+ int_data = [int(x) for x in valid_data]
71
+
72
+ sample_mean = np.mean(int_data) if int_data else 0
73
+ sample_variance = np.var(int_data) if int_data else 0
74
+ sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
75
+
76
+ # Calculate precision, recall, and F1 score
77
+ true_positives = count_true
78
+ false_positives = len(valid_data) - count_true
79
+ false_negatives = len(score_rows) - len(valid_data)
80
+
81
+ precision = (
82
+ true_positives / (true_positives + false_positives)
83
+ if (true_positives + false_positives) > 0
84
+ else 0
85
+ )
86
+ recall = (
87
+ true_positives / (true_positives + false_negatives)
88
+ if (true_positives + false_negatives) > 0
89
+ else 0
90
+ )
91
+ f1_score = (
92
+ (2 * precision * recall) / (precision + recall)
93
+ if (precision + recall) > 0
94
+ else 0
95
+ )
96
+
97
+ return {
98
+ "correct": {
99
+ "true_count": count_true,
100
+ "false_count": len(score_rows) - count_true,
101
+ "true_fraction": float(sample_mean),
102
+ "false_fraction": 1.0 - float(sample_mean),
103
+ "stderr": float(sample_error),
104
+ "precision": precision,
105
+ "recall": recall,
106
+ "f1_score": f1_score,
107
+ }
108
+ }
medrag_multi_modal/metrics/mmlu.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weave
2
+
3
+ from medrag_multi_modal.assistant.schema import MedQAResponse
4
+ from medrag_multi_modal.metrics.base import BaseAccuracyMetric
5
+
6
+
7
+ class MMLUOptionAccuracy(BaseAccuracyMetric):
8
+ """
9
+ MMLUOptionAccuracy is a metric class that inherits from `BaseAccuracyMetric`.
10
+
11
+ This class is designed to evaluate the accuracy of a multiple-choice question
12
+ response by comparing the provided answer with the correct answer from the
13
+ given options. It uses the MedQAResponse schema to extract the response
14
+ and checks if it matches the correct answer.
15
+
16
+ Methods:
17
+ --------
18
+ score(output: MedQAResponse, options: list[str], answer: str) -> dict:
19
+ Compares the provided answer with the correct answer and returns a
20
+ dictionary indicating whether the answer is correct.
21
+ """
22
+ @weave.op()
23
+ def score(self, output: MedQAResponse, options: list[str], answer: str):
24
+ return {"correct": options[answer] == output.response.answer}
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -1,15 +1,3 @@
1
- from .bm25s_retrieval import BM25sRetriever
2
  from .colpali_retrieval import CalPaliRetriever
3
- from .common import SimilarityMetric
4
- from .contriever_retrieval import ContrieverRetriever
5
- from .medcpt_retrieval import MedCPTRetriever
6
- from .nv_embed_2 import NVEmbed2Retriever
7
 
8
- __all__ = [
9
- "CalPaliRetriever",
10
- "BM25sRetriever",
11
- "ContrieverRetriever",
12
- "SimilarityMetric",
13
- "MedCPTRetriever",
14
- "NVEmbed2Retriever",
15
- ]
 
 
1
  from .colpali_retrieval import CalPaliRetriever
 
 
 
 
2
 
3
+ __all__ = ["CalPaliRetriever"]
 
 
 
 
 
 
 
medrag_multi_modal/retrieval/colpali_retrieval.py CHANGED
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
9
  import wandb
10
  from PIL import Image
11
 
12
- from ..utils import get_wandb_artifact
13
 
14
 
15
  class CalPaliRetriever(weave.Model):
 
9
  import wandb
10
  from PIL import Image
11
 
12
+ from medrag_multi_modal.utils import get_wandb_artifact
13
 
14
 
15
  class CalPaliRetriever(weave.Model):
medrag_multi_modal/retrieval/common.py CHANGED
@@ -1,10 +1,5 @@
1
  from enum import Enum
2
 
3
- import safetensors
4
- import safetensors.torch
5
- import torch
6
- import wandb
7
-
8
 
9
  class SimilarityMetric(Enum):
10
  COSINE = "cosine"
@@ -24,21 +19,3 @@ def argsort_scores(scores: list[float], descending: bool = False):
24
  list(enumerate(scores)), key=lambda x: x[1], reverse=descending
25
  )
26
  ]
27
-
28
-
29
- def save_vector_index(
30
- vector_index: torch.Tensor,
31
- type: str,
32
- index_name: str,
33
- metadata: dict,
34
- filename: str = "vector_index.safetensors",
35
- ):
36
- safetensors.torch.save_file({"vector_index": vector_index.cpu()}, filename)
37
- if wandb.run:
38
- artifact = wandb.Artifact(
39
- name=index_name,
40
- type=type,
41
- metadata=metadata,
42
- )
43
- artifact.add_file(filename)
44
- artifact.save()
 
1
  from enum import Enum
2
 
 
 
 
 
 
3
 
4
  class SimilarityMetric(Enum):
5
  COSINE = "cosine"
 
19
  list(enumerate(scores)), key=lambda x: x[1], reverse=descending
20
  )
21
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
medrag_multi_modal/retrieval/text_retrieval/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .bm25s_retrieval import BM25sRetriever
2
+ from .contriever_retrieval import ContrieverRetriever
3
+ from .medcpt_retrieval import MedCPTRetriever
4
+ from .nv_embed_2 import NVEmbed2Retriever
5
+
6
+ __all__ = [
7
+ "BM25sRetriever",
8
+ "ContrieverRetriever",
9
+ "MedCPTRetriever",
10
+ "NVEmbed2Retriever",
11
+ ]
medrag_multi_modal/retrieval/{bm25s_retrieval.py → text_retrieval/bm25s_retrieval.py} RENAMED
@@ -1,12 +1,17 @@
 
1
  import os
2
- from glob import glob
3
- from typing import Optional
4
 
5
  import bm25s
6
- import wandb
7
  import weave
 
 
8
  from Stemmer import Stemmer
9
 
 
 
10
  LANGUAGE_DICT = {
11
  "english": "en",
12
  "french": "fr",
@@ -26,49 +31,60 @@ class BM25sRetriever(weave.Model):
26
  a new instance is created.
27
  """
28
 
29
- language: str
30
- use_stemmer: bool
31
- _retriever: Optional[bm25s.BM25]
32
 
33
  def __init__(
34
  self,
35
  language: str = "english",
36
  use_stemmer: bool = True,
37
- retriever: Optional[bm25s.BM25] = None,
38
  ):
39
  super().__init__(language=language, use_stemmer=use_stemmer)
40
- self._retriever = retriever or bm25s.BM25()
41
 
42
- def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
 
 
 
 
 
43
  """
44
  Indexes a dataset of text chunks using the BM25 algorithm.
45
 
46
- This function takes a dataset of text chunks identified by `chunk_dataset_name`,
47
- tokenizes the text using the BM25 tokenizer with optional stemming, and indexes
48
- the tokenized text using the BM25 retriever. If an `index_name` is provided, the
49
- index is saved to disk and logged as a Weights & Biases artifact.
50
 
51
  !!! example "Example Usage"
52
  ```python
53
  import weave
54
  from dotenv import load_dotenv
55
 
56
- import wandb
57
- from medrag_multi_modal.retrieval import BM25sRetriever
58
 
59
  load_dotenv()
60
  weave.init(project_name="ml-colabs/medrag-multi-modal")
61
- wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index")
62
  retriever = BM25sRetriever()
63
- retriever.index(chunk_dataset_name="grays-anatomy-text:v13", index_name="grays-anatomy-bm25s")
 
 
 
64
  ```
65
 
66
  Args:
67
- chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
68
- index_name (Optional[str]): The name to save the index under. If provided, the index
69
- is saved to disk and logged as a Weights & Biases artifact.
 
70
  """
71
- chunk_dataset = weave.ref(chunk_dataset_name).get().rows
 
 
 
 
72
  corpus = [row["text"] for row in chunk_dataset]
73
  corpus_tokens = bm25s.tokenize(
74
  corpus,
@@ -76,28 +92,40 @@ class BM25sRetriever(weave.Model):
76
  stemmer=Stemmer(self.language) if self.use_stemmer else None,
77
  )
78
  self._retriever.index(corpus_tokens)
79
- if index_name:
 
 
80
  self._retriever.save(
81
- index_name, corpus=[dict(row) for row in chunk_dataset]
 
 
 
 
 
82
  )
83
- if wandb.run:
84
- artifact = wandb.Artifact(
85
- name=index_name,
86
- type="bm25s-index",
87
- metadata={
88
  "language": self.language,
89
  "use_stemmer": self.use_stemmer,
90
  },
 
 
91
  )
92
- artifact.add_dir(index_name, name=index_name)
93
- artifact.save()
 
 
 
 
 
94
 
95
  @classmethod
96
- def from_wandb_artifact(cls, index_artifact_address: str):
97
  """
98
- Creates an instance of the class from a Weights & Biases artifact.
99
 
100
- This class method retrieves a BM25 index artifact from Weights & Biases,
101
  downloads the artifact, and loads the BM25 retriever with the index and its
102
  associated corpus. The method also extracts metadata from the artifact to
103
  initialize the class instance with the appropriate language and stemming
@@ -108,41 +136,26 @@ class BM25sRetriever(weave.Model):
108
  import weave
109
  from dotenv import load_dotenv
110
 
111
- from medrag_multi_modal.retrieval import BM25sRetriever
112
 
113
  load_dotenv()
114
  weave.init(project_name="ml-colabs/medrag-multi-modal")
115
- retriever = BM25sRetriever.from_wandb_artifact(
116
- index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:latest"
117
- )
118
  ```
119
 
120
  Args:
121
- index_artifact_address (str): The address of the Weights & Biases artifact
122
- containing the BM25 index.
123
 
124
  Returns:
125
  An instance of the class initialized with the BM25 retriever and metadata
126
  from the artifact.
127
  """
128
- if wandb.run:
129
- artifact = wandb.run.use_artifact(
130
- index_artifact_address, type="bm25s-index"
131
- )
132
- artifact_dir = artifact.download()
133
- else:
134
- api = wandb.Api()
135
- artifact = api.artifact(index_artifact_address)
136
- artifact_dir = artifact.download()
137
- retriever = bm25s.BM25.load(
138
- glob(os.path.join(artifact_dir, "*"))[0], load_corpus=True
139
- )
140
- metadata = artifact.metadata
141
- return cls(
142
- language=metadata["language"],
143
- use_stemmer=metadata["use_stemmer"],
144
- retriever=retriever,
145
- )
146
 
147
  @weave.op()
148
  def retrieve(self, query: str, top_k: int = 2):
@@ -155,6 +168,20 @@ class BM25sRetriever(weave.Model):
155
  The results are returned as a list of dictionaries, each containing a chunk and
156
  its corresponding relevance score.
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  Args:
159
  query (str): The input query string to search for relevant chunks.
160
  top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
@@ -192,13 +219,12 @@ class BM25sRetriever(weave.Model):
192
  import weave
193
  from dotenv import load_dotenv
194
 
195
- from medrag_multi_modal.retrieval import BM25sRetriever
196
 
197
  load_dotenv()
198
  weave.init(project_name="ml-colabs/medrag-multi-modal")
199
- retriever = BM25sRetriever.from_wandb_artifact(
200
- index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:latest"
201
- )
202
  retrieved_chunks = retriever.predict(query="What are Ribosomes?")
203
  ```
204
 
 
1
+ import json
2
  import os
3
+ import shutil
4
+ from typing import Optional, Union
5
 
6
  import bm25s
7
+ import huggingface_hub
8
  import weave
9
+ from bm25s import BM25
10
+ from datasets import Dataset, load_dataset
11
  from Stemmer import Stemmer
12
 
13
+ from medrag_multi_modal.utils import fetch_from_huggingface, save_to_huggingface
14
+
15
  LANGUAGE_DICT = {
16
  "english": "en",
17
  "french": "fr",
 
31
  a new instance is created.
32
  """
33
 
34
+ language: Optional[str]
35
+ use_stemmer: bool = True
36
+ _retriever: Optional[BM25]
37
 
38
  def __init__(
39
  self,
40
  language: str = "english",
41
  use_stemmer: bool = True,
42
+ retriever: Optional[BM25] = None,
43
  ):
44
  super().__init__(language=language, use_stemmer=use_stemmer)
45
+ self._retriever = retriever or BM25()
46
 
47
+ def index(
48
+ self,
49
+ chunk_dataset: Union[Dataset, str],
50
+ index_repo_id: Optional[str] = None,
51
+ cleanup: bool = True,
52
+ ):
53
  """
54
  Indexes a dataset of text chunks using the BM25 algorithm.
55
 
56
+ This method retrieves a dataset of text chunks from a specified source, tokenizes
57
+ the text using the BM25 tokenizer with optional stemming, and indexes the tokenized
58
+ text using the BM25 retriever. If an `index_repo_id` is provided, the index is saved
59
+ to disk and optionally logged as a Huggingface artifact.
60
 
61
  !!! example "Example Usage"
62
  ```python
63
  import weave
64
  from dotenv import load_dotenv
65
 
66
+ from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
 
67
 
68
  load_dotenv()
69
  weave.init(project_name="ml-colabs/medrag-multi-modal")
 
70
  retriever = BM25sRetriever()
71
+ retriever.index(
72
+ chunk_dataset="geekyrakshit/grays-anatomy-chunks-test",
73
+ index_repo_id="geekyrakshit/grays-anatomy-index",
74
+ )
75
  ```
76
 
77
  Args:
78
+ chunk_dataset (str): The Huggingface dataset containing the text chunks to be indexed. Either a
79
+ dataset repository name or a dataset object can be provided.
80
+ index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
81
+ cleanup (bool, optional): Whether to delete the local index directory after saving the vector index.
82
  """
83
+ chunk_dataset = (
84
+ load_dataset(chunk_dataset, split="chunks")
85
+ if isinstance(chunk_dataset, str)
86
+ else chunk_dataset
87
+ )
88
  corpus = [row["text"] for row in chunk_dataset]
89
  corpus_tokens = bm25s.tokenize(
90
  corpus,
 
92
  stemmer=Stemmer(self.language) if self.use_stemmer else None,
93
  )
94
  self._retriever.index(corpus_tokens)
95
+ if index_repo_id:
96
+ os.makedirs(".huggingface", exist_ok=True)
97
+ index_save_dir = os.path.join(".huggingface", index_repo_id.split("/")[-1])
98
  self._retriever.save(
99
+ index_save_dir, corpus=[dict(row) for row in chunk_dataset]
100
+ )
101
+ commit_type = (
102
+ "update"
103
+ if huggingface_hub.repo_exists(index_repo_id, repo_type="model")
104
+ else "add"
105
  )
106
+ with open(os.path.join(index_save_dir, "config.json"), "w") as config_file:
107
+ json.dump(
108
+ {
 
 
109
  "language": self.language,
110
  "use_stemmer": self.use_stemmer,
111
  },
112
+ config_file,
113
+ indent=4,
114
  )
115
+ save_to_huggingface(
116
+ index_repo_id,
117
+ index_save_dir,
118
+ commit_message=f"{commit_type}: BM25s index",
119
+ )
120
+ if cleanup:
121
+ shutil.rmtree(index_save_dir)
122
 
123
  @classmethod
124
+ def from_index(cls, index_repo_id: str):
125
  """
126
+ Creates an instance of the class from a Huggingface repository.
127
 
128
+ This class method retrieves a BM25 index artifact from a Huggingface repository,
129
  downloads the artifact, and loads the BM25 retriever with the index and its
130
  associated corpus. The method also extracts metadata from the artifact to
131
  initialize the class instance with the appropriate language and stemming
 
136
  import weave
137
  from dotenv import load_dotenv
138
 
139
+ from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
140
 
141
  load_dotenv()
142
  weave.init(project_name="ml-colabs/medrag-multi-modal")
143
+ retriever = BM25sRetriever()
144
+ retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
 
145
  ```
146
 
147
  Args:
148
+ index_repo_id (Optional[str]): The Huggingface repository of the index artifact to be saved.
 
149
 
150
  Returns:
151
  An instance of the class initialized with the BM25 retriever and metadata
152
  from the artifact.
153
  """
154
+ index_dir = fetch_from_huggingface(index_repo_id, ".huggingface")
155
+ retriever = bm25s.BM25.load(index_dir, load_corpus=True)
156
+ with open(os.path.join(index_dir, "config.json"), "r") as config_file:
157
+ config = json.load(config_file)
158
+ return cls(retriever=retriever, **config)
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  @weave.op()
161
  def retrieve(self, query: str, top_k: int = 2):
 
168
  The results are returned as a list of dictionaries, each containing a chunk and
169
  its corresponding relevance score.
170
 
171
+ !!! example "Example Usage"
172
+ ```python
173
+ import weave
174
+ from dotenv import load_dotenv
175
+
176
+ from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
177
+
178
+ load_dotenv()
179
+ weave.init(project_name="ml-colabs/medrag-multi-modal")
180
+ retriever = BM25sRetriever()
181
+ retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
182
+ retrieved_chunks = retriever.retrieve(query="What are Ribosomes?")
183
+ ```
184
+
185
  Args:
186
  query (str): The input query string to search for relevant chunks.
187
  top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
 
219
  import weave
220
  from dotenv import load_dotenv
221
 
222
+ from medrag_multi_modal.retrieval.text_retrieval import BM25sRetriever
223
 
224
  load_dotenv()
225
  weave.init(project_name="ml-colabs/medrag-multi-modal")
226
+ retriever = BM25sRetriever()
227
+ retriever = BM25sRetriever().from_index(index_repo_id="geekyrakshit/grays-anatomy-index")
 
228
  retrieved_chunks = retriever.predict(query="What are Ribosomes?")
229
  ```
230