# Heck2API / app.py
# Uploaded by aithink (commit 165a580, verified)
import hmac
import os
import time
from uuid import uuid4

import requests
from flask import Flask, request, Response, json
from flask_cors import CORS
# Flask application with CORS enabled for cross-origin API clients.
app = Flask(__name__)
CORS(app)  # enable CORS support

# Read the API key from the environment; refuse to start without one.
API_KEY = os.environ.get('API_KEY')
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")

# Maps the model names exposed by this API to the upstream Heck model ids.
MODEL_MAPPING = {
    "deepseek": "deepseek/deepseek-chat",
    "gpt-4o-mini": "openai/gpt-4o-mini",
    "gemini-flash-1.5": "google/gemini-flash-1.5",
    "deepseek-reasoner": "deepseek-reasoner",
    "minimax-01": "minimax/minimax-01"
}
def verify_api_key():
    """Validate the request's Authorization header against API_KEY.

    Accepts either the "Bearer <token>" scheme or a bare token value.
    Returns True only on a match; a missing or malformed header yields
    False. The comparison is constant-time to avoid leaking key material
    through timing differences.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header:
        return False
    # Support "Bearer <token>" as well as a raw token value.
    if auth_header.startswith('Bearer '):
        token = auth_header[len('Bearer '):]
    else:
        token = auth_header
    try:
        # compare_digest raises TypeError for non-ASCII str input; treat a
        # malformed token as a failed match. (Replaces the original bare
        # `except:` which silently swallowed every exception.)
        return hmac.compare_digest(token, API_KEY)
    except TypeError:
        return False
def make_heck_request(question, session_id, messages, actual_model):
    """POST a chat question to the Heck gateway; return the streaming response.

    Walks the conversation history backwards (from the second-to-last
    message) for the most recent user turn that is directly followed by an
    assistant reply, and forwards that pair as previousQuestion /
    previousAnswer for context.
    """
    prev_question = None
    prev_answer = None
    if len(messages) >= 2:
        # A user message without a following assistant reply still updates
        # prev_question and the scan continues, matching the original logic.
        for idx in reversed(range(len(messages) - 1)):
            if messages[idx]["role"] != "user":
                continue
            prev_question = messages[idx]["content"]
            follower = messages[idx + 1]
            if follower["role"] == "assistant":
                prev_answer = follower["content"]
                break
    body = {
        "model": actual_model,
        "question": question,
        "language": "Chinese",
        "sessionId": session_id,
        "previousQuestion": prev_question,
        "previousAnswer": prev_answer
    }
    request_headers = {
        "Content-Type": "application/json",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
    }
    return requests.post(
        "https://gateway.aiapilab.com/api/ha/v1/chat",
        json=body,
        headers=request_headers,
        stream=True
    )
def stream_response(question, session_id, messages, request_model, actual_model):
    """Proxy one chat request upstream and yield OpenAI-style SSE chunks.

    Emits a role chunk at [ANSWER_START], a content chunk per answer line,
    and a final chunk with finish_reason "stop" at [ANSWER_DONE]. Lines that
    look like related-question suggestions ("[RELATE_Q...") are dropped.
    """
    def _chunk(delta, finish_reason=None):
        # Build one chat.completion.chunk SSE frame; finish_reason is only
        # present on the terminal frame, matching the upstream contract.
        frame = {
            "id": session_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request_model,
            "choices": [{
                "index": 0,
                "delta": delta,
            }]
        }
        if finish_reason is not None:
            frame["choices"][0]["finish_reason"] = finish_reason
        return f"data: {json.dumps(frame, ensure_ascii=False)}\n\n"

    upstream = make_heck_request(question, session_id, messages, actual_model)
    answering = False
    for raw in upstream.iter_lines():
        if not raw:
            continue
        decoded = raw.decode('utf-8')
        if not decoded.startswith('data: '):
            continue
        payload = decoded[6:].strip()
        if payload == "[ANSWER_START]":
            answering = True
            yield _chunk({"role": "assistant"})
        elif payload == "[ANSWER_DONE]":
            yield _chunk({}, "stop")
            break
        elif answering and payload and not payload.startswith("[RELATE_Q"):
            yield _chunk({"content": payload})
def normal_response(question, session_id, messages, request_model, actual_model):
    """Proxy a chat request upstream; return one OpenAI-style completion dict.

    Fully consumes the upstream SSE stream, concatenating every line that
    arrives between the [ANSWER_START] and [ANSWER_DONE] markers.
    """
    upstream = make_heck_request(question, session_id, messages, actual_model)
    parts = []
    collecting = False
    for raw in upstream.iter_lines():
        if not raw:
            continue
        decoded = raw.decode('utf-8')
        if not decoded.startswith('data: '):
            continue
        payload = decoded[6:].strip()
        if payload == "[ANSWER_START]":
            collecting = True
        elif payload == "[ANSWER_DONE]":
            break
        elif collecting:
            parts.append(payload)
    return {
        "id": session_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "".join(parts)
            },
            "finish_reason": "stop"
        }]
    }
@app.route("/hf/v1/models", methods=["GET"])
def list_models():
models = []
for model_id, _ in MODEL_MAPPING.items():
models.append({
"id": model_id,
"object": "model",
"created": int(time.time()),
"owned_by": "heck",
})
return {
"object": "list",
"data": models
}
@app.route("/hf/v1/chat/completions", methods=["POST"])
def chat_completions():
# API Key 验证
if not verify_api_key():
return {"error": "Invalid API Key"}, 401
data = request.json
if not data or "model" not in data:
return {"error": "Invalid request - missing model"}, 400
if not data.get("messages"):
return {"error": "Invalid request - missing messages"}, 400
# 验证消息格式
for msg in data["messages"]:
if not isinstance(msg, dict):
return {"error": "Invalid message format"}, 400
if "role" not in msg or "content" not in msg:
return {"error": "Invalid message format"}, 400
# 检查content的类型
if isinstance(msg["content"], list):
# 如果content是列表,确保每个元素都有text字段
for item in msg["content"]:
if not isinstance(item, dict) or "text" not in item:
return {"error": "Invalid content format"}, 400
# 提取所有text字段并合并
msg["content"] = " ".join(item["text"] for item in msg["content"])
elif not isinstance(msg["content"], str):
return {"error": "Invalid content type"}, 400
model = MODEL_MAPPING.get(data["model"])
if not model:
return {"error": "Unsupported Model"}, 400
try:
question = next((msg["content"] for msg in reversed(data["messages"])
if msg["role"] == "user"), None)
except Exception as e:
return {"error": "Failed to extract question"}, 400
if not question:
return {"error": "No user message found"}, 400
session_id = str(uuid4())
try:
if data.get("stream"):
return Response(
stream_response(question, session_id, data["messages"],
data["model"], model),
mimetype="text/event-stream"
)
else:
return normal_response(question, session_id, data["messages"],
data["model"], model)
except Exception as e:
return {"error": f"Internal server error: {str(e)}"}, 500
@app.route("/", methods=["GET"])
def root():
return {
"message": "App running"
}
if __name__ == "__main__":
# 使用环境变量获取端口,默认为7860(HF Spaces 默认端口)
port = int(os.environ.get("PORT", 7860))
app.run(host='0.0.0.0', port=port)