Spaces:
Running
Running
优化应用初始化逻辑,使用异步上下文管理器处理生命周期;改进模型加载机制,添加线程锁以确保线程安全;更新 Gunicorn 配置以提高性能和稳定性
Browse files- app.py +27 -19
- blkeras.py +38 -23
- gunicorn.conf.py +17 -1
app.py
CHANGED
@@ -1,16 +1,37 @@
|
|
1 |
import os
|
2 |
from fastapi import FastAPI
|
3 |
from pydantic import BaseModel
|
4 |
-
from fastapi.middleware.wsgi import WSGIMiddleware
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
7 |
-
|
8 |
-
from
|
9 |
|
10 |
from RequestModel import PredictRequest
|
11 |
-
from us_stock import fetch_symbols
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
# 添加 CORS 中间件和限流配置
|
16 |
app.add_middleware(
|
@@ -55,25 +76,12 @@ async def api_bbb(request: TextRequest):
|
|
55 |
result = request.text + 'bbb'
|
56 |
return {"result": result}
|
57 |
|
58 |
-
|
59 |
-
@app.on_event("startup")
|
60 |
-
async def initialize_symbols():
|
61 |
-
# 在 FastAPI 启动时初始化变量
|
62 |
-
await fetch_symbols()
|
63 |
-
|
64 |
# 优化预测路由
|
65 |
@app.post("/api/predict")
|
66 |
async def predict(request: PredictRequest):
|
67 |
from blkeras import predict
|
68 |
-
|
69 |
try:
|
70 |
-
|
71 |
-
import asyncio
|
72 |
-
result = await asyncio.to_thread(
|
73 |
-
predict,
|
74 |
-
request.text,
|
75 |
-
request.stock_codes
|
76 |
-
)
|
77 |
return result
|
78 |
except Exception as e:
|
79 |
return []
|
|
|
1 |
import os
|
2 |
from fastapi import FastAPI
|
3 |
from pydantic import BaseModel
|
|
|
4 |
from fastapi.middleware.cors import CORSMiddleware
|
5 |
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
6 |
+
import asyncio
|
7 |
+
from contextlib import asynccontextmanager
|
8 |
|
9 |
from RequestModel import PredictRequest
|
|
|
10 |
|
11 |
+
# 全局变量,用于跟踪初始化状态
|
12 |
+
is_initialized = False
|
13 |
+
initialization_lock = asyncio.Lock()
|
14 |
+
|
15 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run start-up initialisation once, then hand control to the app.

    Everything before ``yield`` runs when the application starts. There is
    no code after ``yield``, so nothing is executed on shutdown.

    Args:
        app: The FastAPI instance this lifespan is attached to (unused here).
    """
    global is_initialized

    # The lock plus the flag ensure initialize_application() is awaited at
    # most once, even if this context is entered again in the same process.
    async with initialization_lock:
        needs_init = not is_initialized
        if needs_init:
            await initialize_application()
            is_initialized = True

    yield
|
26 |
+
|
27 |
+
async def initialize_application():
    """Perform the application's start-up initialisation.

    At present this only awaits ``us_stock.fetch_symbols``; any further
    initialisation steps belong here as well.
    """
    # Imported at call time rather than at module import.
    from us_stock import fetch_symbols as load_symbols

    await load_symbols()
|
33 |
+
|
34 |
+
app = FastAPI(lifespan=lifespan)
|
35 |
|
36 |
# 添加 CORS 中间件和限流配置
|
37 |
app.add_middleware(
|
|
|
76 |
result = request.text + 'bbb'
|
77 |
return {"result": result}
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
# 优化预测路由
|
80 |
@app.post("/api/predict")
async def predict(request: PredictRequest):
    """Run the prediction model for the posted text and stock codes.

    The blocking ``blkeras.predict`` call is dispatched through
    ``asyncio.to_thread`` so the event loop is not blocked while it runs.

    Args:
        request: Request body; its ``text`` and ``stock_codes`` fields are
            passed through to ``blkeras.predict`` unchanged.

    Returns:
        Whatever ``blkeras.predict`` returns, or an empty list if it raised.
    """
    import logging

    # Imported at call time, under an alias so it does not shadow this
    # handler's own name inside the function body.
    from blkeras import predict as run_prediction

    try:
        return await asyncio.to_thread(
            run_prediction, request.text, request.stock_codes
        )
    except Exception:
        # Keep the best-effort contract (callers still receive []), but
        # record the traceback instead of discarding the failure silently.
        logging.getLogger(__name__).exception("Prediction request failed")
        return []
|
blkeras.py
CHANGED
@@ -27,35 +27,48 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
27 |
# 设置环境变量,指定 Hugging Face 缓存路径
|
28 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# 加载模型
|
31 |
model = None
|
32 |
-
if model is None:
|
33 |
-
# 从环境变量中获取 Hugging Face token
|
34 |
-
hf_token = os.environ.get("HF_Token")
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
if
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
|
|
|
|
59 |
|
60 |
|
61 |
|
@@ -106,6 +119,7 @@ def predict(text: str, stock_codes: list):
|
|
106 |
|
107 |
print(f"Input Text Length: {len(text)}, Start with: {text[:200] if len(text) > 200 else text}")
|
108 |
print("Input stock codes:", stock_codes)
|
|
|
109 |
|
110 |
start_time = datetime.now()
|
111 |
input_text = text
|
@@ -230,6 +244,7 @@ def predict(text: str, stock_codes: list):
|
|
230 |
# print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
|
231 |
|
232 |
# 使用模型进行预测
|
|
|
233 |
predictions = model.predict(features)
|
234 |
|
235 |
# 生成伪精准度值
|
|
|
27 |
# 设置环境变量,指定 Hugging Face 缓存路径
|
28 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
29 |
|
30 |
+
import threading
|
31 |
+
|
32 |
+
# 添加线程锁
|
33 |
+
model_lock = threading.Lock()
|
34 |
+
model_initialized = False
|
35 |
+
|
36 |
# 加载模型
|
37 |
model = None
|
|
|
|
|
|
|
38 |
|
39 |
+
def get_model():
    """Return the shared Keras model, loading it on first use.

    The first caller logs in to Hugging Face with the token from the
    ``HF_Token`` environment variable, downloads the model file, loads it
    with Keras and stores it in the module-level ``model``. Later callers
    get that same object back.

    Returns:
        The loaded model stored in the module-level ``model``.

    Raises:
        ValueError: If ``HF_Token`` is missing or empty in the environment.
    """
    global model, model_initialized

    # Fast path: skip the lock entirely once loading has completed.
    if model_initialized:
        return model

    with model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting to acquire it.
        if model_initialized:
            return model

        # Read the Hugging Face token from the environment.
        hf_token = os.environ.get("HF_Token")
        if not hf_token:
            raise ValueError("Hugging Face token not found in environment variables.")

        # Log in with the Hugging Face API token (make sure it is read-only).
        login(token=hf_token)

        # Download the model file locally.
        # NOTE(review): newer huggingface_hub releases deprecate
        # `use_auth_token` in favour of `token` — confirm against the
        # version this project pins before changing it.
        model_path = hf_hub_download(
            repo_id="parkerjj/BuckLake-Stock-Model",
            filename="stock_prediction_model_1118_final.keras",
            use_auth_token=hf_token,
        )

        # Load the model with Keras.
        # NOTE(review): presumably this only takes effect if keras has not
        # been imported yet — confirm the import order in this module.
        os.environ["KERAS_BACKEND"] = "jax"
        print(f"Loading saved model from {model_path}...")
        from model_build import TransformerEncoder, ExpandDimension, ConcatenateTimesteps

        custom_layers = {
            "TransformerEncoder": TransformerEncoder,
            "ExpandDimension": ExpandDimension,
            "ConcatenateTimesteps": ConcatenateTimesteps,
        }
        model = keras.saving.load_model(model_path, custom_objects=custom_layers)

        model.summary()
        # Set the flag last so the lock-free fast path never sees it set
        # before `model` has been assigned.
        model_initialized = True

    return model
|
72 |
|
73 |
|
74 |
|
|
|
119 |
|
120 |
print(f"Input Text Length: {len(text)}, Start with: {text[:200] if len(text) > 200 else text}")
|
121 |
print("Input stock codes:", stock_codes)
|
122 |
+
print("Current Time:", datetime.now())
|
123 |
|
124 |
start_time = datetime.now()
|
125 |
input_text = text
|
|
|
244 |
# print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
|
245 |
|
246 |
# 使用模型进行预测
|
247 |
+
model = get_model()
|
248 |
predictions = model.predict(features)
|
249 |
|
250 |
# 生成伪精准度值
|
gunicorn.conf.py
CHANGED
@@ -11,6 +11,9 @@ workers = multiprocessing.cpu_count() + 1
|
|
11 |
# 设置为2,增加并发处理能力
|
12 |
threads = 2
|
13 |
|
|
|
|
|
|
|
14 |
# 工作方式
|
15 |
worker_class = "uvicorn.workers.UvicornWorker"
|
16 |
|
@@ -27,7 +30,20 @@ worker_connections = 2000
|
|
27 |
|
28 |
# 工作模式
|
29 |
worker_tmp_dir = "/dev/shm" # 使用内存文件系统提高性能
|
30 |
-
preload_app =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# 进程名称前缀
|
33 |
proc_name = 'gunicorn_fastapi'
|
|
|
11 |
# 设置为2,增加并发处理能力
|
12 |
threads = 2
|
13 |
|
14 |
+
# 请求超时时间
|
15 |
+
timeout = 600
|
16 |
+
|
17 |
# 工作方式
|
18 |
worker_class = "uvicorn.workers.UvicornWorker"
|
19 |
|
|
|
30 |
|
31 |
# 工作模式
|
32 |
worker_tmp_dir = "/dev/shm" # 使用内存文件系统提高性能
|
33 |
+
preload_app = False # 修改为 False,避免重复加载
|
34 |
+
|
35 |
+
# 添加新的配置
|
36 |
+
reload = False # 禁用自动重载
|
37 |
+
daemon = False # 非守护进程模式运行
|
38 |
+
|
39 |
+
# 添加应用初始化钩子
|
40 |
+
def when_ready(server):
    """Gunicorn ``when_ready`` hook: log that the server is ready.

    Args:
        server: Object passed in by Gunicorn; only its ``log`` attribute
            is used here.
    """
    # No initialisation happens here; the hook only writes a log line.
    log = server.log
    log.info("Server is ready. Doing nothing.")
|
43 |
+
|
44 |
+
def post_fork(server, worker):
    """Gunicorn ``post_fork`` hook: log the pid of the worker just forked.

    Args:
        server: Object passed in by Gunicorn; only its ``log`` attribute
            is used here.
        worker: The worker object; only its ``pid`` attribute is read.
    """
    # Pass the pid as a lazy %-style argument so the logger does the
    # formatting, instead of building the string up front with an f-string.
    server.log.info("Worker spawned (pid: %s)", worker.pid)
|
47 |
|
48 |
# 进程名称前缀
|
49 |
proc_name = 'gunicorn_fastapi'
|