ai_school_hw5 / backend /query_llm.py
complynx's picture
change clients
cb68ee9
raw
history blame
4.53 kB
import openai
import gradio as gr
import os
from typing import Any, Dict, Generator, List
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Short model name -> Hugging Face Hub repository id.
hf_models = {
    "mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    "llama 3": "meta-llama/Meta-Llama-3-70B-Instruct",
}

# Model names that are served through the OpenAI API instead.
openai_models = {
    "gpt-4o",
    "gpt-3.5-turbo-0125",
}

# One tokenizer per HF model, keyed by the same short name as `hf_models`;
# loaded eagerly when the module is imported.
tokenizers = {
    short_name: AutoTokenizer.from_pretrained(repo_id)
    for short_name, repo_id in hf_models.items()
}

# One inference client per HF model, keyed by the same short name.
clients = {
    short_name: InferenceClient(repo_id, token=HF_TOKEN)
    for short_name, repo_id in hf_models.items()
}
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment variable ``name``.

    ``bool("false")`` is ``True`` because every non-empty string is truthy, so
    the text is compared against the usual "off" spellings instead of being
    passed to ``bool``. An unset variable yields ``default``.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# Generation parameters unpacked into the Hugging Face `text_generation` call.
# Each value can be overridden through the environment variable named in its
# `os.getenv` lookup; the second argument is the default used when it is unset.
HF_GENERATE_KWARGS = {
    # Floored at 1e-2 so the temperature can never be zero or negative.
    'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
    'max_new_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
    'top_p': float(os.getenv("TOP_P", 0.6)),
    'repetition_penalty': float(os.getenv("REP_PENALTY", 1.2)),
    # Parsed as a real boolean so DO_SAMPLE=false / DO_SAMPLE=0 disable sampling.
    'do_sample': _env_flag("DO_SAMPLE", True),
}
# Generation parameters unpacked into the OpenAI chat-completions call.
# TEMPERATURE, MAX_NEW_TOKENS and TOP_P are the same environment variables the
# Hugging Face settings read; FREQ_PENALTY is specific to this dictionary.
OAI_GENERATE_KWARGS = dict(
    # Never lower than 1e-2.
    temperature=max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
    max_tokens=int(os.getenv("MAX_NEW_TOKENS", 256)),
    top_p=float(os.getenv("TOP_P", 0.6)),
    # Clamped into the closed interval [-2, 2].
    frequency_penalty=max(-2, min(float(os.getenv("FREQ_PENALTY", 0)), 2)),
)
def format_prompt(message: str, model: str):
    """
    Wrap a single user message in the prompt format expected by ``model``.

    Args:
        message (str): The user message to be formatted.
        model (str): Model name; either a member of ``openai_models`` or a key
            of ``hf_models``.

    Returns:
        For an OpenAI model, a one-element list of chat message dicts
        (``[{'role': 'user', 'content': message}]``). For a Hugging Face
        model, the string produced by applying that model's tokenizer chat
        template to the same list.

    Raises:
        ValueError: If ``model`` is in neither registry.
    """
    # Both providers start from the same single-turn chat message list.
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]

    if model in openai_models:
        # The OpenAI chat API consumes the message list directly.
        return messages

    if model in hf_models:
        # add_generation_prompt=True makes the template append the tokens that
        # open the assistant turn (when the template defines any), so the
        # model answers the user instead of continuing the user's message.
        return tokenizers[model].apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    raise ValueError(f"Model {model} is not supported")
def generate_hf(model: str, prompt: str, history: str, _: str) -> Generator[str, None, None]:
    """
    Stream a completion for ``prompt`` from the Hugging Face client of ``model``.

    Args:
        model (str): Key into ``clients`` selecting the inference client.
        prompt (str): The prompt for the text generation.
        history (str): Context or history for the text generation. Not used
            by the current implementation.
        _ (str): Ignored; occupies the slot where ``generate_openai`` takes
            its API key so both generators share one call shape.

    Yields:
        str: The cumulative text generated so far, once per streamed
        non-special token.

    Raises:
        ValueError: From ``format_prompt`` when ``model`` is not supported.
        gr.Error: When the request or the stream fails; the message
            distinguishes rate limiting, an invalid HF token, and everything
            else.
    """
    formatted_prompt = format_prompt(prompt, model)
    try:
        stream = clients[model].text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
            details=True,
            return_full_text=False,
        )
        output = ""
        for response in stream:
            token = response.token
            # Special tokens (e.g. end-of-sequence markers) are control
            # symbols, not text for the user, so they are left out.
            if getattr(token, "special", False):
                continue
            output += token.text
            yield output
    except Exception as e:
        # Surface every failure as a Gradio error, keeping the original
        # exception chained for the server log.
        message = str(e)
        if "Too Many Requests" in message:
            raise gr.Error(f"Too many requests: {message}") from e
        elif "Authorization header is invalid" in message:
            raise gr.Error("Authentication error: HF token was either not provided or incorrect") from e
        else:
            raise gr.Error(f"Unhandled Exception: {message}") from e
def generate_openai(model: str, prompt: str, history: str, api_key: str) -> Generator[str, None, None]:
    """
    Stream a chat completion for ``prompt`` from the OpenAI API.

    Args:
        model (str): OpenAI model name, passed to the API as ``model``.
        prompt (str): The initial prompt for the text generation.
        history (str): Context or history for the text generation. Not used
            by the current implementation.
        api_key (str): OpenAI API key used to construct the client.

    Yields:
        str: The cumulative text generated so far, once per streamed chunk
        that carries content.

    Raises:
        ValueError: From ``format_prompt`` when ``model`` is not supported.
        gr.Error: When client creation, the request or the stream fails; the
            message distinguishes rate limiting, authentication problems, and
            everything else.
    """
    messages = format_prompt(prompt, model)
    try:
        # Created inside the try block so a client construction failure is
        # reported as a Gradio error like any other failure.
        client = openai.Client(api_key=api_key)
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            **OAI_GENERATE_KWARGS,
            stream=True,
        )
        output = ""
        for chunk in stream:
            # Some stream chunks carry no choices at all; indexing them
            # would raise IndexError and abort an otherwise healthy stream.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece:
                output += piece
                yield output
    except openai.RateLimitError as e:
        raise gr.Error("ERROR: Too many requests on OpenAI client") from e
    except openai.AuthenticationError as e:
        raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect") from e
    except Exception as e:
        # Fallback classification by message text for errors that do not
        # arrive as one of the typed exceptions above.
        message = str(e)
        if "Too Many Requests" in message:
            raise gr.Error("ERROR: Too many requests on OpenAI client") from e
        elif "You didn't provide an API key" in message:
            raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect") from e
        else:
            raise gr.Error(f"Unhandled Exception: {message}") from e