# HuggingFace Space: document-based question answering (RAG) with local Llama models.
import os

import gradio as gr
from llama_cpp import Llama

from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
class RAGInterface:
    """Retrieval-augmented chat over a persisted Chroma store, served via Gradio.

    Queries are embedded with a MiniLM sentence-transformer, the top matching
    chunks are retrieved from the local ``mydb`` Chroma database, and a local
    GGUF Llama model (llama-cpp-python) answers from that retrieved context only.
    """

    def __init__(self):
        # Embedding model used at query time; it must match the model the
        # 'mydb' store was originally built with, or retrieval quality breaks.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        # Load the persisted vector store that sits next to this script.
        persist_directory = os.path.join(os.path.dirname(__file__), 'mydb')
        self.vectorstore = Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embeddings
        )
        # Display name -> HuggingFace Hub location of each selectable model.
        self.model_configs = {
            "Llama 3.2 3B (Fast, Less Accurate)": {
                "repo_id": "bartowski/Llama-3.2-3B-Instruct-GGUF",
                "filename": "Llama-3.2-3B-Instruct-Q6_K.gguf",
            },
            "Llama 3.1 8B (Slower, More Accurate)": {
                "repo_id": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
                "filename": "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
            },
        }
        # Eagerly load the default model so the first chat request does not
        # pay the download/initialization cost.
        self.current_model = "Llama 3.1 8B (Slower, More Accurate)"
        self.load_model(self.current_model)
        # RAG prompt: instructs the model to answer only from the retrieved
        # context and to admit when the context is insufficient.
        self.template = """Answer the question based only on the following context:
{context}
Question: {question}
Answer the question in a clear way. If you cannot find the answer in the context,
just say "I don't have enough information to answer this question."
Make sure to:
1. Only use information from the provided context
2. If you're unsure, acknowledge it
"""
        self.prompt = PromptTemplate.from_template(self.template)

    def load_model(self, model_name):
        """Download (if not cached) and load the GGUF model for *model_name*.

        Replaces ``self.llm`` and records *model_name* as the current model.
        Raises ``KeyError`` if *model_name* is not in ``self.model_configs``.
        """
        config = self.model_configs[model_name]
        # from_pretrained fetches the GGUF file from the HF Hub and caches it.
        self.llm = Llama.from_pretrained(
            repo_id=config["repo_id"],
            filename=config["filename"],
            n_ctx=2048  # context window: prompt (incl. retrieved docs) + answer must fit
        )
        self.current_model = model_name

    def respond(self, message, history, system_message, model_choice, temperature, max_tokens=2048):
        """Answer *message* using retrieved context; returns the reply text.

        Called by gr.ChatInterface with the chat *history* plus the values of
        the additional inputs (system message, model choice, temperature).
        """
        # Swap models lazily, only when the dropdown selection changed.
        if model_choice != self.current_model:
            self.load_model(model_choice)
        # Rebuild the conversation for the chat-completion API.
        # NOTE(review): assumes the tuple-style history of Gradio 4
        # ([(user, assistant), ...]); Gradio 5 "messages" format would need
        # different unpacking — confirm the pinned gradio version.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        # Retrieve the 5 chunks most similar to the user's question.
        # NOTE(review): get_relevant_documents is deprecated in newer
        # LangChain in favor of retriever.invoke(message); kept for
        # compatibility with the pinned version.
        retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
        docs = retriever.get_relevant_documents(message)
        context = "\n\n".join([doc.page_content for doc in docs])
        # The RAG prompt (context + question) goes in as the final user turn.
        final_prompt = self.prompt.format(context=context, question=message)
        messages.append({"role": "user", "content": final_prompt})
        response = self.llm.create_chat_completion(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response['choices'][0]['message']['content']

    def create_interface(self):
        """Build and return the Gradio UI (a Blocks wrapping a ChatInterface)."""
        # Raw CSS for gr.Blocks(css=...). Gradio injects this string into its
        # own <style> tag, so it must NOT be wrapped in <style> markup —
        # the original HTML-wrapped version made the stylesheet invalid.
        custom_css = """
        /* Global Styles */
        body, #root {
            font-family: Helvetica, Arial, sans-serif;
            background-color: #1a1a1a;
            color: #fafafa;
        }
        /* Header Styles */
        .app-header {
            background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
            padding: 24px;
            border-radius: 8px;
            margin-bottom: 24px;
            text-align: center;
        }
        .app-title {
            font-size: 36px;
            margin: 0;
            color: #fafafa;
        }
        .app-subtitle {
            font-size: 18px;
            margin: 8px 0;
            color: #fafafa;
            opacity: 0.8;
        }
        /* Chat Container */
        .chat-container {
            background-color: #2a2a2a;
            border-radius: 8px;
            padding: 20px;
            margin-bottom: 20px;
        }
        /* Control Panel */
        .control-panel {
            background-color: #333;
            padding: 16px;
            border-radius: 8px;
            margin-top: 16px;
        }
        /* Gradio Component Overrides */
        .gr-button {
            background-color: #4a4a4a;
            color: #fff;
            border: none;
            border-radius: 4px;
            padding: 8px 16px;
            transition: background-color 0.3s;
        }
        .gr-button:hover {
            background-color: #5a5a5a;
        }
        .gr-input, .gr-dropdown {
            background-color: #3a3a3a;
            color: #fff;
            border: 1px solid #4a4a4a;
            border-radius: 4px;
            padding: 8px;
        }
        """
        # Header banner only — the CSS is supplied once via gr.Blocks(css=...)
        # instead of being duplicated inside the HTML.
        header_html = """
        <div class="app-header">
            <h1 class="app-title">Document-Based Question Answering</h1>
            <h2 class="app-subtitle">Powered by Llama and RAG</h2>
        </div>
        """
        demo = gr.ChatInterface(
            fn=self.respond,
            additional_inputs=[
                gr.Textbox(
                    value="You are a friendly chatbot.",
                    label="System Message",
                    elem_classes="control-panel"
                ),
                gr.Dropdown(
                    choices=list(self.model_configs.keys()),
                    value=self.current_model,
                    label="Select Model",
                    elem_classes="control-panel"
                ),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    elem_classes="control-panel"
                ),
            ],
            title="",  # Title is handled in custom HTML
            description="Ask questions about Computers and get AI-powered answers.",
            theme=gr.themes.Default(),
        )
        # Re-render the ChatInterface inside a Blocks wrapper that carries
        # the custom header and stylesheet.
        with gr.Blocks(css=custom_css) as wrapper:
            gr.HTML(header_html)
            demo.render()
        return wrapper
def main():
    """Build the RAG app and serve it with Gradio in debug mode."""
    app = RAGInterface()
    app.create_interface().launch(debug=True)


if __name__ == "__main__":
    main()