Spaces:
Runtime error
Runtime error
import openai | |
import gradio as gr | |
import os | |
from typing import Any, Dict, Generator, List | |
from huggingface_hub import InferenceClient | |
from transformers import AutoTokenizer | |
# Token for the Hugging Face Inference API; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# UI model name -> Hugging Face Hub repository id.
hf_models = {
    "mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "llama 3": "meta-llama/Meta-Llama-3-70B-Instruct",
}

# Model ids passed unchanged to the OpenAI Chat Completions API.
openai_models = {"gpt-4o", "gpt-3.5-turbo-0125"}

# One tokenizer per HF model (used for its chat template), loaded via
# from_pretrained when this module is imported. All tokenizers are created
# before any client below.
tokenizers = dict(zip(hf_models, map(AutoTokenizer.from_pretrained, hf_models.values())))

# One InferenceClient per HF model, each authenticated with HF_TOKEN.
clients = {
    name: InferenceClient(repo_id, token=HF_TOKEN)
    for name, repo_id in hf_models.items()
}
def _env_flag(name: str, default: bool) -> bool:
    """
    Read a boolean flag from the environment.

    Args:
        name (str): Environment variable name.
        default (bool): Value returned when the variable is not set.

    Returns:
        bool: ``default`` if the variable is unset. False for an empty value or
        "0"/"false"/"no"/"off" (case-insensitive, surrounding whitespace
        ignored). True for any other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Sampling settings shared by both providers, parsed once so they always agree.
# Temperature is floored at 0.01, presumably because the HF text-generation
# endpoint rejects a temperature of 0 -- confirm before relaxing.
_TEMPERATURE = max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2)
_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 256))
_TOP_P = float(os.getenv("TOP_P", 0.6))

# Keyword arguments forwarded to InferenceClient.text_generation.
HF_GENERATE_KWARGS = {
    'temperature': _TEMPERATURE,
    'max_new_tokens': _MAX_NEW_TOKENS,
    'top_p': _TOP_P,
    'repetition_penalty': float(os.getenv("REP_PENALTY", 1.2)),
    # bool("false") is True, so the flag must be parsed rather than passed to bool().
    'do_sample': _env_flag("DO_SAMPLE", True),
}

# Keyword arguments forwarded to OpenAI chat.completions.create.
OAI_GENERATE_KWARGS = {
    'temperature': _TEMPERATURE,
    'max_tokens': _MAX_NEW_TOKENS,
    'top_p': _TOP_P,
    # Clamped to [-2, 2], the range the OpenAI API accepts for frequency_penalty.
    'frequency_penalty': max(-2.0, min(float(os.getenv("FREQ_PENALTY", 0)), 2.0)),
}
def format_prompt(message: str, model: str):
    """
    Build the provider-specific prompt for a single user message.

    Args:
        message (str): The user message to be formatted.
        model (str): Model key; must be a member of ``openai_models`` or a key
            of ``hf_models``.

    Returns:
        list[dict] | str: For OpenAI models, the chat ``messages`` list
        ``[{'role': 'user', 'content': message}]`` (passed as-is to the Chat
        Completions API). For Hugging Face models, ``message`` rendered as a
        string through that model's tokenizer chat template, ending with the
        assistant generation prompt when the template defines one.

    Raises:
        ValueError: If ``model`` is not a supported model key.
    """
    # A single-turn conversation: only the current user message is sent.
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
    if model in openai_models:
        return messages
    if model in hf_models:
        # add_generation_prompt=True makes templates that define an assistant
        # header (e.g. Llama 3) end the prompt on the assistant turn, so the
        # model replies as the assistant. Templates without a generation
        # prompt ignore the flag.
        return tokenizers[model].apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    raise ValueError(f"Model {model} is not supported")
def generate_hf(model: str, prompt: str, history: str, _: str) -> Generator[str, None, str]:
    """
    Stream a completion for ``prompt`` from a Hugging Face Inference API model.

    Args:
        model (str): Key into ``hf_models`` / ``clients`` selecting the model.
        prompt (str): The user message to complete.
        history (str): Conversation history; currently not used.
        _ (str): Unused; keeps the signature parallel to ``generate_openai``,
            whose fourth argument is the API key.

    Yields:
        str: The accumulated generated text after each new non-special token.

    Raises:
        ValueError: From ``format_prompt`` if ``model`` is not supported.
        gr.Error: If the inference request or the stream fails (rate limit,
            invalid/missing HF token, or any other exception).
    """
    formatted_prompt = format_prompt(prompt, model)
    try:
        stream = clients[model].text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
            details=True,
            return_full_text=False,
        )
        output = ""
        for response in stream:
            # With details=True, special tokens (such as the end-of-sequence
            # marker) are streamed too; skip them so they do not appear in
            # the chat output.
            if response.token.special:
                continue
            output += response.token.text
            yield output
    except Exception as e:
        message = str(e)
        if "Too Many Requests" in message:
            raise gr.Error(f"Too many requests: {message}") from e
        elif "Authorization header is invalid" in message:
            raise gr.Error("Authentication error: HF token was either not provided or incorrect") from e
        else:
            raise gr.Error(f"Unhandled Exception: {message}") from e
def generate_openai(model: str, prompt: str, history: str, api_key: str) -> Generator[str, None, str]:
    """
    Stream a chat completion for ``prompt`` from an OpenAI model.

    Args:
        model (str): OpenAI model id; must be a member of ``openai_models``.
        prompt (str): The user message to complete.
        history (str): Conversation history; currently not used.
        api_key (str): OpenAI API key. When None, the OpenAI client falls back
            to the OPENAI_API_KEY environment variable.

    Yields:
        str: The accumulated generated text after each non-empty content delta.

    Raises:
        ValueError: From ``format_prompt`` if ``model`` is not supported.
        gr.Error: If the client cannot be created or the request/stream fails
            (rate limit, missing/incorrect API key, or any other exception).
    """
    formatted_prompt = format_prompt(prompt, model)
    auth_error = "Authentication error: OpenAI key was either not provided or incorrect"
    try:
        # The client constructor raises OpenAIError when api_key is None and
        # OPENAI_API_KEY is unset, so it is guarded as well.
        client = openai.Client(api_key=api_key)
    except openai.OpenAIError as e:
        raise gr.Error(auth_error) from e
    try:
        stream = client.chat.completions.create(
            model=model,
            messages=formatted_prompt,
            **OAI_GENERATE_KWARGS,
            stream=True,
        )
        output = ""
        for chunk in stream:
            # Some stream chunks can carry no choices; skip them instead of
            # failing with IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                output += content
                yield output
    except openai.RateLimitError as e:
        raise gr.Error("ERROR: Too many requests on OpenAI client") from e
    except openai.AuthenticationError as e:
        # Covers both a missing and an incorrect key (HTTP 401).
        raise gr.Error(auth_error) from e
    except Exception as e:
        # Fallback message matching for errors not classified by type above.
        if "Too Many Requests" in str(e):
            raise gr.Error("ERROR: Too many requests on OpenAI client") from e
        elif "You didn't provide an API key" in str(e):
            raise gr.Error(auth_error) from e
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}") from e