Spaces:
Sleeping
Sleeping
File size: 9,634 Bytes
abbebf8 327982a cf4506f 8093276 74d6e52 327982a 44bab49 469f650 44bab49 139217d 56e785c 8093276 74d6e52 a96b492 6c32632 56e785c 469f650 8093276 327982a a96b492 ddb0d91 8093276 74d6e52 8093276 56e785c 74d6e52 56e785c e01e28e e327a9e 74d6e52 469f650 74d6e52 469f650 74d6e52 469f650 74d6e52 56e785c 74d6e52 e01e28e 327982a e4b918c 327982a e4b918c f84c1a6 327982a e4b918c 327982a e01e28e 327982a e01e28e e4b918c 327982a e4b918c a96b492 e4b918c 327982a e4b918c 327982a e4b918c 3ebb6e1 327982a e4b918c 327982a e4b918c 139217d ddb0d91 a0f49a0 ddb0d91 e01e28e ddb0d91 a0f49a0 ddb0d91 a0f49a0 ddb0d91 9475016 56e785c 469f650 00af17e 56e785c 9475016 56e785c 00af17e 56e785c 469f650 3c6c618 c013599 56e785c 8093276 abbebf8 124003a abbebf8 124003a abbebf8 124003a abbebf8 124003a abbebf8 124003a abbebf8 124003a abbebf8 8093276 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
import datetime
import json
import re
import uuid
from os import environ as env
from time import sleep, time
from typing import Any, Dict, Union

import requests
from huggingface_hub import hf_hub_download

from data import log_to_jsonl
# There are currently 4 ways to run the LLM, selected with the LLM_WORKER
# environment variable:
# 1. Use the HTTP server (LLM_WORKER=http). This is good for development
# when you want to change the logic of the translator without restarting the server.
# 2. Load the model into memory (LLM_WORKER=in_memory).
# Note: when using the HTTP server (option 1), it must be run separately. See the README for instructions.
# The llama_cpp Python HTTP server communicates with the AI model, similar
# to the OpenAI API but adds a unique "grammar" parameter.
# The real OpenAI API has other ways to set the output format.
# It's possible to switch to another LLM API by changing the llm_streaming function.
# 3. Use the RunPod API (LLM_WORKER=runpod), which is a paid service with serverless GPU functions.
# See serverless.md for more information.
# 4. Use the Mistral API (LLM_WORKER=mistral, the default), which is a paid service.
# Chat-completions endpoint of the separately-run llama_cpp HTTP server
# (used by llm_streaming when LLM_WORKER == "http").
URL = "http://localhost:5834/v1/chat/completions"
# llama_cpp model instance; only created below when LLM_WORKER == "in_memory".
in_memory_llm = None
worker_options = ["runpod", "http", "in_memory", "mistral"]
LLM_WORKER = env.get("LLM_WORKER", "mistral")
if LLM_WORKER not in worker_options:
    raise ValueError(f"Invalid worker: {LLM_WORKER}")

N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1))  # Default to -1, use all layers if available
CONTEXT_SIZE = int(env.get("CONTEXT_SIZE", 2048))
LLM_MODEL_PATH = env.get("LLM_MODEL_PATH", None)
MAX_TOKENS = int(env.get("MAX_TOKENS", 1000))
TEMPERATURE = float(env.get("TEMPERATURE", 0.3))

# Only the "http" and "in_memory" workers need llama_cpp and a local GGUF model.
performing_local_inference = LLM_WORKER in ("in_memory", "http")

# A non-empty string is truthy, so this matches the old `x and len(x) > 0` check.
if LLM_MODEL_PATH:
    print(f"Using local model from {LLM_MODEL_PATH}")

if performing_local_inference and not LLM_MODEL_PATH:
    print("No local LLM_MODEL_PATH environment variable set. We need a model, downloading model from HuggingFace Hub")
    LLM_MODEL_PATH = hf_hub_download(
        repo_id=env.get("REPO_ID", "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"),
        filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"),
    )
    print(f"Model downloaded to {LLM_MODEL_PATH}")

if performing_local_inference:
    # Imported lazily so the remote workers (runpod, mistral) do not require llama_cpp.
    from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf

    if in_memory_llm is None and LLM_WORKER == "in_memory":
        # The worker is selected via LLM_WORKER; there is no USE_HTTP_SERVER variable.
        print("Loading model into memory. If you didn't want this, set the LLM_WORKER environment variable to 'http'.")
        in_memory_llm = Llama(model_path=LLM_MODEL_PATH, n_ctx=CONTEXT_SIZE, n_gpu_layers=N_GPU_LAYERS, verbose=True)
def llm_streaming(
    prompt: str, pydantic_model_class, return_pydantic_object=False
) -> Union[str, Dict[str, Any]]:
    """Generate structured output via the llama_cpp HTTP server, echoing tokens to stdout.

    The pydantic model's JSON schema is converted to a GBNF grammar and sent in
    the server-specific ``grammar`` field, constraining generation to JSON that
    matches the schema. The response is consumed as a server-sent-event stream.

    Args:
        prompt: User message, sent as a single chat turn.
        pydantic_model_class: Pydantic (v2) model class describing the output.
        return_pydantic_object: If True, return a validated model instance
            instead of a plain dict.

    Returns:
        The generated JSON parsed into a dict, or a ``pydantic_model_class``
        instance when ``return_pydantic_object`` is True.

    Raises:
        requests.HTTPError: If the server responds with a non-2xx status.
        json.JSONDecodeError / pydantic validation errors: If the generated
            text is not valid JSON for the schema.
    """
    schema = pydantic_model_class.model_json_schema()
    # Optional example field from schema, is not needed for the grammar generation
    schema.pop("example", None)
    grammar = json_schema_to_gbnf(json.dumps(schema))

    payload = {
        "stream": True,
        "max_tokens": MAX_TOKENS,
        "grammar": grammar,
        "temperature": TEMPERATURE,
        "messages": [{"role": "user", "content": prompt}],
    }
    headers = {
        "Content-Type": "application/json",
    }

    data_prefix = "data: "
    tokens = []
    # Context manager releases the streamed connection even if parsing fails.
    with requests.post(URL, headers=headers, json=payload, stream=True) as response:
        # Fail fast with a clear error instead of trying to parse an error body as events.
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # blank lines separate SSE events
            line = raw_line.decode("utf-8")
            if line.startswith(data_prefix):
                # Strip only the leading prefix; the JSON payload may itself contain "data: ".
                line = line[len(data_prefix):]
            if line.strip() == "[DONE]":
                break
            event = json.loads(line)
            # Some chunks (e.g. the final one) may carry no delta content.
            delta = event.get("choices")[0].get("delta") or {}
            new_token = delta.get("content")
            if new_token:
                tokens.append(new_token)
                print(new_token, end="", flush=True)
    print('\n')

    output_text = "".join(tokens)
    if return_pydantic_object:
        return pydantic_model_class.model_validate_json(output_text)
    return json.loads(output_text)
def replace_text(template: str, replacements: dict) -> str:
    """Substitute ``{key}`` placeholders in *template* with values from *replacements*.

    All placeholders are replaced in a single pass. Text inserted for one key is
    therefore never re-scanned for other placeholders; for example, user-supplied
    text that contains ``{other_key}`` is inserted verbatim. Non-string values are
    converted with ``str()``. Placeholders with no matching key are left unchanged.

    >>> replace_text("Hi {name}", {"name": "Ann"})
    'Hi Ann'
    """
    if not replacements:
        return template
    # Map each literal placeholder ("{key}") to its replacement text.
    lookup = {f"{{{key}}}": str(value) for key, value in replacements.items()}
    # Longest placeholders first so a shorter one can never shadow a longer one.
    alternatives = sorted(lookup, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(placeholder) for placeholder in alternatives))
    return pattern.sub(lambda match: lookup[match.group(0)], template)
def calculate_overall_score(faithfulness, spiciness):
    """Combine faithfulness and spiciness into one score.

    Spiciness contributes a bonus proportional to faithfulness, weighted by
    ``1 - 0.8``; with zero spiciness the result equals faithfulness.
    """
    baseline_weight = 0.8
    spice_weight = 1 - baseline_weight
    # Same evaluation order as before: (weight * spiciness) * faithfulness.
    spice_bonus = spice_weight * spiciness * faithfulness
    return faithfulness + spice_bonus
def llm_stream_sans_network(
    prompt: str, pydantic_model_class, return_pydantic_object=False
) -> Union[str, Dict[str, Any]]:
    """Run grammar-constrained generation on the in-process llama_cpp model.

    The pydantic model's JSON schema is turned into a ``LlamaGrammar`` so the
    module-level ``in_memory_llm`` can only emit matching JSON. Streamed text is
    echoed to stdout as it arrives.

    Args:
        prompt: Raw completion prompt passed to the model.
        pydantic_model_class: Pydantic (v2) model class describing the output.
        return_pydantic_object: If True, return a validated model instance
            instead of a plain dict.

    Returns:
        The generated JSON as a dict, or a ``pydantic_model_class`` instance
        when ``return_pydantic_object`` is True.
    """
    model_schema = pydantic_model_class.model_json_schema()
    # The optional "example" entry is not needed to build the grammar.
    model_schema.pop("example", None)
    json_grammar = LlamaGrammar.from_json_schema(json.dumps(model_schema))

    token_stream = in_memory_llm(
        prompt,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        grammar=json_grammar,
        stream=True
    )
    pieces = []
    for part in token_stream:
        text = part["choices"][0]["text"]
        print(text, end='', flush=True)
        pieces.append(text)
    print('\n')
    generated = "".join(pieces)

    if not return_pydantic_object:
        return json.loads(generated)
    return pydantic_model_class.model_validate_json(generated)
def llm_stream_serverless(prompt, model):
    """Run generation on a RunPod serverless endpoint via its synchronous ``runsync`` route.

    Args:
        prompt: Prompt text, forwarded as ``input.prompt``.
        model: Pydantic (v2) model class; its JSON schema is forwarded as
            ``input.schema``.

    Returns:
        The endpoint's ``output`` string parsed as JSON.

    Raises:
        ValueError: If RUNPOD_ENDPOINT_ID or RUNPOD_API_KEY is unset, if the API
            returns a non-200 status, or if the response has no ``output``
            (e.g. the job did not finish within the runsync window).
    """
    RUNPOD_ENDPOINT_ID = env.get("RUNPOD_ENDPOINT_ID")
    RUNPOD_API_KEY = env.get("RUNPOD_API_KEY")
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not RUNPOD_ENDPOINT_ID:
        raise ValueError("RUNPOD_ENDPOINT_ID environment variable not set")
    if not RUNPOD_API_KEY:
        raise ValueError("RUNPOD_API_KEY environment variable not set")

    url = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/runsync"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {RUNPOD_API_KEY}'
    }
    # model_json_schema() is the pydantic v2 API used by the other workers;
    # .schema() is its deprecated alias.
    schema = model.model_json_schema()
    data = {
        'input': {
            'schema': json.dumps(schema),
            'prompt': prompt
        }
    }
    response = requests.post(url, json=data, headers=headers)
    if response.status_code != 200:
        raise ValueError(f"Unexpected RunPod API status code: {response.status_code} with body: {response.text}")
    result = response.json()
    print(result)
    # TODO: After a 30 second timeout, a job ID is returned in the response instead,
    # and the client must poll the job status endpoint to get the result.
    if 'output' not in result:
        raise ValueError(
            f"RunPod response has no output (status: {result.get('status')}, id: {result.get('id')}); "
            "polling for unfinished jobs is not implemented"
        )
    # TODO: remove replacement once new version of runpod is deployed
    output = result['output'].replace("model:mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf\n", "")
    return json.loads(output)
# Global variables to enforce rate limiting
LAST_REQUEST_TIME = None
REQUEST_INTERVAL = 0.5  # Minimum time interval between requests in seconds


def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
    """Query the Mistral chat-completions API in JSON mode, with simple client-side rate limiting.

    Before each request, sleeps so that consecutive calls are at least
    ``REQUEST_INTERVAL`` seconds apart (tracked in the module-global
    ``LAST_REQUEST_TIME``).

    Args:
        prompt: User message sent as a single chat turn.
        pydantic_model_class: Optional pydantic (v2) model class. When given, the
            output is validated against it, and validation errors propagate.

    Returns:
        The assistant message content parsed as JSON (a dict for JSON objects).

    Raises:
        ValueError: If MISTRAL_API_KEY is unset or the API returns a non-200 status.
    """
    global LAST_REQUEST_TIME
    current_time = time()
    if LAST_REQUEST_TIME is not None:
        elapsed_time = current_time - LAST_REQUEST_TIME
        if elapsed_time < REQUEST_INTERVAL:
            sleep_time = REQUEST_INTERVAL - elapsed_time
            sleep(sleep_time)
            print(f"Slept for {sleep_time} seconds to enforce rate limit")
    LAST_REQUEST_TIME = time()

    MISTRAL_API_URL = env.get("MISTRAL_API_URL", "https://api.mistral.ai/v1/chat/completions")
    MISTRAL_API_KEY = env.get("MISTRAL_API_KEY", None)
    if not MISTRAL_API_KEY:
        raise ValueError("MISTRAL_API_KEY environment variable not set")
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {MISTRAL_API_KEY}'
    }
    data = {
        'model': 'mistral-small-latest',
        # JSON mode is a request-level field of the chat-completions API. Placed
        # inside a message object, it does not enable JSON mode.
        'response_format': {'type': 'json_object'},
        'messages': [
            {
                'role': 'user',
                'content': prompt
            }
        ]
    }
    response = requests.post(MISTRAL_API_URL, headers=headers, json=data)
    if response.status_code != 200:
        raise ValueError(f"Unexpected Mistral API status code: {response.status_code} with body: {response.text}")
    result = response.json()
    print(result)
    output = result['choices'][0]['message']['content']
    if pydantic_model_class:
        # This will raise an exception if the output does not match the model.
        # TODO: handle exception with retry logic
        parsed_result = pydantic_model_class.model_validate_json(output)
        print(parsed_result)
    else:
        print("No pydantic model class provided, returning without class validation")
    return json.loads(output)
def query_ai_prompt(prompt, replacements, model_class):
    """Fill a prompt template, run it on the configured LLM worker, and log the exchange.

    Args:
        prompt: Template containing ``{key}`` placeholders.
        replacements: Mapping of placeholder names to substitution values.
        model_class: Pydantic model class describing the expected output.

    Returns:
        The selected worker function's result (the parsed model output).

    Raises:
        ValueError: If LLM_WORKER does not name a known worker.
    """
    prompt = replace_text(prompt, replacements)
    # Dispatch to exactly one backend per query. Each worker is called once,
    # which matters for the paid APIs.
    workers = {
        "mistral": llm_stream_mistral_api,
        "runpod": llm_stream_serverless,
        "http": llm_streaming,
        "in_memory": llm_stream_sans_network,
    }
    worker = workers.get(LLM_WORKER)
    if worker is None:
        raise ValueError(f"Invalid worker: {LLM_WORKER}")
    result = worker(prompt, model_class)

    log_entry = {
        "uuid": str(uuid.uuid4()),
        # Naive UTC ISO-8601 string, identical in format to the deprecated
        # datetime.utcnow().isoformat().
        "timestamp": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(),
        "worker": LLM_WORKER,
        "prompt_input": prompt,
        "prompt_output": result
    }
    log_to_jsonl('prompt_inputs_and_outputs.jsonl', log_entry)
    return result
|