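"""OpenAI-compatible chat-completions API server built on Flask.

Validates requests with Pydantic, checks API keys and rate limits via
core_logic, forwards inference to API_provider, and supports both
streaming (SSE) and non-streaming responses.
"""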
from flask import Flask, request, jsonify, Response
from functools import wraps
import uuid
import json
import time
from typing import List, Optional

from pydantic import BaseModel, ValidationError

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)
app = Flask(__name__)
class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    stream: Optional[bool] = False
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95
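# A minimal sketch of a request body that ChatCompletionRequest accepts
# (the model name and message contents here are illustrative only):
#
# {
#     "model": "gpt-4o",
#     "messages": [
#         {"role": "system", "content": "You are a helpful assistant."},
#         {"role": "user", "content": "Hello!"}
#     ],
#     "stream": false,
#     "max_tokens": 4000,
#     "temperature": 0.5,
#     "top_p": 0.95
# }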
def get_api_key():
    """Extract the bearer token from an `Authorization: Bearer <key>` header."""
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None
    return auth_header.split(' ')[1]
def requires_api_key(func):
    """Reject unauthenticated requests; inject the key as the `api_key` kwarg."""
    @wraps(func)  # preserve the view's name so Flask's endpoint names don't collide
    def decorated(*args, **kwargs):
        api_key = get_api_key()
        if not api_key:
            return jsonify({'detail': 'Not authenticated'}), 401
        kwargs['api_key'] = api_key
        return func(*args, **kwargs)
    return decorated
@app.route('/')
def index():
    return 'Hello, World!'
@app.route('/v1/chat/completions', methods=['POST'])  # assumed path, following the OpenAI convention
@requires_api_key
def chat_completions(api_key):
    try:
        # Parse and validate request data
        try:
            data = request.get_json()
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Check API key validity and rate limit
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]

        # Get model info
        model_info = get_model_info(chat_request.model)
        if not model_info:
            return jsonify({'detail': 'Invalid model specified'}), 400

        # Map public model names to provider-specific identifiers
        model_mapping = {
            "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "claude-3.5-sonnet": "claude-3-sonnet-20240229",
        }
        model_name = model_mapping.get(chat_request.model, chat_request.model)

        # Credits charged per request; unlisted models cost nothing
        credits_reduction = {
            "gpt-4o": 1,
            "claude-3-sonnet-20240229": 1,
            "gemini-1.5-pro": 1,
            "gemini-1-5-flash": 1,
            "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
            "o1-mini": 2,
            "o1-preview": 3,
        }.get(model_name, 0)

        if chat_request.stream:
            def generate():
                try:
                    for chunk in API_Inference(messages, model=model_name, stream=True,
                                               max_tokens=chat_request.max_tokens,
                                               temperature=chat_request.temperature,
                                               top_p=chat_request.top_p):
                        data = json.dumps({'choices': [{'delta': {'content': chunk}}]})
                        yield f"data: {data}\n\n"
                    # Emit the credits note as an SSE comment (leading ':') so it
                    # doesn't break clients, then the standard [DONE] sentinel.
                    yield f": credits used {credits_reduction}\n\n"
                    yield "data: [DONE]\n\n"
                    update_request_count(api_key, credits_reduction)
                except Exception as e:
                    yield f"data: [ERROR] {str(e)}\n\n"
            return Response(generate(), mimetype='text/event-stream')
        else:
            response = API_Inference(messages, model=model_name, stream=False,
                                     max_tokens=chat_request.max_tokens,
                                     temperature=chat_request.temperature,
                                     top_p=chat_request.top_p)
            update_request_count(api_key, credits_reduction)

            # Rough token accounting by whitespace splitting (not a real tokenizer)
            prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
            completion_tokens = len(response.split())
            total_tokens = prompt_tokens + completion_tokens

            return jsonify({
                "id": f"chatcmpl-{uuid.uuid4()}",
                "object": "chat.completion",
                "created": int(time.time()),  # Unix timestamp, per the OpenAI schema
                "model": model_name,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": response
                        },
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens
                },
                "credits_used": credits_reduction
            })
    except Exception as e:
        return jsonify({'detail': str(e)}), 500
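# A minimal sketch of calling the endpoint above with the `requests` library;
# the base URL and the API key are assumptions for local testing:
#
# import requests
#
# resp = requests.post(
#     "http://localhost:8000/v1/chat/completions",
#     headers={"Authorization": "Bearer YOUR_API_KEY"},  # hypothetical key
#     json={
#         "model": "gpt-4o",
#         "messages": [{"role": "user", "content": "Hello!"}],
#     },
# )
# print(resp.json()["choices"][0]["message"]["content"])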
@app.route('/rate_limit_status')  # assumed route path
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_rate_limit_status(api_key))
@app.route('/subscription_status')  # assumed route path
@requires_api_key
def get_subscription_status_endpoint(api_key):
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_subscription_status(api_key))
@app.route('/v1/models')  # assumed path, following the OpenAI convention
@requires_api_key
def get_available_models_endpoint(api_key):
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify({"data": [{"id": model} for model in get_available_models().values()]})
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)
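# A sketch of consuming the SSE stream with curl; the endpoint path and key
# are the same assumptions as above:
#
#   curl -N http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer YOUR_API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "stream": true,
#          "messages": [{"role": "user", "content": "Hello!"}]}'
#
# Each event arrives as a `data: {...}` line carrying a
# `choices[0].delta.content` fragment, followed by `data: [DONE]`.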