parkerjj commited on
Commit
068fdbc
·
1 Parent(s): 558076d

优化应用初始化逻辑,使用异步上下文管理器处理生命周期;改进模型加载机制,添加线程锁以确保线程安全;更新 Gunicorn 配置以提高性能和稳定性

Browse files
Files changed (3) hide show
  1. app.py +27 -19
  2. blkeras.py +38 -23
  3. gunicorn.conf.py +17 -1
app.py CHANGED
@@ -1,16 +1,37 @@
1
  import os
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
- from fastapi.middleware.wsgi import WSGIMiddleware
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.middleware.trustedhost import TrustedHostMiddleware
7
-
8
- from transformers import pipeline
9
 
10
  from RequestModel import PredictRequest
11
- from us_stock import fetch_symbols
12
 
13
- app = FastAPI() # 创建 FastAPI 应用
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # 添加 CORS 中间件和限流配置
16
  app.add_middleware(
@@ -55,25 +76,12 @@ async def api_bbb(request: TextRequest):
55
  result = request.text + 'bbb'
56
  return {"result": result}
57
 
58
-
59
- @app.on_event("startup")
60
- async def initialize_symbols():
61
- # 在 FastAPI 启动时初始化变量
62
- await fetch_symbols()
63
-
64
  # 优化预测路由
65
  @app.post("/api/predict")
66
  async def predict(request: PredictRequest):
67
  from blkeras import predict
68
-
69
  try:
70
- # 使用 asyncio.to_thread 将同步操作转换为异步
71
- import asyncio
72
- result = await asyncio.to_thread(
73
- predict,
74
- request.text,
75
- request.stock_codes
76
- )
77
  return result
78
  except Exception as e:
79
  return []
 
1
  import os
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.middleware.trustedhost import TrustedHostMiddleware
6
+ import asyncio
7
+ from contextlib import asynccontextmanager
8
 
9
  from RequestModel import PredictRequest
 
10
 
11
+ # 全局变量,用于跟踪初始化状态
12
+ is_initialized = False
13
+ initialization_lock = asyncio.Lock()
14
+
15
+ @asynccontextmanager
16
+ async def lifespan(app: FastAPI):
17
+ # 启动时运行
18
+ global is_initialized
19
+ async with initialization_lock:
20
+ if not is_initialized:
21
+ await initialize_application()
22
+ is_initialized = True
23
+ yield
24
+ # 关闭时运行
25
+ # cleanup_code_here()
26
+
27
+ async def initialize_application():
28
+ # 在这里进行所有需要的初始化
29
+ from us_stock import fetch_symbols
30
+
31
+ await fetch_symbols()
32
+ # 其他初始化代码...
33
+
34
+ app = FastAPI(lifespan=lifespan)
35
 
36
  # 添加 CORS 中间件和限流配置
37
  app.add_middleware(
 
76
  result = request.text + 'bbb'
77
  return {"result": result}
78
 
 
 
 
 
 
 
79
  # 优化预测路由
80
  @app.post("/api/predict")
81
  async def predict(request: PredictRequest):
82
  from blkeras import predict
 
83
  try:
84
+ result = await asyncio.to_thread(predict, request.text, request.stock_codes)
 
 
 
 
 
 
85
  return result
86
  except Exception as e:
87
  return []
blkeras.py CHANGED
@@ -27,35 +27,48 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
27
  # 设置环境变量,指定 Hugging Face 缓存路径
28
  os.environ["HF_HOME"] = "/tmp/huggingface"
29
 
 
 
 
 
 
 
30
  # 加载模型
31
  model = None
32
- if model is None:
33
- # 从环境变量中获取 Hugging Face token
34
- hf_token = os.environ.get("HF_Token")
35
 
36
-
37
- # 使用 Hugging Face API token 登录 (确保只读权限)
38
- if hf_token:
39
- login(token=hf_token)
40
- else:
41
- raise ValueError("Hugging Face token not found in environment variables.")
 
 
 
 
 
 
 
 
42
 
43
- # 下载模型到本地
44
- model_path = hf_hub_download(repo_id="parkerjj/BuckLake-Stock-Model",
45
- filename="stock_prediction_model_1118_final.keras",
46
- use_auth_token=hf_token)
47
 
48
- # 使用 Keras 加载模型
49
- os.environ["KERAS_BACKEND"] = "jax"
50
- print(f"Loading saved model from {model_path}...")
51
- from model_build import TransformerEncoder, ExpandDimension, ConcatenateTimesteps
52
- model = keras.saving.load_model(model_path, custom_objects={
53
- "TransformerEncoder": TransformerEncoder,
54
- "ExpandDimension": ExpandDimension,
55
- "ConcatenateTimesteps": ConcatenateTimesteps
56
- })
57
 
58
- model.summary()
 
 
59
 
60
 
61
 
@@ -106,6 +119,7 @@ def predict(text: str, stock_codes: list):
106
 
107
  print(f"Input Text Length: {len(text)}, Start with: {text[:200] if len(text) > 200 else text}")
108
  print("Input stock codes:", stock_codes)
 
109
 
110
  start_time = datetime.now()
111
  input_text = text
@@ -230,6 +244,7 @@ def predict(text: str, stock_codes: list):
230
  # print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
231
 
232
  # 使用模型进行预测
 
233
  predictions = model.predict(features)
234
 
235
  # 生成伪精准度值
 
27
  # 设置环境变量,指定 Hugging Face 缓存路径
28
  os.environ["HF_HOME"] = "/tmp/huggingface"
29
 
30
+ import threading
31
+
32
+ # 添加线程锁
33
+ model_lock = threading.Lock()
34
+ model_initialized = False
35
+
36
  # 加载模型
37
  model = None
 
 
 
38
 
39
+ def get_model():
40
+ global model, model_initialized
41
+ if not model_initialized:
42
+ with model_lock:
43
+ if not model_initialized: # 双重检查锁定
44
+ # 从环境变量中获取 Hugging Face token
45
+ hf_token = os.environ.get("HF_Token")
46
+
47
+
48
+ # 使用 Hugging Face API token 登录 (确保只读权限)
49
+ if hf_token:
50
+ login(token=hf_token)
51
+ else:
52
+ raise ValueError("Hugging Face token not found in environment variables.")
53
 
54
+ # 下载模型到本地
55
+ model_path = hf_hub_download(repo_id="parkerjj/BuckLake-Stock-Model",
56
+ filename="stock_prediction_model_1118_final.keras",
57
+ use_auth_token=hf_token)
58
 
59
+ # 使用 Keras 加载模型
60
+ os.environ["KERAS_BACKEND"] = "jax"
61
+ print(f"Loading saved model from {model_path}...")
62
+ from model_build import TransformerEncoder, ExpandDimension, ConcatenateTimesteps
63
+ model = keras.saving.load_model(model_path, custom_objects={
64
+ "TransformerEncoder": TransformerEncoder,
65
+ "ExpandDimension": ExpandDimension,
66
+ "ConcatenateTimesteps": ConcatenateTimesteps
67
+ })
68
 
69
+ model.summary()
70
+ model_initialized = True
71
+ return model
72
 
73
 
74
 
 
119
 
120
  print(f"Input Text Length: {len(text)}, Start with: {text[:200] if len(text) > 200 else text}")
121
  print("Input stock codes:", stock_codes)
122
+ print("Current Time:", datetime.now())
123
 
124
  start_time = datetime.now()
125
  input_text = text
 
244
  # print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
245
 
246
  # 使用模型进行预测
247
+ model = get_model()
248
  predictions = model.predict(features)
249
 
250
  # 生成伪精准度值
gunicorn.conf.py CHANGED
@@ -11,6 +11,9 @@ workers = multiprocessing.cpu_count() + 1
11
  # 设置为2,增加并发处理能力
12
  threads = 2
13
 
 
 
 
14
  # 工作方式
15
  worker_class = "uvicorn.workers.UvicornWorker"
16
 
@@ -27,7 +30,20 @@ worker_connections = 2000
27
 
28
  # 工作模式
29
  worker_tmp_dir = "/dev/shm" # 使用内存文件系统提高性能
30
- preload_app = True # 预加载应用,减少启动时间
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # 进程名称前缀
33
  proc_name = 'gunicorn_fastapi'
 
11
  # 设置为2,增加并发处理能力
12
  threads = 2
13
 
14
+ # 请求超时时间
15
+ timeout = 600
16
+
17
  # 工作方式
18
  worker_class = "uvicorn.workers.UvicornWorker"
19
 
 
30
 
31
  # 工作模式
32
  worker_tmp_dir = "/dev/shm" # 使用内存文件系统提高性能
33
+ preload_app = False # 修改为 False,避免重复加载
34
+
35
+ # 添加新的配置
36
+ reload = False # 禁用自动重载
37
+ daemon = False # 非守护进程模式运行
38
+
39
+ # 添加应用初始化钩子
40
+ def when_ready(server):
41
+ # 当 Gunicorn 准备好时执行
42
+ server.log.info("Server is ready. Doing nothing.")
43
+
44
+ def post_fork(server, worker):
45
+ # 当 worker 进程被 fork 后执行
46
+ server.log.info(f"Worker spawned (pid: {worker.pid})")
47
 
48
  # 进程名称前缀
49
  proc_name = 'gunicorn_fastapi'