# Test-Running / app.py
# DevsDoCode's picture
# Update app.py
# 84e772f verified
import json
import time
import uuid
from functools import wraps
from typing import List, Optional

from flask import Flask, request, jsonify, Response
from pydantic import BaseModel, ValidationError

from API_provider import API_Inference
from core_logic import (
    check_api_key_validity,
    update_request_count,
    get_rate_limit_status,
    get_subscription_status,
    get_available_models,
    get_model_info,
)
# Flask application object; the route decorators in this module register
# their view functions on it, and the __main__ guard runs it.
app = Flask(__name__)
class Message(BaseModel):
    """One chat message: a ``role`` string and its text ``content``."""

    # Both fields are required strings; no further validation is applied here.
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """Validated JSON body of a ``/chat/completions`` request.

    ``model`` and ``messages`` are required. The remaining fields are
    optional and fall back to the defaults declared below when omitted.
    """

    # Model name as supplied by the client.
    model: str
    # Conversation so far, in order.
    messages: List[Message]

    # When true the endpoint answers with a text/event-stream response.
    stream: Optional[bool] = False

    # Generation parameters forwarded unchanged to API_Inference.
    max_tokens: Optional[int] = 4000
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 0.95
def get_api_key():
    """Extract the bearer token from the current request's Authorization header.

    Returns:
        The token with surrounding whitespace removed, or ``None`` when the
        header is missing, does not start with ``"Bearer "``, or carries no
        token after the prefix.
    """
    prefix = 'Bearer '
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith(prefix):
        return None
    # Take everything after the prefix instead of splitting on single spaces:
    # extra whitespace after "Bearer" no longer yields an empty token, and a
    # value containing a space is passed on whole rather than silently cut.
    token = auth_header[len(prefix):].strip()
    return token or None
def requires_api_key(func):
    """Decorate a view so it only runs when a bearer token was supplied.

    The token returned by ``get_api_key()`` is handed to the wrapped view as
    the ``api_key`` keyword argument. When no token is available the view is
    not called and a 401 JSON response is returned instead.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        token = get_api_key()
        if token:
            kwargs['api_key'] = token
            return func(*args, **kwargs)
        return jsonify({'detail': 'Not authenticated'}), 401

    return wrapper
@app.route('/')
def index():
    """Serve a fixed plain-text greeting at the root URL."""
    greeting = 'Hello, World!'
    return greeting
# Client-facing model aliases mapped to the model identifier that is passed
# to API_Inference. Names without an alias are passed through unchanged.
# NOTE(review): the "claude-3.5-sonnet" alias resolves to a "claude-3-sonnet"
# identifier - confirm this is intentional.
_MODEL_ALIASES = {
    "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "claude-3.5-sonnet": "claude-3-sonnet-20240229",
}

# Credits recorded per completed request, keyed by the resolved model name.
# Models that are not listed cost 0 credits.
# NOTE(review): "gemini-1-5-flash" is spelled with hyphens while
# "gemini-1.5-pro" uses a dot - verify against the names returned by
# get_available_models(), otherwise that model is never charged.
_CREDIT_COSTS = {
    "gpt-4o": 1,
    "claude-3-sonnet-20240229": 1,
    "gemini-1.5-pro": 1,
    "gemini-1-5-flash": 1,
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
    "o1-mini": 2,
    "o1-preview": 3,
}


@app.route('/chat/completions', methods=['POST', 'GET'])
@requires_api_key
def chat_completions(api_key):
    """Run a chat completion and return it as JSON or as an event stream.

    Args:
        api_key: Bearer token injected by ``requires_api_key``.

    Returns:
        * 400 with ``detail`` when the body is not a JSON object, fails
          ``ChatCompletionRequest`` validation, or names an unknown model.
        * 401 with ``detail`` when ``check_api_key_validity`` rejects the key.
        * A ``text/event-stream`` response when ``stream`` is true.
        * Otherwise a ``chat.completion`` JSON document with the reply, word
          based usage counts and the credits recorded for the request.
        * 500 with ``detail`` for any other exception.
    """
    try:
        # Parse and validate request data. silent=True makes a missing or
        # malformed body come back as None so it is reported as a client
        # error (400) instead of surfacing as a 500 from the handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'detail': 'Request body must be a JSON object'}), 400
        try:
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Check API key validity and rate limit
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [
            {"role": msg.role, "content": msg.content}
            for msg in chat_request.messages
        ]

        # The model info is only used to reject unknown model names.
        model_info = get_model_info(chat_request.model)
        if not model_info:
            return jsonify({'detail': 'Invalid model specified'}), 400

        model_name = _MODEL_ALIASES.get(chat_request.model, chat_request.model)
        credits_reduction = _CREDIT_COSTS.get(model_name, 0)

        if chat_request.stream:
            def generate():
                try:
                    for chunk in API_Inference(messages, model=model_name, stream=True,
                                               max_tokens=chat_request.max_tokens,
                                               temperature=chat_request.temperature,
                                               top_p=chat_request.top_p):
                        payload = json.dumps({'choices': [{'delta': {'content': chunk}}]})
                        yield f"data: {payload}\n\n"
                    # Record the credits before emitting the terminal event so
                    # the accounting does not depend on the server resuming
                    # this generator after its final yield.
                    update_request_count(api_key, credits_reduction)
                    yield f"data: [DONE]\n\nCredits used: {credits_reduction}\n\n"
                except Exception as e:
                    # Headers are already sent, so errors are reported in-band.
                    yield f"data: [ERROR] {str(e)}\n\n"

            return Response(generate(), mimetype='text/event-stream')

        response = API_Inference(messages, model=model_name, stream=False,
                                 max_tokens=chat_request.max_tokens,
                                 temperature=chat_request.temperature,
                                 top_p=chat_request.top_p)
        update_request_count(api_key, credits_reduction)

        # Usage figures are whitespace-separated word counts, not real
        # tokenizer counts.
        prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
        completion_tokens = len(response.split())
        total_tokens = prompt_tokens + completion_tokens

        return jsonify({
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            # Unix timestamp in seconds. The previous uuid1()-based value
            # counted from 1582-10-15 and was therefore centuries off.
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens
            },
            "credits_used": credits_reduction
        })
    except Exception as e:
        # Top-level boundary: report any unexpected failure as a 500.
        return jsonify({'detail': str(e)}), 500
@app.route('/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return the rate-limit status for the caller's API key as JSON.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 response carrying the validator's message in ``detail``.
    """
    key_ok, reason = check_api_key_validity(api_key, check_rate_limit=False)
    if key_ok:
        return jsonify(get_rate_limit_status(api_key))
    return jsonify({'detail': reason}), 401
@app.route('/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return the subscription status for the caller's API key as JSON.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 response carrying the validator's message in ``detail``.
    """
    key_ok, reason = check_api_key_validity(api_key, check_rate_limit=False)
    if key_ok:
        return jsonify(get_subscription_status(api_key))
    return jsonify({'detail': reason}), 401
@app.route('/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List the available models as ``{"data": [{"id": ...}, ...]}``.

    The key is validated with ``check_rate_limit=False``; a rejected key
    yields a 401 response carrying the validator's message in ``detail``.
    Each ``id`` is one of the values of the mapping returned by
    ``get_available_models()``.
    """
    key_ok, reason = check_api_key_validity(api_key, check_rate_limit=False)
    if not key_ok:
        return jsonify({'detail': reason}), 401

    catalogue = get_available_models()
    entries = [{"id": model_id} for model_id in catalogue.values()]
    return jsonify({"data": entries})
if __name__ == "__main__":
    # When executed as a script, run the app on all interfaces, port 8000.
    app.run(port=8000, host="0.0.0.0")