Upload 2 files
Browse files
app.py
CHANGED
@@ -4,12 +4,10 @@ import google.generativeai as genai
|
|
4 |
import json
|
5 |
from datetime import datetime
|
6 |
import os
|
7 |
-
import base64
|
8 |
from termcolor import colored
|
9 |
import logging
|
10 |
-
|
11 |
-
|
12 |
-
import re
|
13 |
os.environ['TZ'] = 'Asia/Shanghai'
|
14 |
app = Flask(__name__)
|
15 |
if 'API_KEYS' not in os.environ:
|
@@ -63,78 +61,31 @@ current_api_key = key_manager.get_available_key()
|
|
63 |
logger.info(f"Current API key: {current_api_key}")
|
64 |
# 模型列表
|
65 |
GEMINI_MODELS = [
|
66 |
-
{"id": "gemini-pro"
|
67 |
-
{"id": "gemini-pro-vision"
|
68 |
-
{"id": "gemini-1.0-pro"
|
69 |
-
{"id": "gemini-1.0-pro-vision"
|
70 |
-
{"id": "gemini-1.5-pro-002"
|
71 |
-
{"id": "gemini-exp-1114"
|
72 |
-
{"id": "gemini-exp-1121"
|
73 |
-
{"id": "gemini-exp-1206"
|
74 |
-
{"id": "gemini-2.0-flash-exp"
|
75 |
-
{"id": "gemini-2.0-exp"
|
76 |
-
{"id": "gemini-2.0-pro-exp"
|
77 |
]
|
78 |
-
def authenticate_request(request):
|
79 |
-
auth_header = request.headers.get('Authorization')
|
80 |
-
hf_api_key = os.environ.get('HF_API_KEY').split(',')
|
81 |
-
|
82 |
-
if not auth_header:
|
83 |
-
return False, jsonify({'error': 'Authorization header is missing'}), 401
|
84 |
-
|
85 |
-
try:
|
86 |
-
auth_type, api_key = auth_header.split(' ', 1)
|
87 |
-
except ValueError:
|
88 |
-
return False, jsonify({'error': 'Invalid Authorization header format'}), 401
|
89 |
-
|
90 |
-
if auth_type.lower() != 'bearer':
|
91 |
-
return False, jsonify({'error': 'Authorization type must be Bearer'}), 401
|
92 |
-
|
93 |
-
if api_key not in hf_api_key:
|
94 |
-
return False, jsonify({'error': 'Unauthorized'}), 401
|
95 |
-
|
96 |
-
return True, None, None
|
97 |
-
|
98 |
-
def sanitize_request_data(request_data):
|
99 |
-
"""
|
100 |
-
从请求数据中删除base64编码的数据。
|
101 |
-
|
102 |
-
Args:
|
103 |
-
request_data: 包含可能存在base64数据的字典。
|
104 |
-
|
105 |
-
Returns:
|
106 |
-
清理后的字典,其中base64数据被替换为"[Base64 Data Omitted]"。
|
107 |
-
"""
|
108 |
-
|
109 |
-
|
110 |
-
def replace_base64(match):
|
111 |
-
# 替换base64数据为提示信息
|
112 |
-
return '"[Base64 Data Omitted]"'
|
113 |
-
|
114 |
-
|
115 |
-
request_data_str = json.dumps(request_data)
|
116 |
-
|
117 |
-
# 使用正则表达式匹配base64数据,并替换为提示信息
|
118 |
-
sanitized_request_data_str = re.sub(
|
119 |
-
r'"(data:[^;]+;base64,)[^"]+"',
|
120 |
-
replace_base64,
|
121 |
-
request_data_str
|
122 |
-
)
|
123 |
-
|
124 |
-
return json.loads(sanitized_request_data_str)
|
125 |
|
126 |
@app.route('/hf/v1/chat/completions', methods=['POST'])
|
127 |
def chat_completions():
|
128 |
global current_api_key
|
129 |
-
is_authenticated, auth_error, status_code = authenticate_request(request)
|
130 |
if not is_authenticated:
|
131 |
return auth_error if auth_error else jsonify({'error': 'Unauthorized'}), status_code if status_code else 401
|
132 |
try:
|
133 |
request_data = request.get_json()
|
134 |
-
r_data = sanitize_request_data(request_data)
|
135 |
logger.info(json.dumps(r_data, indent=4, ensure_ascii=False), extra={"color": "green"})
|
136 |
messages = request_data.get('messages', [])
|
137 |
-
model = request_data.get('model', 'gemini-exp-1206')
|
138 |
temperature = request_data.get('temperature', 1)
|
139 |
max_tokens = request_data.get('max_tokens', 8192)
|
140 |
stream = request_data.get('stream', False)
|
@@ -142,67 +93,12 @@ def chat_completions():
|
|
142 |
logger.info(colored(f"\n{model} [r] -> {current_api_key[:11]}...", 'yellow'))
|
143 |
|
144 |
# 将 OpenAI 格式的消息转换为 Gemini 格式
|
145 |
-
gemini_history =
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
if role == 'system':
|
152 |
-
gemini_history.append({"role": "user", "parts": [content]})
|
153 |
-
elif role == 'user':
|
154 |
-
gemini_history.append({"role": "user", "parts": [content]})
|
155 |
-
elif role == 'assistant':
|
156 |
-
gemini_history.append({"role": "model", "parts": [content]})
|
157 |
-
elif isinstance(content, list): # 图文
|
158 |
-
parts = []
|
159 |
-
for item in content:
|
160 |
-
if item.get('type') == 'text':
|
161 |
-
parts.append(item.get('text'))
|
162 |
-
elif item.get('type') == 'image_url':
|
163 |
-
image_data = item.get('image_url', {}).get('url', '')
|
164 |
-
if image_data.startswith('data:image/'): # 修改判断条件
|
165 |
-
try:
|
166 |
-
# 提取 base64 编码和图片类型
|
167 |
-
image_type = image_data.split(';')[0].split('/')[1].upper() # 提取图片类型并转为大写
|
168 |
-
base64_image = image_data.split(';base64,')[1]
|
169 |
-
|
170 |
-
image = Image.open(BytesIO(base64.b64decode(base64_image)))
|
171 |
-
|
172 |
-
# 将图片转换为 RGB 模式
|
173 |
-
if image.mode != 'RGB':
|
174 |
-
image = image.convert('RGB')
|
175 |
-
|
176 |
-
# 压缩图像
|
177 |
-
if image.width > 2048 or image.height > 2048:
|
178 |
-
image.thumbnail((2048, 2048))
|
179 |
-
|
180 |
-
output_buffer = BytesIO()
|
181 |
-
image.save(output_buffer, format=image_type) # 使用原始图片类型保存
|
182 |
-
output_buffer.seek(0)
|
183 |
-
parts.append(image) # 直接添加 image 对象
|
184 |
-
except Exception as e:
|
185 |
-
logger.error(f"Error processing image: {e}")
|
186 |
-
return jsonify({'error': 'Invalid image data'}), 400
|
187 |
-
else:
|
188 |
-
return jsonify({'error': 'Invalid image URL format'}), 400
|
189 |
-
|
190 |
-
|
191 |
-
# 根据 role 添加到 gemini_history
|
192 |
-
if role in ['user', 'system']:
|
193 |
-
gemini_history.append({"role": "user", "parts": parts}) # 保持 parts 的原有顺序
|
194 |
-
elif role == 'assistant':
|
195 |
-
gemini_history.append({"role": "model", "parts": parts})
|
196 |
-
else:
|
197 |
-
return jsonify({'error': f'Invalid role: {role}'}), 400
|
198 |
-
|
199 |
-
# 用户最后一条消息
|
200 |
-
if gemini_history:
|
201 |
-
user_message = gemini_history[-1]
|
202 |
-
gemini_history = gemini_history[:-1] # 历史记录不包含最后一条消息
|
203 |
-
else:
|
204 |
-
user_message = {"role": "user", "parts": [""]}
|
205 |
-
|
206 |
genai.configure(api_key=current_api_key)
|
207 |
|
208 |
generation_config = {
|
@@ -305,10 +201,10 @@ def chat_completions():
|
|
305 |
}
|
306 |
logger.info(colored(f"Generation Success", 'green'))
|
307 |
return jsonify(response_data)
|
308 |
-
|
309 |
except Exception as e:
|
310 |
logger.error(f"Error in chat completions: {str(e)}")
|
311 |
-
|
312 |
return jsonify({
|
313 |
'error': {
|
314 |
'message': str(e),
|
@@ -319,10 +215,11 @@ def chat_completions():
|
|
319 |
current_api_key = key_manager.get_available_key()
|
320 |
logger.info(colored(f"API KEY Switched -> {current_api_key[:11]}...", 'aqua'))
|
321 |
|
322 |
-
|
323 |
-
|
324 |
@app.route('/hf/v1/models', methods=['GET'])
|
325 |
def list_models():
|
|
|
|
|
|
|
326 |
response = {"object": "list", "data": GEMINI_MODELS}
|
327 |
return jsonify(response)
|
328 |
|
|
|
4 |
import json
|
5 |
from datetime import datetime
|
6 |
import os
|
|
|
7 |
from termcolor import colored
|
8 |
import logging
|
9 |
+
import func
|
10 |
+
|
|
|
11 |
os.environ['TZ'] = 'Asia/Shanghai'
|
12 |
app = Flask(__name__)
|
13 |
if 'API_KEYS' not in os.environ:
|
|
|
61 |
logger.info(f"Current API key: {current_api_key}")
|
62 |
# 模型列表
|
63 |
GEMINI_MODELS = [
|
64 |
+
{"id": "gemini-pro"},
|
65 |
+
{"id": "gemini-pro-vision"},
|
66 |
+
{"id": "gemini-1.0-pro"},
|
67 |
+
{"id": "gemini-1.0-pro-vision"},
|
68 |
+
{"id": "gemini-1.5-pro-002"},
|
69 |
+
{"id": "gemini-exp-1114"},
|
70 |
+
{"id": "gemini-exp-1121"},
|
71 |
+
{"id": "gemini-exp-1206"},
|
72 |
+
{"id": "gemini-2.0-flash-exp"},
|
73 |
+
{"id": "gemini-2.0-exp"},
|
74 |
+
{"id": "gemini-2.0-pro-exp"},
|
75 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
@app.route('/hf/v1/chat/completions', methods=['POST'])
|
78 |
def chat_completions():
|
79 |
global current_api_key
|
80 |
+
is_authenticated, auth_error, status_code = func.authenticate_request(request)
|
81 |
if not is_authenticated:
|
82 |
return auth_error if auth_error else jsonify({'error': 'Unauthorized'}), status_code if status_code else 401
|
83 |
try:
|
84 |
request_data = request.get_json()
|
85 |
+
r_data = func.sanitize_request_data(request_data)
|
86 |
logger.info(json.dumps(r_data, indent=4, ensure_ascii=False), extra={"color": "green"})
|
87 |
messages = request_data.get('messages', [])
|
88 |
+
model = request_data.get('model', 'gemini-exp-1206')
|
89 |
temperature = request_data.get('temperature', 1)
|
90 |
max_tokens = request_data.get('max_tokens', 8192)
|
91 |
stream = request_data.get('stream', False)
|
|
|
93 |
logger.info(colored(f"\n{model} [r] -> {current_api_key[:11]}...", 'yellow'))
|
94 |
|
95 |
# 将 OpenAI 格式的消息转换为 Gemini 格式
|
96 |
+
gemini_history, user_message, error_response = func.process_messages_for_gemini(messages)
|
97 |
+
|
98 |
+
if error_response:
|
99 |
+
# 处理错误
|
100 |
+
print(error_response)
|
101 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
genai.configure(api_key=current_api_key)
|
103 |
|
104 |
generation_config = {
|
|
|
201 |
}
|
202 |
logger.info(colored(f"Generation Success", 'green'))
|
203 |
return jsonify(response_data)
|
204 |
+
|
205 |
except Exception as e:
|
206 |
logger.error(f"Error in chat completions: {str(e)}")
|
207 |
+
|
208 |
return jsonify({
|
209 |
'error': {
|
210 |
'message': str(e),
|
|
|
215 |
current_api_key = key_manager.get_available_key()
|
216 |
logger.info(colored(f"API KEY Switched -> {current_api_key[:11]}...", 'aqua'))
|
217 |
|
|
|
|
|
218 |
@app.route('/hf/v1/models', methods=['GET'])
def list_models():
    """Return the supported Gemini models in OpenAI /v1/models list format."""
    # Same bearer-token gate as the chat-completions endpoint.
    ok, err_resp, code = func.authenticate_request(request)
    if not ok:
        return err_resp if err_resp else jsonify({'error': 'Unauthorized'}), code if code else 401
    return jsonify({"object": "list", "data": GEMINI_MODELS})
|
225 |
|
func.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import BytesIO
|
2 |
+
import base64
|
3 |
+
from PIL import Image
|
4 |
+
from flask import jsonify
|
5 |
+
import logging
|
6 |
+
import json
|
7 |
+
import re
|
8 |
+
import os
|
9 |
+
GEMINI_MODELS = [
|
10 |
+
{"id": "gemini-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
|
11 |
+
{"id": "gemini-pro-vision", "object": "model", "created": 1700000000, "owned_by": "google"},
|
12 |
+
{"id": "gemini-1.0-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
|
13 |
+
{"id": "gemini-1.0-pro-vision", "object": "model", "created": 1700000000, "owned_by": "google"},
|
14 |
+
{"id": "gemini-1.5-pro-002", "object": "model", "created": 1700000000, "owned_by": "google"},
|
15 |
+
{"id": "gemini-exp-1114", "object": "model", "created": 1700000000, "owned_by": "google"},
|
16 |
+
{"id": "gemini-exp-1121", "object": "model", "created": 1700000000, "owned_by": "google"},
|
17 |
+
{"id": "gemini-exp-1206", "object": "model", "created": 1700000000, "owned_by": "google"},
|
18 |
+
{"id": "gemini-2.0-flash-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
|
19 |
+
{"id": "gemini-2.0-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
|
20 |
+
{"id": "gemini-2.0-pro-exp", "object": "model", "created": 1700000000, "owned_by": "google"},
|
21 |
+
]
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
def authenticate_request(request):
    """Validate the Bearer token in *request* against the HF_API_KEY list.

    Args:
        request: Flask request object; only ``request.headers`` is read.

    Returns:
        Tuple ``(is_authenticated, error_response, status_code)``:
        ``(True, None, None)`` on success, otherwise ``(False, jsonify(...), code)``.
    """
    auth_header = request.headers.get('Authorization')

    if not auth_header:
        return False, jsonify({'error': 'Authorization header is missing'}), 401

    try:
        auth_type, api_key = auth_header.split(' ', 1)
    except ValueError:
        return False, jsonify({'error': 'Invalid Authorization header format'}), 401

    if auth_type.lower() != 'bearer':
        return False, jsonify({'error': 'Authorization type must be Bearer'}), 401

    # BUG FIX: os.environ.get('HF_API_KEY').split(',') raised AttributeError
    # when the variable was unset; report a configuration error instead.
    hf_keys_env = os.environ.get('HF_API_KEY')
    if not hf_keys_env:
        return False, jsonify({'error': 'HF_API_KEY is not configured'}), 500

    if api_key not in hf_keys_env.split(','):
        return False, jsonify({'error': 'Unauthorized'}), 401

    return True, None, None
|
44 |
+
|
45 |
+
def sanitize_request_data(request_data):
    """Return a copy of *request_data* with inline base64 payloads removed.

    Any JSON string value shaped like ``data:<mime>;base64,<payload>`` is
    replaced with the placeholder ``[Base64 Data Omitted]`` so that large
    binary blobs never reach the logs.

    Args:
        request_data: JSON-serializable dict that may contain data URIs.

    Returns:
        The sanitized dict.
    """
    serialized = json.dumps(request_data)

    # Swap each whole quoted data-URI string for a short quoted marker,
    # keeping the document valid JSON throughout.
    cleaned = re.sub(
        r'"(data:[^;]+;base64,)[^"]+"',
        '"[Base64 Data Omitted]"',
        serialized,
    )

    return json.loads(cleaned)
|
72 |
+
|
73 |
+
def process_messages_for_gemini(messages):
    """Convert OpenAI-style chat messages into the Gemini history format.

    Args:
        messages: List of dicts with a 'role' key ('system', 'user', or
            'assistant') and a 'content' key. 'content' is either a string
            (text message) or a list of part dicts with 'type' ('text' or
            'image_url') and the matching 'text' / 'image_url' payload.

    Returns:
        Tuple ``(gemini_history, user_message, error_response)``:
        - gemini_history: Gemini-format dicts, excluding the final message.
        - user_message: the final message, or a default empty user message
          when *messages* produces no entries.
        - error_response: ``(jsonify(...), status_code)`` on failure,
          otherwise ``None``.
    """
    # OpenAI 'system' and 'user' both map onto Gemini's 'user' role.
    role_map = {'system': 'user', 'user': 'user', 'assistant': 'model'}

    gemini_history = []
    for message in messages:
        role = message.get('role')
        content = message.get('content')

        if isinstance(content, str):  # plain-text message
            gemini_role = role_map.get(role)
            if gemini_role is None:
                # BUG FIX: unknown roles used to be silently dropped in the
                # string branch while the list branch rejected them; fail
                # consistently with a 400 in both cases.
                return [], None, (jsonify({'error': f'Invalid role: {role}'}), 400)
            gemini_history.append({"role": gemini_role, "parts": [content]})

        elif isinstance(content, list):  # mixed text / image parts
            parts = []
            for item in content:
                if item.get('type') == 'text':
                    parts.append(item.get('text'))
                elif item.get('type') == 'image_url':
                    image_data = item.get('image_url', {}).get('url', '')
                    if not image_data.startswith('data:image/'):
                        return [], None, (jsonify({'error': 'Invalid image URL format'}), 400)
                    try:
                        base64_image = image_data.split(';base64,')[1]
                        image = Image.open(BytesIO(base64.b64decode(base64_image)))

                        # Gemini expects RGB input.
                        if image.mode != 'RGB':
                            image = image.convert('RGB')

                        # Downscale oversized images in place.
                        if image.width > 2048 or image.height > 2048:
                            image.thumbnail((2048, 2048))

                        # BUG FIX: the old code re-encoded the image into an
                        # output_buffer that was never used afterwards (dead
                        # code), and that save crashed on 'data:image/jpg'
                        # URIs because PIL has no 'JPG' format name. The PIL
                        # Image object itself is what gets appended.
                        parts.append(image)
                    except Exception as e:
                        logger.error(f"Error processing image: {e}")
                        return [], None, (jsonify({'error': 'Invalid image data'}), 400)

            gemini_role = role_map.get(role)
            if gemini_role is None:
                return [], None, (jsonify({'error': f'Invalid role: {role}'}), 400)
            gemini_history.append({"role": gemini_role, "parts": parts})

    # Split off the final (user) message; the history must not contain it.
    if gemini_history:
        user_message = gemini_history[-1]
        gemini_history = gemini_history[:-1]
    else:
        user_message = {"role": "user", "parts": [""]}

    return gemini_history, user_message, None
|