"""OpenAI-compatible HTTP proxy in front of the Heck chat gateway.

Exposes ``/hf/v1/models`` and ``/hf/v1/chat/completions`` (streaming and
non-streaming) and translates them to the upstream
``gateway.aiapilab.com`` protocol.
"""
import hmac
import os
import time
from uuid import uuid4

import requests
from flask import Flask, Response, json, request
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # enable CORS so browser clients on other origins can call us

# API key required on every /hf/v1 request; fail fast at startup if unset.
API_KEY = os.environ.get('API_KEY')
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")

# Public model name (what clients send) -> upstream model identifier.
MODEL_MAPPING = {
    "deepseek": "deepseek/deepseek-chat",
    "gpt-4o-mini": "openai/gpt-4o-mini",
    "gemini-flash-1.5": "google/gemini-flash-1.5",
    "deepseek-reasoner": "deepseek-reasoner",
    "minimax-01": "minimax/minimax-01",
}

# (connect, read) timeout in seconds for the upstream gateway; without this a
# stalled upstream would block a worker forever.
UPSTREAM_TIMEOUT = (10, 300)


def verify_api_key():
    """Return True iff the request carries the configured API key.

    Accepts either ``Authorization: Bearer <key>`` or the bare key as the
    header value.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header:
        return False
    if auth_header.startswith('Bearer '):
        token = auth_header.split(' ', 1)[1]
    else:
        token = auth_header
    # Constant-time comparison to avoid leaking key bytes via timing.
    return hmac.compare_digest(token, API_KEY)


def make_heck_request(question, session_id, messages, actual_model):
    """POST the question to the upstream gateway and return the streaming response.

    The upstream protocol carries at most one turn of history, so the most
    recent prior user question (and, if present, the assistant answer that
    immediately followed it) is extracted from ``messages``.
    """
    previous_question = previous_answer = None
    if len(messages) >= 2:
        # Walk backwards past the current question to the previous user turn.
        for i in range(len(messages) - 2, -1, -1):
            if messages[i]["role"] == "user":
                previous_question = messages[i]["content"]
                if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
                    previous_answer = messages[i + 1]["content"]
                break

    payload = {
        "model": actual_model,
        "question": question,
        "language": "Chinese",
        "sessionId": session_id,
        "previousQuestion": previous_question,
        "previousAnswer": previous_answer,
    }
    headers = {
        "Content-Type": "application/json",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
    }
    return requests.post(
        "https://gateway.aiapilab.com/api/ha/v1/chat",
        json=payload,
        headers=headers,
        stream=True,
        timeout=UPSTREAM_TIMEOUT,
    )


def _sse_chunk(session_id, model, delta, finish_reason=None):
    """Format one OpenAI-style ``chat.completion.chunk`` as an SSE data line."""
    choice = {"index": 0, "delta": delta}
    if finish_reason is not None:
        choice["finish_reason"] = finish_reason
    chunk = {
        "id": session_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
    }
    return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"


def stream_response(question, session_id, messages, request_model, actual_model):
    """Yield OpenAI-compatible SSE chunks translated from the upstream stream.

    Upstream frames the answer between ``[ANSWER_START]`` and
    ``[ANSWER_DONE]`` markers; ``[RELATE_Q...]`` lines (suggested follow-up
    questions) are dropped.
    """
    resp = make_heck_request(question, session_id, messages, actual_model)
    is_answering = False
    for raw in resp.iter_lines():
        if not raw:
            continue
        line = raw.decode('utf-8')
        if not line.startswith('data: '):
            continue
        content = line[6:].strip()
        if content == "[ANSWER_START]":
            is_answering = True
            # Opening chunk announces the assistant role, per OpenAI protocol.
            yield _sse_chunk(session_id, request_model, {"role": "assistant"})
        elif content == "[ANSWER_DONE]":
            yield _sse_chunk(session_id, request_model, {}, finish_reason="stop")
            break
        elif is_answering and content and not content.startswith("[RELATE_Q"):
            yield _sse_chunk(session_id, request_model, {"content": content})


def normal_response(question, session_id, messages, request_model, actual_model):
    """Collect the full upstream answer and return one OpenAI-style completion dict."""
    resp = make_heck_request(question, session_id, messages, actual_model)
    full_content = []
    is_answering = False
    for raw in resp.iter_lines():
        if not raw:
            continue
        line = raw.decode('utf-8')
        if not line.startswith('data: '):
            continue
        content = line[6:].strip()
        if content == "[ANSWER_START]":
            is_answering = True
        elif content == "[ANSWER_DONE]":
            break
        elif is_answering:
            full_content.append(content)

    return {
        "id": session_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "".join(full_content),
            },
            "finish_reason": "stop",
        }],
    }


@app.route("/hf/v1/models", methods=["GET"])
def list_models():
    """Return the supported model list in OpenAI ``/v1/models`` format."""
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "heck",
        }
        for model_id in MODEL_MAPPING
    ]
    return {"object": "list", "data": models}


@app.route("/hf/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-compatible chat endpoint; validates, maps the model, and proxies."""
    if not verify_api_key():
        return {"error": "Invalid API Key"}, 401

    data = request.json
    if not data or "model" not in data:
        return {"error": "Invalid request - missing model"}, 400
    if not data.get("messages"):
        return {"error": "Invalid request - missing messages"}, 400

    # Validate message shape and normalise multi-part content to plain text.
    for msg in data["messages"]:
        if not isinstance(msg, dict):
            return {"error": "Invalid message format"}, 400
        if "role" not in msg or "content" not in msg:
            return {"error": "Invalid message format"}, 400
        if isinstance(msg["content"], list):
            # Multi-part content: every part must carry a "text" field, which
            # we merge into a single string for the upstream API.
            for item in msg["content"]:
                if not isinstance(item, dict) or "text" not in item:
                    return {"error": "Invalid content format"}, 400
            msg["content"] = " ".join(item["text"] for item in msg["content"])
        elif not isinstance(msg["content"], str):
            return {"error": "Invalid content type"}, 400

    model = MODEL_MAPPING.get(data["model"])
    if not model:
        return {"error": "Unsupported Model"}, 400

    # Messages are validated above, so next() with a default cannot raise.
    question = next(
        (msg["content"] for msg in reversed(data["messages"]) if msg["role"] == "user"),
        None,
    )
    if not question:
        return {"error": "No user message found"}, 400

    session_id = str(uuid4())
    try:
        if data.get("stream"):
            return Response(
                stream_response(question, session_id, data["messages"],
                                data["model"], model),
                mimetype="text/event-stream",
            )
        return normal_response(question, session_id, data["messages"],
                               data["model"], model)
    except Exception as e:
        return {"error": f"Internal server error: {str(e)}"}, 500


@app.route("/", methods=["GET"])
def root():
    """Health-check endpoint."""
    return {"message": "App running"}


if __name__ == "__main__":
    # Port from the environment, defaulting to 7860 (HF Spaces default).
    port = int(os.environ.get("PORT", 7860))
    app.run(host='0.0.0.0', port=port)