import datetime
import json
import uuid
from os import environ as env
from time import time, sleep
from typing import Any, Dict, Union

import requests
from huggingface_hub import hf_hub_download

from data import log_to_jsonl

# The LLM backend is selected with the LLM_WORKER environment variable
# (default: "mistral"). There are four options:
# 1. "http": call a separately running llama_cpp HTTP server at URL. This is
#    convenient during development because the translator logic can change
#    without restarting the server. See the README for how to start it.
#    The llama_cpp server exposes an OpenAI-style chat completions API but
#    also accepts a "grammar" parameter, used here to constrain the output
#    format (the real OpenAI API has other ways to set the output format).
#    Switching to another LLM API only requires changing llm_streaming.
# 2. "in_memory": load the GGUF model into this process with llama_cpp.
# 3. "runpod": use the RunPod API, a paid service with serverless GPU
#    functions. See serverless.md for more information.
# 4. "mistral": use the Mistral API, a paid service.

URL = "http://localhost:5834/v1/chat/completions"
in_memory_llm = None

worker_options = ["runpod", "http", "in_memory", "mistral"]

LLM_WORKER = env.get("LLM_WORKER", "mistral")
if LLM_WORKER not in worker_options:
    raise ValueError(f"Invalid worker: {LLM_WORKER}")

N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1))  # Default to -1, use all layers if available
CONTEXT_SIZE = int(env.get("CONTEXT_SIZE", 2048))
LLM_MODEL_PATH = env.get("LLM_MODEL_PATH", None)
MAX_TOKENS = int(env.get("MAX_TOKENS", 1000))
TEMPERATURE = float(env.get("TEMPERATURE", 0.3))

performing_local_inference = LLM_WORKER in ("in_memory", "http")

if LLM_MODEL_PATH:
    print(f"Using local model from {LLM_MODEL_PATH}")

if performing_local_inference and not LLM_MODEL_PATH:
    print(
        "No local LLM_MODEL_PATH environment variable set. "
        "We need a model, downloading model from HuggingFace Hub"
    )
    LLM_MODEL_PATH = hf_hub_download(
        repo_id=env.get("REPO_ID", "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF"),
        filename=env.get("MODEL_FILE", "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"),
    )
    print(f"Model downloaded to {LLM_MODEL_PATH}")

if performing_local_inference:
    # llama_cpp is only imported (and only needs to be installed) for the
    # local workers.
    from llama_cpp import Llama, LlamaGrammar, json_schema_to_gbnf

    if in_memory_llm is None and LLM_WORKER == "in_memory":
        print(
            "Loading model into memory. If you didn't want this, "
            "set the LLM_WORKER environment variable to 'http'."
        )
        in_memory_llm = Llama(
            model_path=LLM_MODEL_PATH,
            n_ctx=CONTEXT_SIZE,
            n_gpu_layers=N_GPU_LAYERS,
            verbose=True,
        )


def _schema_json_without_example(pydantic_model_class) -> str:
    """Return the model's JSON schema as a string, minus any top-level "example".

    The optional "example" field is not needed for grammar generation.
    """
    schema = pydantic_model_class.model_json_schema()
    schema.pop("example", None)
    return json.dumps(schema)


def _parse_output(output_text: str, pydantic_model_class, return_pydantic_object: bool):
    """Parse generated JSON text into a dict, or a validated pydantic instance if requested."""
    if return_pydantic_object:
        return pydantic_model_class.model_validate_json(output_text)
    return json.loads(output_text)


def llm_streaming(
    prompt: str, pydantic_model_class, return_pydantic_object=False
) -> Union[str, Dict[str, Any]]:
    """Generate grammar-constrained JSON using the llama_cpp HTTP server at URL.

    Tokens are printed to stdout as they stream in.

    Args:
        prompt: Sent as the single user chat message.
        pydantic_model_class: Pydantic model class whose JSON schema is
            converted into a GBNF grammar constraining the output.
        return_pydantic_object: If True, return a validated model instance
            instead of the parsed dict.

    Returns:
        The parsed JSON output, or a model instance if requested.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        json.JSONDecodeError: If the output is not valid JSON.
    """
    grammar = json_schema_to_gbnf(_schema_json_without_example(pydantic_model_class))

    payload = {
        "stream": True,
        "max_tokens": MAX_TOKENS,
        "grammar": grammar,
        "temperature": TEMPERATURE,
        "messages": [{"role": "user", "content": prompt}],
    }
    headers = {
        "Content-Type": "application/json",
    }

    data_prefix = "data: "
    tokens = []
    # Context manager so the streamed connection is released even when we
    # stop reading early at "[DONE]".
    with requests.post(URL, headers=headers, json=payload, stream=True) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode("utf-8")
            if line.startswith(data_prefix):
                # Slice instead of split so a token containing "data: " survives.
                line = line[len(data_prefix):]
            if line.strip() == "[DONE]":
                break
            event = json.loads(line)
            new_token = event.get("choices")[0].get("delta", {}).get("content")
            if new_token:
                tokens.append(new_token)
                print(new_token, sep="", end="", flush=True)
    print('\n')

    return _parse_output("".join(tokens), pydantic_model_class, return_pydantic_object)


def replace_text(template: str, replacements: dict) -> str:
    """Substitute every "{key}" placeholder in template with its replacement.

    Values are converted with str(), so non-string values are accepted.
    """
    for key, value in replacements.items():
        template = template.replace(f"{{{key}}}", str(value))
    return template


def calculate_overall_score(faithfulness, spiciness):
    """Combine faithfulness and spiciness into a single score.

    The result equals faithfulness scaled up by (1 - baseline_weight) *
    spiciness, so spiciness only adds to the score in proportion to
    faithfulness.
    """
    baseline_weight = 0.8
    overall = faithfulness + (1 - baseline_weight) * spiciness * faithfulness
    return overall


def llm_stream_sans_network(
    prompt: str, pydantic_model_class, return_pydantic_object=False
) -> Union[str, Dict[str, Any]]:
    """Generate grammar-constrained JSON with the in-process llama_cpp model.

    Tokens are printed to stdout as they are produced.

    Args:
        prompt: Raw completion prompt passed to the model.
        pydantic_model_class: Pydantic model class whose JSON schema is
            turned into a LlamaGrammar constraining the output.
        return_pydantic_object: If True, return a validated model instance
            instead of the parsed dict.

    Returns:
        The parsed JSON output, or a model instance if requested.
    """
    grammar = LlamaGrammar.from_json_schema(_schema_json_without_example(pydantic_model_class))

    stream = in_memory_llm(
        prompt,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        grammar=grammar,
        stream=True,
    )

    pieces = []
    for chunk in stream:
        text = chunk["choices"][0]["text"]
        print(text, end='', flush=True)
        pieces.append(text)
    print('\n')

    return _parse_output("".join(pieces), pydantic_model_class, return_pydantic_object)


def llm_stream_serverless(prompt, model):
    """Run the prompt on the RunPod serverless endpoint via the /runsync API.

    Args:
        prompt: Prompt text sent to the worker.
        model: Pydantic model class; its JSON schema is sent along with the prompt.

    Returns:
        The worker output parsed as JSON.

    Raises:
        ValueError: If RunPod credentials are missing, the API returns a
            non-200 status, or the job did not finish synchronously.
    """
    endpoint_id = env.get("RUNPOD_ENDPOINT_ID")
    api_key = env.get("RUNPOD_API_KEY")
    # Explicit raises instead of asserts, which are stripped under `python -O`.
    if not endpoint_id:
        raise ValueError("RUNPOD_ENDPOINT_ID environment variable not set")
    if not api_key:
        raise ValueError("RUNPOD_API_KEY environment variable not set")

    url = f"https://api.runpod.ai/v2/{endpoint_id}/runsync"
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }

    schema = model.model_json_schema()
    data = {
        'input': {
            'schema': json.dumps(schema),
            'prompt': prompt
        }
    }

    response = requests.post(url, json=data, headers=headers)
    if response.status_code != 200:
        raise ValueError(
            f"Unexpected RunPod API status code: {response.status_code} with body: {response.text}"
        )

    result = response.json()
    print(result)

    # TODO: After a 30 second timeout, a job ID is returned in the response instead,
    # and the client must poll the job status endpoint to get the result.
    if 'output' not in result:
        raise ValueError(
            f"RunPod job has no output yet (status: {result.get('status')}); "
            f"polling is not implemented. Response: {result}"
        )

    output = result['output'].replace("model:mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf\n", "")  # TODO: remove replacement once new version of runpod is deployed
    return json.loads(output)


# Global variables to enforce rate limiting
LAST_REQUEST_TIME = None
REQUEST_INTERVAL = 0.5  # Minimum time interval between requests in seconds


def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
    """Send the prompt to the Mistral chat completions API in JSON mode.

    Successive calls are spaced at least REQUEST_INTERVAL seconds apart.

    Args:
        prompt: User message content.
        pydantic_model_class: Optional pydantic model class used to validate
            the output; pass a falsy value to skip validation.

    Returns:
        The model output parsed as JSON (a dict).

    Raises:
        ValueError: If MISTRAL_API_KEY is unset or the API returns a non-200 status.
    """
    global LAST_REQUEST_TIME
    current_time = time()
    if LAST_REQUEST_TIME is not None:
        elapsed_time = current_time - LAST_REQUEST_TIME
        if elapsed_time < REQUEST_INTERVAL:
            sleep_time = REQUEST_INTERVAL - elapsed_time
            sleep(sleep_time)
            print(f"Slept for {sleep_time} seconds to enforce rate limit")
    LAST_REQUEST_TIME = time()

    MISTRAL_API_URL = env.get("MISTRAL_API_URL", "https://api.mistral.ai/v1/chat/completions")
    MISTRAL_API_KEY = env.get("MISTRAL_API_KEY", None)
    if not MISTRAL_API_KEY:
        raise ValueError("MISTRAL_API_KEY environment variable not set")

    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': f'Bearer {MISTRAL_API_KEY}'
    }
    data = {
        'model': 'mistral-small-latest',
        # JSON mode is a top-level request option, not a per-message field.
        'response_format': {'type': 'json_object'},
        'messages': [
            {
                'role': 'user',
                'content': prompt
            }
        ]
    }

    response = requests.post(MISTRAL_API_URL, headers=headers, json=data)
    if response.status_code != 200:
        raise ValueError(f"Unexpected Mistral API status code: {response.status_code} with body: {response.text}")

    result = response.json()
    print(result)
    output = result['choices'][0]['message']['content']
    if pydantic_model_class:
        # This will raise an exception if the output does not match the model.
        # TODO: handle exception with retry logic
        parsed_result = pydantic_model_class.model_validate_json(output)
        print(parsed_result)
    else:
        print("No pydantic model class provided, returning without class validation")
    return json.loads(output)


def query_ai_prompt(prompt, replacements, model_class):
    """Fill the prompt template, run it on the configured LLM_WORKER, and log the exchange.

    Args:
        prompt: Template containing "{key}" placeholders.
        replacements: Mapping of placeholder names to replacement values.
        model_class: Pydantic model class describing the expected output.

    Returns:
        The worker's result (parsed JSON).
    """
    prompt = replace_text(prompt, replacements)

    # Exactly one worker call per query.
    if LLM_WORKER == "mistral":
        result = llm_stream_mistral_api(prompt, model_class)
    elif LLM_WORKER == "runpod":
        result = llm_stream_serverless(prompt, model_class)
    elif LLM_WORKER == "http":
        result = llm_streaming(prompt, model_class)
    elif LLM_WORKER == "in_memory":
        result = llm_stream_sans_network(prompt, model_class)
    else:
        # Unreachable after the import-time check; fail loudly instead of
        # hitting an UnboundLocalError on `result` below.
        raise ValueError(f"Invalid worker: {LLM_WORKER}")

    log_entry = {
        "uuid": str(uuid.uuid4()),
        # Naive UTC ISO timestamp (same format utcnow() produced) without the
        # deprecated datetime.utcnow().
        "timestamp": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(),
        "worker": LLM_WORKER,
        "prompt_input": prompt,
        "prompt_output": result
    }
    log_to_jsonl('prompt_inputs_and_outputs.jsonl', log_entry)
    return result