Spaces:
Running
Running
from flask import Flask, request, jsonify, Response, stream_with_context | |
import google.generativeai as genai | |
import json | |
from datetime import datetime | |
import os | |
import logging | |
import func | |
# Process-wide timezone name for local-time conversions.
# NOTE(review): on POSIX, a TZ change made after startup may only take effect
# once time.tzset() is called; nothing visible here calls it.
os.environ['TZ'] = 'Asia/Shanghai'

app = Flask(__name__)
# A new random secret is generated on every start.
app.secret_key = os.urandom(24)

# Shared password checked on every request; a missing variable fails at import (KeyError).
PASSWORD = os.environ['password']

# Module logger that prints bare messages (no timestamp/level) to a stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
# Gemini safety configuration passed to every model: no blocking in any listed harm category.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
class APIKeyManager:
    """Hand out API keys in round-robin order.

    By default the key pool comes from the comma-separated ``KeyArray``
    environment variable; an explicit iterable may be supplied instead.
    """

    def __init__(self, api_keys=None):
        """Load and validate the key pool.

        Args:
            api_keys: Optional iterable of key strings. When omitted, the
                comma-separated ``KeyArray`` environment variable is read.

        Raises:
            RuntimeError: If no non-blank key is configured (previously an
                unset variable crashed with an opaque AttributeError).
        """
        if api_keys is None:
            api_keys = os.environ.get('KeyArray', '').split(',')
        # Strip padding and drop blanks so "k1, k2," yields exactly two usable keys.
        self.api_keys = [key.strip() for key in api_keys if key.strip()]
        if not self.api_keys:
            raise RuntimeError(
                "No API keys configured; set the KeyArray environment variable "
                "to a comma-separated list of keys."
            )
        self.current_index = 0

    def get_available_key(self):
        """Return the next key, wrapping to the first after the last one."""
        if self.current_index >= len(self.api_keys):
            self.current_index = 0
        current_key = self.api_keys[self.current_index]
        self.current_index += 1
        return current_key
# Module-level rotation state: the request handler reads current_api_key and
# advances it through key_manager after each request.
key_manager = APIKeyManager()
current_api_key = key_manager.get_available_key()
# Log only a short prefix (as elsewhere in this file); the full key is a secret.
logger.info(f"Current API key: {current_api_key[:11]}...")
# Model ids advertised by the model-listing endpoint, one {"id": ...} entry each.
GEMINI_MODELS = [
    {"id": model_id}
    for model_id in (
        "gemini-pro",
        "gemini-pro-vision",
        "gemini-1.0-pro",
        "gemini-1.0-pro-vision",
        "gemini-1.5-pro-002",
        "gemini-exp-1114",
        "gemini-exp-1121",
        "gemini-exp-1206",
        "gemini-2.0-flash-exp",
        "gemini-2.0-exp",
        "gemini-2.0-pro-exp",
    )
]
def _sse_event(payload):
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _stream_chunk(delta, finish_reason=None):
    """Build an OpenAI-style ``chat.completion.chunk`` dict carrying *delta*."""
    return {
        'choices': [
            {
                'delta': delta,
                'finish_reason': finish_reason,
                'index': 0,
            }
        ],
        'object': 'chat.completion.chunk',
    }


def _candidate_text(response):
    """Concatenate the text of every part of the first candidate in *response*.

    Raises:
        AttributeError, IndexError, TypeError: if the response has no
        candidate, no content, or no parts.
    """
    parts = response.candidates[0].content.parts
    if not parts:
        raise IndexError("first candidate has no content parts")
    # A reply may span several parts; reading only parts[0] would truncate it.
    return "".join(part.text for part in parts)


def chat_completions():
    """Serve an OpenAI-style chat completion request through Gemini.

    NOTE(review): no ``@app.route`` decorator is visible in this copy of the
    file; it was presumably lost in capture — confirm how this view is registered.

    Reads ``messages``, ``model``, ``temperature``, ``max_tokens`` and
    ``stream`` from the JSON body, converts the messages with
    ``func.process_messages_for_gemini`` and calls Gemini with the current key.

    Returns:
        - The auth error (default body ``{'error': 'Unauthorized'}``, default
          status 401) when authentication fails.
        - 400 when the body is not a JSON object.
        - A ``text/event-stream`` response of ``chat.completion.chunk`` frames,
          terminated by ``data: [DONE]``, when ``stream`` is truthy.
        - Otherwise a ``chat.completion`` JSON body.
        - 500 with an error body if building or calling the model raises.

    After every authenticated request the module-level key advances to the next one.
    """
    global current_api_key
    is_authenticated, auth_error, status_code = func.authenticate_request(PASSWORD, request)
    if not is_authenticated:
        body = auth_error if auth_error else jsonify({'error': 'Unauthorized'})
        return body, status_code if status_code else 401
    try:
        # silent=True: a missing or malformed body yields None instead of raising.
        request_data = request.get_json(silent=True)
        if not isinstance(request_data, dict):
            return jsonify({
                'error': {
                    'message': 'Request body must be a JSON object.',
                    'type': 'invalid_request_error'
                }
            }), 400
        messages = request_data.get('messages', [])
        model = request_data.get('model', 'gemini-exp-1206')
        temperature = request_data.get('temperature', 1)
        max_tokens = request_data.get('max_tokens', 8192)
        stream = bool(request_data.get('stream', False))
        logger.info(f"\n{model} [r] -> {current_api_key[:11]}...")

        gemini_history, user_message, error_response = func.process_messages_for_gemini(messages)
        if error_response:
            # Best-effort as before: report the conversion problem and continue.
            # NOTE(review): confirm whether such errors should reject the request instead.
            logger.warning(f"Message conversion error: {error_response}")

        # NOTE(review): genai.configure is module-wide, so concurrent requests may race on the key.
        genai.configure(api_key=current_api_key)
        generation_config = {
            "temperature": temperature,
            "max_output_tokens": max_tokens
        }
        gen_model = genai.GenerativeModel(
            model_name=model,
            generation_config=generation_config,
            safety_settings=safety_settings
        )

        if gemini_history:
            chat_session = gen_model.start_chat(history=gemini_history)
            response = chat_session.send_message(user_message, stream=stream)
        else:
            response = gen_model.generate_content(user_message, stream=stream)

        if stream:
            def generate():
                try:
                    for chunk in response:
                        try:
                            text = _candidate_text(chunk)
                        except (AttributeError, IndexError, TypeError):
                            # A chunk with no text parts has nothing to forward;
                            # skip it rather than aborting the whole stream.
                            continue
                        if text:
                            yield _sse_event(_stream_chunk({'content': text}))
                    yield _sse_event(_stream_chunk({}, 'stop'))
                except Exception as e:
                    logger.error(f"Error during streaming: {str(e)}")
                    yield _sse_event({
                        'error': {
                            'message': str(e),
                            'type': 'internal_server_error'
                        }
                    })
                # OpenAI-compatible clients treat this sentinel as end-of-stream.
                yield "data: [DONE]\n\n"

            return Response(stream_with_context(generate()), mimetype='text/event-stream')

        try:
            text_content = _candidate_text(response)
        except (AttributeError, IndexError, TypeError) as e:
            logger.error(f"Error getting text content: {str(e)}")
            text_content = "Error: Unable to get text content."
        response_data = {
            'id': 'chatcmpl-xxxxxxxxxxxx',
            'object': 'chat.completion',
            'created': int(datetime.now().timestamp()),
            'model': model,
            'choices': [{
                'index': 0,
                'message': {
                    'role': 'assistant',
                    'content': text_content
                },
                'finish_reason': 'stop'
            }],
            # Token accounting is not computed; zeros keep the OpenAI schema shape.
            'usage': {
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0
            }
        }
        logger.info("Generation Success")
        return jsonify(response_data)
    except Exception as e:
        logger.error(f"Error in chat completions: {str(e)}")
        return jsonify({
            'error': {
                'message': str(e),
                'type': 'invalid_request_error'
            }
        }), 500
    finally:
        # Rotate after every authenticated request, successful or not.
        current_api_key = key_manager.get_available_key()
        logger.info(f"API KEY Switched -> {current_api_key[:11]}...")
def list_models():
    """Return the advertised models as an OpenAI-style ``{"object": "list"}`` body.

    NOTE(review): no route decorator is visible in this copy of the file —
    confirm how this view is registered.

    Uses the same authentication as chat completions. On failure it returns the
    auth error, or ``{'error': 'Unauthorized'}`` if none was supplied, with the
    supplied status (default 401).
    """
    authenticated, failure_body, failure_status = func.authenticate_request(PASSWORD, request)
    if authenticated:
        return jsonify({"object": "list", "data": GEMINI_MODELS})
    body = failure_body or jsonify({'error': 'Unauthorized'})
    return body, failure_status or 401
if __name__ == '__main__':
    # The Werkzeug interactive debugger can execute arbitrary Python from the
    # browser, so it must not be on by default while listening on 0.0.0.0.
    # Opt in explicitly (FLASK_DEBUG=1/true/yes) for local development only.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled, host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))