Spaces:
Sleeping
Sleeping
Commit
·
49cde8e
1
Parent(s):
01ed12d
add: MedQAAssistant
Browse files
medrag_multi_modal/assistant/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
from .llm_client import LLMClient
|
|
|
2 |
|
3 |
-
__all__ = ["LLMClient"]
|
|
|
1 |
from .llm_client import LLMClient
|
2 |
+
from .medqa_assistant import MedQAAssistant
|
3 |
|
4 |
+
__all__ = ["LLMClient", "MedQAAssistant"]
|
medrag_multi_modal/assistant/llm_client.py
CHANGED
@@ -29,7 +29,7 @@ class LLMClient(weave.Model):
|
|
29 |
schema: Optional[Any] = None,
|
30 |
) -> Union[str, Any]:
|
31 |
import google.generativeai as genai
|
32 |
-
|
33 |
system_prompt = (
|
34 |
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
35 |
)
|
|
|
29 |
schema: Optional[Any] = None,
|
30 |
) -> Union[str, Any]:
|
31 |
import google.generativeai as genai
|
32 |
+
|
33 |
system_prompt = (
|
34 |
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
35 |
)
|
medrag_multi_modal/assistant/medqa_assistant.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import weave
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from ..retrieval import SimilarityMetric
|
7 |
+
from .llm_client import LLMClient
|
8 |
+
|
9 |
+
|
10 |
+
class MedQAAssistant(weave.Model):
|
11 |
+
llm_client: LLMClient
|
12 |
+
retriever: weave.Model
|
13 |
+
top_k_chunks: int = 2
|
14 |
+
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
15 |
+
|
16 |
+
@weave.op()
|
17 |
+
def predict(self, query: str, image: Optional[Image.Image] = None) -> str:
|
18 |
+
_image = image
|
19 |
+
retrieved_chunks = self.retriever.predict(
|
20 |
+
query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
|
21 |
+
)
|
22 |
+
retrieved_chunks = [chunk["text"] for chunk in retrieved_chunks]
|
23 |
+
system_prompt = """
|
24 |
+
You are a medical expert. You are given a query and a list of chunks from a medical document.
|
25 |
+
"""
|
26 |
+
return self.llm_client.predict(
|
27 |
+
system_prompt=system_prompt, user_prompt=retrieved_chunks
|
28 |
+
)
|