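"""Llama3-Med AI Assistant.

A Gradio app that answers medical questions with a local Bio-Medical Llama-3
GGUF model, optionally enriched by FAISS-based RAG over an uploaded PDF,
DuckDuckGo web search, and a basic medical-image description step.
"""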
import os
import logging

import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image

from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from duckduckgo_search import DDGS
from llama_cpp import Llama

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
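
# ---------------------------------------------------------------------------
# Local GGUF model
# ---------------------------------------------------------------------------
# NOTE: n_gpu_layers=0 keeps inference entirely on the CPU; raise it if a
# CUDA/Metal build of llama-cpp-python is available.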
MODEL_PATH = "Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q4_0.gguf"

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file {MODEL_PATH} not found. Upload it to the same directory.")

llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=8192,        # context window, in tokens
    n_gpu_layers=0,    # CPU-only inference
    logits_all=True,   # keep per-token logits (not required for plain generation)
    n_batch=512,
)

logger.info("Llama GGUF Model Loaded Successfully.")
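
# ---------------------------------------------------------------------------
# Unified system prompt shared across specialties
# ---------------------------------------------------------------------------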
UNIFIED_MEDICAL_PROMPT = """
You are an advanced Medical AI Assistant capable of providing thorough,
comprehensive answers for a wide range of medical specialties:
General Practice, Radiology, Cardiology, Neurology, Psychiatry, Pediatrics,
Endocrinology, Oncology, and more.

You can:
1) Analyze images if provided (Radiology).
2) Search the web for up-to-date medical info (Web Search).
3) Retrieve relevant documents from a knowledge base (Vector Store).
4) Provide scientific, evidence-based explanations and references when possible.

Always strive to provide a detailed, helpful, and empathetic response.
"""
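
# ---------------------------------------------------------------------------
# Retrieval-augmented generation: FAISS index over a local PDF, if present
# ---------------------------------------------------------------------------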
_vector_store_cache = None


def load_vectorstore(pdf_path="medical_docs.pdf"):
    """Build a FAISS vector store from a PDF so it can be queried by the RAG tool."""
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
        logger.info(f"Vector store loaded with {len(docs)} chunks.")
        return vector_store
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None


if os.path.exists("medical_docs.pdf"):
    _vector_store_cache = load_vectorstore("medical_docs.pdf")

vector_store = _vector_store_cache
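
# ---------------------------------------------------------------------------
# LangChain tools (invoked directly in chat_with_agent, not via an agent executor)
# ---------------------------------------------------------------------------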
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # The GGUF model is loaded above without a vision chat handler, so it cannot
    # inspect pixels directly; we describe the file and ask for a radiological
    # discussion based on that description.
    prompt = (
        "Analyze this medical image for radiological findings.\n"
        f"Image metadata: file={os.path.basename(image_path)}, "
        f"size={image.size}, mode={image.mode}\n"
    )
    output = llm(prompt, max_tokens=512)
    return output["choices"][0]["text"]
@tool
def retrieve_medical_knowledge(query: str):
    """Retrieves medical knowledge from FAISS vector store."""
    if vector_store is None:
        return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    docs = retriever.get_relevant_documents(query)
    citations = [f"[{i+1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
    content = "\n".join([doc.page_content for doc in docs])
    citations_text = "\n".join(citations)
    return content + f"\n\n**Citations:**\n{citations_text}"
@tool
def web_search(query: str):
    """Performs a real-time web search using DuckDuckGo."""
    try:
        results = DDGS().text(query, max_results=5)
        summary = "\n".join([f"{r['title']}: {r['body']} ({r['href']})" for r in results]) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."
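
# ---------------------------------------------------------------------------
# Orchestration: gather image, RAG and web context, then stream the answer.
# LangGraph's StateGraph / MemorySaver are imported above but not wired in yet;
# the flow below is a plain sequential pipeline.
# ---------------------------------------------------------------------------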
def chat_with_agent(user_query, image_file, pdf_file):
    """Gather context from every tool, then stream the model's answer."""
    global vector_store

    # If the user uploaded a PDF, rebuild the knowledge base from it.
    if pdf_file:
        pdf_path = getattr(pdf_file, "name", pdf_file)
        vector_store = load_vectorstore(pdf_path)

    image_analysis = analyze_medical_image.invoke(image_file) if image_file else ""
    rag_text = retrieve_medical_knowledge.invoke(user_query)
    web_text = web_search.invoke(user_query)

    combined_context = f"""
{UNIFIED_MEDICAL_PROMPT}

Additional Context:
- Radiology Analysis (if any): {image_analysis}
- Retrieved from Vector Store (RAG): {rag_text}
- Web Search Results: {web_text}

Now, respond to the user's query with detailed, medically accurate information.
Q: {user_query}
A:
"""

    # Stream partial completions back to the UI as they are generated.
    response_accumulator = ""
    for token in llm(
        prompt=combined_context,
        max_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        stream=True,
    ):
        partial_text = token["choices"][0]["text"]
        response_accumulator += partial_text
        yield response_accumulator
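
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------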
with gr.Blocks(title="🏥 Llama3-Med AI Assistant") as demo:
    gr.Markdown("# 🏥 Llama3-Med AI Assistant\n_Your intelligent medical assistant powered by advanced AI._")

    with gr.Row():
        user_input = gr.Textbox(label="💬 Ask a medical question", placeholder="Type your question here...")
        image_file = gr.Image(label="📷 Upload Medical Image", type="filepath")
        pdf_file = gr.File(label="📄 Upload PDF (Optional)", file_types=[".pdf"])

    submit_btn = gr.Button("🚀 Submit", variant="primary")
    output_text = gr.Textbox(label="📝 Assistant's Response", interactive=False, lines=25)

    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)