geekyrakshit committed on
Commit
49cde8e
·
1 Parent(s): 01ed12d

add: MedQAAssistant

Browse files
medrag_multi_modal/assistant/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .llm_client import LLMClient
 
2
 
3
- __all__ = ["LLMClient"]
 
1
  from .llm_client import LLMClient
2
+ from .medqa_assistant import MedQAAssistant
3
 
4
+ __all__ = ["LLMClient", "MedQAAssistant"]
medrag_multi_modal/assistant/llm_client.py CHANGED
@@ -29,7 +29,7 @@ class LLMClient(weave.Model):
29
  schema: Optional[Any] = None,
30
  ) -> Union[str, Any]:
31
  import google.generativeai as genai
32
-
33
  system_prompt = (
34
  [system_prompt] if isinstance(system_prompt, str) else system_prompt
35
  )
 
29
  schema: Optional[Any] = None,
30
  ) -> Union[str, Any]:
31
  import google.generativeai as genai
32
+
33
  system_prompt = (
34
  [system_prompt] if isinstance(system_prompt, str) else system_prompt
35
  )
medrag_multi_modal/assistant/medqa_assistant.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import weave
4
+ from PIL import Image
5
+
6
+ from ..retrieval import SimilarityMetric
7
+ from .llm_client import LLMClient
8
+
9
+
10
class MedQAAssistant(weave.Model):
    """Retrieval-augmented assistant for answering medical questions.

    For each query, the assistant retrieves the `top_k_chunks` most similar
    text chunks from the indexed medical document via `retriever`, then asks
    `llm_client` to produce an answer grounded in those chunks.

    Attributes:
        llm_client: LLM wrapper used to generate the final answer.
        retriever: Retrieval model; its `predict(query, top_k=..., metric=...)`
            is expected to return a list of dicts each containing a "text" key.
        top_k_chunks: Number of chunks to retrieve as context (default 2).
        retrieval_similarity_metric: Similarity metric used during retrieval.
    """

    llm_client: LLMClient
    retriever: weave.Model
    top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    @weave.op()
    def predict(self, query: str, image: Optional[Image.Image] = None) -> str:
        """Answer `query` using retrieved document chunks as context.

        Args:
            query: The medical question to answer.
            image: Optional image associated with the query.
                NOTE(review): currently not forwarded to the LLM — TODO wire
                into `llm_client.predict` once multimodal prompting is ready.

        Returns:
            The LLM-generated answer string.
        """
        retrieved_chunks = self.retriever.predict(
            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
        )
        chunk_texts = [chunk["text"] for chunk in retrieved_chunks]
        system_prompt = """
        You are a medical expert. You are given a query and a list of chunks from a medical document.
        """
        # Fix: the query itself was previously omitted from the prompt, so the
        # LLM only ever saw the retrieved chunks. Send the query followed by
        # its supporting context chunks.
        return self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *chunk_texts],
        )