|
import gradio as gr |
|
import os |
|
import warnings |
|
import asyncio |
|
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings |
|
from llama_index.llms.cerebras import Cerebras |
|
from llama_index.embeddings.mixedbreadai import MixedbreadAIEmbedding, EncodingFormat |
|
from groq import Groq |
|
import io |
|
|
|
|
|
# Hide a noisy warning whose text mentions `clean_up_tokenization_spaces`
# (presumably emitted by a Hugging Face tokenizer used by a dependency).
warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")

# Shared state for the Gradio callbacks: both stay None until
# load_documents() builds an index, and clear_all() sets them back to None.
index = query_engine = None
|
|
|
|
|
# LLM used by llama_index to synthesize answers: Llama 3.3 70B served by Cerebras.
# Fail fast at startup if the key is missing.
api_key = os.getenv("CEREBRAS_API_KEY")
if not api_key:
    raise ValueError("CEREBRAS_API_KEY is not set in Hugging Face Secrets.")
print("Cerebras API key loaded successfully.")

# Pass the key directly; it was just read from the environment, so writing it
# back into os.environ and re-reading it would be a no-op round-trip.
llm = Cerebras(model="llama-3.3-70b", api_key=api_key)

# Make this the default LLM for llama_index components (e.g. the query engine).
Settings.llm = llm
|
|
|
|
|
# Embedding model used to vectorize document chunks and queries.
# Validate the key up front, like the Cerebras and Groq keys, instead of
# failing later with an opaque error the first time documents are loaded.
mixedbread_api_key = os.getenv("MXBAI_API_KEY")
if not mixedbread_api_key:
    raise ValueError("MXBAI_API_KEY is not set.")
print("Mixedbread API key loaded successfully.")

embed_model = MixedbreadAIEmbedding(api_key=mixedbread_api_key, model_name="mixedbread-ai/mxbai-embed-large-v1")

# Register it globally, mirroring Settings.llm, so any llama_index component
# that is not handed embed_model explicitly uses Mixedbread rather than the
# library's default embedding model.
Settings.embed_model = embed_model
|
|
|
|
|
# Groq client used for Whisper speech-to-text; abort startup without a key.
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise ValueError("GROQ_API_KEY is not set.")
print("Groq API key loaded successfully.")
client = Groq(api_key=groq_api_key)
|
|
|
|
|
def transcribe_or_translate_audio(audio_file, translate=False):
    """Transcribe or translate an audio file with Whisper Large V3 via the Groq API.

    Args:
        audio_file: Path to the audio file on disk.
        translate: If True, call the translations endpoint (English output);
            otherwise call the transcriptions endpoint.

    Returns:
        The recognized text, or a string starting with
        "Error processing audio:" if reading the file or the API call fails.
    """
    try:
        # Both endpoints take identical arguments; only the target differs.
        endpoint = client.audio.translations if translate else client.audio.transcriptions

        # Read the bytes and close the file before the (slow) network call.
        with open(audio_file, "rb") as file:
            audio_bytes = file.read()

        result = endpoint.create(
            # Send just the file name (its extension identifies the format);
            # the local temp directory path is of no use to the API.
            file=(os.path.basename(audio_file), audio_bytes),
            model="whisper-large-v3",
            response_format="json",
            temperature=0.0,
        )
        return result.text
    except Exception as e:
        return f"Error processing audio: {str(e)}"
|
|
|
|
|
def _upload_path(file_obj):
    """Return the filesystem path of one uploaded file.

    Depending on the Gradio version and ``gr.File(type=...)``, uploads arrive
    either as path strings or as tempfile-like objects exposing ``.name``.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def load_documents(file_objs):
    """Index the uploaded files and make them available for querying.

    Each file is read with ``SimpleDirectoryReader``, every resulting document
    is tagged with ``metadata["source"] = <file base name>``, and a new
    ``VectorStoreIndex`` / query engine replace the module-level ones.

    Args:
        file_objs: Uploaded files from the Gradio ``File`` component (path
            strings or objects with a ``.name`` path); may be empty or None.

    Returns:
        A human-readable status string for the "Load Status" box. Failures are
        reported in the string rather than raised.
    """
    global index, query_engine

    if not file_objs:
        return "Error: No files selected."

    try:
        documents = []
        document_names = []
        for file_obj in file_objs:
            path = _upload_path(file_obj)
            file_name = os.path.basename(path)
            document_names.append(file_name)
            for doc in SimpleDirectoryReader(input_files=[path]).load_data():
                # Tag every document so answers can cite the originating file.
                doc.metadata["source"] = file_name
                documents.append(doc)

        if not documents:
            return "No documents found in the selected files."

        index = VectorStoreIndex.from_documents(documents, llm=llm, embed_model=embed_model)
        query_engine = index.as_query_engine()

        return f"Successfully loaded {len(documents)} documents from the files: {', '.join(document_names)}"
    except Exception as e:
        return f"Error loading documents: {str(e)}"
|
|
|
# Prefix of the string transcribe_or_translate_audio() returns on failure.
_AUDIO_ERROR_PREFIX = "Error processing audio:"


def _collect_sources(response):
    """Return the distinct ``source`` metadata values of a response's retrieved nodes.

    Order of first appearance is kept; nodes without a ``source`` are skipped.
    Objects without ``source_nodes`` yield an empty list.
    """
    names = []
    for node in getattr(response, "source_nodes", None) or []:
        metadata = getattr(node, "metadata", None)
        if metadata is None:
            # Fall back to the wrapped node if the wrapper exposes no metadata.
            metadata = getattr(getattr(node, "node", None), "metadata", None)
        source = (metadata or {}).get("source")
        if source:
            names.append(source)
    return list(dict.fromkeys(names))


async def perform_rag(query, history, audio_file=None, translate_audio=False):
    """Answer a question against the loaded documents and append it to the chat.

    Args:
        query: Text typed by the user; may be empty when only audio is given.
        history: Chat history as a list of (user, bot) tuples; None is treated
            as empty.
        audio_file: Optional path to an audio file whose transcription (or
            English translation) is appended to ``query``.
        translate_audio: If True, translate the audio instead of transcribing.

    Returns:
        A new history list with one (user, bot) pair appended, or the history
        unchanged when there is no question to ask. The bot reply lists the
        distinct source files of the retrieved chunks under "Sources:".
    """
    history = list(history or [])
    query = (query or "").strip()

    if query_engine is None:
        # Show the hint as the bot's reply, not as a user message.
        return history + [(query or None, "Please load documents first.")]

    if audio_file:
        # The Groq request blocks; keep it off the event loop like the query.
        transcription = await asyncio.to_thread(
            transcribe_or_translate_audio, audio_file, translate=translate_audio
        )
        transcription = transcription or ""
        if transcription.startswith(_AUDIO_ERROR_PREFIX):
            # Report the failure instead of sending the error text to the LLM.
            return history + [(query or None, transcription)]
        query = f"{query} {transcription}".strip()

    if not query:
        # Nothing typed and nothing transcribed (e.g. the audio input was cleared).
        return history

    try:
        response = await asyncio.to_thread(query_engine.query, query)
        answer = str(response).strip()
        sources = _collect_sources(response)
    except Exception as e:
        return history + [(query, f"Error processing query: {str(e)}")]

    if sources:
        answer = f"{answer}\n\nSources:\n" + "\n".join(sources)
    return history + [(query, answer)]
|
|
|
|
|
def clear_all():
    """Drop the loaded index and produce blank values for the UI.

    Returns:
        A 4-tuple for (file_input, load_output, chatbot, msg): no files,
        empty status text, empty chat history, empty question box.
    """
    global index, query_engine
    index, query_engine = None, None
    fresh_history = []
    return None, "", fresh_history, ""
|
|
|
|
|
# Build the Gradio interface: chat transcript, a document-loading row, a
# question row (text + audio), and the event wiring between them.
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal", secondary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# RAG Multi-file Chat Application with Speech-to-Text")

    # Chat transcript; the handlers below receive it as `history` and return
    # the updated list of (user, bot) pairs.
    chatbot = gr.Chatbot()

    # Document loading: pick one or more files, press the button, and the
    # status string returned by load_documents() is shown in load_output.
    with gr.Row():
        file_input = gr.File(label="Select files to load", file_count="multiple")
        load_btn = gr.Button("Load Documents")
        load_output = gr.Textbox(label="Load Status")

    # Question input: typed text, an optional audio file (passed as a path),
    # and a toggle choosing English translation over plain transcription.
    with gr.Row():
        msg = gr.Textbox(label="Enter your question")
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        translate_checkbox = gr.Checkbox(label="Translate Audio to English Text", value=False)

    clear = gr.Button("Clear")

    load_btn.click(load_documents, inputs=[file_input], outputs=[load_output])

    # Pressing Enter asks the typed question only: the audio input and the
    # translate checkbox are not passed on this path.
    msg.submit(perform_rag, inputs=[msg, chatbot], outputs=[chatbot])

    # A change to the audio input asks the typed text plus the audio.
    # NOTE(review): `.change` presumably also fires when the audio is cleared,
    # which would re-ask the typed question — confirm this is intended.
    audio_input.change(perform_rag, inputs=[msg, chatbot, audio_input, translate_checkbox], outputs=[chatbot])

    # clear_all()'s 4-tuple maps positionally onto these outputs; audio_input
    # is not listed, so a loaded audio file is not reset.
    clear.click(clear_all, outputs=[file_input, load_output, chatbot, msg], queue=False)
|
|
|
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue (queue() returns the Blocks itself),
    # then start the web server.
    demo.queue().launch()