# Nick088's picture
# Update app.py
# 8ecdb24 verified
# raw
# history blame
# 18.7 kB
import os
import subprocess
import random
import numpy as np
import json
from datetime import timedelta
import tempfile
import re
import gradio as gr
import groq
from groq import Groq
# Set up the Groq client shared by the chat and speech handlers. The API key
# is read from the "Groq_Api_Key" environment variable (None when unset).
client = Groq(api_key=os.getenv("Groq_Api_Key"))
def handle_groq_error(e, model_name):
    """Re-raise a Groq SDK exception as a user-facing ``gr.Error``.

    The code assumes the exception's first argument is a string that embeds
    the API response body as a dict-like ``{...}`` span (NOTE(review): the
    exact format comes from the Groq SDK and is not visible here). For
    rate-limit errors the nested ``error.message`` is extracted and shown on
    its own; every other case shows the full exception text.

    Args:
        e: Exception raised by the Groq client.
        model_name: Model that was being called. Not used in the message;
            kept so existing call sites keep working.

    Raises:
        gr.Error: Always. This function never returns normally, so callers
            cannot accidentally continue with a missing result.
    """
    error_data = e.args[0] if e.args else str(e)
    if isinstance(error_data, str):
        # Grab the outermost {...} span of the message text.
        json_match = re.search(r"(\{.*\})", error_data)
        if json_match:
            # The span looks like a Python dict repr, so swap single quotes for
            # double quotes before parsing. This is best effort: an apostrophe
            # inside the message breaks it, and then we fall back to the raw
            # text instead of crashing inside the error handler.
            json_str = json_match.group(1).replace("'", '"')
            try:
                error_data = json.loads(json_str)
            except json.JSONDecodeError:
                pass
    if isinstance(e, groq.RateLimitError):
        error_field = error_data.get("error") if isinstance(error_data, dict) else None
        if isinstance(error_field, dict) and "message" in error_field:
            raise gr.Error(error_field["message"])
    # Reached for non-rate-limit errors and for rate-limit errors whose
    # message could not be extracted.
    raise gr.Error(f"Error during Groq API call: {e}")
# --- LLM chat ---
# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
def update_max_tokens(model):
    """Cap the "Max Tokens" slider at the selected model's context window.

    Args:
        model: Model id chosen in the Model dropdown.

    Returns:
        ``gr.update(maximum=...)`` for a known model id, otherwise ``None``.
    """
    # (model ids, context window) pairs, checked in order; tuple membership
    # compares with ==, exactly like the list check it replaces.
    context_windows = (
        (("llama3-70b-8192", "llama3-8b-8192", "gemma-7b-it", "gemma2-9b-it"), 8192),
        (("mixtral-8x7b-32768",), 32768),
    )
    for model_ids, window in context_windows:
        if model in model_ids:
            return gr.update(maximum=window)
    return None
def create_history_messages(history):
    """Convert Gradio chat history into a chat-completions message list.

    Each ``[user_message, assistant_message]`` pair becomes a ``user`` message
    immediately followed by its ``assistant`` reply, so the conversation keeps
    its real turn order. Entries that are already ``{"role", "content"}`` dicts
    (Gradio's "messages" history format) are passed through as role/content.

    Args:
        history: Iterable of ``(user, assistant)`` pairs or role/content dicts.
            ``None`` is treated as an empty history.

    Returns:
        list[dict]: Messages with ``role`` and ``content`` keys, oldest first.
    """
    history_messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            history_messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_content, assistant_content = turn[0], turn[1]
        if user_content is not None:
            history_messages.append({"role": "user", "content": user_content})
        # A turn may have no reply yet; skip it rather than sending a message
        # whose content is None.
        if assistant_content is not None:
            history_messages.append({"role": "assistant", "content": assistant_content})
    return history_messages
def generate_response(prompt, history, model, temperature, max_tokens, top_p, seed):
    """Stream a Groq chat completion for the Gradio ChatInterface.

    Args:
        prompt: The new user message.
        history: Previous chat turns, converted by ``create_history_messages``.
        model: Groq model id.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        seed: Sampling seed; 0 (or an empty box) picks a random seed in
            ``[1, MAX_SEED]``.

    Yields:
        str: The accumulated response text after each non-empty delta.

    Raises:
        gr.Error: When the Groq API call fails.
    """
    messages = create_history_messages(history)
    messages.append({"role": "user", "content": prompt})
    if not seed:
        seed = random.randint(1, MAX_SEED)
    try:
        stream = client.chat.completions.create(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            seed=seed,
            stop=None,
            stream=True,
        )
        response = ""
        for chunk in stream:
            # Some stream chunks may carry no choices; nothing to append then.
            if not chunk.choices:
                continue
            delta_content = chunk.choices[0].delta.content
            if delta_content is not None:
                response += delta_content
                yield response
    except groq.APIError as e:
        # groq.APIError is the SDK's base API exception class (the previous
        # `Groq.GroqApiException` does not exist and turned every API error
        # into an AttributeError).
        handle_groq_error(e, model)
        raise gr.Error(f"Error during Groq API call: {e}") from e
# --- Speech to text ---
# File extensions accepted by the upload widgets and re-checked in check_file().
ALLOWED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
# Files larger than this many MiB (1024 * 1024 bytes) are downsampled first.
MAX_FILE_SIZE_MB = 25
# Chunk size in MiB used when a downsampled file is still over MAX_FILE_SIZE_MB.
CHUNK_SIZE_MB = 25
# Display name -> language code. Used as the (label, value) choices of the
# Language dropdown; the code is sent as the transcription "language".
LANGUAGE_CODES = {
    "English": "en",
    "Chinese": "zh",
    "German": "de",
    "Spanish": "es",
    "Russian": "ru",
    "Korean": "ko",
    "French": "fr",
    "Japanese": "ja",
    "Portuguese": "pt",
    "Turkish": "tr",
    "Polish": "pl",
    "Catalan": "ca",
    "Dutch": "nl",
    "Arabic": "ar",
    "Swedish": "sv",
    "Italian": "it",
    "Indonesian": "id",
    "Hindi": "hi",
    "Finnish": "fi",
    "Vietnamese": "vi",
    "Hebrew": "he",
    "Ukrainian": "uk",
    "Greek": "el",
    "Malay": "ms",
    "Czech": "cs",
    "Romanian": "ro",
    "Danish": "da",
    "Hungarian": "hu",
    "Tamil": "ta",
    "Norwegian": "no",
    "Thai": "th",
    "Urdu": "ur",
    "Croatian": "hr",
    "Bulgarian": "bg",
    "Lithuanian": "lt",
    "Latin": "la",
    "Māori": "mi",
    "Malayalam": "ml",
    "Welsh": "cy",
    "Slovak": "sk",
    "Telugu": "te",
    "Persian": "fa",
    "Latvian": "lv",
    "Bengali": "bn",
    "Serbian": "sr",
    "Azerbaijani": "az",
    "Slovenian": "sl",
    "Kannada": "kn",
    "Estonian": "et",
    "Macedonian": "mk",
    "Breton": "br",
    "Basque": "eu",
    "Icelandic": "is",
    "Armenian": "hy",
    "Nepali": "ne",
    "Mongolian": "mn",
    "Bosnian": "bs",
    "Kazakh": "kk",
    "Albanian": "sq",
    "Swahili": "sw",
    "Galician": "gl",
    "Marathi": "mr",
    "Panjabi": "pa",
    "Sinhala": "si",
    "Khmer": "km",
    "Shona": "sn",
    "Yoruba": "yo",
    "Somali": "so",
    "Afrikaans": "af",
    "Occitan": "oc",
    "Georgian": "ka",
    "Belarusian": "be",
    "Tajik": "tg",
    "Sindhi": "sd",
    "Gujarati": "gu",
    "Amharic": "am",
    "Yiddish": "yi",
    "Lao": "lo",
    "Uzbek": "uz",
    "Faroese": "fo",
    "Haitian": "ht",
    "Pashto": "ps",
    "Turkmen": "tk",
    "Norwegian Nynorsk": "nn",
    "Maltese": "mt",
    "Sanskrit": "sa",
    "Luxembourgish": "lb",
    "Burmese": "my",
    "Tibetan": "bo",
    "Tagalog": "tl",
    "Malagasy": "mg",
    "Assamese": "as",
    "Tatar": "tt",
    "Hawaiian": "haw",
    "Lingala": "ln",
    "Hausa": "ha",
    "Bashkir": "ba",
    # Whisper's code for Javanese is "jw" (not the ISO 639-1 "jv").
    "Javanese": "jw",
    "Sundanese": "su",
}
def split_audio(audio_file_path, chunk_size_mb):
    """Split a file into consecutive byte chunks of at most ``chunk_size_mb`` MiB.

    Chunks are written next to the source file as
    ``<source stem>_partNNN.mp3``. The part number is zero-padded so the names
    sort in order. The split is purely byte-based; nothing is re-encoded here.

    Args:
        audio_file_path: Path of the file to split.
        chunk_size_mb: Maximum chunk size in MiB (int or float).

    Returns:
        list[str]: Paths of the written chunk files, in order.

    Raises:
        ValueError: If ``chunk_size_mb`` amounts to less than one byte.
    """
    # f.read() needs an integer byte count; converting lets callers pass
    # fractional sizes without a TypeError.
    chunk_size = int(chunk_size_mb * 1024 * 1024)
    if chunk_size < 1:
        # A zero size would silently produce no chunks, and a negative one
        # would read the whole file as a single chunk.
        raise ValueError(f"chunk_size_mb must be positive, got {chunk_size_mb!r}")
    base_name = os.path.splitext(audio_file_path)[0]
    chunks = []
    with open(audio_file_path, "rb") as source:
        while True:
            data = source.read(chunk_size)
            if not data:
                break
            chunk_name = f"{base_name}_part{len(chunks) + 1:03}.mp3"
            with open(chunk_name, "wb") as chunk_file:
                chunk_file.write(data)
            chunks.append(chunk_name)
    return chunks
def merge_audio(chunks, output_file_path):
    """Concatenate audio chunk files into one file with ffmpeg's concat demuxer.

    The chunks are stream-copied (``-c copy``, no re-encode) into
    ``output_file_path``, which is overwritten if it exists (``-y``). On
    success the chunk files are deleted.

    Args:
        chunks: Ordered paths of the chunk files to join.
        output_file_path: Destination path for the merged audio.

    Raises:
        gr.Error: If ffmpeg fails or cannot be started.
    """
    # Write the concat list to a unique temp file rather than a fixed name in
    # the CWD, so concurrent requests cannot overwrite each other's list.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".txt", delete=False, encoding="utf-8"
    ) as list_file:
        for chunk in chunks:
            # The concat demuxer resolves relative entries against the list
            # file's directory, so write absolute paths. A literal single quote
            # inside a quoted entry is written as '\''.
            escaped = os.path.abspath(chunk).replace("'", "'\\''")
            list_file.write(f"file '{escaped}'\n")
        list_path = list_file.name
    try:
        subprocess.run(
            [
                "ffmpeg",
                "-f",
                "concat",
                "-safe", "0",  # allow absolute paths in the list file
                "-i",
                list_path,
                "-c",
                "copy",
                "-y",
                output_file_path,
            ],
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the ffmpeg executable is not on PATH.
        raise gr.Error(f"Error during audio merging: {e}") from e
    finally:
        # Remove the list file whether or not ffmpeg succeeded.
        os.remove(list_path)
    for chunk in chunks:
        os.remove(chunk)
# Checks file extension, size, and downsamples or splits if needed.
def check_file(audio_file_path):
    """Validate an uploaded audio file and shrink it for the API if needed.

    Files over ``MAX_FILE_SIZE_MB`` are re-encoded with ffmpeg to a 16 kHz mono
    MP3 at 128 kbps. If that copy is still too large, it is split into
    ``CHUNK_SIZE_MB`` byte chunks.

    Args:
        audio_file_path: Path of the uploaded file (may be empty/None).

    Returns:
        tuple: ``(path, None)`` when a single file can be sent (the original
        or the downsampled copy), or ``(list_of_chunk_paths, "split")``.

    Raises:
        gr.Error: No file, unsupported extension, or ffmpeg failure.
    """
    if not audio_file_path:
        raise gr.Error("Please upload an audio file.")
    # Take the extension from the file name only, so dots in directory names
    # (or a name without a dot) cannot produce a bogus extension.
    file_name = os.path.basename(audio_file_path)
    file_extension = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
    if file_extension not in ALLOWED_FILE_EXTENSIONS:
        raise gr.Error(f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}")
    file_size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
    if file_size_mb <= MAX_FILE_SIZE_MB:
        return audio_file_path, None
    gr.Warning(
        f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz MP3 128kbps. Maximum size allowed: {MAX_FILE_SIZE_MB} MB"
    )
    output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.mp3"
    try:
        subprocess.run(
            [
                "ffmpeg",
                "-i",
                audio_file_path,
                "-ar",
                "16000",
                "-ab",
                "128k",
                "-ac",
                "1",
                "-f",
                "mp3",
                "-y",
                output_file_path,
            ],
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the ffmpeg executable is not on PATH.
        raise gr.Error(f"Error during downsampling: {e}") from e
    # Check size after downsampling
    downsampled_size_mb = os.path.getsize(output_file_path) / (1024 * 1024)
    if downsampled_size_mb > MAX_FILE_SIZE_MB:
        gr.Warning(f"File still too large after downsampling ({downsampled_size_mb:.2f} MB). Splitting into {CHUNK_SIZE_MB} MB chunks.")
        return split_audio(output_file_path, CHUNK_SIZE_MB), "split"
    return output_file_path, None
def transcribe_audio(audio_file_path, model, prompt, language, auto_detect_language):
    """Transcribe an uploaded audio file with a Groq Whisper model.

    Args:
        audio_file_path: Path of the uploaded audio file.
        model: Groq speech model id.
        prompt: Optional context / spelling hints for the model.
        language: Language code to request when auto-detection is off.
        auto_detect_language: If true, no language is sent.

    Returns:
        tuple: ``(text, merged_audio_path_or_None)``. For files that had to be
        chunked, the transcribed chunks are re-joined into
        ``merged_output.mp3`` and that path is returned with the text. If the
        rate limit is hit part-way, the text of the chunks done so far is
        returned.

    Raises:
        gr.Error: On invalid input or Groq API errors, other than a rate limit
            hit while processing chunks.
    """
    processed_path, split_status = check_file(audio_file_path)
    request_language = None if auto_detect_language else language

    def transcribe_one(path):
        # Send one file to the API and return its transcript as a str.
        with open(path, "rb") as file:
            result = client.audio.transcriptions.create(
                file=(os.path.basename(path), file.read()),
                model=model,
                prompt=prompt,
                response_format="text",
                language=request_language,
                temperature=0.0,
            )
        # NOTE(review): with response_format="text" the SDK result may be a
        # plain str or an object with .text; accept both.
        if isinstance(result, str):
            return result
        return getattr(result, "text", str(result))

    if split_status != "split":
        try:
            return transcribe_one(processed_path), None
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise gr.Error(f"Error during Groq API call: {e}") from e

    full_transcription = ""
    processed_chunks = []
    for i, chunk_path in enumerate(processed_path):
        try:
            full_transcription += transcribe_one(chunk_path)
        except groq.RateLimitError:
            # Keep the chunks already transcribed instead of raising; this
            # path is why the rate limit is not routed through handle_groq_error.
            gr.Warning(f"API limit reached during chunk {i+1}. Returning processed chunks only.")
            if not processed_chunks:
                return "Transcription failed due to API limits.", None
            break
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise gr.Error(f"Error during Groq API call: {e}") from e
        processed_chunks.append(chunk_path)
    merge_audio(processed_chunks, "merged_output.mp3")
    return full_transcription, "merged_output.mp3"
def translate_audio(audio_file_path, model, prompt):
    """Translate speech in an uploaded audio file to English text.

    Args:
        audio_file_path: Path of the uploaded audio file.
        model: Groq speech model id.
        prompt: Optional context / spelling hints for the model.

    Returns:
        str: The translation. If the rate limit is hit part-way through a
        chunked file, the text ``"API limit reached. Partial translation: "``
        followed by what was translated so far.

    Raises:
        gr.Error: On invalid input or Groq API errors, other than a rate limit
            hit while processing chunks.
    """
    processed_path, split_status = check_file(audio_file_path)

    def translate_one(path):
        # Send one file to the API and return its translation as a str.
        with open(path, "rb") as file:
            result = client.audio.translations.create(
                file=(os.path.basename(path), file.read()),
                model=model,
                prompt=prompt,
                response_format="text",
                temperature=0.0,
            )
        # NOTE(review): with response_format="text" the SDK result may be a
        # plain str or an object with .text; always return plain text.
        if isinstance(result, str):
            return result
        return getattr(result, "text", str(result))

    if split_status != "split":
        try:
            return translate_one(processed_path)
        except groq.APIError as e:
            handle_groq_error(e, model)
            raise gr.Error(f"Error during Groq API call: {e}") from e

    full_translation = ""
    try:
        for chunk_path in processed_path:
            try:
                full_translation += translate_one(chunk_path)
            except groq.RateLimitError:
                return f"API limit reached. Partial translation: {full_translation}"
            except groq.APIError as e:
                handle_groq_error(e, model)
                raise gr.Error(f"Error during Groq API call: {e}") from e
        return full_translation
    finally:
        # The chunk files are intermediate artifacts that are not returned to
        # the UI; delete them so repeated large uploads don't fill the disk.
        for chunk_path in processed_path:
            try:
                os.remove(chunk_path)
            except OSError:
                pass
# --- Gradio UI ---
with gr.Blocks() as interface:
    # NOTE(review): the "Duplicate Space" link targets
    # Nick088/Fast-Subtitle-Maker; confirm that is the intended Space.
    gr.Markdown(
        """
# Groq API UI
Inference by Groq API
If you are having API Rate Limit issues, you can retry later based on the [rate limits](https://console.groq.com/docs/rate-limits) or <a href="https://huggingface.co/spaces/Nick088/Fast-Subtitle-Maker?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank"> <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a> with <a href=https://console.groq.com/keys>your own API Key</a>
Hugging Face Space by [Nick088](https://linktr.ee/Nick088)
<br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
"""
    )
    with gr.Tabs():
        with gr.TabItem("LLMs"):
            with gr.Row():
                with gr.Column(scale=1, min_width=250):
                    model = gr.Dropdown(
                        choices=[
                            "llama3-70b-8192",
                            "llama3-8b-8192",
                            "mixtral-8x7b-32768",
                            "gemma-7b-it",
                            "gemma2-9b-it",
                        ],
                        value="llama3-70b-8192",
                        label="Model",
                    )
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Temperature",
                        info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
                    )
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=8192,
                        step=1,
                        value=4096,
                        label="Max Tokens",
                        info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama3 8b & 70b, 32k for mixtral 8x7b.",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Top P",
                        info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
                    )
                    seed = gr.Number(
                        precision=0, value=42, label="Seed", info="A starting point to initiate generation, use 0 for random"
                    )
                    # Keep the Max Tokens slider's upper bound in sync with the
                    # selected model. Registered exactly once so the update
                    # runs a single time per model change.
                    model.change(update_max_tokens, inputs=[model], outputs=max_tokens)
                with gr.Column(scale=1, min_width=400):
                    chatbot = gr.ChatInterface(
                        fn=generate_response,
                        chatbot=None,
                        additional_inputs=[
                            model,
                            temperature,
                            max_tokens,
                            top_p,
                            seed,
                        ],
                    )
        with gr.TabItem("Speech To Text"):
            with gr.Tabs():
                with gr.TabItem("Transcription"):
                    gr.Markdown("Transcribe audio files to text!")
                    with gr.Row():
                        audio_input = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                        model_choice_transcribe = gr.Dropdown(
                            choices=["whisper-large-v3"],
                            value="whisper-large-v3",
                            label="Model",
                        )
                    with gr.Row():
                        transcribe_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                        with gr.Column():
                            # (label, value) pairs: show the language name,
                            # pass its code to transcribe_audio.
                            language = gr.Dropdown(
                                choices=list(LANGUAGE_CODES.items()),
                                value="en",
                                label="Language",
                            )
                            auto_detect_language = gr.Checkbox(label="Auto Detect Language")
                    transcribe_button = gr.Button("Transcribe")
                    transcription_output = gr.Textbox(label="Transcription")
                    merged_audio_output = gr.File(label="Merged Audio (if chunked)")
                    transcribe_button.click(
                        transcribe_audio,
                        inputs=[audio_input, model_choice_transcribe, transcribe_prompt, language, auto_detect_language],
                        outputs=[transcription_output, merged_audio_output],
                    )
                with gr.TabItem("Translation"):
                    gr.Markdown("Transcribe audio files and translate them to English text!")
                    with gr.Row():
                        audio_input_translate = gr.File(
                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
                        )
                        model_choice_translate = gr.Dropdown(
                            choices=["whisper-large-v3"],
                            value="whisper-large-v3",
                            label="Audio Speech Recognition (ASR) Model",
                        )
                    with gr.Row():
                        translate_prompt = gr.Textbox(
                            label="Prompt (Optional)",
                            info="Specify any context or spelling corrections.",
                        )
                    translate_button = gr.Button("Translate")
                    translation_output = gr.Textbox(label="Translation")
                    translate_button.click(
                        translate_audio,
                        inputs=[audio_input_translate, model_choice_translate, translate_prompt],
                        outputs=translation_output,
                    )
interface.launch(share=True)