Spaces:
Sleeping
Sleeping
File size: 6,124 Bytes
e628215 c002f09 e628215 cae172a e628215 cae172a e628215 cae172a e628215 cae172a e628215 cae172a e628215 cae172a e628215 cae172a e628215 0c11451 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import json
import time
import uuid
from functools import wraps
from typing import List, Optional

from flask import Flask, request, jsonify, Response
from pydantic import BaseModel, ValidationError

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)
# Flask application object; the @app.route decorators in this module register
# their handlers on it, and the __main__ guard runs it.
app = Flask(__name__)
class Message(BaseModel):
    """One entry of a chat conversation, validated by pydantic."""

    # Speaker label for this entry; any string is accepted.
    role: str
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    """Validated JSON body of a chat-completion request.

    Only ``model`` and ``messages`` are required; the remaining fields fall
    back to the defaults declared here.
    """

    # Model name as supplied by the client.
    model: str
    # Conversation so far, in order.
    messages: List[Message]
    # When true, the handler answers with a text/event-stream response.
    stream: Optional[bool] = False
    # Generation settings forwarded to API_Inference.
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95
def get_api_key():
    """Extract the bearer token from the current request's Authorization header.

    Returns:
        The token text that follows the ``Bearer `` prefix, with surrounding
        whitespace removed, or ``None`` when the header is missing, does not
        start with ``Bearer ``, or carries no token.
    """
    prefix = 'Bearer '
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith(prefix):
        return None
    # Slice off the scheme instead of using split(' ')[1]: splitting on single
    # spaces yields '' when more than one space follows the scheme and drops
    # everything after a second space.
    token = auth_header[len(prefix):].strip()
    return token or None
def requires_api_key(func):
    """Decorate a view so it only runs when a bearer token is present.

    The token returned by ``get_api_key()`` is handed to the wrapped view as
    the ``api_key`` keyword argument. When no token is found, the view is not
    called and a 401 JSON response is returned instead.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        token = get_api_key()
        if token:
            # Overwrite any existing 'api_key' entry, then dispatch to the view.
            kwargs.update(api_key=token)
            return func(*args, **kwargs)
        return jsonify({'detail': 'Not authenticated'}), 401

    return wrapper
@app.route('/')
def index():
    """Return a fixed greeting string for the root URL."""
    greeting = 'Hello, World!'
    return greeting
# Client-facing model aliases mapped to the identifier passed to API_Inference.
# Names that are not listed are forwarded unchanged.
# NOTE(review): "claude-3.5-sonnet" resolves to a "claude-3-sonnet" identifier;
# confirm this is the intended upstream model.
_MODEL_ALIASES = {
    "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "claude-3.5-sonnet": "claude-3-sonnet-20240229",
}

# Credits charged per request, keyed by the resolved model name. Models that
# are not listed cost 0 credits.
# NOTE(review): "gemini-1-5-flash" uses dashes where "gemini-1.5-pro" uses a
# dot; verify the key matches the name clients actually send.
_CREDIT_COSTS = {
    "gpt-4o": 1,
    "claude-3-sonnet-20240229": 1,
    "gemini-1.5-pro": 1,
    "gemini-1-5-flash": 1,
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
    "o1-mini": 2,
    "o1-preview": 3,
}


@app.route('/chat/completions', methods=['POST', 'GET'])
@requires_api_key
def chat_completions(api_key):
    """Run a chat completion and return it as JSON or as an event stream.

    Args:
        api_key: Bearer token injected by ``requires_api_key``.

    Returns:
        * 400 when the body is not a JSON object, fails
          ``ChatCompletionRequest`` validation, or names a model for which
          ``get_model_info`` returns nothing.
        * 401 when ``check_api_key_validity`` rejects the key.
        * A ``text/event-stream`` response when ``stream`` is true.
        * Otherwise a ``chat.completion`` JSON document including word-count
          based usage figures and the credits charged.
        * 500 with the exception text for any other failure.
    """
    try:
        # Parse and validate request data. silent=True makes a missing,
        # malformed or non-JSON body come back as None so it is reported as a
        # client error (400) rather than falling through to the 500 handler.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            return jsonify({'detail': 'Request body must be a JSON object'}), 400
        try:
            chat_request = ChatCompletionRequest(**payload)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Check API key validity and rate limit
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [
            {"role": msg.role, "content": msg.content}
            for msg in chat_request.messages
        ]

        # Get model info
        if not get_model_info(chat_request.model):
            return jsonify({'detail': 'Invalid model specified'}), 400

        model_name = _MODEL_ALIASES.get(chat_request.model, chat_request.model)
        credits_reduction = _CREDIT_COSTS.get(model_name, 0)

        if chat_request.stream:
            def generate():
                try:
                    for chunk in API_Inference(
                        messages,
                        model=model_name,
                        stream=True,
                        max_tokens=chat_request.max_tokens,
                        temperature=chat_request.temperature,
                        top_p=chat_request.top_p,
                    ):
                        event = json.dumps({'choices': [{'delta': {'content': chunk}}]})
                        yield f"data: {event}\n\n"
                    yield f"data: [DONE]\n\nCredits used: {credits_reduction}\n\n"
                    # Runs only once the generator is resumed after the final
                    # event, i.e. after the whole stream has been consumed.
                    update_request_count(api_key, credits_reduction)
                except Exception as e:
                    # The response has already started, so the failure is
                    # reported in-band instead of through a status code.
                    yield f"data: [ERROR] {str(e)}\n\n"

            return Response(generate(), mimetype='text/event-stream')

        response = API_Inference(
            messages,
            model=model_name,
            stream=False,
            max_tokens=chat_request.max_tokens,
            temperature=chat_request.temperature,
            top_p=chat_request.top_p,
        )
        update_request_count(api_key, credits_reduction)

        # Usage figures are whitespace-separated word counts, not tokenizer
        # output.
        prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
        completion_tokens = len(response.split())
        total_tokens = prompt_tokens + completion_tokens

        return jsonify({
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            # Unix timestamp in seconds. uuid1().time counts 100 ns intervals
            # from 1582-10-15, so it cannot be used as an epoch timestamp.
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens
            },
            "credits_used": credits_reduction
        })
    except Exception as e:
        return jsonify({'detail': str(e)}), 500
@app.route('/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return ``get_rate_limit_status`` for the caller's key as JSON.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 carrying the validator's message.
    """
    key_ok, reason = check_api_key_validity(api_key, check_rate_limit=False)
    if key_ok:
        return jsonify(get_rate_limit_status(api_key))
    return jsonify({'detail': reason}), 401
@app.route('/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return ``get_subscription_status`` for the caller's key as JSON.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 carrying the validator's message.
    """
    key_ok, reason = check_api_key_validity(api_key, check_rate_limit=False)
    if key_ok:
        return jsonify(get_subscription_status(api_key))
    return jsonify({'detail': reason}), 401
@app.route('/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List the values of ``get_available_models()`` as ``{"id": ...}`` entries.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 carrying the validator's message.
    """
    key_ok, reason = check_api_key_validity(api_key, check_rate_limit=False)
    if not key_ok:
        return jsonify({'detail': reason}), 401

    entries = []
    for model in get_available_models().values():
        entries.append({"id": model})
    return jsonify({"data": entries})
if __name__ == "__main__":
    # Run the app on all interfaces, port 8000, when executed as a script.
    app.run(host="0.0.0.0", port=8000)