import os
import logging
import torch
import gradio as gr
from tqdm import tqdm
from PIL import Image
# LangChain & LangGraph
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from duckduckgo_search import DDGS
from llama_cpp import Llama
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ------------------------------
# Load GGUF Model with llama-cpp-python
# ------------------------------
MODEL_PATH = "Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q4_0.gguf"
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file {MODEL_PATH} not found. Upload it to the same directory.")
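# llama-cpp-python loads the quantized GGUF weights directly from disk;
# n_ctx sets the context window and n_gpu_layers=0 keeps inference on the CPU.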
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=8192,
    n_gpu_layers=0,  # Set to 0 for CPU inference
    logits_all=True,
    n_batch=512
)
logger.info("Llama GGUF Model Loaded Successfully.")
# ------------------------------
# Multi-Specialty Prompt
# ------------------------------
UNIFIED_MEDICAL_PROMPT = """
You are an advanced Medical AI Assistant capable of providing thorough,
comprehensive answers for a wide range of medical specialties:
General Practice, Radiology, Cardiology, Neurology, Psychiatry, Pediatrics,
Endocrinology, Oncology, and more.
You can:
1) Analyze images if provided (Radiology).
2) Search the web for up-to-date medical info (Web Search).
3) Retrieve relevant documents from a knowledge base (Vector Store).
4) Provide scientific, evidence-based explanations and references when possible.
Always strive to provide a detailed, helpful, and empathetic response.
"""
# ------------------------------
# FAISS Vector Store for RAG
# ------------------------------
_vector_store_cache = None
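# Build the retrieval index: load the PDF, split it into overlapping chunks,
# embed each chunk with a MiniLM sentence-transformer, and store the vectors in FAISS.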
def load_vectorstore(pdf_path="medical_docs.pdf"):
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        docs = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
        logger.info(f"Vector store loaded with {len(docs)} document chunks.")
        return vector_store
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None
if os.path.exists("medical_docs.pdf"):
    _vector_store_cache = load_vectorstore("medical_docs.pdf")
vector_store = _vector_store_cache
# ------------------------------
# Define AI Tools
# ------------------------------
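# Each function below is wrapped with @tool so it can be exposed to a LangChain
# agent; in this script they are also invoked directly from chat_with_agent.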
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # NOTE: llm() receives only the text repr of the PIL image, not the pixel data;
    # a multimodal chat handler would be required for true image understanding.
    output = llm(f"Analyze this medical image for radiological findings:\n{image}", max_tokens=512)
    return output["choices"][0]["text"]
@tool
def retrieve_medical_knowledge(query: str):
    """Retrieves medical knowledge from the FAISS vector store."""
    if vector_store is None:
        return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    docs = retriever.get_relevant_documents(query)
    citations = [f"[{i+1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
    content = "\n".join([doc.page_content for doc in docs])
    citations_text = "\n".join(citations)
    return content + f"\n\n**Citations:**\n{citations_text}"
@tool
def web_search(query: str):
    """Performs a real-time web search using DuckDuckGo."""
    try:
        results = DDGS().text(query, max_results=5)
        summary = "\n".join([f"{r['title']}: {r['body']} ({r['href']})" for r in results]) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."
# ------------------------------
# Multi-Context Chat Function
# ------------------------------
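# Gather context from the three tools, assemble a single prompt, and stream the
# model's answer token by token (the function is a generator, so Gradio can
# update the output textbox incrementally).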
def chat_with_agent(user_query, image_file, pdf_file):
    global vector_store
    # If the user uploaded a PDF, index it so retrieval draws on that document.
    if pdf_file:
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        vector_store = load_vectorstore(pdf_path)
    # The @tool-decorated functions are LangChain tools, so call them via .invoke().
    image_analysis = analyze_medical_image.invoke(image_file) if image_file else ""
    rag_text = retrieve_medical_knowledge.invoke(user_query)
    web_text = web_search.invoke(user_query)
    combined_context = f"""
{UNIFIED_MEDICAL_PROMPT}
Additional Context:
- Radiology Analysis (if any): {image_analysis}
- Retrieved from Vector Store (RAG): {rag_text}
- Web Search Results: {web_text}
Now, respond to the user's query with detailed, medically accurate information.
Q: {user_query}
A:
"""
    response_accumulator = ""
    # Stream tokens from llama-cpp and yield the growing answer for live display.
    for token in llm(
        prompt=combined_context,
        max_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        stream=True
    ):
        partial_text = token["choices"][0]["text"]
        response_accumulator += partial_text
        yield response_accumulator
# ------------------------------
# Gradio Interface
# ------------------------------
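# Blocks UI: one row with the question box, an optional image, and an optional PDF;
# because chat_with_agent is a generator, queue() lets the response textbox update
# as tokens stream in.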
with gr.Blocks(title="Llama3-Med AI Assistant") as demo:
    gr.Markdown("# Llama3-Med AI Assistant\n_Your intelligent medical assistant powered by advanced AI._")
    with gr.Row():
        user_input = gr.Textbox(label="Ask a medical question", placeholder="Type your question here...")
        image_file = gr.Image(label="Upload Medical Image", type="filepath")
        pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"])
    submit_btn = gr.Button("Submit", variant="primary")
    output_text = gr.Textbox(label="Assistant's Response", interactive=False, lines=25)
    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)