import os
import logging
import gradio as gr
from PIL import Image

# LangChain & LangGraph (StateGraph / MemorySaver are currently unused in this script)
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from duckduckgo_search import DDGS
from llama_cpp import Llama

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# πŸ”Ή Load GGUF Model with llama-cpp-python
# ------------------------------
MODEL_PATH = "Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q4_0.gguf"

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file {MODEL_PATH} not found. Upload it to the same directory.")

llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=8192,         # context window size in tokens
    n_gpu_layers=0,     # 0 = pure CPU inference; raise if a CUDA/Metal build of llama-cpp is available
    n_batch=512,        # prompt-processing batch size
    logits_all=False    # per-token logprobs are not needed here and would be costly at this context size
)

logger.info("Llama GGUF Model Loaded Successfully.")

# ------------------------------
# πŸ”Ή Multi-Specialty Prompt
# ------------------------------
UNIFIED_MEDICAL_PROMPT = """
You are an advanced Medical AI Assistant capable of providing thorough,
comprehensive answers for a wide range of medical specialties:
General Practice, Radiology, Cardiology, Neurology, Psychiatry, Pediatrics,
Endocrinology, Oncology, and more.

You can:
1) Analyze images if provided (Radiology).
2) Search the web for up-to-date medical info (Web Search).
3) Retrieve relevant documents from a knowledge base (Vector Store).
4) Provide scientific, evidence-based explanations and references when possible.

Always strive to provide a detailed, helpful, and empathetic response.
"""

# ------------------------------
# πŸ”Ή FAISS Vector Store for RAG
# ------------------------------
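# Cache the FAISS index at module level so the bundled PDF is only embedded once at startup.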
_vector_store_cache = None

def load_vectorstore(pdf_path="medical_docs.pdf"):
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
        logger.info(f"Vector store loaded with {len(docs)} documents.")
        return vector_store
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None

if os.path.exists("medical_docs.pdf"):
    _vector_store_cache = load_vectorstore("medical_docs.pdf")

vector_store = _vector_store_cache
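# The retrieval tool below reads this module-level handle; chat_with_agent may replace it
# when the user uploads their own PDF through the UI.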

# ------------------------------
# πŸ”Ή Define AI Tools
# ------------------------------
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # NOTE: llama-cpp's plain text-completion API cannot consume pixel data directly;
    # without a multimodal chat handler, only basic image properties can be put in the prompt.
    image_description = f"format={image.format}, size={image.size}, mode={image.mode}"
    output = llm(
        f"Analyze this medical image for radiological findings.\nImage properties: {image_description}\n",
        max_tokens=512
    )
    return output["choices"][0]["text"]

@tool
def retrieve_medical_knowledge(query: str):
    """Retrieves medical knowledge from FAISS vector store."""
    if vector_store is None:
        return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    docs = retriever.invoke(query)  # .get_relevant_documents() is deprecated in recent LangChain releases
    citations = [f"[{i+1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
    content = "\n".join([doc.page_content for doc in docs])
    citations_text = "\n".join(citations)
    return content + f"\n\n**Citations:**\n{citations_text}"

@tool
def web_search(query: str):
    """Performs a real-time web search using DuckDuckGo."""
    try:
        results = DDGS().text(query, max_results=5)
        summary = "\n".join([f"{r['title']}: {r['body']} ({r['href']})" for r in results]) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."

# ------------------------------
# πŸ”Ή Multi-Context Chat Function
# ------------------------------
def chat_with_agent(user_query, image_file, pdf_file):
    global vector_store
    # If a PDF was uploaded, rebuild the vector store from it so retrieval uses that document.
    if pdf_file:
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        vector_store = load_vectorstore(pdf_path) or vector_store

    # Functions decorated with @tool are LangChain tools; call them via .invoke() rather than directly.
    image_analysis = analyze_medical_image.invoke(image_file) if image_file else ""
    rag_text = retrieve_medical_knowledge.invoke(user_query)
    web_text = web_search.invoke(user_query)

    combined_context = f"""
    {UNIFIED_MEDICAL_PROMPT}

    Additional Context:
    - Radiology Analysis (if any): {image_analysis}
    - Retrieved from Vector Store (RAG): {rag_text}
    - Web Search Results: {web_text}

    Now, respond to the user's query with detailed, medically accurate information.
    Q: {user_query}
    A:
    """

    response_accumulator = ""
    for token in llm(
        prompt=combined_context,
        max_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        stream=True
    ):
        partial_text = token["choices"][0]["text"]
        response_accumulator += partial_text
        yield response_accumulator

# ------------------------------
# πŸ”Ή Gradio Interface
# ------------------------------
with gr.Blocks(title="πŸ₯ Llama3-Med AI Assistant") as demo:
    gr.Markdown("# πŸ₯ Llama3-Med AI Assistant\n_Your intelligent medical assistant powered by advanced AI._")

    with gr.Row():
        user_input = gr.Textbox(label="πŸ’¬ Ask a medical question", placeholder="Type your question here...")
        image_file = gr.Image(label="πŸ“· Upload Medical Image", type="filepath")
        pdf_file = gr.File(label="πŸ“„ Upload PDF (Optional)", file_types=[".pdf"])

    submit_btn = gr.Button("πŸš€ Submit", variant="primary")
    output_text = gr.Textbox(label="πŸ“ Assistant's Response", interactive=False, lines=25)

    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)
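    # Because chat_with_agent is a generator, Gradio streams partial responses into the textbox;
    # this relies on the request queue enabled via demo.queue() below.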

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)