# Source: Hugging Face Space file view of app.py
# Author: muhammadsalmanalfaridzi
# Last commit: d5275bb (verified) - "Update app.py"
import gradio as gr
import os
import warnings
import asyncio
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings
from llama_index.llms.cerebras import Cerebras
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding, EncodingFormat
from groq import Groq
import io
# Hide warnings whose message mentions `clean_up_tokenization_spaces`.
warnings.filterwarnings(
    action="ignore",
    message=".*clean_up_tokenization_spaces.*",
)

# Shared application state: both start empty, are filled by load_documents()
# and are reset by clear_all().
index = query_engine = None
# Read the Cerebras API key from the environment (Hugging Face Secrets) and
# stop immediately when it is missing.
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
print("Cerebras API key loaded successfully.")

# Create the Cerebras-hosted LLM and register it as the LlamaIndex default.
os.environ["CEREBRAS_API_KEY"] = api_key
llm = Cerebras(model="llama-3.3-70b", api_key=api_key)
Settings.llm = llm
# Initialize the Mixedbread embedding model used when building the index.
# The key is validated up front, in the same way as the Cerebras and Groq
# keys, so a missing secret fails at startup with a clear message.
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
if not mixedbread_api_key:
    raise ValueError("MXBAI_API_KEY is not set.")
print("Mixedbread API key loaded successfully.")
embed_model = MixedbreadAIEmbedding(
    api_key=mixedbread_api_key,
    model_name="mixedbread-ai/mxbai-embed-large-v1",
)
# Groq client used for Whisper Large V3 speech-to-text; the module refuses to
# load without its API key.
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set.")
print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key)
# Speech-to-text helper (Whisper Large V3 served by Groq)
def transcribe_or_translate_audio(audio_file, translate=False):
    """
    Turn an audio file into text with Whisper Large V3 via the Groq API.

    `audio_file` is the path of the file to read. When `translate` is true the
    translations endpoint is called, otherwise the transcriptions endpoint.

    Returns the text of the result. Failures are not raised: they come back as
    a string of the form "Error processing audio: <details>".
    """
    try:
        with open(audio_file, "rb") as audio_stream:
            # Both endpoints take identical arguments; only the endpoint differs.
            endpoint = (
                client.audio.translations if translate else client.audio.transcriptions
            )
            outcome = endpoint.create(
                file=(audio_file, audio_stream.read()),
                model="whisper-large-v3",
                response_format="json",
                temperature=0.0,
            )
            return outcome.text
    except Exception as exc:
        return f"Error processing audio: {str(exc)}"
# Function to load documents and create index
def load_documents(file_objs):
    """
    Read the uploaded files, build a vector index over them and prepare the
    global query engine.

    Args:
        file_objs: Iterable of uploaded files. Each item may be a plain path
            (str / os.PathLike) or an object exposing its path as `.name`.

    Returns:
        A status string: a success summary, or an error message. Errors are
        reported through the returned string and never raised.
    """
    global index, query_engine
    try:
        if not file_objs:
            return "Error: No files selected."

        documents = []
        document_names = []
        for file_obj in file_objs:
            # Plain paths are used as-is; file wrappers carry the path in `.name`.
            if isinstance(file_obj, (str, os.PathLike)):
                file_path = os.fspath(file_obj)
            else:
                file_path = file_obj.name
            file_name = os.path.basename(file_path)
            document_names.append(file_name)
            for doc in SimpleDirectoryReader(input_files=[file_path]).load_data():
                # Remember the originating file so answers can cite it.
                doc.metadata["source"] = file_name
                documents.append(doc)

        if not documents:
            return "No documents found in the selected files."

        # Build into locals first: if either step fails, the previously loaded
        # index and query engine stay paired and usable.
        new_index = VectorStoreIndex.from_documents(documents, llm=llm, embed_model=embed_model)
        new_query_engine = new_index.as_query_engine()
        index, query_engine = new_index, new_query_engine

        return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"
def _collect_sources(response):
    """Return the distinct `source` metadata values of the retrieved nodes, in order."""
    found = []
    for node in getattr(response, "source_nodes", None) or []:
        metadata = getattr(node, "metadata", None) or {}
        source = metadata.get("source")
        if source and source not in found:
            found.append(source)
    return found


async def perform_rag(query, history, audio_file=None, translate_audio=False):
    """
    Answer a question against the loaded documents and append the exchange to
    the chat history.

    Args:
        query: Text typed by the user (may be empty when only audio is given).
        history: Existing chat history as a list of (user, assistant) pairs.
        audio_file: Optional path of an audio file whose transcription is
            appended to `query`.
        translate_audio: When true, the audio is translated instead of
            transcribed.

    Returns:
        The chat history with one new (user, assistant) pair appended, or the
        history unchanged when there is nothing to ask. Errors are reported as
        the assistant message rather than raised.
    """
    global query_engine
    history = history or []
    query = (query or "").strip()

    # Nothing typed and no audio (e.g. the audio widget was just cleared):
    # do not send an empty query to the engine.
    if not query and not audio_file:
        return history

    if query_engine is None:
        # The notice belongs in the assistant slot; a user slot of None is used
        # when there is no typed text to show.
        return history + [(query or None, "Please load documents first.")]

    try:
        # Handle audio input if provided
        if audio_file:
            # Run the blocking speech-to-text call off the event loop.
            transcription = await asyncio.to_thread(
                transcribe_or_translate_audio, audio_file, translate=translate_audio
            )
            # The helper reports failures as text with this prefix; show the
            # failure instead of sending it to the LLM as a question.
            if transcription.startswith("Error processing audio:"):
                return history + [(query or None, transcription)]
            query = f"{query} {transcription}".strip()

        if not query:
            return history

        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response)

        # Retrieved chunks are read from `source_nodes`; the earlier
        # `get_documents` probe never matched, so sources were never listed.
        sources = "\n\n".join(_collect_sources(response))

        # Combine answer with sources (if any) without additional labels
        final_result = f"{answer}\n\n{sources}".strip()
        return history + [(query, final_result)]
    except Exception as e:
        return history + [(query or None, f"Error processing query: {str(e)}")]
# Reset hook for the "Clear" button
def clear_all():
    """
    Forget the loaded index and query engine and return blank UI values.

    The returned 4-tuple holds, in order, the reset values for the file input,
    the load-status box, the chatbot history and the message box.
    """
    global index, query_engine
    index = query_engine = None
    blank_files, blank_status, blank_chat, blank_message = None, "", [], ""
    return blank_files, blank_status, blank_chat, blank_message
# Assemble the Gradio interface.
# NOTE(review): the source's indentation was lost, so which widgets sit inside
# each gr.Row() below is reconstructed - confirm against the intended layout.
app_theme = gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")

with gr.Blocks(theme=app_theme) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")
    chatbot = gr.Chatbot()

    # Document upload controls
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
    load_output = gr.Textbox(label="Load Status")

    # Question input: typed text and/or an audio clip
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)
    clear = gr.Button("Clear")

    # Wire the widgets to their handlers.
    load_btn.click(fn=load_documents, inputs=[file_input], outputs=[load_output])

    # Submitting the textbox sends only the typed text and the chat history.
    msg.submit(fn=perform_rag, inputs=[msg, chatbot], outputs=[chatbot])

    # A change of the audio widget sends the typed text, the history, the
    # audio path and the translate flag.
    audio_input.change(
        fn=perform_rag,
        inputs=[msg, chatbot, audio_input, translate_checkbox],
        outputs=[chatbot],
    )

    clear.click(
        fn=clear_all,
        outputs=[file_input, load_output, chatbot, msg],
        queue=False,
    )

# Run the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()