FAYO commited on
Commit
1ef9436
·
1 Parent(s): 77b0e0f
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. server/__init__.py +0 -0
  2. server/__pycache__/__init__.cpython-310.pyc +0 -0
  3. server/__pycache__/web_configs.cpython-310.pyc +0 -0
  4. server/asr/asr_server.py +58 -0
  5. server/asr/asr_worker.py +54 -0
  6. server/base/__init__.py +0 -0
  7. server/base/base_server.py +206 -0
  8. server/base/database/__init__.py +0 -0
  9. server/base/database/init_db.py +44 -0
  10. server/base/database/llm_db.py +22 -0
  11. server/base/database/product_db.py +186 -0
  12. server/base/database/streamer_info_db.py +152 -0
  13. server/base/database/streamer_room_db.py +415 -0
  14. server/base/database/user_db.py +48 -0
  15. server/base/models/__init__.py +0 -0
  16. server/base/models/llm_model.py +17 -0
  17. server/base/models/product_model.py +59 -0
  18. server/base/models/streamer_info_model.py +40 -0
  19. server/base/models/streamer_room_model.py +127 -0
  20. server/base/models/user_model.py +43 -0
  21. server/base/modules/__init__.py +0 -0
  22. server/base/modules/agent/__init__.py +0 -0
  23. server/base/modules/agent/agent_worker.py +200 -0
  24. server/base/modules/agent/delivery_time_query.py +300 -0
  25. server/base/modules/rag/__init__.py +0 -0
  26. server/base/modules/rag/feature_store.py +545 -0
  27. server/base/modules/rag/file_operation.py +228 -0
  28. server/base/modules/rag/rag_worker.py +122 -0
  29. server/base/modules/rag/retriever.py +244 -0
  30. server/base/modules/rag/test_queries.json +4 -0
  31. server/base/queue_thread.py +73 -0
  32. server/base/routers/__init__.py +0 -0
  33. server/base/routers/digital_human.py +85 -0
  34. server/base/routers/llm.py +187 -0
  35. server/base/routers/products.py +119 -0
  36. server/base/routers/streamer_info.py +156 -0
  37. server/base/routers/streaming_room.py +335 -0
  38. server/base/routers/users.py +157 -0
  39. server/base/server_info.py +134 -0
  40. server/base/utils.py +485 -0
  41. server/digital_human/digital_human_server.py +68 -0
  42. server/digital_human/modules/__init__.py +6 -0
  43. server/digital_human/modules/digital_human_worker.py +33 -0
  44. server/digital_human/modules/musetalk/models/unet.py +43 -0
  45. server/digital_human/modules/musetalk/models/vae.py +149 -0
  46. server/digital_human/modules/musetalk/utils/__init__.py +5 -0
  47. server/digital_human/modules/musetalk/utils/blending.py +110 -0
  48. server/digital_human/modules/musetalk/utils/dwpose/default_runtime.py +54 -0
  49. server/digital_human/modules/musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py +257 -0
  50. server/digital_human/modules/musetalk/utils/face_detection/README.md +1 -0
server/__init__.py ADDED
File without changes
server/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (139 Bytes). View file
 
server/__pycache__/web_configs.cpython-310.pyc ADDED
Binary file (4.19 kB). View file
 
server/asr/asr_server.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import PlainTextResponse
from loguru import logger
from pydantic import BaseModel

from ..web_configs import WEB_CONFIGS
from .asr_worker import load_asr_model, process_asr

app = FastAPI()

# Load the ASR model once at import time; when ASR is disabled keep a None
# handler so the endpoint can report a clean failure instead of crashing.
if WEB_CONFIGS.ENABLE_ASR:
    ASR_HANDLER = load_asr_model()
else:
    ASR_HANDLER = None


class ASRItem(BaseModel):
    user_id: int  # User ID, used to distinguish different callers
    request_id: str  # Request ID, used for generating TTS & digital human output
    wav_path: str  # Path of the wav file to transcribe


@app.post("/asr")
async def get_asr(asr_item: ASRItem):
    """Speech-to-text endpoint: transcribe the wav file referenced by ``asr_item``.

    Returns a dict echoing user_id/request_id plus ``status`` ("success"/"fail")
    and ``result`` (the recognized text, or an error message on failure).
    """
    result = ""
    status = "success"
    if ASR_HANDLER is None:
        # fixed typo in the returned message: "not enable in sever"
        result = "ASR not enabled in server"
        status = "fail"
        logger.error("ASR not enabled...")  # plain string: no placeholders needed
    else:
        result = process_asr(ASR_HANDLER, asr_item.wav_path)
        logger.info(f"ASR res for id {asr_item.request_id}, res = {result}")

    return {"user_id": asr_item.user_id, "request_id": asr_item.request_id, "status": status, "result": result}


@app.get("/asr/check")
async def check_server():
    """Health-check endpoint for the ASR service."""
    return {"message": "server enabled"}


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Callback invoked when request validation fails.

    Args:
        request: the offending request (logged for debugging).
        exc (RequestValidationError): validation error detail.

    Returns:
        PlainTextResponse: the stringified error with HTTP 400.
    """
    logger.info(request)
    logger.info(exc)
    return PlainTextResponse(str(exc), status_code=400)
server/asr/asr_worker.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import datetime

from funasr import AutoModel
from funasr.download.name_maps_from_hub import name_maps_ms as NAME_MAPS_MS
from modelscope import snapshot_download
from modelscope.utils.constant import Invoke, ThirdParty

from ..web_configs import WEB_CONFIGS


def load_asr_model():
    """Download (if needed) and build the FunASR pipeline.

    Returns:
        AutoModel: ASR model with VAD and punctuation restoration attached.
    """
    # Download model weights from ModelScope into the configured cache dir.
    model_path_info = dict()
    for model_name in ["paraformer-zh", "fsmn-vad", "ct-punc"]:
        print(f"downloading asr model : {NAME_MAPS_MS[model_name]}")
        mode_dir = snapshot_download(
            NAME_MAPS_MS[model_name],
            revision="master",
            user_agent={Invoke.KEY: Invoke.PIPELINE, ThirdParty.KEY: "funasr"},
            cache_dir=WEB_CONFIGS.ASR_MODEL_DIR,
        )
        model_path_info[model_name] = mode_dir
        NAME_MAPS_MS[model_name] = mode_dir  # point the name map at the local weights

    print(f"ASR model path info = {model_path_info}")
    # paraformer-zh is a multi-functional asr model
    # use vad, punc, spk or not as you need
    model = AutoModel(
        model="paraformer-zh",  # speech recognition with timestamps, offline
        vad_model="fsmn-vad",  # voice activity detection, streaming
        punc_model="ct-punc",  # punctuation restoration
        # spk_model="cam++"  # speaker verification / diarization
        model_path=model_path_info["paraformer-zh"],
        vad_kwargs={"model_path": model_path_info["fsmn-vad"]},
        punc_kwargs={"model_path": model_path_info["ct-punc"]},
    )
    return model


def process_asr(model: AutoModel, wav_path: str) -> str:
    """Run ASR on a wav file.

    Args:
        model (AutoModel): pipeline built by :func:`load_asr_model`.
        wav_path (str): path of the wav file to transcribe.

    Returns:
        str: recognized text, or "" when the result cannot be parsed.
    """
    # https://github.com/modelscope/FunASR/blob/main/README_zh.md#%E5%AE%9E%E6%97%B6%E8%AF%AD%E9%9F%B3%E8%AF%86%E5%88%AB
    f_start_time = datetime.datetime.now()
    res = model.generate(input=wav_path, batch_size_s=50, hotword="魔搭")
    delta_time = datetime.datetime.now() - f_start_time

    try:
        res_str = res[0]["text"]
        print(f"ASR using time {delta_time}s, text: ", res_str)
    except (IndexError, KeyError, TypeError):
        # narrowed from `except Exception as e` (unused `e`): only the result
        # structure can fail here; other errors should surface to the caller
        print("ASR 解析失败,无法获取到文字")
        return ""

    return res_str
server/base/__init__.py ADDED
File without changes
server/base/base_server.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : base_server.py
5
+ @Time : 2024/09/02
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 中台服务入口文件
10
+ """
11
+
12
+ import time
13
+ import uuid
14
+ from contextlib import asynccontextmanager
15
+ from pathlib import Path
16
+
17
+ from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile
18
+ from fastapi.exceptions import RequestValidationError
19
+ from fastapi.responses import PlainTextResponse
20
+ from fastapi.staticfiles import StaticFiles
21
+ from loguru import logger
22
+
23
+ from ..web_configs import API_CONFIG, WEB_CONFIGS
24
+ from .database.init_db import create_db_and_tables
25
+ from .routers import digital_human, llm, products, streamer_info, streaming_room, users
26
+ from .server_info import SERVER_PLUGINS_INFO
27
+ from .utils import ChatItem, ResultCode, gen_default_data, make_return_data, streamer_sales_process
28
+
29
+ swagger_description = """
30
+
31
+ ## 项目地址
32
+
33
+ [销冠 —— 卖货主播大模型 && 后台管理系统](https://github.com/PeterH0323/Streamer-Sales)
34
+
35
+ ## 功能点
36
+
37
+ 1. 📜 **主播文案一键生成**
38
+ 2. 🚀 KV cache + Turbomind **推理加速**
39
+ 3. 📚 RAG **检索增强生成**
40
+ 4. 🔊 TTS **文字转语音**
41
+ 5. 🦸 **数字人生成**
42
+ 6. 🌐 **Agent 网络查询**
43
+ 7. 🎙️ **ASR 语音转文字**
44
+ 8. 🍍 **Vue + pinia + element-plus **搭建的前端,可自由扩展快速开发
45
+ 9. 🗝️ 后端采用 FastAPI + Uvicorn + PostgreSQL,**高性能,高效编码,生产可用,同时具有 JWT 身份验证**
46
+ 10. 🐋 采用 Docker-compose 部署,**一键实现分布式部署**
47
+
48
+ """
49
+
50
+
51
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Service lifespan hook: set up resources on startup, log on shutdown."""
    # Startup
    create_db_and_tables()  # create the database and tables; existing ones are skipped

    # Fresh deployment: generate default data; comment out or modify as needed
    gen_default_data()

    if WEB_CONFIGS.ENABLE_RAG:
        # imported lazily so the RAG stack is only loaded when enabled
        from .modules.rag.rag_worker import load_rag_model

        # build the RAG vector database
        await load_rag_model(user_id=1)

    yield

    # Shutdown
    logger.info("Base server stopped.")
70
+
71
+
72
+ app = FastAPI(
73
+ title="销冠 —— 卖货主播大模型 && 后台管理系统",
74
+ description=swagger_description,
75
+ summary="一个能够根据给定的商品特点从激发用户购买意愿角度出发进行商品解说的卖货主播大模型。",
76
+ version="1.0.0",
77
+ license_info={
78
+ "name": "AGPL-3.0 license",
79
+ "url": "https://github.com/PeterH0323/Streamer-Sales/blob/main/LICENSE",
80
+ },
81
+ root_path=API_CONFIG.API_V1_STR,
82
+ lifespan=lifespan,
83
+ )
84
+
85
+ # 注册路由
86
+ app.include_router(users.router)
87
+ app.include_router(products.router)
88
+ app.include_router(llm.router)
89
+ app.include_router(streamer_info.router)
90
+ app.include_router(streaming_room.router)
91
+ app.include_router(digital_human.router)
92
+
93
+
94
+ # 挂载静态文件目录,以便访问上传的文件
95
+ WEB_CONFIGS.SERVER_FILE_ROOT = str(Path(WEB_CONFIGS.SERVER_FILE_ROOT).absolute())
96
+ Path(WEB_CONFIGS.SERVER_FILE_ROOT).mkdir(parents=True, exist_ok=True)
97
+ logger.info(f"上传文件挂载路径: {WEB_CONFIGS.SERVER_FILE_ROOT}")
98
+ logger.info(f"上传文件访问 URL: {API_CONFIG.REQUEST_FILES_URL}")
99
+ app.mount(
100
+ f"/{API_CONFIG.REQUEST_FILES_URL.split('/')[-1]}",
101
+ StaticFiles(directory=WEB_CONFIGS.SERVER_FILE_ROOT),
102
+ name=API_CONFIG.REQUEST_FILES_URL.split("/")[-1],
103
+ )
104
+
105
+
106
+ @app.get("/")
107
+ async def hello():
108
+ return {"message": "Hello Streamer-Sales"}
109
+
110
+
111
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Callback invoked when API request validation fails.

    Args:
        request: the offending request (its headers are logged for debugging).
        exc (RequestValidationError): validation error detail.

    Returns:
        PlainTextResponse: the stringified error with HTTP 400 status.
    """
    logger.info(request.headers)
    logger.info(exc)
    return PlainTextResponse(str(exc), status_code=400)
125
+
126
+
127
+ @app.get("/dashboard", tags=["base"], summary="获取主页信息接口")
128
+ async def get_dashboard_info():
129
+ """首页展示数据"""
130
+ fake_dashboard_data = {
131
+ "registeredBrandNum": 98431, # 入驻品牌方
132
+ "productNum": 49132, # 商品数
133
+ "dailyActivity": 68431, # 日活
134
+ "todayOrder": 8461321, # 订单量
135
+ "totalSales": 245578131857, # 销售额
136
+ "conversionRate": 90.0, # 转化率
137
+ # 折线图
138
+ "orderNumList": [46813, 68461, 99561, 138131, 233812, 84613, 846122], # 订单量
139
+ "totalSalesList": [46813, 68461, 99561, 138131, 23383, 84613, 841213], # 销售额
140
+ "newUserList": [3215, 65131, 6513, 6815, 2338, 84614, 84213], # 新增用户
141
+ "activityUserList": [132, 684, 59431, 4618, 31354, 68431, 88431], # 活跃用户
142
+ # 柱状图
143
+ "knowledgeBasesNum": 12, # 知识库数量
144
+ "digitalHumanNum": 3, # 数字人数量
145
+ "LiveRoomNum": 5, # 直播间数量
146
+ }
147
+
148
+ return make_return_data(True, ResultCode.SUCCESS, "成功", fake_dashboard_data)
149
+
150
+
151
+ @app.get("/plugins_info", tags=["base"], summary="获取组件信息接口")
152
+ async def get_plugins_info():
153
+
154
+ plugins_info = SERVER_PLUGINS_INFO.get_status()
155
+ return make_return_data(True, ResultCode.SUCCESS, "成功", plugins_info)
156
+
157
+
158
+ @app.post("/upload/file", tags=["base"], summary="上传文件接口")
159
+ async def upload_product_api(file: UploadFile = File(...), user_id: int = Depends(users.get_current_user_info)):
160
+
161
+ file_type = file.filename.split(".")[-1] # eg. png
162
+ logger.info(f"upload file type = {file_type}")
163
+
164
+ sub_dir_name_map = {
165
+ "md": WEB_CONFIGS.INSTRUCTIONS_DIR,
166
+ "png": WEB_CONFIGS.IMAGES_DIR,
167
+ "jpg": WEB_CONFIGS.IMAGES_DIR,
168
+ "mp4": WEB_CONFIGS.STREAMER_INFO_FILES_DIR,
169
+ "wav": WEB_CONFIGS.STREAMER_INFO_FILES_DIR,
170
+ "webm": WEB_CONFIGS.ASR_FILE_DIR,
171
+ }
172
+ if file_type in ["wav", "mp4"]:
173
+ save_root = WEB_CONFIGS.STREAMER_FILE_DIR
174
+ elif file_type in ["webm"]:
175
+ save_root = ""
176
+ else:
177
+ save_root = WEB_CONFIGS.PRODUCT_FILE_DIR
178
+
179
+ upload_time = str(int(time.time())) + "__" + str(uuid.uuid4().hex)
180
+
181
+ sub_dir_name = sub_dir_name_map[file_type]
182
+ save_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(save_root, sub_dir_name, upload_time + "." + file_type)
183
+ save_path.parent.mkdir(exist_ok=True, parents=True)
184
+ logger.info(f"save path = {save_path}")
185
+
186
+ # 使用流式处理接收文件
187
+ with open(save_path, "wb") as buffer:
188
+ while chunk := await file.read(1024 * 1024 * 5): # 每次读取 5MB 的数据块
189
+ buffer.write(chunk)
190
+
191
+ split_dir_name = Path(WEB_CONFIGS.SERVER_FILE_ROOT).name # 保存文件夹根目录名字
192
+ file_url = f"{API_CONFIG.REQUEST_FILES_URL}{str(save_path).split(split_dir_name)[-1]}"
193
+
194
+ # TODO 文件归属记录表
195
+
196
+ return make_return_data(True, ResultCode.SUCCESS, "成功", file_url)
197
+
198
+
199
+ @app.post("/streamer-sales/chat", tags=["base"], summary="对话接口", deprecated=True)
200
+ async def streamer_sales_chat(chat_item: ChatItem, response: Response):
201
+ from sse_starlette import EventSourceResponse
202
+
203
+ # 对话总接口
204
+ response.headers["Content-Type"] = "text/event-stream"
205
+ response.headers["Cache-Control"] = "no-cache"
206
+ return EventSourceResponse(streamer_sales_process(chat_item))
server/base/database/__init__.py ADDED
File without changes
server/base/database/init_db.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File    : init_db.py
@Time    : 2024/09/06
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author  : HinGwenWong
@Version : 1.0
@Desc    : Database initialization
"""

from loguru import logger
from pydantic import PostgresDsn
from pydantic_core import MultiHostUrl
from sqlmodel import SQLModel, create_engine

from ...web_configs import WEB_CONFIGS

ECHO_DB_MESG = True  # echo executed SQL statements, for debug


def sqlalchemy_db_url() -> PostgresDsn:
    """Build the PostgreSQL connection URL from the web configs.

    Returns:
        PostgresDsn: database DSN (contains credentials — never log it verbatim).
    """
    return MultiHostUrl.build(
        scheme="postgresql+psycopg",
        username=WEB_CONFIGS.POSTGRES_USER,
        password=WEB_CONFIGS.POSTGRES_PASSWORD,
        host=WEB_CONFIGS.POSTGRES_SERVER,
        port=WEB_CONFIGS.POSTGRES_PORT,
        path=WEB_CONFIGS.POSTGRES_DB,
    )


# Security fix: the previous line logged the full DSN, leaking the database
# password into the logs. Log only host/port/database instead.
logger.info(f"connecting to db: {WEB_CONFIGS.POSTGRES_SERVER}:{WEB_CONFIGS.POSTGRES_PORT}/{WEB_CONFIGS.POSTGRES_DB}")
DB_ENGINE = create_engine(str(sqlalchemy_db_url()), echo=ECHO_DB_MESG)


def create_db_and_tables():
    """Create all tables registered on SQLModel metadata; existing ones are skipped."""
    SQLModel.metadata.create_all(DB_ENGINE)
server/base/database/llm_db.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File    : llm_db.py
@Time    : 2024/09/01
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author  : HinGwenWong
@Version : 1.0
@Desc    : LLM conversation configuration access
"""

import yaml

from ...web_configs import WEB_CONFIGS


async def get_llm_product_prompt_base_info():
    """Load and return the conversation prompt configuration.

    Returns:
        The parsed YAML content — structure defined by the config file
        (presumably prompt templates; verify against callers).
    """
    # Load the conversation configuration file
    with open(WEB_CONFIGS.CONVERSATION_CFG_YAML_PATH, "r", encoding="utf-8") as f:
        dataset_yaml = yaml.safe_load(f)

    return dataset_yaml
server/base/database/product_db.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : product_db.py
5
+ @Time : 2024/08/30
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 商品数据表文件读写
10
+ """
11
+
12
+ from typing import List, Tuple
13
+
14
+ from loguru import logger
15
+ from sqlalchemy import func
16
+ from sqlmodel import Session, and_, not_, select
17
+
18
+ from ...web_configs import API_CONFIG
19
+ from ..models.product_model import ProductInfo
20
+ from .init_db import DB_ENGINE
21
+
22
+
23
+ async def get_db_product_info(
24
+ user_id: int,
25
+ current_page: int = -1,
26
+ page_size: int = 10,
27
+ product_name: str | None = None,
28
+ product_id: int | None = None,
29
+ exclude_list: List[int] | None = None,
30
+ ) -> Tuple[List[ProductInfo], int]:
31
+ """查询数据库中的商品信息
32
+
33
+ Args:
34
+ user_id (int): 用户 ID
35
+ current_page (int, optional): 页数. Defaults to -1.
36
+ page_size (int, optional): 每页的大小. Defaults to 10.
37
+ product_name (str | None, optional): 商品名称,模糊搜索. Defaults to None.
38
+ product_id (int | None, optional): 商品 ID,用户获取特定商品信息. Defaults to None.
39
+
40
+ Returns:
41
+ List[ProductInfo]: 商品信息
42
+ int : 该用户持有的总商品数,已剔除删除的
43
+ """
44
+
45
+ assert current_page != 0
46
+ assert page_size != 0
47
+
48
+ # 查询条件
49
+ query_condiction = and_(ProductInfo.user_id == user_id, ProductInfo.delete == False)
50
+
51
+ # 获取总数
52
+ with Session(DB_ENGINE) as session:
53
+ # 获得该用户所有商品的总数
54
+ total_product_num = session.scalar(select(func.count(ProductInfo.product_id)).where(query_condiction))
55
+
56
+ if product_name is not None:
57
+ # 查询条件更改为商品名称模糊搜索
58
+ query_condiction = and_(
59
+ ProductInfo.user_id == user_id, ProductInfo.delete == False, ProductInfo.product_name.ilike(f"%{product_name}%")
60
+ )
61
+
62
+ elif product_id is not None:
63
+ # 查询条件更改为查找特定 ID
64
+ query_condiction = and_(
65
+ ProductInfo.user_id == user_id, ProductInfo.delete == False, ProductInfo.product_id == product_id
66
+ )
67
+
68
+ elif exclude_list is not None:
69
+ # 排除查询
70
+ query_condiction = and_(
71
+ ProductInfo.user_id == user_id, ProductInfo.delete == False, not_(ProductInfo.product_id.in_(exclude_list))
72
+ )
73
+
74
+ # 查询获取商品
75
+ if current_page < 0:
76
+ # 全部查询
77
+ product_list = session.exec(select(ProductInfo).where(query_condiction).order_by(ProductInfo.product_id)).all()
78
+ else:
79
+ # 分页查询
80
+ offset_idx = (current_page - 1) * page_size
81
+ product_list = session.exec(
82
+ select(ProductInfo).where(query_condiction).offset(offset_idx).limit(page_size).order_by(ProductInfo.product_id)
83
+ ).all()
84
+
85
+ if product_list is None:
86
+ logger.warning("nothing to find in db...")
87
+ product_list = []
88
+
89
+ # 将路径换成服务器路径
90
+ for product in product_list:
91
+ product.image_path = API_CONFIG.REQUEST_FILES_URL + product.image_path
92
+ product.instruction = API_CONFIG.REQUEST_FILES_URL + product.instruction
93
+
94
+ logger.info(product_list)
95
+ logger.info(f"len {len(product_list)}")
96
+
97
+ return product_list, total_product_num
98
+
99
+
100
+ async def delete_product_id(product_id: int, user_id: int) -> bool:
101
+ """删除特定的商品 ID
102
+
103
+ Args:
104
+ product_id (int): 商品 ID
105
+ user_id (int): 用户 ID,用于防止其他用户恶意删除
106
+
107
+ Returns:
108
+ bool: 是否删除成功
109
+ """
110
+
111
+ delete_success = True
112
+
113
+ try:
114
+ # 获取总数
115
+ with Session(DB_ENGINE) as session:
116
+ # 查找特定 ID
117
+ product_info = session.exec(
118
+ select(ProductInfo).where(and_(ProductInfo.product_id == product_id, ProductInfo.user_id == user_id))
119
+ ).one()
120
+
121
+ if product_info is None:
122
+ logger.error("Delete by other ID !!!")
123
+ return False
124
+
125
+ product_info.delete = True # 设置为删除
126
+ session.add(product_info)
127
+ session.commit() # 提交
128
+ except Exception:
129
+ delete_success = False
130
+
131
+ return delete_success
132
+
133
+
134
+ def create_or_update_db_product_by_id(product_id: int, new_info: ProductInfo, user_id: int) -> bool:
135
+ """新增 or 编辑商品信息
136
+
137
+ Args:
138
+ product_id (int): 商品 ID
139
+ new_info (ProductInfo): 新的信息
140
+ user_id (int): 用户 ID,用于防止其他用户恶意修改
141
+
142
+ Returns:
143
+ bool: 说明书是否变化
144
+ """
145
+
146
+ instruction_updated = False
147
+
148
+ # 去掉服务器地址
149
+ new_info.image_path = new_info.image_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
150
+ new_info.instruction = new_info.instruction.replace(API_CONFIG.REQUEST_FILES_URL, "")
151
+
152
+ with Session(DB_ENGINE) as session:
153
+
154
+ if product_id > 0:
155
+ # 更新特定 ID
156
+ product_info = session.exec(
157
+ select(ProductInfo).where(and_(ProductInfo.product_id == product_id, ProductInfo.user_id == user_id))
158
+ ).one()
159
+
160
+ if product_info is None:
161
+ logger.error("Edit by other ID !!!")
162
+ return False
163
+
164
+ if product_info.instruction != new_info.instruction:
165
+ # 判断说明书是否变化了
166
+ instruction_updated = True
167
+
168
+ # 更新对应的值
169
+ product_info.product_name = new_info.product_name
170
+ product_info.product_class = new_info.product_class
171
+ product_info.heighlights = new_info.heighlights
172
+ product_info.image_path = new_info.image_path
173
+ product_info.instruction = new_info.instruction
174
+ product_info.departure_place = new_info.departure_place
175
+ product_info.delivery_company = new_info.delivery_company
176
+ product_info.selling_price = new_info.selling_price
177
+ product_info.amount = new_info.amount
178
+
179
+ session.add(product_info)
180
+ else:
181
+ # 新增,直接添加即可
182
+ session.add(new_info)
183
+ instruction_updated = True
184
+
185
+ session.commit() # 提交
186
+ return instruction_updated
server/base/database/streamer_info_db.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : streamer_info_db.py
5
+ @Time : 2024/08/30
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 主播信息数据库操作
10
+ """
11
+
12
+
13
+ from typing import List
14
+
15
+ from loguru import logger
16
+ from sqlmodel import Session, and_, select
17
+
18
+ from ...web_configs import API_CONFIG
19
+ from ..models.streamer_info_model import StreamerInfo
20
+ from .init_db import DB_ENGINE
21
+
22
+
23
async def get_db_streamer_info(user_id: int, streamer_id: int | None = None) -> List[StreamerInfo] | None:
    """Query streamer info from the database.

    Args:
        user_id (int): user ID
        streamer_id (int | None, optional): streamer ID, to fetch one specific streamer. Defaults to None.

    Returns:
        List[StreamerInfo] | None: streamer info; the user's full list when no ID is given,
        an empty list when nothing matches or the query fails
    """

    # Base query: the user's streamers that are not soft-deleted
    query_condiction = and_(StreamerInfo.user_id == user_id, StreamerInfo.delete == False)

    with Session(DB_ENGINE) as session:
        # Total count of this user's streamers (currently unused)
        # total_product_num = session.scalar(select(func.count(StreamerInfo.product_id)).where(query_condiction))

        if streamer_id is not None:
            # Switch to lookup by a specific streamer ID
            query_condiction = and_(
                StreamerInfo.user_id == user_id, StreamerInfo.delete == False, StreamerInfo.streamer_id == streamer_id
            )

        # Fetch streamers ordered by ID
        try:
            streamer_list = session.exec(select(StreamerInfo).where(query_condiction).order_by(StreamerInfo.streamer_id)).all()
        except Exception as e:
            # NOTE(review): broad except silently turns any DB error into "not found"
            streamer_list = None

        if streamer_list is None:
            logger.warning("nothing to find in db...")
            streamer_list = []

        # Rewrite stored relative paths into server URLs
        for streamer in streamer_list:
            streamer.avatar = API_CONFIG.REQUEST_FILES_URL + streamer.avatar
            streamer.tts_reference_audio = API_CONFIG.REQUEST_FILES_URL + streamer.tts_reference_audio
            streamer.poster_image = API_CONFIG.REQUEST_FILES_URL + streamer.poster_image
            streamer.base_mp4_path = API_CONFIG.REQUEST_FILES_URL + streamer.base_mp4_path

        logger.info(streamer_list)
        logger.info(f"len {len(streamer_list)}")

        return streamer_list
69
+
70
+
71
async def delete_streamer_id(streamer_id: int, user_id: int) -> bool:
    """Soft-delete a specific streamer.

    Args:
        streamer_id (int): streamer ID
        user_id (int): user ID, guards against other users deleting someone else's streamer

    Returns:
        bool: whether deletion succeeded
    """

    delete_success = True

    try:
        with Session(DB_ENGINE) as session:
            # Look up the specific ID for this user.
            # Fix: `.first()` instead of `.one()` — `.one()` raises when no row
            # matches, which made the None check below unreachable and hid the
            # not-found case inside the broad except; `.first()` returns None.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).first()

            if streamer_info is None:
                logger.error("Delete by other ID !!!")
                return False

            streamer_info.delete = True  # mark as soft-deleted
            session.add(streamer_info)
            session.commit()  # commit
    except Exception:
        delete_success = False

    return delete_success
103
+
104
+
105
def create_or_update_db_streamer_by_id(streamer_id: int, new_info: StreamerInfo, user_id: int) -> int:
    """Create or update streamer info.

    Args:
        streamer_id (int): streamer ID; > 0 updates an existing row, otherwise a new row is inserted
        new_info (StreamerInfo): new streamer info
        user_id (int): user ID, guards against other users editing someone else's streamer

    Returns:
        int: streamer ID of the created/updated row; -1 when editing a row the user does not own
    """

    # Strip the server URL prefix so only relative paths are stored
    new_info.avatar = new_info.avatar.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.tts_reference_audio = new_info.tts_reference_audio.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.poster_image = new_info.poster_image.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.base_mp4_path = new_info.base_mp4_path.replace(API_CONFIG.REQUEST_FILES_URL, "")

    with Session(DB_ENGINE) as session:

        if streamer_id > 0:
            # Update a specific ID.
            # NOTE(review): `.one()` raises when no row matches, so the None
            # check below never triggers — confirm whether `.first()` was intended.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).one()

            if streamer_info is None:
                logger.error("Edit by other ID !!!")
                return -1
        else:
            # New streamer: create a fresh row owned by this user
            streamer_info = StreamerInfo(user_id=user_id)

        # Copy over the editable fields
        streamer_info.name = new_info.name
        streamer_info.character = new_info.character
        streamer_info.avatar = new_info.avatar
        streamer_info.tts_weight_tag = new_info.tts_weight_tag
        streamer_info.tts_reference_sentence = new_info.tts_reference_sentence
        streamer_info.tts_reference_audio = new_info.tts_reference_audio
        streamer_info.poster_image = new_info.poster_image
        streamer_info.base_mp4_path = new_info.base_mp4_path

        session.add(streamer_info)
        session.commit()  # commit
        session.refresh(streamer_info)

        return int(streamer_info.streamer_id)
server/base/database/streamer_room_db.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : streamer_room_db.py
5
+ @Time : 2024/08/31
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 直播间信息数据库操作
10
+ """
11
+
12
+
13
+ from datetime import datetime
14
+ from typing import List
15
+
16
+ from loguru import logger
17
+ from sqlmodel import Session, and_, not_, select
18
+
19
+ from ...web_configs import API_CONFIG
20
+ from ..models.streamer_room_model import ChatMessageInfo, OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo
21
+ from .init_db import DB_ENGINE
22
+
23
+
24
async def get_db_streaming_room_info(user_id: int, room_id: int | None = None) -> List[StreamRoomInfo] | None:
    """Query streaming-room info from the database.

    Args:
        user_id (int): user ID
        room_id (int | None, optional): room ID, to fetch one specific room. Defaults to None.

    Returns:
        List[StreamRoomInfo] | None: streaming-room info (empty list when nothing matches)
    """

    # Base query: the user's rooms that are not soft-deleted
    query_condiction = and_(StreamRoomInfo.user_id == user_id, StreamRoomInfo.delete == False)

    with Session(DB_ENGINE) as session:
        if room_id is not None:
            # Switch to lookup by a specific room ID
            query_condiction = and_(
                StreamRoomInfo.user_id == user_id, StreamRoomInfo.delete == False, StreamRoomInfo.room_id == room_id
            )

        # Fetch room info ordered by ID
        stream_room_list = session.exec(select(StreamRoomInfo).where(query_condiction).order_by(StreamRoomInfo.room_id)).all()

        if stream_room_list is None:
            logger.warning("nothing to find in db...")
            stream_room_list = []

        # Rewrite stored relative paths into server URLs
        for stream_room in stream_room_list:
            # Streamer info
            stream_room.streamer_info.avatar = API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.avatar
            stream_room.streamer_info.tts_reference_audio = (
                API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.tts_reference_audio
            )
            stream_room.streamer_info.poster_image = API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.poster_image
            stream_room.streamer_info.base_mp4_path = API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.base_mp4_path

            # Product info
            for idx, product in enumerate(stream_room.product_list):
                stream_room.product_list[idx].product_info.image_path = API_CONFIG.REQUEST_FILES_URL + product.product_info.image_path
                stream_room.product_list[idx].product_info.instruction = (
                    API_CONFIG.REQUEST_FILES_URL + product.product_info.instruction
                )

        logger.info(stream_room_list)
        logger.info(f"len {len(stream_room_list)}")

        return stream_room_list
74
+
75
+
76
async def delete_room_id(room_id: int, user_id: int) -> bool:
    """Soft-delete a specific streaming room.

    Args:
        room_id (int): room ID
        user_id (int): user ID, guards against other users deleting someone else's room

    Returns:
        bool: whether deletion succeeded
    """

    delete_success = True

    try:
        with Session(DB_ENGINE) as session:
            # Look up the specific ID for this user.
            # Fix: `.first()` instead of `.one()` — `.one()` raises when no row
            # matches, which made the None check below unreachable and hid the
            # not-found case inside the broad except; `.first()` returns None.
            room_info = session.exec(
                select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
            ).first()

            if room_info is None:
                logger.error("Delete by other ID !!!")
                return False

            room_info.delete = True  # mark as soft-deleted
            session.add(room_info)
            session.commit()  # commit
    except Exception:
        delete_success = False

    return delete_success
108
+
109
+
110
def create_or_update_db_room_by_id(room_id: int, new_info: StreamRoomInfo, user_id: int):
    """Create a new streaming room or update an existing one.

    Args:
        room_id (int): room ID; a value <= 0 means "create a new room".
        new_info (StreamRoomInfo): new room info to persist.
        user_id (int): user ID, used to prevent other users from editing
            rooms they do not own.

    Returns:
        int: the room ID of the created/updated room. Returns None (bare
        ``return``) when the room exists but belongs to another user.
    """

    with Session(DB_ENGINE) as session:

        # Persist the on-air status row first so its ID can be attached
        # to a newly created room below.
        if new_info.status_id is not None:
            status_info = session.exec(
                select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == new_info.status_id)
            ).one()
        else:
            status_info = OnAirRoomStatusItem()

        # The DB stores paths relative to the file server, so strip the URL prefix
        status_info.streaming_video_path = new_info.status.streaming_video_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
        status_info.live_status = new_info.status.live_status
        session.add(status_info)
        session.commit()
        session.refresh(status_info)  # refresh to obtain the generated status_id

        if room_id > 0:

            # Existing room: load it, scoped to the owning user
            room_info = session.exec(
                select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
            ).one()

            if room_info is None:
                logger.error("Edit by other ID !!!")
                return

        else:
            # New room: link it to the status row created above
            room_info = StreamRoomInfo(status_id=status_info.status_id, user_id=user_id)

        # Update basic room info (image paths stored relative to the file server)
        room_info.name = new_info.name
        room_info.prohibited_words_id = new_info.prohibited_words_id
        room_info.room_poster = new_info.room_poster.replace(API_CONFIG.REQUEST_FILES_URL, "")
        room_info.background_image = new_info.background_image.replace(API_CONFIG.REQUEST_FILES_URL, "")
        room_info.streamer_id = new_info.streamer_id

        session.add(room_info)
        session.commit()
        session.refresh(room_info)  # refresh to obtain the generated room_id

        # Upsert the per-product sales info rows for this room
        if len(new_info.product_list) > 0:
            selected_id_list = [product.product_id for product in new_info.product_list]
            for product in new_info.product_list:
                if product.sales_info_id is not None:
                    # Existing sales-info row: update in place
                    sales_info = session.exec(
                        select(SalesDocAndVideoInfo).where(
                            and_(
                                SalesDocAndVideoInfo.room_id == room_info.room_id,
                                SalesDocAndVideoInfo.product_id == product.product_id,
                                SalesDocAndVideoInfo.sales_info_id == product.sales_info_id,
                            )
                        )
                    ).one()
                else:
                    # New sales-info row
                    sales_info = SalesDocAndVideoInfo()

                sales_info.product_id = product.product_id
                sales_info.sales_doc = product.sales_doc
                sales_info.start_time = product.start_time
                sales_info.start_video = product.start_video.replace(API_CONFIG.REQUEST_FILES_URL, "")
                sales_info.selected = True
                sales_info.room_id = room_info.room_id
                session.add(sales_info)
                session.commit()

            # Delete sales-info rows for products no longer selected
            if len(selected_id_list) > 0:
                cancel_select_sales_info = session.exec(
                    select(SalesDocAndVideoInfo).where(
                        and_(
                            SalesDocAndVideoInfo.room_id == room_info.room_id,
                            not_(SalesDocAndVideoInfo.product_id.in_(selected_id_list)),
                        )
                    )
                ).all()

                if cancel_select_sales_info is not None:
                    for cancel_select in cancel_select_sales_info:
                        session.delete(cancel_select)
                        session.commit()

        return room_info.room_id
205
+
206
+
207
def init_conversation(db_session, sales_info_id: int, streamer_id: int, sales_doc: str):
    """Seed a room conversation with the streamer's opening sales pitch.

    Typically triggered when the streamer goes live or moves on to the
    next product.

    Args:
        db_session: database session handle (the caller owns the commit).
        sales_info_id (int): sales session ID.
        streamer_id (int): streamer ID.
        sales_doc (str): the streamer's sales script.
    """
    opening_message = ChatMessageInfo(
        sales_info_id=sales_info_id,
        streamer_id=streamer_id,
        role="streamer",
        message=sales_doc,
        send_time=datetime.now(),
    )
    db_session.add(opening_message)
220
+
221
+
222
def update_message_info(sales_info_id: int, role_id: int, role: str, message: str):
    """Append one chat message to a sales-session conversation.

    Args:
        sales_info_id (int): sales session ID.
        role_id (int): sender ID (streamer ID or user ID, depending on role).
        role (str): sender type, "streamer" or "user".
        message (str): message text to insert.
    """

    assert role in ["streamer", "user"]

    # Route the sender ID to the matching foreign-key column
    if role == "streamer":
        sender_kwargs = {"streamer_id": role_id}
    else:
        sender_kwargs = {"user_id": role_id}

    with Session(DB_ENGINE) as session:
        session.add(
            ChatMessageInfo(
                **sender_kwargs,
                sales_info_id=sales_info_id,
                role=role,
                message=message,
                send_time=datetime.now(),
            )
        )
        session.commit()
244
+
245
+
246
def update_db_room_status(room_id: int, user_id: int, process_type: str):
    """Update a live room's on-air status for a lifecycle event.

    Args:
        room_id (int): room ID.
        user_id (int): user ID, used to prevent other users from editing
            rooms they do not own.
        process_type (str): lifecycle event, one of "online",
            "next-product" or "offline".

    Raises:
        NotImplementedError: if ``process_type`` is not a supported value.
    """

    with Session(DB_ENGINE) as session:

        # Load the room, scoped to the owning user
        room_info = session.exec(
            select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
        ).one()

        if room_info is None:
            logger.error("Edit by other ID !!!")
            return

        # Load the status row referenced by the room
        if room_info.status_id is not None:
            status_info = session.exec(
                select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == room_info.status_id)
            ).one()
        else:
            # Bug fix: the original left ``status_info`` undefined on this
            # branch, raising NameError on the check below instead of
            # logging and returning.
            status_info = None

        if status_info is None:
            logger.error("status_info is None !!!")
            return

        if process_type in ["online", "next-product"]:

            if process_type == "online":
                # Going live: reset the status to the first product
                status_info.live_status = 1
                status_info.start_time = datetime.now()
                status_info.end_time = None
                status_info.current_product_index = 0

            elif process_type == "next-product":
                # Advance to the next product in the room's list
                status_info.current_product_index += 1

            current_idx = status_info.current_product_index

            status_info.streaming_video_path = room_info.product_list[current_idx].start_video
            status_info.sales_info_id = room_info.product_list[current_idx].sales_info_id

            # Stamp the start time of the product now being presented
            sales_info = session.exec(
                select(SalesDocAndVideoInfo).where(
                    SalesDocAndVideoInfo.sales_info_id == room_info.product_list[current_idx].sales_info_id
                )
            ).one()

            sales_info.start_time = datetime.now()
            session.add(sales_info)

            # Seed the conversation with the streamer's sales pitch
            init_conversation(
                session, status_info.sales_info_id, room_info.streamer_id, room_info.product_list[current_idx].sales_doc
            )

        elif process_type == "offline":
            status_info.streaming_video_path = ""
            status_info.live_status = 2
            status_info.end_time = datetime.now()

        else:
            # Bug fix: ``raise NotImplemented(...)`` is a TypeError at
            # runtime — NotImplemented is a constant, not an exception class.
            raise NotImplementedError("process type error !!")

        session.add(status_info)
        session.commit()
316
+
317
+
318
def get_message_list(sales_info_id: int) -> List[ChatMessageInfo]:
    """Fetch the full conversation for a sales session.

    Args:
        sales_info_id (int): sales session ID.

    Returns:
        List[ChatMessageInfo]: chat items ordered by message ID; empty list
        when there are no messages.
    """
    with Session(DB_ENGINE) as session:

        db_messages = session.exec(
            select(ChatMessageInfo)
            .where(and_(ChatMessageInfo.sales_info_id == sales_info_id))
            .order_by(ChatMessageInfo.message_id)
        ).all()

        if db_messages is None:
            return []

        message_list = []
        for msg in db_messages:
            # Pick sender details from the user or streamer relationship
            is_user = msg.role == "user"
            chat_item = {
                "role": msg.role,
                "avatar": msg.user_info.avatar if is_user else msg.streamer_info.avatar,
                "userName": msg.user_info.username if is_user else msg.streamer_info.name,
                "message": msg.message,
                "datetime": msg.send_time,
            }

            # Expose the avatar as a file-server URL
            chat_item["avatar"] = API_CONFIG.REQUEST_FILES_URL + chat_item["avatar"]
            message_list.append(chat_item)

        return message_list
352
+
353
+
354
def update_room_video_path(status_id: int, news_video_server_path: str):
    """Persist a new streamer video path on a room status row.

    Args:
        status_id (int): room status ID.
        news_video_server_path (str): server URL of the new streamer video;
            stored with the file-server prefix stripped.
    """
    relative_path = news_video_server_path.replace(API_CONFIG.REQUEST_FILES_URL, "")

    with Session(DB_ENGINE) as session:
        # Load and update the status row
        status_info = session.exec(select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == status_id)).one()
        status_info.streaming_video_path = relative_path
        session.add(status_info)
        session.commit()
369
+
370
+
371
async def get_live_room_info(user_id: int, room_id: int):
    """Collect the realtime state of a live streaming room.

    Args:
        user_id (int): user ID.
        room_id (int): room ID.

    Returns:
        dict: realtime room info (streamer, conversation, current product,
        current video, timing and whether this is the final product).
    """

    # Load the room record for this user
    room_detail = (await get_db_streaming_room_info(user_id, room_id))[0]

    # Streamer details
    streamer_info = room_detail.streamer_info

    # Index of the product currently being presented
    product_index = room_detail.status.current_product_index

    # Whether the current product is the last one in the list
    is_final_product = product_index == len(room_detail.product_list) - 1

    # Full conversation so far
    conversation_list = get_message_list(room_detail.status.sales_info_id)

    # Expose the current video as a file-server URL
    video_path = API_CONFIG.REQUEST_FILES_URL + room_detail.status.streaming_video_path

    # Response payload (keys match the frontend contract, including the
    # existing "currentPoductStartTime" spelling)
    res_data = {
        "streamerInfo": streamer_info,
        "conversation": conversation_list,
        "currentProductInfo": room_detail.product_list[product_index].product_info,
        "currentStreamerVideo": video_path,
        "currentProductIndex": room_detail.status.current_product_index,
        "startTime": room_detail.status.start_time,
        "currentPoductStartTime": room_detail.product_list[product_index].start_time,
        "finalProduct": is_final_product,
    }

    logger.info(res_data)

    return res_data
server/base/database/user_db.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : user_db.py
5
+ @Time : 2024/08/31
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 用户信息数据库操作
10
+ """
11
+
12
+ from sqlmodel import Session, select
13
+
14
+ from ...web_configs import API_CONFIG
15
+ from ..models.user_model import UserBaseInfo, UserInfo
16
+ from .init_db import DB_ENGINE
17
+
18
+
19
def get_db_user_info(id: int = -1, username: str = "", all_info: bool = False) -> UserBaseInfo | UserInfo | None:
    """Query the database for user info.

    Args:
        id (int): user ID, used when ``username`` is empty.
        username (str): user name; takes precedence over ``id`` when non-empty.
        all_info (bool): whether to return the full record including the
            hashed password and other sensitive fields.

    Returns:
        UserBaseInfo | UserInfo | None: user info, or None when not found.
    """

    if username == "":
        # Query by ID
        query = select(UserInfo).where(UserInfo.user_id == id)
    else:
        # Query by user name
        query = select(UserInfo).where(UserInfo.username == username)

    # Query the database
    with Session(DB_ENGINE) as session:
        results = session.exec(query).first()

    if results is None:
        # Bug fix: the original dereferenced ``results.avatar`` before this
        # None check, raising AttributeError for unknown users.
        return None

    # Rewrite the avatar path as a file-server URL
    results.avatar = API_CONFIG.REQUEST_FILES_URL + results.avatar

    if all_info is False:
        # Strip sensitive fields (hashed password etc.)
        results = UserBaseInfo(**results.model_dump())

    return results
server/base/models/__init__.py ADDED
File without changes
server/base/models/llm_model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : llm_model.py
5
+ @Time : 2024/09/01
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 大模型对话数据结构
10
+ """
11
+
12
+ from pydantic import BaseModel
13
+
14
+
15
class GenProductItem(BaseModel):
    """Request body for LLM-based product copy generation."""

    gen_type: str  # generation type requested by the caller
    instruction: str  # instruction / prompt text passed to the model
server/base/models/product_model.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : product_model.py
5
+ @Time : 2024/08/30
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 商品数据类型定义
10
+ """
11
+
12
+ from datetime import datetime
13
+ from typing import List
14
+ from pydantic import BaseModel
15
+ from sqlmodel import Field, Relationship, SQLModel
16
+
17
+
18
+ # =======================================================
19
+ # 数据库模型
20
+ # =======================================================
21
+
22
+
23
class ProductInfo(SQLModel, table=True):
    """Product information table."""

    __tablename__ = "product_info"

    product_id: int | None = Field(default=None, primary_key=True, unique=True)
    product_name: str = Field(index=True, unique=True)
    product_class: str  # product category
    heighlights: str  # selling points (misspelling kept: it is the DB column name)
    image_path: str  # product image path
    instruction: str  # instruction-manual file path
    departure_place: str  # shipping origin
    delivery_company: str  # courier company
    selling_price: float
    amount: int  # stock amount
    # Bug fix: a plain ``datetime.now()`` default is evaluated once at import
    # time, stamping every row with server start time. A factory gives each
    # new record its own upload timestamp.
    upload_date: datetime = Field(default_factory=datetime.now)
    delete: bool = False  # soft-delete flag

    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")

    sales_info: list["SalesDocAndVideoInfo"] = Relationship(back_populates="product_info")
44
+
45
+
46
+ # =======================================================
47
+ # 基本模型
48
+ # =======================================================
49
+
50
+
51
class ProductPageItem(BaseModel):
    """One page of products for a paginated product listing."""

    product_list: List[ProductInfo] = []  # products on the current page
    currentPage: int = 0  # current page number
    pageSize: int = 0  # number of items per page
    totalSize: int = 0  # total number of products
56
+
57
+
58
class ProductQueryItem(BaseModel):
    """Request body for fetching a product's instruction manual."""

    instructionPath: str = ""  # instruction-file path, used to read the manual content
server/base/models/streamer_info_model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : streamer_info_model.py
5
+ @Time : 2024/08/30
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 主播信息数据结构
10
+ """
11
+
12
+ from typing import Optional
13
+ from sqlmodel import Field, Relationship, SQLModel
14
+
15
+
16
+ # =======================================================
17
+ # 数据库模型
18
+ # =======================================================
19
class StreamerInfo(SQLModel, table=True):
    """Streamer (virtual host) information table."""

    __tablename__ = "streamer_info"

    streamer_id: int | None = Field(default=None, primary_key=True, unique=True)
    name: str = Field(index=True, unique=True)
    character: str = ""  # persona / character description
    avatar: str = ""  # avatar image path

    tts_weight_tag: str = ""  # TTS voice-weight tag (e.g. a named character voice)
    tts_reference_sentence: str = ""  # transcript of the TTS reference audio
    tts_reference_audio: str = ""  # reference audio path for voice cloning

    poster_image: str = ""  # poster image path
    base_mp4_path: str = ""  # base digital-human video path

    delete: bool = False  # soft-delete flag

    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")

    # Room(s) fronted by this streamer; eagerly loaded via selectin
    room_info: Optional["StreamRoomInfo"] | None = Relationship(
        back_populates="streamer_info", sa_relationship_kwargs={"lazy": "selectin"}
    )
server/base/models/streamer_room_model.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : streamer_room_model.py
5
+ @Time : 2024/08/31
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 直播间信息数据结构定义
10
+ """
11
+
12
+ from datetime import datetime
13
+ from typing import Optional
14
+
15
+ from pydantic import BaseModel
16
+ from sqlmodel import Field, Relationship, SQLModel
17
+
18
+ from ..models.user_model import UserInfo
19
+ from ..models.product_model import ProductInfo
20
+ from ..models.streamer_info_model import StreamerInfo
21
+
22
+
23
class RoomChatItem(BaseModel):
    """Request body for a chat message sent to a live room."""

    roomId: int  # target room ID
    message: str = ""  # text message (presumably empty when audio is used — verify against caller)
    asrFileUrl: str = ""  # URL of an uploaded audio file to run ASR on
27
+
28
+
29
+ # =======================================================
30
+ # 直播间数据库模型
31
+ # =======================================================
32
+
33
+
34
class SalesDocAndVideoInfo(SQLModel, table=True):
    """Per-room sales script and digital-human intro video."""

    __tablename__ = "sales_doc_and_video_info"

    sales_info_id: int | None = Field(default=None, primary_key=True, unique=True)

    sales_doc: str = ""  # sales script read by the streamer
    start_video: str = ""  # first explainer video played when the product goes live
    start_time: datetime | None = None  # when this product's presentation started
    selected: bool = True  # whether the product is currently selected for the room

    product_id: int | None = Field(default=None, foreign_key="product_info.product_id")
    product_info: ProductInfo | None = Relationship(back_populates="sales_info", sa_relationship_kwargs={"lazy": "selectin"})

    room_id: int | None = Field(default=None, foreign_key="stream_room_info.room_id")
    stream_room: Optional["StreamRoomInfo"] | None = Relationship(back_populates="product_list")
51
+
52
+
53
class OnAirRoomStatusItem(SQLModel, table=True):
    """Live-room on-air status."""

    __tablename__ = "on_air_room_status_item"

    status_id: int | None = Field(default=None, primary_key=True, unique=True)  # status row ID

    sales_info_id: int | None = Field(default=None, foreign_key="sales_doc_and_video_info.sales_info_id")

    current_product_index: int = 0  # index of the product currently being presented
    streaming_video_path: str = ""  # video currently used for the presentation

    live_status: int = 0  # room state: 0 not started, 1 live, 2 ended
    start_time: datetime | None = None  # stream start time
    end_time: datetime | None = None  # stream end time

    room_info: Optional["StreamRoomInfo"] | None = Relationship(
        back_populates="status", sa_relationship_kwargs={"lazy": "selectin"}
    )
72
+
73
+ """直播间信息,数据库保存时的数据结构"""
74
+
75
+
76
+ class StreamRoomInfo(SQLModel, table=True):
77
+
78
+ __tablename__ = "stream_room_info"
79
+
80
+ room_id: int | None = Field(default=None, primary_key=True, unique=True) # 直播间 ID
81
+
82
+ name: str = "" # 直播间名字
83
+
84
+ product_list: list[SalesDocAndVideoInfo] = Relationship(
85
+ back_populates="stream_room",
86
+ sa_relationship_kwargs={"lazy": "selectin", "order_by": "asc(SalesDocAndVideoInfo.product_id)"},
87
+ ) # 商品列表,查找的时候加上 order_by 自动排序,desc -> 降序; asc -> 升序
88
+
89
+ prohibited_words_id: int = 0 # 违禁词表 ID
90
+ room_poster: str = "" # 海报图
91
+ background_image: str = "" # 主播背景图
92
+
93
+ delete: bool = False # 是否删除
94
+
95
+ status_id: int | None = Field(default=None, foreign_key="on_air_room_status_item.status_id")
96
+ status: OnAirRoomStatusItem | None = Relationship(back_populates="room_info", sa_relationship_kwargs={"lazy": "selectin"})
97
+
98
+ streamer_id: int | None = Field(default=None, foreign_key="streamer_info.streamer_id") # 主播 ID
99
+ streamer_info: StreamerInfo | None = Relationship(back_populates="room_info", sa_relationship_kwargs={"lazy": "selectin"})
100
+
101
+ user_id: int | None = Field(default=None, foreign_key="user_info.user_id")
102
+
103
+
104
+ # =======================================================
105
+ # 直播对话数据库模型
106
+ # =======================================================
107
+
108
+
109
class ChatMessageInfo(SQLModel, table=True):
    """One chat message on the live-room page."""

    __tablename__ = "chat_message_info"

    message_id: int | None = Field(default=None, primary_key=True, unique=True)  # message ID

    sales_info_id: int | None = Field(default=None, foreign_key="sales_doc_and_video_info.sales_info_id")
    sales_info: SalesDocAndVideoInfo | None = Relationship(sa_relationship_kwargs={"lazy": "selectin"})

    # Exactly one of user_id / streamer_id is expected to be set, matching ``role``
    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")
    user_info: UserInfo | None = Relationship(sa_relationship_kwargs={"lazy": "selectin"})

    streamer_id: int | None = Field(default=None, foreign_key="streamer_info.streamer_id")
    streamer_info: StreamerInfo | None = Relationship(sa_relationship_kwargs={"lazy": "selectin"})

    role: str  # sender type: "streamer" or "user"
    message: str  # message text
    send_time: datetime | None = None  # send timestamp
server/base/models/user_model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : user_model.py
5
+ @Time : 2024/08/31
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 用户信息数据结构
10
+ """
11
+
12
+ from datetime import datetime
13
+ from ipaddress import IPv4Address
14
+ from pydantic import BaseModel
15
+ from sqlmodel import Field, SQLModel
16
+
17
+
18
+ # =======================================================
19
+ # 基本模型
20
+ # =======================================================
21
class TokenItem(BaseModel):
    """OAuth2 access-token response body."""

    access_token: str  # the bearer token string
    token_type: str  # token type, e.g. "bearer"
24
+
25
+
26
class UserBaseInfo(BaseModel):
    """User info without sensitive fields (safe to return to clients)."""

    user_id: int | None = Field(default=None, primary_key=True, unique=True)
    username: str = Field(index=True, unique=True)
    email: str | None = None
    avatar: str | None = None
    # Bug fix: a plain ``datetime.now()`` default is evaluated once at import
    # time, so every user would share the server start time. A factory gives
    # each instance its own creation timestamp.
    create_time: datetime = Field(default_factory=datetime.now)
32
+
33
+
34
+ # =======================================================
35
+ # 数据库模型
36
+ # =======================================================
37
class UserInfo(UserBaseInfo, SQLModel, table=True):
    """Full user record including credential hash (database table)."""

    __tablename__ = "user_info"

    hashed_password: str  # password hash; never return this to clients
    ip_address: IPv4Address | None = None  # client IP (presumably last seen — verify against caller)
    delete: bool = False  # soft-delete flag
server/base/modules/__init__.py ADDED
File without changes
server/base/modules/agent/__init__.py ADDED
File without changes
server/base/modules/agent/agent_worker.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+
4
+ from lagent.actions import ActionExecutor
5
+ from lagent.agents.internlm2_agent import Internlm2Protocol
6
+ from lagent.schema import ActionReturn, AgentReturn
7
+ from loguru import logger
8
+
9
+ from .delivery_time_query import DeliveryTimeQueryAction
10
+
11
+
12
def init_handlers(departure_place, delivery_company_name):
    """Build the lagent action executor and InternLM2 agent protocol handler.

    Args:
        departure_place: shipping origin, forwarded to the delivery-time plugin.
        delivery_company_name: courier name, forwarded to the delivery-time plugin.

    Returns:
        tuple: (action_executor, protocol_handler) ready for the agent loop.
    """
    # System prompts (kept in Chinese: they are runtime prompt strings fed to
    # the Chinese-language model, not comments)
    META_CN = "当开启工具以及代码时,根据需求选择合适的工具进行调用"

    INTERPRETER_CN = (
        "你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。"
        "当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。"
        "这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),"
        "复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),"
        "文本处理和分析(比如文本解析和自然语言处理),"
        "机器学习和数据科学(用于展示模型训练和数据可视化),"
        "以及文件操作和数据导入(处理CSV、JSON等格式的文件)。"
    )

    PLUGIN_CN = (
        "你可以使用如下工具:"
        "\n{prompt}\n"
        "如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! "
        "同时注意你可以使用的工具,不要随意捏造!"
    )

    # Protocol handler: wraps tool calls in InternLM2's special action tokens
    protocol_handler = Internlm2Protocol(
        meta_prompt=META_CN,
        interpreter_prompt=INTERPRETER_CN,
        plugin_prompt=PLUGIN_CN,
        tool=dict(
            begin="{start_token}{name}\n",
            start_token="<|action_start|>",
            name_map=dict(plugin="<|plugin|>", interpreter="<|interpreter|>"),
            belong="assistant",
            end="<|action_end|>\n",
        ),
    )
    # Single plugin: delivery-time lookup
    action_list = [
        DeliveryTimeQueryAction(
            departure_place=departure_place,
            delivery_company_name=delivery_company_name,
        ),
    ]
    plugin_map = {action.name: action for action in action_list}
    plugin_name = [action.name for action in action_list]
    plugin_action = [plugin_map[name] for name in plugin_name]
    action_executor = ActionExecutor(actions=plugin_action)

    return action_executor, protocol_handler
56
+
57
+
58
def get_agent_result(llm_model_handler, prompt_input, departure_place, delivery_company_name):
    """Run one agent tool-use loop over the LLM and return the plugin output.

    Flow per turn: (1) format the history into an InternLM2 agent prompt;
    (2) stream a completion from the LLM with special tokens kept, so
    ``<|action_start|><|plugin|>...<|action_end|>`` survives; (3) parse the
    response — if a plugin call is found, execute it and return its result
    string; otherwise stop.

    Args:
        llm_model_handler: LMDeploy-style handler exposing ``available_models``
            and ``chat_completions_v1`` (streaming).
        prompt_input: the user's question.
        departure_place: shipping origin for the delivery-time plugin.
        delivery_company_name: courier name for the delivery-time plugin.

    Returns:
        str: the plugin's result content, or "" when no tool call was made
        or the call failed.
    """

    action_executor, protocol_handler = init_handlers(departure_place, delivery_company_name)

    inner_history = [{"role": "user", "content": prompt_input}]
    interpreter_executor = None  # code interpreter is intentionally disabled
    max_turn = 7  # hard cap on agent turns to avoid infinite tool loops
    for _ in range(max_turn):

        # Build the agent-style prompt (system meta + plugin spec + history)
        prompt = protocol_handler.format(
            inner_step=inner_history,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        cur_response = ""

        agent_return = AgentReturn()

        logger.info(f"agent input for llm: {prompt}")

        # skip_special_tokens=False keeps <|action_start|>/<|plugin|> markers
        # in the stream so the protocol parser can find the tool call
        model_name = llm_model_handler.available_models[0]
        for item in llm_model_handler.chat_completions_v1(
            model=model_name, messages=prompt, stream=True, skip_special_tokens=False
        ):
            logger.info(f"agent return = {item}")
            if "content" not in item["choices"][0]["delta"]:
                continue
            current_res = item["choices"][0]["delta"]["content"]

            # Normalize "~" into Chinese full stops for TTS friendliness.
            # NOTE(review): ``item.text`` looks wrong here — ``item`` is a
            # dict elsewhere in this loop; this branch would likely raise
            # AttributeError if ever hit. Confirm and fix upstream.
            if "~" in current_res:
                current_res = item.text.replace("~", "。").replace("。。", "。")

            cur_response += current_res

            logger.info(f"agent return = {item}")

        # Parse the accumulated response for a tool invocation
        name, language, action = protocol_handler.parse(
            message=cur_response,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        if name:  # e.g. "plugin"
            if name == "plugin":
                if action_executor:
                    executor = action_executor
                else:
                    logging.info(msg="No plugin is instantiated!")
                    continue
                try:
                    action = json.loads(action)
                except Exception as e:
                    logging.info(msg=f"Invaild action {e}")
                    continue
            elif name == "interpreter":
                if interpreter_executor:
                    executor = interpreter_executor
                else:
                    logging.info(msg="No interpreter is instantiated!")
                    continue
        agent_return.response = action

        print(f"Agent response: {cur_response}")

        if name:
            # Execute the tool call and return its textual result directly
            print(f"Agent action: {action}")
            action_return: ActionReturn = executor(action["name"], action["parameters"])
            try:
                return_str = action_return.result[0]["content"]
                return return_str
            except Exception as e:
                return ""

        if not name:
            # No tool call requested: the model answered directly; stop looping
            agent_return.response = language
            break

    return ""
server/base/modules/agent/delivery_time_query.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from datetime import datetime
3
+ import hashlib
4
+ import json
5
+ from typing import Optional, Type
6
+
7
+ import jionlp as jio
8
+ import requests
9
+ from lagent.actions.base_action import BaseAction, tool_api
10
+ from lagent.actions.parser import BaseParser, JsonParser
11
+ from lagent.schema import ActionReturn, ActionStatusCode
12
+
13
+ from ....web_configs import WEB_CONFIGS
14
+
15
+
16
+ class DeliveryTimeQueryAction(BaseAction):
17
+ """快递时效查询插件,用于根据用户提出的收货地址查询到达期限"""
18
+
19
+ def __init__(
20
+ self,
21
+ departure_place: str,
22
+ delivery_company_name: str,
23
+ description: Optional[dict] = None,
24
+ parser: Type[BaseParser] = JsonParser,
25
+ enable: bool = True,
26
+ ) -> None:
27
+ super().__init__(description, parser, enable)
28
+ self.departure_place = departure_place # 发货地
29
+
30
+ # 天气查询
31
+ self.weather_query_handler = WeatherQuery(departure_place, WEB_CONFIGS.AGENT_WEATHER_API_KEY)
32
+ self.delivery_time_handler = DeliveryTimeQuery(delivery_company_name, WEB_CONFIGS.AGENT_DELIVERY_TIME_API_KEY)
33
+
34
+ @tool_api
35
+ def run(self, query: str) -> ActionReturn:
36
+ """一个到货时间查询API。可以根据城市名查询到货时间信息。
37
+
38
+ Args:
39
+ query (:class:`str`): 需要查询的城市名。
40
+ """
41
+
42
+ # 获取文本中收货地,发货地后台设置
43
+ # 防止 LLM 将城市识别错误,进行兜底
44
+ city_info = jio.parse_location(query, town_village=True)
45
+ city_name = city_info["city"]
46
+
47
+ # 获取收货地代号 -> 天气
48
+ destination_weather = self.weather_query_handler(city_name)
49
+
50
+ # 获取发货地代号 -> 天气
51
+ departure_weather = self.weather_query_handler(self.departure_place)
52
+
53
+ # 获取到达时间
54
+ delivery_time = self.delivery_time_handler(self.departure_place, city_name)
55
+
56
+ final_str = (
57
+ f"今天日期:{datetime.now().strftime('%m月%d日')}\n"
58
+ f"收货地天气:{destination_weather.result[0]['content']}\n"
59
+ f"发货地天气:{departure_weather.result[0]['content']}\n"
60
+ f"物流信息:{delivery_time.result[0]['content']}\n"
61
+ "回答突出“预计送达时间”和“收货地天气”,如果收货地或者发货地遇到暴雨暴雪等极端天气,须告知用户快递到达时间会有所增加。"
62
+ )
63
+
64
+ tool_return = ActionReturn(type=self.name)
65
+ tool_return.result = [dict(type="text", content=final_str)]
66
+ return tool_return
67
+
68
+
69
+ class WeatherQuery:
70
+ """快递时效查询插件,用于根据用户提出的收货地址查询到达期限"""
71
+
72
+ def __init__(
73
+ self,
74
+ departure_place: str,
75
+ api_key: Optional[str] = None,
76
+ ) -> None:
77
+ self.departure_place = departure_place # 发货地
78
+
79
+ # 天气查询
80
+ # api_key = os.environ.get("WEATHER_API_KEY", key)
81
+ if api_key is None:
82
+ raise ValueError("Please set Weather API key either in the environment as WEATHER_API_KEY")
83
+ self.api_key = api_key
84
+ self.location_query_url = "https://geoapi.qweather.com/v2/city/lookup"
85
+ self.weather_query_url = "https://devapi.qweather.com/v7/weather/now"
86
+
87
+ def parse_results(self, city_name: str, results: dict) -> str:
88
+ """解析 API 返回的信息
89
+
90
+ Args:
91
+ results (dict): JSON 格式的 API 报文。
92
+
93
+ Returns:
94
+ str: 解析后的结果。
95
+ """
96
+ now = results["now"]
97
+ data = (
98
+ # f'数据观测时间: {now["obsTime"]};'
99
+ f"城市名: {city_name};"
100
+ f'温度: {now["temp"]}°C;'
101
+ f'体感温度: {now["feelsLike"]}°C;'
102
+ f'天气: {now["text"]};'
103
+ # f'风向: {now["windDir"]},角度为 {now["wind360"]}°;'
104
+ f'风力等级: {now["windScale"]},风速为 {now["windSpeed"]} km/h;'
105
+ f'相对湿度: {now["humidity"]};'
106
+ f'当前小时累计降水量: {now["precip"]} mm;'
107
+ # f'大气压强: {now["pressure"]} 百帕;'
108
+ f'能见度: {now["vis"]} km。'
109
+ )
110
+ return data
111
+
112
+ def __call__(self, query):
113
+ tool_return = ActionReturn()
114
+ status_code, response = self.search_weather_with_city(query)
115
+ if status_code == -1:
116
+ tool_return.errmsg = response
117
+ tool_return.state = ActionStatusCode.HTTP_ERROR
118
+ elif status_code == 200:
119
+ parsed_res = self.parse_results(query, response)
120
+ tool_return.result = [dict(type="text", content=str(parsed_res))]
121
+ tool_return.state = ActionStatusCode.SUCCESS
122
+ else:
123
+ tool_return.errmsg = str(status_code)
124
+ tool_return.state = ActionStatusCode.API_ERROR
125
+ return tool_return
126
+
127
+ def search_weather_with_city(self, query: str):
128
+ """根据城市名获取城市代号,然后进行天气���询
129
+
130
+ Args:
131
+ query (str): 城市名
132
+
133
+ Returns:
134
+ int: 天气接口调用状态码
135
+ dict: 天气接口返回信息
136
+ """
137
+
138
+ # 获取城市代号
139
+ try:
140
+ city_code_response = requests.get(self.location_query_url, params={"key": self.api_key, "location": query})
141
+ except Exception as e:
142
+ return -1, str(e)
143
+
144
+ if city_code_response.status_code != 200:
145
+ return city_code_response.status_code, city_code_response.json()
146
+ city_code_response = city_code_response.json()
147
+ if len(city_code_response["location"]) == 0:
148
+ return -1, "未查询到城市"
149
+ city_code = city_code_response["location"][0]["id"]
150
+
151
+ # 获取天气
152
+ try:
153
+ weather_response = requests.get(self.weather_query_url, params={"key": self.api_key, "location": city_code})
154
+ except Exception as e:
155
+ return -1, str(e)
156
+ return weather_response.status_code, weather_response.json()
157
+
158
+
159
+ class DeliveryTimeQuery:
160
+ def __init__(
161
+ self,
162
+ delivery_company_name: Optional[str] = "中通",
163
+ api_key: Optional[str] = None,
164
+ ) -> None:
165
+
166
+ # 快递时效查询
167
+ # api_key = os.environ.get("DELIVERY_TIME_API_KEY", key)
168
+ if api_key is None or "," not in api_key:
169
+ raise ValueError(
170
+ 'Please set Delivery time API key either in the environment as DELIVERY_TIME_API_KEY="${e_business_id},${api_key}"'
171
+ )
172
+ self.e_business_id = api_key.split(",")[0]
173
+ self.api_key = api_key.split(",")[1]
174
+ self.api_url = "http://api.kdniao.com/api/dist" # 快递鸟
175
+ self.china_location = jio.china_location_loader()
176
+ # 快递鸟对应的
177
+ DELIVERY_COMPANY_MAP = {
178
+ "德邦": "DBL",
179
+ "邮政": "EMS",
180
+ "京东": "JD",
181
+ "极兔速递": "JTSD",
182
+ "顺丰": "SF",
183
+ "申通": "STO",
184
+ "韵达": "YD",
185
+ "圆通": "YTO",
186
+ "中通": "ZTO",
187
+ }
188
+ self.delivery_company_name = delivery_company_name
189
+ self.delivery_company_id = DELIVERY_COMPANY_MAP[delivery_company_name]
190
+
191
+ @staticmethod
192
+ def data_md5(n):
193
+ # md5加密
194
+ md5 = hashlib.md5()
195
+ md5.update(str(n).encode("utf-8"))
196
+ return md5.hexdigest()
197
+
198
+ def get_data_sign(self, n):
199
+ # 签名
200
+ md5Data = self.data_md5(json.dumps(n) + self.api_key)
201
+ res = str(base64.b64encode(md5Data.encode("utf-8")), "utf-8")
202
+ return res
203
+
204
+ def get_city_detail(self, name):
205
+ # 如果是城市名,使用第一个区名
206
+ city_info = jio.parse_location(name, town_village=True)
207
+ # china_location = jio.china_location_loader()
208
+
209
+ county_name = ""
210
+ for i in self.china_location[city_info["province"]][city_info["city"]].keys():
211
+ if "区" == i[-1]:
212
+ county_name = i
213
+ break
214
+
215
+ return {
216
+ "province": city_info["province"],
217
+ "city": city_info["city"],
218
+ "county": county_name,
219
+ }
220
+
221
+ def get_params(self, send_city, receive_city):
222
+
223
+ # 根据市查出省份和区名称
224
+ send_city_info = self.get_city_detail(send_city)
225
+ receive_city_info = self.get_city_detail(receive_city)
226
+
227
+ # 预计送达时间接口文档;https://www.yuque.com/kdnjishuzhichi/dfcrg1/ynkmts0e5owsnpvu
228
+ # 请求接口指令
229
+ RequestType = "6004"
230
+ # 组装应用级参数
231
+ RequestData = {
232
+ "ShipperCode": self.delivery_company_id,
233
+ "ReceiveArea": receive_city_info["county"],
234
+ "ReceiveCity": receive_city_info["city"],
235
+ "ReceiveProvince": receive_city_info["province"],
236
+ "SendArea": send_city_info["county"],
237
+ "SendCity": send_city_info["city"],
238
+ "SendProvince": send_city_info["province"],
239
+ }
240
+ # 组装系统级参数
241
+ data = {
242
+ "RequestData": json.dumps(RequestData),
243
+ "RequestType": RequestType,
244
+ "EBusinessID": self.e_business_id,
245
+ "DataSign": self.get_data_sign(RequestData),
246
+ "DataType": 2,
247
+ }
248
+ return data
249
+
250
+ def parse_results(self, response):
251
+
252
+ # 返回例子:
253
+ # {
254
+ # "EBusinessID" : "1000000",
255
+ # "Data" : {
256
+ # "DeliveryTime" : "06月15日下午可达",
257
+ # "SendAddress" : null,
258
+ # "ReceiveArea" : "芙蓉区",
259
+ # "SendProvince" : "广东省",
260
+ # "ReceiveProvince" : "湖南省",
261
+ # "ShipperCode" : "DBL",
262
+ # "Hour" : "52h",
263
+ # "SendArea" : "白云区",
264
+ # "ReceiveAddress" : null,
265
+ # "SendCity" : "广州市",
266
+ # "ReceiveCity" : "长沙市"
267
+ # },
268
+ # "ResultCode" : "100",
269
+ # "Success" : true
270
+ # }
271
+
272
+ response = response["Data"]
273
+ data = (
274
+ f'发货地点: {response["SendProvince"]} {response["SendCity"]};'
275
+ f'收货地点: {response["ReceiveProvince"]} {response["ReceiveCity"]};'
276
+ f'预计送达时间: {response["DeliveryTime"]};'
277
+ f"快递公司: {self.delivery_company_name};"
278
+ f'预计时效: {response["Hour"]}。'
279
+ )
280
+ return data
281
+
282
+ def __call__(self, send_city, receive_city):
283
+ tool_return = ActionReturn()
284
+ try:
285
+ res = requests.post(self.api_url, self.get_params(send_city, receive_city))
286
+ status_code = res.status_code
287
+ response = res.json()
288
+ except Exception as e:
289
+ tool_return.errmsg = str(e)
290
+ tool_return.state = ActionStatusCode.API_ERROR
291
+ return tool_return
292
+
293
+ if status_code == 200:
294
+ parsed_res = self.parse_results(response)
295
+ tool_return.result = [dict(type="text", content=str(parsed_res))]
296
+ tool_return.state = ActionStatusCode.SUCCESS
297
+ else:
298
+ tool_return.errmsg = str(status_code)
299
+ tool_return.state = ActionStatusCode.API_ERROR
300
+ return tool_return
server/base/modules/rag/__init__.py ADDED
File without changes
server/base/modules/rag/feature_store.py ADDED
@@ -0,0 +1,545 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """extract feature and search with user query."""
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import re
7
+ import shutil
8
+ from multiprocessing import Pool
9
+ from pathlib import Path
10
+ from typing import Any, List, Optional
11
+
12
+ import yaml
13
+
14
+ # 解决 Warning:huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks…
15
+ # os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ from BCEmbedding.tools.langchain import BCERerank
18
+ from langchain.embeddings import HuggingFaceEmbeddings
19
+ from langchain.text_splitter import (MarkdownHeaderTextSplitter,
20
+ MarkdownTextSplitter,
21
+ RecursiveCharacterTextSplitter)
22
+ from langchain.vectorstores.faiss import FAISS as Vectorstore
23
+ from langchain_core.documents import Document
24
+ from loguru import logger
25
+ from torch.cuda import empty_cache
26
+
27
+ from .file_operation import FileName, FileOperation
28
+ from .retriever import CacheRetriever, Retriever
29
+
30
+
31
+ def read_and_save(file: FileName):
32
+ if os.path.exists(file.copypath):
33
+ # already exists, return
34
+ logger.info("already exist, skip load")
35
+ return
36
+ file_opr = FileOperation()
37
+ logger.info("reading {}, would save to {}".format(file.origin, file.copypath))
38
+ content, error = file_opr.read(file.origin)
39
+ if error is not None:
40
+ logger.error("{} load error: {}".format(file.origin, str(error)))
41
+ return
42
+
43
+ if content is None or len(content) < 1:
44
+ logger.warning("{} empty, skip save".format(file.origin))
45
+ return
46
+
47
+ with open(file.copypath, "w") as f:
48
+ f.write(content)
49
+
50
+
51
+ def _split_text_with_regex_from_end(text: str, separator: str, keep_separator: bool) -> List[str]:
52
+ # Now that we have the separator, split the text
53
+ if separator:
54
+ if keep_separator:
55
+ # The parentheses in the pattern keep the delimiters in the result.
56
+ _splits = re.split(f"({separator})", text)
57
+ splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
58
+ if len(_splits) % 2 == 1:
59
+ splits += _splits[-1:]
60
+ # splits = [_splits[0]] + splits
61
+ else:
62
+ splits = re.split(separator, text)
63
+ else:
64
+ splits = list(text)
65
+ return [s for s in splits if s != ""]
66
+
67
+
68
+ # copy from https://github.com/chatchat-space/Langchain-Chatchat/blob/master/text_splitter/chinese_recursive_text_splitter.py
69
+ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
70
+
71
+ def __init__(
72
+ self,
73
+ separators: Optional[List[str]] = None,
74
+ keep_separator: bool = True,
75
+ is_separator_regex: bool = True,
76
+ **kwargs: Any,
77
+ ) -> None:
78
+ """Create a new TextSplitter."""
79
+ super().__init__(keep_separator=keep_separator, **kwargs)
80
+ self._separators = separators or ["\n\n", "\n", "。|!|?", "\.\s|\!\s|\?\s", ";|;\s", ",|,\s"]
81
+ self._is_separator_regex = is_separator_regex
82
+
83
+ def _split_text(self, text: str, separators: List[str]) -> List[str]:
84
+ """Split incoming text and return chunks."""
85
+ final_chunks = []
86
+ # Get appropriate separator to use
87
+ separator = separators[-1]
88
+ new_separators = []
89
+ for i, _s in enumerate(separators):
90
+ _separator = _s if self._is_separator_regex else re.escape(_s)
91
+ if _s == "":
92
+ separator = _s
93
+ break
94
+ if re.search(_separator, text):
95
+ separator = _s
96
+ new_separators = separators[i + 1 :]
97
+ break
98
+
99
+ _separator = separator if self._is_separator_regex else re.escape(separator)
100
+ splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
101
+
102
+ # Now go merging things, recursively splitting longer texts.
103
+ _good_splits = []
104
+ _separator = "" if self._keep_separator else separator
105
+ for s in splits:
106
+ if self._length_function(s) < self._chunk_size:
107
+ _good_splits.append(s)
108
+ else:
109
+ if _good_splits:
110
+ merged_text = self._merge_splits(_good_splits, _separator)
111
+ final_chunks.extend(merged_text)
112
+ _good_splits = []
113
+ if not new_separators:
114
+ final_chunks.append(s)
115
+ else:
116
+ other_info = self._split_text(s, new_separators)
117
+ final_chunks.extend(other_info)
118
+ if _good_splits:
119
+ merged_text = self._merge_splits(_good_splits, _separator)
120
+ final_chunks.extend(merged_text)
121
+ return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip() != ""]
122
+
123
+
124
+ class FeatureStore:
125
+ """Tokenize and extract features from the project's documents, for use in
126
+ the reject pipeline and response pipeline."""
127
+
128
+ def __init__(
129
+ self, embeddings: HuggingFaceEmbeddings, reranker: BCERerank, config_path: str = "config.ini", language: str = "zh"
130
+ ) -> None:
131
+ """Init with model device type and config."""
132
+ self.config_path = config_path
133
+ self.reject_throttle = -1
134
+ self.language = language
135
+ with open(config_path, "r", encoding="utf-8") as f:
136
+ config = yaml.safe_load(f)["feature_store"]
137
+ self.reject_throttle = config["reject_throttle"]
138
+
139
+ logger.warning(
140
+ "!!! If your feature generated by `text2vec-large-chinese` before 20240208, please rerun `python3 -m huixiangdou.service.feature_store`" # noqa E501
141
+ )
142
+
143
+ logger.debug("loading text2vec model..")
144
+ self.embeddings = embeddings
145
+ self.reranker = reranker
146
+ self.compression_retriever = None
147
+ self.rejecter = None
148
+ self.retriever = None
149
+ self.md_splitter = MarkdownTextSplitter(chunk_size=768, chunk_overlap=32)
150
+
151
+ if language == "zh":
152
+ self.text_splitter = ChineseRecursiveTextSplitter(
153
+ keep_separator=True, is_separator_regex=True, chunk_size=768, chunk_overlap=32
154
+ )
155
+ else:
156
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=768, chunk_overlap=32)
157
+
158
+ self.head_splitter = MarkdownHeaderTextSplitter(
159
+ headers_to_split_on=[
160
+ ("#", "Header 1"),
161
+ ("##", "Header 2"),
162
+ ("###", "Header 3"),
163
+ ]
164
+ )
165
+
166
+ def split_md(self, text: str, source: None):
167
+ """Split the markdown document in a nested way, first extracting the
168
+ header.
169
+
170
+ If the extraction result exceeds 1024, split it again according to
171
+ length.
172
+ """
173
+ docs = self.head_splitter.split_text(text)
174
+
175
+ final = []
176
+ for doc in docs:
177
+ header = ""
178
+ if len(doc.metadata) > 0:
179
+ if "Header 1" in doc.metadata:
180
+ header += doc.metadata["Header 1"]
181
+ if "Header 2" in doc.metadata:
182
+ header += " "
183
+ header += doc.metadata["Header 2"]
184
+ if "Header 3" in doc.metadata:
185
+ header += " "
186
+ header += doc.metadata["Header 3"]
187
+
188
+ if len(doc.page_content) >= 1024:
189
+ subdocs = self.md_splitter.create_documents([doc.page_content])
190
+ for subdoc in subdocs:
191
+ if len(subdoc.page_content) >= 10:
192
+ final.append("{} {}".format(header, subdoc.page_content.lower()))
193
+ elif len(doc.page_content) >= 10:
194
+ final.append("{} {}".format(header, doc.page_content.lower())) # noqa E501
195
+
196
+ for item in final:
197
+ if len(item) >= 1024:
198
+ logger.debug("source {} split length {}".format(source, len(item)))
199
+ return final
200
+
201
+ def clean_md(self, text: str):
202
+ """Remove parts of the markdown document that do not contain the key
203
+ question words, such as code blocks, URL links, etc."""
204
+ # remove ref
205
+ pattern_ref = r"\[(.*?)\]\(.*?\)"
206
+ new_text = re.sub(pattern_ref, r"\1", text)
207
+
208
+ # remove code block
209
+ pattern_code = r"```.*?```"
210
+ new_text = re.sub(pattern_code, "", new_text, flags=re.DOTALL)
211
+
212
+ # remove underline
213
+ new_text = re.sub("_{5,}", "", new_text)
214
+
215
+ # remove table
216
+ # new_text = re.sub('\|.*?\|\n\| *\:.*\: *\|.*\n(\|.*\|.*\n)*', '', new_text, flags=re.DOTALL) # noqa E501
217
+
218
+ # use lower
219
+ new_text = new_text.lower()
220
+ return new_text
221
+
222
+ def get_md_documents(self, file: FileName):
223
+ documents = []
224
+ length = 0
225
+ text = ""
226
+ with open(file.copypath, encoding="utf8") as f:
227
+ text = f.read()
228
+ text = file.prefix + "\n" + self.clean_md(text)
229
+ if len(text) <= 1:
230
+ return [], length
231
+
232
+ chunks = self.split_md(text=text, source=os.path.abspath(file.copypath))
233
+ for chunk in chunks:
234
+ new_doc = Document(page_content=chunk, metadata={"source": file.basename, "read": file.copypath})
235
+ length += len(chunk)
236
+ documents.append(new_doc)
237
+ return documents, length
238
+
239
+ def get_text_documents(self, text: str, file: FileName):
240
+ if len(text) <= 1:
241
+ return []
242
+ chunks = self.text_splitter.create_documents([text])
243
+ documents = []
244
+ for chunk in chunks:
245
+ # `source` is for return references
246
+ # `read` is for LLM response
247
+ chunk.metadata = {"source": file.basename, "read": file.copypath}
248
+ documents.append(chunk)
249
+ return documents
250
+
251
+ def ingress_response(self, files: list, work_dir: str):
252
+ """Extract the features required for the response pipeline based on the
253
+ document."""
254
+ feature_dir = os.path.join(work_dir, "db_response")
255
+ if not os.path.exists(feature_dir):
256
+ os.makedirs(feature_dir)
257
+
258
+ # logger.info('glob {} in dir {}'.format(files, file_dir))
259
+ file_opr = FileOperation()
260
+ documents = []
261
+
262
+ for i, file in enumerate(files):
263
+ logger.debug("{}/{}.. {}".format(i + 1, len(files), file.basename))
264
+ if not file.state:
265
+ continue
266
+
267
+ if file._type == "md":
268
+ md_documents, md_length = self.get_md_documents(file)
269
+ documents += md_documents
270
+ logger.info("{} content length {}".format(file._type, md_length))
271
+ file.reason = str(md_length)
272
+
273
+ else:
274
+ # now read pdf/word/excel/ppt text
275
+ text, error = file_opr.read(file.copypath)
276
+ if error is not None:
277
+ file.state = False
278
+ file.reason = str(error)
279
+ continue
280
+ file.reason = str(len(text))
281
+ logger.info("{} content length {}".format(file._type, len(text)))
282
+ text = file.prefix + text
283
+ documents += self.get_text_documents(text, file)
284
+
285
+ if len(documents) < 1:
286
+ return
287
+ vs = Vectorstore.from_documents(documents, self.embeddings)
288
+ vs.save_local(feature_dir)
289
+
290
+ def ingress_reject(self, files: list, work_dir: str):
291
+ """Extract the features required for the reject pipeline based on
292
+ documents."""
293
+ feature_dir = os.path.join(work_dir, "db_reject")
294
+ if not os.path.exists(feature_dir):
295
+ os.makedirs(feature_dir)
296
+
297
+ documents = []
298
+ file_opr = FileOperation()
299
+
300
+ logger.debug("ingress reject..")
301
+ for i, file in enumerate(files):
302
+ if not file.state:
303
+ continue
304
+
305
+ if file._type == "md":
306
+ # reject base not clean md
307
+ text = file.basename + "\n"
308
+ with open(file.copypath, encoding="utf8") as f:
309
+ text += f.read()
310
+ if len(text) <= 1:
311
+ continue
312
+
313
+ chunks = self.split_md(text=text, source=os.path.abspath(file.copypath))
314
+ for chunk in chunks:
315
+ new_doc = Document(page_content=chunk, metadata={"source": file.basename, "read": file.copypath})
316
+ documents.append(new_doc)
317
+
318
+ else:
319
+ text, error = file_opr.read(file.copypath)
320
+ if error is not None:
321
+ continue
322
+ text = file.basename + text
323
+ documents += self.get_text_documents(text, file)
324
+
325
+ if len(documents) < 1:
326
+ return
327
+ vs = Vectorstore.from_documents(documents, self.embeddings)
328
+ vs.save_local(feature_dir)
329
+
330
+ def preprocess(self, files: list, work_dir: str):
331
+ """Preprocesses files in a given directory. Copies each file to
332
+ 'preprocess' with new name formed by joining all subdirectories with
333
+ '_'.
334
+
335
+ Args:
336
+ files (list): original file list.
337
+ work_dir (str): Working directory where preprocessed files will be stored. # noqa E501
338
+
339
+ Returns:
340
+ str: Path to the directory where preprocessed markdown files are saved.
341
+
342
+ Raises:
343
+ Exception: Raise an exception if no markdown files are found in the provided repository directory. # noqa E501
344
+ """
345
+ preproc_dir = os.path.join(work_dir, "preprocess")
346
+ if not os.path.exists(preproc_dir):
347
+ os.makedirs(preproc_dir)
348
+
349
+ pool = Pool(processes=16)
350
+ file_opr = FileOperation()
351
+ for idx, file in enumerate(files):
352
+ if not os.path.exists(file.origin):
353
+ file.state = False
354
+ file.reason = "skip not exist"
355
+ continue
356
+
357
+ if file._type == "image":
358
+ file.state = False
359
+ file.reason = "skip image"
360
+
361
+ elif file._type in ["pdf", "word", "excel", "ppt", "html"]:
362
+ # read pdf/word/excel file and save to text format
363
+ md5 = file_opr.md5(file.origin)
364
+ file.copypath = os.path.join(preproc_dir, "{}.text".format(md5))
365
+ pool.apply_async(read_and_save, (file,))
366
+
367
+ elif file._type in ["md", "text"]:
368
+ # rename text files to new dir
369
+ md5 = file_opr.md5(file.origin)
370
+ file.copypath = os.path.join(preproc_dir, file.origin.replace("/", "_")[-84:])
371
+ try:
372
+ shutil.copy(file.origin, file.copypath)
373
+ file.state = True
374
+ file.reason = "preprocessed"
375
+ except Exception as e:
376
+ file.state = False
377
+ file.reason = str(e)
378
+
379
+ else:
380
+ file.state = False
381
+ file.reason = "skip unknown format"
382
+ pool.close()
383
+ logger.debug("waiting for preprocess read finish..")
384
+ pool.join()
385
+
386
+ # check process result
387
+ for file in files:
388
+ if file._type in ["pdf", "word", "excel"]:
389
+ if os.path.exists(file.copypath):
390
+ file.state = True
391
+ file.reason = "preprocessed"
392
+ else:
393
+ file.state = False
394
+ file.reason = "read error"
395
+
396
+ def initialize(self, files: list, work_dir: str):
397
+ """Initializes response and reject feature store.
398
+
399
+ Only needs to be called once. Also calculates the optimal threshold
400
+ based on provided good and bad question examples, and saves it in the
401
+ configuration file.
402
+ """
403
+ logger.info("initialize response and reject feature store, you only need call this once.") # noqa E501
404
+ self.preprocess(files=files, work_dir=work_dir)
405
+ self.ingress_response(files=files, work_dir=work_dir)
406
+ self.ingress_reject(files=files, work_dir=work_dir)
407
+
408
+
409
+ def parse_args():
410
+ """Parse command-line arguments."""
411
+ parser = argparse.ArgumentParser(description="Feature store for processing directories.")
412
+ parser.add_argument("--work_dir", type=str, default="work_dir", help="Working directory.")
413
+ parser.add_argument("--repo_dir", type=str, default="repodir", help="Root directory where the repositories are located.")
414
+ parser.add_argument(
415
+ "--config_path", default="config.ini", help="Feature store configuration path. Default value is config.ini"
416
+ )
417
+ parser.add_argument(
418
+ "--good_questions",
419
+ default="resource/good_questions.json",
420
+ help="Positive examples in the dataset. Default value is resource/good_questions.json", # noqa E251 # noqa E501
421
+ )
422
+ parser.add_argument(
423
+ "--bad_questions",
424
+ default="resource/bad_questions.json",
425
+ help="Negative examples json path. Default value is resource/bad_questions.json", # noqa E251 # noqa E501
426
+ )
427
+ parser.add_argument("--sample", help="Input an json file, save reject and search output.")
428
+ args = parser.parse_args()
429
+ return args
430
+
431
+
432
+ def test_reject(retriever: Retriever, sample: str = None):
433
+ """Simple test reject pipeline."""
434
+ if sample is None:
435
+ real_questions = [
436
+ "SAM 10个T 的训练集,怎么比比较公平呢~?速度上还有缺陷吧?",
437
+ "想问下,如果只是推理的话,amp的fp16是不会省显存么,我看parameter仍然是float32,开和不开推理的显存占用都是一样的。能不能直接用把数据和model都 .half() 代替呢,相比之下amp好在哪里", # noqa E501
438
+ "mmdeploy支持ncnn vulkan部署么,我只找到了ncnn cpu 版本",
439
+ "大佬们,如果我想在高空检测安全帽,我应该用 mmdetection 还是 mmrotate",
440
+ "请问 ncnn 全称是什么",
441
+ "有啥中文的 text to speech 模型吗?",
442
+ "今天中午吃什么?",
443
+ "huixiangdou 是什么?",
444
+ "mmpose 如何安装?",
445
+ "使用科研仪器需要注意什么?",
446
+ ]
447
+ else:
448
+ with open(sample) as f:
449
+ real_questions = json.load(f)
450
+
451
+ for example in real_questions:
452
+ reject, _ = retriever.is_reject(example)
453
+
454
+ if reject:
455
+ logger.error(f"reject query: {example}")
456
+ else:
457
+ logger.warning(f"process query: {example}")
458
+
459
+ if sample is not None:
460
+ if reject:
461
+ with open("workdir/negative.txt", "a+") as f:
462
+ f.write(example)
463
+ f.write("\n")
464
+ else:
465
+ with open("workdir/positive.txt", "a+") as f:
466
+ f.write(example)
467
+ f.write("\n")
468
+
469
+ empty_cache()
470
+
471
+
472
+ def test_query(retriever: Retriever, sample: str = None):
473
+ """Simple test response pipeline."""
474
+ if sample is not None:
475
+ with open(sample) as f:
476
+ real_questions = json.load(f)
477
+ logger.add("logs/feature_store_query.log", rotation="4MB")
478
+ else:
479
+ real_questions = ["mmpose installation", "how to use std::vector ?"]
480
+
481
+ for example in real_questions:
482
+ example = example[0:400]
483
+ print(retriever.query(example))
484
+ empty_cache()
485
+
486
+ empty_cache()
487
+
488
+
489
+ def fix_system_error():
490
+ """
491
+ Fix `No module named 'faiss.swigfaiss_avx2`
492
+ """
493
+ import os
494
+ from pathlib import Path
495
+
496
+ import faiss
497
+
498
+ if Path(faiss.__file__).parent.joinpath("swigfaiss_avx2.py").exists():
499
+ return
500
+
501
+ print("Fixing faiss error...")
502
+ os.system(f"cd {Path(faiss.__file__).parent} && ln -s swigfaiss.py swigfaiss_avx2.py")
503
+
504
+
505
+ def gen_vector_db(config_path, source_dir, work_dir, test_mode=False, update_reject=False):
506
+
507
+ # 解决 faiss 导入问题
508
+ fix_system_error()
509
+
510
+ # 必须是绝对路径,否则加载会有问题
511
+ work_dir = str(Path(work_dir).absolute())
512
+
513
+ cache = CacheRetriever(config_path=config_path)
514
+
515
+ # 生成向量数据库
516
+ fs_init = FeatureStore(embeddings=cache.embeddings, reranker=cache.reranker, config_path=config_path)
517
+
518
+ # walk all files in repo dir
519
+ file_opr = FileOperation()
520
+ files = file_opr.scan_dir(repo_dir=source_dir)
521
+ fs_init.initialize(files=files, work_dir=work_dir)
522
+ file_opr.summarize(files)
523
+ del fs_init
524
+
525
+ # update reject throttle
526
+ if update_reject:
527
+ retriever = cache.get(config_path=config_path, work_dir=work_dir)
528
+ with open(os.path.join("resource", "good_questions.json")) as f:
529
+ good_questions = json.load(f)
530
+ with open(os.path.join("resource", "bad_questions.json")) as f:
531
+ bad_questions = json.load(f)
532
+ retriever.update_throttle(config_path=config_path, good_questions=good_questions, bad_questions=bad_questions)
533
+
534
+ cache.pop("default")
535
+
536
+ if test_mode:
537
+ # test
538
+ retriever = cache.get(config_path=config_path, work_dir=work_dir)
539
+ # test_reject(retriever, args.sample)
540
+ test_query(retriever, args.sample)
541
+
542
+
543
+ if __name__ == "__main__":
544
+ args = parse_args()
545
+ gen_vector_db(args.config_path, args.repo_dir, args.work_dir, test_mode=True)
server/base/modules/rag/file_operation.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+
4
+ import pandas as pd
5
+ from bs4 import BeautifulSoup
6
+ from loguru import logger
7
+
8
+
9
class FileName:
    """Record a file's original name, processing state and the path of its
    text-format copy."""

    def __init__(self, root: str, filename: str, _type: str):
        # Keep the original location plus derived naming variants.
        self.root = root
        self.prefix = filename.replace("/", "_")
        self.basename = os.path.basename(filename)
        self.origin = os.path.join(root, filename)
        # Filled in later once the file has been converted/copied.
        self.copypath = ""
        self._type = _type
        # state=True means "processed OK"; reason explains skip/failure.
        self.state = True
        self.reason = ""

    def __str__(self):
        # CSV-style status line, one record per file.
        return f"{self.basename},{self.copypath},{self.state},{self.reason}\n"
+
26
+
27
+ class FileOperation:
28
+ """Encapsulate all file reading operations."""
29
+
30
+ def __init__(self):
31
+ self.image_suffix = [".jpg", ".jpeg", ".png", ".bmp"]
32
+ self.md_suffix = ".md"
33
+ self.text_suffix = [".txt", ".text"]
34
+ self.excel_suffix = [".xlsx", ".xls", ".csv"]
35
+ self.pdf_suffix = ".pdf"
36
+ self.ppt_suffix = ".pptx"
37
+ self.html_suffix = [".html", ".htm", ".shtml", ".xhtml"]
38
+ self.word_suffix = [".docx", ".doc"]
39
+ self.normal_suffix = (
40
+ [self.md_suffix]
41
+ + self.text_suffix
42
+ + self.excel_suffix
43
+ + [self.pdf_suffix]
44
+ + self.word_suffix
45
+ + [self.ppt_suffix]
46
+ + self.html_suffix
47
+ )
48
+
49
+ def get_type(self, filepath: str):
50
+ filepath = filepath.lower()
51
+ if filepath.endswith(self.pdf_suffix):
52
+ return "pdf"
53
+
54
+ if filepath.endswith(self.md_suffix):
55
+ return "md"
56
+
57
+ if filepath.endswith(self.ppt_suffix):
58
+ return "ppt"
59
+
60
+ for suffix in self.image_suffix:
61
+ if filepath.endswith(suffix):
62
+ return "image"
63
+
64
+ for suffix in self.text_suffix:
65
+ if filepath.endswith(suffix):
66
+ return "text"
67
+
68
+ for suffix in self.word_suffix:
69
+ if filepath.endswith(suffix):
70
+ return "word"
71
+
72
+ for suffix in self.excel_suffix:
73
+ if filepath.endswith(suffix):
74
+ return "excel"
75
+
76
+ for suffix in self.html_suffix:
77
+ if filepath.endswith(suffix):
78
+ return "html"
79
+ return None
80
+
81
+ def md5(self, filepath: str):
82
+ hash_object = hashlib.sha256()
83
+ with open(filepath, "rb") as file:
84
+ chunk_size = 8192
85
+ while chunk := file.read(chunk_size):
86
+ hash_object.update(chunk)
87
+
88
+ return hash_object.hexdigest()[0:8]
89
+
90
+ def summarize(self, files: list):
91
+ success = 0
92
+ skip = 0
93
+ failed = 0
94
+
95
+ for file in files:
96
+ if file.state:
97
+ success += 1
98
+ elif file.reason == "skip":
99
+ skip += 1
100
+ else:
101
+ logger.info("{} {}".format(file.origin, file.reason))
102
+ failed += 1
103
+
104
+ logger.info("{} {}".format(file.reason, file.copypath))
105
+ logger.info("累计{}文件,成功{}个,跳过{}个,异常{}个".format(len(files), success, skip, failed))
106
+
107
+ def scan_dir(self, repo_dir: str):
108
+ files = []
109
+ for root, _, filenames in os.walk(repo_dir):
110
+ for filename in filenames:
111
+ _type = self.get_type(filename)
112
+ if _type is not None:
113
+ files.append(FileName(root=root, filename=filename, _type=_type))
114
+ return files
115
+
116
+ def read_pdf(self, filepath: str):
117
+ # load pdf and serialize table
118
+
119
+ # TODO fitz 安装有些不兼容,后续按需完善
120
+ import fitz
121
+
122
+ text = ""
123
+ with fitz.open(filepath) as pages:
124
+ for page in pages:
125
+ text += page.get_text()
126
+ tables = page.find_tables()
127
+ for table in tables:
128
+ tablename = "_".join(filter(lambda x: x is not None and "Col" not in x, table.header.names))
129
+ pan = table.to_pandas()
130
+ json_text = pan.dropna(axis=1).to_json(force_ascii=False)
131
+ text += tablename
132
+ text += "\n"
133
+ text += json_text
134
+ text += "\n"
135
+ return text
136
+
137
+ def read_excel(self, filepath: str):
138
+ table = None
139
+ if filepath.endswith(".csv"):
140
+ table = pd.read_csv(filepath)
141
+ else:
142
+ table = pd.read_excel(filepath)
143
+ if table is None:
144
+ return ""
145
+ json_text = table.dropna(axis=1).to_json(force_ascii=False)
146
+ return json_text
147
+
148
+ def read(self, filepath: str):
149
+ file_type = self.get_type(filepath)
150
+
151
+ text = ""
152
+
153
+ if not os.path.exists(filepath):
154
+ return text, None
155
+
156
+ try:
157
+
158
+ if file_type == "md" or file_type == "text":
159
+ with open(filepath) as f:
160
+ text = f.read()
161
+
162
+ elif file_type == "pdf":
163
+ text += self.read_pdf(filepath)
164
+
165
+ elif file_type == "excel":
166
+ text += self.read_excel(filepath)
167
+
168
+ elif file_type == "word" or file_type == "ppt":
169
+ # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
170
+ # https://textract.readthedocs.io/en/latest/installation.html
171
+
172
+ # TODO textract 在 pip 高于 24.1 后安装不了,因为其库自身原因,后续按需进行完善
173
+ # 可自行安装 pip install textract==1.6.5
174
+ import textract # for word and ppt
175
+
176
+ text = textract.process(filepath).decode("utf8")
177
+ if file_type == "ppt":
178
+ text = text.replace("\n", " ")
179
+
180
+ elif file_type == "html":
181
+ with open(filepath) as f:
182
+ soup = BeautifulSoup(f.read(), "html.parser")
183
+ text += soup.text
184
+
185
+ except Exception as e:
186
+ logger.error((filepath, str(e)))
187
+ return "", e
188
+ text = text.replace("\n\n", "\n")
189
+ text = text.replace("\n\n", "\n")
190
+ text = text.replace("\n\n", "\n")
191
+ text = text.replace(" ", " ")
192
+ text = text.replace(" ", " ")
193
+ text = text.replace(" ", " ")
194
+ return text, None
195
+
196
+
197
+ if __name__ == "__main__":
198
+
199
+ def get_pdf_files(directory):
200
+ pdf_files = []
201
+ # 遍历目录
202
+ for root, dirs, files in os.walk(directory):
203
+ for file in files:
204
+ # 检查文件扩展名是否为.pdf
205
+ if file.lower().endswith(".pdf"):
206
+ # 将完整路径添加到列表中
207
+ pdf_files.append(os.path.abspath(os.path.join(root, file)))
208
+ return pdf_files
209
+
210
+ # 将你想要搜索的目录替换为下面的路径
211
+ pdf_list = get_pdf_files("/home/khj/huixiangdou-web-online-data/hxd-bad-file")
212
+
213
+ # 打印所有找到的PDF文件的绝对路径
214
+
215
+ opr = FileOperation()
216
+ for pdf_path in pdf_list:
217
+ text, error = opr.read(pdf_path)
218
+ print("processing {}".format(pdf_path))
219
+ if error is not None:
220
+ # pdb.set_trace()
221
+ print("")
222
+
223
+ else:
224
+ if text is not None:
225
+ print(len(text))
226
+ else:
227
+ # pdb.set_trace()
228
+ print("")
server/base/modules/rag/rag_worker.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from loguru import logger
6
+
7
+ from ....web_configs import WEB_CONFIGS
8
+ from ...database.product_db import get_db_product_info
9
+ from .feature_store import gen_vector_db
10
+ from .retriever import CacheRetriever
11
+
12
+ # 基础配置
13
+ CONTEXT_MAX_LENGTH = 3000 # 上下文最大长度
14
+ GENERATE_TEMPLATE = "这是说明书:“{}”\n 客户的问题:“{}” \n 请阅读说明并运用你的性格进行解答。" # RAG prompt 模板
15
+
16
+ # RAG 实例句柄
17
+ RAG_RETRIEVER = None
18
+
19
+
20
def build_rag_prompt(rag_retriever: CacheRetriever, product_name, prompt):
    """Look up product context in the vector DB and wrap *prompt* with it.

    Returns "" when the retriever failed to load, the bare *prompt* when no
    usable context was found, and the RAG-augmented prompt otherwise.
    """
    retriever = rag_retriever.get(fs_id="default")

    # CacheRetriever.get returns a (None, message) tuple on load failure.
    if isinstance(retriever, tuple):
        logger.info(f" @@@ GOT real_retriever == tuple : {retriever}")
        return ""

    # Leave headroom in the context budget for the template text itself.
    query_budget = CONTEXT_MAX_LENGTH - 2 * len(GENERATE_TEMPLATE)
    chunk, db_context, references = retriever.query(f"商品名:{product_name}。{prompt}", context_max_length=query_budget)
    logger.info(f"db_context = {db_context}")

    if db_context is None or len(db_context) <= 1:
        logger.info("db_context get error")
        prompt_rag = prompt
    else:
        prompt_rag = GENERATE_TEMPLATE.format(db_context, prompt)

    logger.info(f"RAG reference = {references}")
    logger.info("=" * 20)

    return prompt_rag
43
+
44
+
45
def init_rag_retriever(rag_config: str, db_path: str):
    """Create the retriever cache and warm up the default vector store."""
    # Free leftover GPU memory before the embedding/rerank models load.
    torch.cuda.empty_cache()

    cache_retriever = CacheRetriever(config_path=rag_config)

    # Eagerly load the "default" store so the first query does not pay the cost.
    cache_retriever.get(fs_id="default", config_path=rag_config, work_dir=db_path)

    return cache_retriever
54
+
55
+
56
async def gen_rag_db(user_id, force_gen=False):
    """
    Generate the vector database from the user's product instruction files.

    Args:
        user_id: owner whose products' instruction files are indexed.
        force_gen: when True, regenerate the database even if it already exists.
    """

    # If the database directory exists and force_gen is False, skip generation.
    if Path(WEB_CONFIGS.RAG_VECTOR_DB_DIR).exists() and not force_gen:
        return

    if force_gen and Path(WEB_CONFIGS.RAG_VECTOR_DB_DIR).exists():
        shutil.rmtree(WEB_CONFIGS.RAG_VECTOR_DB_DIR)

    # Only index the files referenced by the products' instructions field;
    # stage them in a fresh temporary directory.
    if Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).exists():
        shutil.rmtree(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP)
    Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).mkdir(exist_ok=True, parents=True)

    # Read each product record, collect its instruction file path and copy the
    # file into the tmp directory.
    product_list, _ = await get_db_product_info(user_id)

    for info in product_list:

        shutil.copyfile(
            Path(
                WEB_CONFIGS.SERVER_FILE_ROOT,
                WEB_CONFIGS.PRODUCT_FILE_DIR,
                WEB_CONFIGS.INSTRUCTIONS_DIR,
                Path(info.instruction).name,
            ),
            Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).joinpath(Path(info.instruction).name),
        )

    logger.info("Generating rag database, pls wait ...")
    # Build the vector database from the staged instruction files.
    gen_vector_db(
        WEB_CONFIGS.RAG_CONFIG_PATH,
        str(Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).absolute()),
        WEB_CONFIGS.RAG_VECTOR_DB_DIR,
    )

    # Remove the intermediate files.
    shutil.rmtree(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP)
101
+
102
+
103
async def load_rag_model(user_id):
    """Build the vector DB if needed and load the global RAG retriever."""
    global RAG_RETRIEVER

    # (Re)generate the RAG vector database when it does not exist yet.
    await gen_rag_db(user_id)

    # Load the RAG model stack and publish it through the module-level handle.
    RAG_RETRIEVER = init_rag_retriever(rag_config=WEB_CONFIGS.RAG_CONFIG_PATH, db_path=WEB_CONFIGS.RAG_VECTOR_DB_DIR)
    logger.info("load rag model done !...")
113
+
114
+
115
async def rebuild_rag_db(user_id, db_name="default"):
    """Regenerate the RAG vector database and reload the cached retriever.

    NOTE(review): assumes load_rag_model() ran first so RAG_RETRIEVER is not
    None — confirm callers guarantee that ordering.
    """

    # Regenerate the RAG vector database from scratch.
    await gen_rag_db(user_id, force_gen=True)

    # Reload the retriever: drop the cached entry, then load from the new DB.
    RAG_RETRIEVER.pop(db_name)
    RAG_RETRIEVER.get(fs_id=db_name, config_path=WEB_CONFIGS.RAG_CONFIG_PATH, work_dir=WEB_CONFIGS.RAG_VECTOR_DB_DIR)
server/base/modules/rag/retriever.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """extract feature and search with user query."""
2
+
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import yaml
8
+ from BCEmbedding.tools.langchain import BCERerank
9
+ from langchain.embeddings import HuggingFaceEmbeddings
10
+ from langchain.retrievers import ContextualCompressionRetriever
11
+ from langchain.vectorstores.faiss import FAISS as Vectorstore
12
+ from langchain_community.vectorstores.utils import DistanceStrategy
13
+ from loguru import logger
14
+ from modelscope import snapshot_download
15
+ from sklearn.metrics import precision_recall_curve
16
+
17
+ from ....web_configs import WEB_CONFIGS
18
+ from .file_operation import FileOperation
19
+
20
+
21
class Retriever:
    """Tokenize and extract features from the project's documents, for use in
    the reject pipeline and response pipeline."""

    def __init__(self, embeddings, reranker, work_dir: str, reject_throttle: float) -> None:
        """Init with model device type and config.

        Args:
            embeddings: LangChain embeddings used to load the FAISS stores.
            reranker: compressor applied on top of the similarity retriever.
            work_dir (str): directory holding the "db_reject" and "db_response" stores.
            reject_throttle (float): minimum relevance score for a chunk to count.
        """
        self.reject_throttle = reject_throttle
        # Store used only to decide whether a question should be rejected.
        self.rejecter = Vectorstore.load_local(
            os.path.join(work_dir, "db_reject"), embeddings=embeddings, allow_dangerous_deserialization=True
        )
        # Store used to fetch answer material: inner-product distance, top-30 candidates.
        self.retriever = Vectorstore.load_local(
            os.path.join(work_dir, "db_response"),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        ).as_retriever(search_type="similarity", search_kwargs={"score_threshold": 0.15, "k": 30})
        # Rerank the candidates down to the most relevant few.
        self.compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=self.retriever)

    def is_reject(self, question, k=30, disable_throttle=False):
        """If no search results below the threshold can be found from the
        database, reject this query.

        Returns:
            tuple: (reject: bool, docs) where docs holds (Document, score) pairs.
        """
        if disable_throttle:
            # for searching throttle during update sample
            docs_with_score = self.rejecter.similarity_search_with_relevance_scores(question, k=1)
            if len(docs_with_score) < 1:
                return True, docs_with_score
            return False, docs_with_score
        else:
            # for retrieve result
            # if no chunk passed the throttle, give the max
            docs_with_score = self.rejecter.similarity_search_with_relevance_scores(question, k=k)
            ret = []
            max_score = -1
            top1 = None
            for doc, score in docs_with_score:
                if score >= self.reject_throttle:
                    ret.append(doc)
                if score > max_score:
                    max_score = score
                    top1 = (doc, score)
            reject = False if len(ret) > 0 else True
            return reject, [top1]

    def update_throttle(self, config_path: str = "config.yaml", good_questions=[], bad_questions=[]):
        """Update reject throttle based on positive and negative examples.

        Picks the threshold maximizing precision + recall over the labelled
        questions and writes it back into *config_path*.
        """

        if len(good_questions) == 0 or len(bad_questions) == 0:
            raise Exception("good and bad question examples cat not be empty.")
        questions = good_questions + bad_questions
        predictions = []
        for question in questions:
            # Disable the throttle so every question gets its raw top-1 score.
            self.reject_throttle = -1
            _, docs = self.is_reject(question=question, disable_throttle=True)
            score = docs[0][1]
            predictions.append(max(0, score))

        labels = [1 for _ in range(len(good_questions))] + [0 for _ in range(len(bad_questions))]
        precision, recall, thresholds = precision_recall_curve(labels, predictions)

        # get the best index for sum(precision, recall)
        sum_precision_recall = precision[:-1] + recall[:-1]
        index_max = np.argmax(sum_precision_recall)
        optimal_threshold = max(thresholds[index_max], 0.0)

        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        config["feature_store"]["reject_throttle"] = float(optimal_threshold)
        with open(config_path, "w", encoding="utf8") as f:
            yaml.dump(config, f)
        logger.info(f"The optimal threshold is: {optimal_threshold}, saved it to {config_path}")  # noqa E501

    def query(self, question: str, context_max_length: int = 16000):  # , tracker: QueryTracker = None):
        """Processes a query and returns the best match from the vector store
        database. If the question is rejected, returns None.

        Args:
            question (str): The question asked by the user.
            context_max_length (int): hard cap on the assembled context length.

        Returns:
            str: The best matching chunk, or None.
            str: The best matching text, or None
            list: basenames of the files the context was taken from.
        """
        print(f"DEBUG -1: enter query")

        if question is None or len(question) < 1:
            print(f"DEBUG 0: len error")

            return None, None, []

        if len(question) > 512:
            logger.warning("input too long, truncate to 512")
            question = question[0:512]

        # reject, docs = self.is_reject(question=question)
        # assert (len(docs) > 0)
        # if reject:
        #     return None, None, [docs[0][0].metadata['source']]

        docs = self.compression_retriever.get_relevant_documents(question)

        print(f"DEBUG 1: {docs}")

        # if tracker is not None:
        #     tracker.log('retrieve', [doc.metadata['source'] for doc in docs])
        chunks = []
        context = ""
        references = []

        # add file text to context, until exceed `context_max_length`

        file_opr = FileOperation()
        for idx, doc in enumerate(docs):
            chunk = doc.page_content
            chunks.append(chunk)

            if "read" not in doc.metadata:
                logger.error(
                    "If you are using the version before 20240319, please rerun `python3 -m huixiangdou.service.feature_store`"
                )
                raise Exception("huixiangdou version mismatch")
            file_text, error = file_opr.read(doc.metadata["read"])
            if error is not None:
                # read file failed, skip
                print(f"DEBUG 2: error")

                continue

            source = doc.metadata["source"]
            logger.info("target {} file length {}".format(source, len(file_text)))

            print(f"DEBUG 3: target {source}, file length {len(file_text)}")

            if len(file_text) + len(context) > context_max_length:
                if source in references:
                    continue
                references.append(source)
                # add and break
                add_len = context_max_length - len(context)
                if add_len <= 0:
                    break
                chunk_index = file_text.find(chunk)
                if chunk_index == -1:
                    # chunk not in file_text
                    context += chunk
                    context += "\n"
                    context += file_text[0 : add_len - len(chunk) - 1]
                else:
                    # Center the window on the chunk's position in the file.
                    start_index = max(0, chunk_index - (add_len - len(chunk)))
                    context += file_text[start_index : start_index + add_len]
                break

            if source not in references:
                context += file_text
                context += "\n"
                references.append(source)

        context = context[0:context_max_length]
        # NOTE(review): references[0] raises IndexError when no document was
        # usable (all reads failed or docs empty) — confirm callers guard this.
        logger.debug("query:{} top1 file:{}".format(question, references[0]))
        return "\n".join(chunks), context, [os.path.basename(r) for r in references]
180
+
181
+
182
class CacheRetriever:
    """LRU-style cache of Retriever instances keyed by feature-store id."""

    def __init__(self, config_path: str, max_len: int = 4):
        """Load the embedding and rerank models once; Retriever objects are
        created lazily per feature store in get()."""
        self.cache = dict()
        self.max_len = max_len
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)["feature_store"]
            embedding_model_path = config["embedding_model_path"]
            reranker_model_path = config["reranker_model_path"]

        # Resolve (and download if missing) model weights via modelscope.
        embedding_model_path = snapshot_download(embedding_model_path, cache_dir=WEB_CONFIGS.RAG_MODEL_DIR)
        reranker_model_path = snapshot_download(reranker_model_path, cache_dir=WEB_CONFIGS.RAG_MODEL_DIR)

        # load text2vec and rerank model
        logger.info("loading test2vec and rerank models")
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"batch_size": 1, "normalize_embeddings": True},
        )
        # fp16 halves the embedding model's GPU memory footprint.
        self.embeddings.client = self.embeddings.client.half()
        reranker_args = {"model": reranker_model_path, "top_n": 7, "device": "cuda", "use_fp16": True}
        self.reranker = BCERerank(**reranker_args)

    def get(self, fs_id: str = "default", config_path="config.yaml", work_dir="workdir"):
        """Return the cached Retriever for *fs_id*, creating it (and evicting
        the least-recently-used entry) as needed.

        Returns a (None, message) tuple when work_dir or config_path is missing.
        """
        if fs_id in self.cache:
            # Refresh the LRU timestamp on hit.
            self.cache[fs_id]["time"] = time.time()
            return self.cache[fs_id]["retriever"]

        if not os.path.exists(work_dir) or not os.path.exists(config_path):
            return None, "workdir or config.yaml not exist"

        with open(config_path, "r", encoding="utf-8") as f:
            reject_throttle = yaml.safe_load(f)["feature_store"]["reject_throttle"]

        if len(self.cache) >= self.max_len:
            # drop the oldest one
            del_key = None
            min_time = time.time()
            for key, value in self.cache.items():
                cur_time = value["time"]
                if cur_time < min_time:
                    min_time = cur_time
                    del_key = key

            if del_key is not None:
                del_value = self.cache[del_key]
                self.cache.pop(del_key)
                del del_value["retriever"]

        retriever = Retriever(
            embeddings=self.embeddings, reranker=self.reranker, work_dir=work_dir, reject_throttle=reject_throttle
        )
        self.cache[fs_id] = {"retriever": retriever, "time": time.time()}
        return retriever

    def pop(self, fs_id: str):
        """Remove *fs_id* from the cache; no-op when absent."""
        if fs_id not in self.cache:
            return
        del_value = self.cache[fs_id]
        self.cache.pop(fs_id)
        # manually free memory
        del del_value
server/base/modules/rag/test_queries.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [
2
+ "我的商品是牛肉。饲养天数",
3
+ "我的商品是唇膏。净含量是多少"
4
+ ]
server/base/queue_thread.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : queue_thread.py
5
+ @Time : 2024/09/02
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 队列调取相关逻辑(半废弃状态)
10
+ """
11
+
12
+
13
+ from loguru import logger
14
+ import requests
15
+ import multiprocessing
16
+
17
+ from ..web_configs import API_CONFIG
18
+ from .server_info import SERVER_PLUGINS_INFO
19
+
20
+
21
+ def process_tts(tts_text_queue):
22
+
23
+ while True:
24
+ try:
25
+ text_chunk = tts_text_queue.get(block=True, timeout=1)
26
+ except Exception as e:
27
+ # logger.info(f"### {e}")
28
+ continue
29
+ logger.info(f"Get tts quene: {type(text_chunk)} , {text_chunk}")
30
+ res = requests.post(API_CONFIG.TTS_URL, json=text_chunk)
31
+
32
+ # # tts 推理成功,放入数字人队列进行推理
33
+ # res_json = res.json()
34
+ # tts_request_dict = {
35
+ # "user_id": "123",
36
+ # "request_id": text_chunk["request_id"],
37
+ # "chunk_id": text_chunk["chunk_id"],
38
+ # "tts_path": res_json["wav_path"],
39
+ # }
40
+
41
+ # DIGITAL_HUMAN_QUENE.put(tts_request_dict)
42
+
43
+ logger.info(f"tts res = {res}")
44
+
45
+
46
def process_digital_human(digital_human_queue):
    """Worker loop: forward queued jobs to the digital-human rendering service."""

    while True:
        try:
            task = digital_human_queue.get(block=True, timeout=1)
        except Exception:
            # Queue was empty within the timeout window; poll again.
            continue

        logger.info(f"Get digital human quene: {type(task)} , {task}")
        response = requests.post(API_CONFIG.DIGITAL_HUMAN_URL, json=task)
        logger.info(f"digital human res = {response}")
57
+
58
+
59
# Start a background worker process only when the corresponding plugin service
# was detected as reachable; otherwise expose None so callers can check
# availability before putting work on the queue.
if SERVER_PLUGINS_INFO.tts_server_enabled:
    TTS_TEXT_QUENE = multiprocessing.Queue(maxsize=100)
    tts_thread = multiprocessing.Process(target=process_tts, args=(TTS_TEXT_QUENE,), name="tts_processer")
    tts_thread.start()
else:
    TTS_TEXT_QUENE = None

if SERVER_PLUGINS_INFO.digital_human_server_enabled:
    DIGITAL_HUMAN_QUENE = multiprocessing.Queue(maxsize=100)
    digital_human_thread = multiprocessing.Process(
        target=process_digital_human, args=(DIGITAL_HUMAN_QUENE,), name="digital_human_processer"
    )
    digital_human_thread.start()
else:
    DIGITAL_HUMAN_QUENE = None
server/base/routers/__init__.py ADDED
File without changes
server/base/routers/digital_human.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : digital_human.py
5
+ @Time : 2024/09/02
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 数字人接口
10
+ """
11
+
12
+
13
+ from pathlib import Path
14
+ import uuid
15
+ import requests
16
+ from fastapi import APIRouter
17
+ from loguru import logger
18
+ from pydantic import BaseModel
19
+
20
+ from ...web_configs import API_CONFIG, WEB_CONFIGS
21
+ from ..utils import ResultCode, make_return_data
22
+
23
+ router = APIRouter(
24
+ prefix="/digital-human",
25
+ tags=["digital-human"],
26
+ responses={404: {"description": "Not found"}},
27
+ )
28
+
29
+
30
class GenDigitalHumanVideoItem(BaseModel):
    """Request body for POST /digital-human/gen."""

    streamerId: int  # streamer whose avatar presents the video
    salesDoc: str  # sales script text to be spoken
34
+
35
async def gen_tts_and_digital_human_video_app(streamer_id: int, sales_doc: str):
    """Generate a TTS wav from *sales_doc*, then render the digital-human video with it.

    Args:
        streamer_id (int): streamer whose configuration the renderer uses.
        sales_doc (str): the sales script to be spoken.

    Returns:
        str: URL of the generated mp4 on the file server.
    """
    logger.info(sales_doc)

    request_id = str(uuid.uuid1())
    sentence_id = 1  # direct one-shot inference, so a single chunk
    user_id = "123"

    # Generate the TTS wav
    tts_json = {
        "user_id": user_id,
        "request_id": request_id,
        "sentence": sales_doc,
        "chunk_id": sentence_id,
        # "wav_save_name": chat_item.request_id + f"{str(sentence_id).zfill(8)}.wav",
    }
    tts_save_path = Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, request_id + f"-{str(1).zfill(8)}.wav")
    logger.info(f"waiting for wav generating done: {tts_save_path}")
    _ = requests.post(API_CONFIG.TTS_URL, json=tts_json)

    # Generate the digital-human video
    digital_human_gen_info = {
        "user_id": user_id,
        "request_id": request_id,
        "chunk_id": 0,
        "tts_path": str(tts_save_path),
        "streamer_id": str(streamer_id),
    }
    video_path = Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH).joinpath(request_id + ".mp4")
    logger.info(f"Generating digital human: {video_path}")
    _ = requests.post(API_CONFIG.DIGITAL_HUMAN_URL, json=digital_human_gen_info)

    # Remove the intermediate wav file
    tts_save_path.unlink()

    server_video_path = f"{API_CONFIG.REQUEST_FILES_URL}/{WEB_CONFIGS.STREAMER_FILE_DIR}/vid_output/{request_id}.mp4"
    logger.info(server_video_path)

    return server_video_path
73
+
74
+
75
+ @router.post("/gen")
76
+ async def get_digital_human_according_doc_api(gen_item: GenDigitalHumanVideoItem):
77
+ """根据口播文案生成数字人介绍视频
78
+
79
+ Args:
80
+ gen_item (GenDigitalHumanVideoItem): _description_
81
+
82
+ """
83
+ server_video_path = await gen_tts_and_digital_human_video_app(gen_item.streamerId, gen_item.salesDoc)
84
+
85
+ return make_return_data(True, ResultCode.SUCCESS, "成功", server_video_path)
server/base/routers/llm.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : llm.py
5
+ @Time : 2024/09/02
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 大模型接口
10
+ """
11
+
12
+
13
+ from typing import Dict, List
14
+
15
+ from fastapi import APIRouter, Depends
16
+ from loguru import logger
17
+
18
+ from ..database.llm_db import get_llm_product_prompt_base_info
19
+ from ..database.product_db import get_db_product_info
20
+ from ..database.streamer_info_db import get_db_streamer_info
21
+ from ..models.product_model import ProductInfo
22
+ from ..models.streamer_info_model import StreamerInfo
23
+ from ..modules.agent.agent_worker import get_agent_result
24
+ from ..server_info import SERVER_PLUGINS_INFO
25
+ from ..utils import LLM_MODEL_HANDLER, ResultCode, make_return_data
26
+ from .users import get_current_user_info
27
+
28
+ router = APIRouter(
29
+ prefix="/llm",
30
+ tags=["llm"],
31
+ responses={404: {"description": "Not found"}},
32
+ )
33
+
34
+
35
def combine_history(prompt: list, history_msg: list):
    """Append the chat history to an existing prompt message list.

    Args:
        prompt (list): OpenAI-style message list; extended in place.
        history_msg (list): history entries carrying "role" ("streamer" or
            "user") and "message" keys.

    Returns:
        list: the same ``prompt`` list with the history appended.
    """
    # Map the app-level speaker names onto OpenAI chat roles.
    role_map = {"streamer": "assistant", "user": "user"}

    prompt.extend({"role": role_map[entry["role"]], "content": entry["message"]} for entry in history_msg)

    return prompt
53
+
54
+
55
async def gen_poduct_base_prompt(
    user_id: int,
    streamer_id: int = -1,
    product_id: int = -1,
    streamer_info: StreamerInfo | None = None,
    product_info: ProductInfo | None = None,
) -> List[Dict[str, str]]:
    """Generate the product-introduction prompt.

    Args:
        user_id (int): user ID
        streamer_id (int): streamer ID
        product_id (int): product ID
        streamer_info (StreamerInfo, optional): streamer info; looked up by streamer_id when None
        product_info (ProductInfo, optional): product info; looked up by product_id when None

    Returns:
        List[Dict[str, str]]: the generated prompt
    """

    # Exactly one of (ID, info object) must be supplied for both streamer and product.
    assert (streamer_id == -1 and streamer_info is not None) or (streamer_id != -1 and streamer_info is None)
    assert (product_id == -1 and product_info is not None) or (product_id != -1 and product_info is None)

    # Load the conversation config file
    dataset_yaml = await get_llm_product_prompt_base_info()

    # Extract the conversation settings from the config:
    # system: system prompt tailored to the sales role
    # first_input_template: template for the first user input of the conversation
    # product_info_struct_template: template describing the product info structure
    system = dataset_yaml["conversation_setting"]["system"]
    first_input_template = dataset_yaml["conversation_setting"]["first_input"]
    product_info_struct_template = dataset_yaml["product_info_struct"]

    # Fetch the streamer info by ID when not supplied directly
    if streamer_info is None:
        streamer_info = await get_db_streamer_info(user_id, streamer_id)
        streamer_info = streamer_info[0]

    # Insert the sales role name and character traits into the system prompt
    character_str = streamer_info.character.replace(";", "、")
    system_str = system.replace("{role_type}", streamer_info.name).replace("{character}", character_str)

    # Fetch the product info by ID when not supplied directly
    if product_info is None:
        product_list, _ = await get_db_product_info(user_id, product_id=product_id)
        product_info = product_list[0]

    heighlights_str = product_info.heighlights.replace(";", "、")
    product_info_str = product_info_struct_template[0].replace("{name}", product_info.product_name)
    product_info_str += product_info_struct_template[1].replace("{highlights}", heighlights_str)

    # Build the sales-script prompt
    sales_doc_prompt = first_input_template.replace("{product_info}", product_info_str)

    prompt = [{"role": "system", "content": system_str}, {"role": "user", "content": sales_doc_prompt}]
    logger.info(prompt)

    return prompt
114
+
115
+
116
async def get_agent_res(prompt, departure_place, delivery_company):
    """Invoke the Agent capability.

    Returns "" when the agent plugin is disabled or produced no answer;
    otherwise the agent answer wrapped in the agent prompt template.
    """
    agent_response = ""

    if not SERVER_PLUGINS_INFO.agent_enabled:
        # Plugin disabled: return empty directly.
        return ""

    GENERATE_AGENT_TEMPLATE = (
        "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt template
    )
    # The user's latest message is the last prompt entry.
    input_prompt = prompt[-1]["content"]
    agent_response = get_agent_result(LLM_MODEL_HANDLER, input_prompt, departure_place, delivery_company)
    if agent_response != "":
        agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)
        logger.info(f"Agent response: {agent_response}")

    return agent_response
134
+
135
+
136
async def get_llm_res(prompt):
    """Run LLM inference and return the final answer.

    Args:
        prompt (list): OpenAI-style message list sent to the model.

    Returns:
        str: content of the last streamed completion chunk (the full answer).
    """

    logger.info(prompt)
    model_name = LLM_MODEL_HANDLER.available_models[0]

    res_data = ""
    # Stream the completion; each chunk carries the accumulated content, so
    # keeping only the last one yields the full answer.
    for item in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=prompt):
        res_data = item["choices"][0]["message"]["content"]

    return res_data
154
+
155
+
156
+ @router.get("/gen_sales_doc", summary="生成主播文案接口")
157
+ async def get_product_info_api(streamer_id: int, product_id: int, user_id: int = Depends(get_current_user_info)):
158
+ """生成口播文案
159
+
160
+ Args:
161
+ streamer_id (int): 主播 ID,用于获取性格等信息
162
+ product_id (int): 商品 ID
163
+ """
164
+
165
+ prompt = await gen_poduct_base_prompt(user_id, streamer_id, product_id)
166
+
167
+ res_data = await get_llm_res(prompt)
168
+
169
+ return make_return_data(True, ResultCode.SUCCESS, "成功", res_data)
170
+
171
+
172
+ @router.get("/gen_product_info")
173
+ async def get_product_info_api(product_id: int, user_id: int = Depends(get_current_user_info)):
174
+ """TODO 根据说明书内容生成商品信息
175
+
176
+ Args:
177
+ gen_product_item (GenProductItem): _description_
178
+ """
179
+
180
+ raise NotImplemented()
181
+ instruction_str = ""
182
+ prompt = [{"system": "现在你是一个文档小助手,你可以从文档里面总结出我需要的信息", "input": ""}]
183
+
184
+ res_data = ""
185
+ model_name = LLM_MODEL_HANDLER.available_models[0]
186
+ for item in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=prompt):
187
+ res_data += item
server/base/routers/products.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : products.py
5
+ @Time : 2024/08/30
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 商品信息接口
10
+ """
11
+
12
+ from pathlib import Path
13
+
14
+ from fastapi import APIRouter, Depends
15
+
16
+ from ...web_configs import WEB_CONFIGS
17
+ from ..database.product_db import (
18
+ create_or_update_db_product_by_id,
19
+ delete_product_id,
20
+ get_db_product_info,
21
+ )
22
+ from ..models.product_model import ProductInfo, ProductPageItem, ProductQueryItem
23
+ from ..modules.rag.rag_worker import rebuild_rag_db
24
+ from ..utils import ResultCode, make_return_data
25
+ from .users import get_current_user_info
26
+
27
# Sub-router for product management endpoints; mounted by the main FastAPI app.
router = APIRouter(
    prefix="/products",
    tags=["products"],
    responses={404: {"description": "Not found"}},
)
32
+
33
+
34
@router.get("/list", summary="获取分页商品信息接口")
async def get_product_info_api(
    currentPage: int = 1, pageSize: int = 5, productName: str | None = None, user_id: int = Depends(get_current_user_info)
):
    """Return one page of the current user's products, optionally filtered by name."""

    # Query the requested page from the database
    db_products, total_size = await get_db_product_info(
        user_id=user_id,
        current_page=currentPage,
        page_size=pageSize,
        product_name=productName,
    )

    page_item = ProductPageItem(product_list=db_products, currentPage=currentPage, pageSize=pageSize, totalSize=total_size)
    return make_return_data(True, ResultCode.SUCCESS, "成功", page_item)
47
+
48
+
49
@router.get("/info/{productId}", summary="获取特定商品 ID 的详细信息接口")
async def get_product_id_info_api(productId: int, user_id: int = Depends(get_current_user_info)):
    """Return the details of one product owned by the current user."""

    match_list, _ = await get_db_product_info(user_id=user_id, product_id=productId)

    # Unwrap when exactly one record matched; otherwise hand back the list as-is
    res_payload = match_list[0] if len(match_list) == 1 else match_list

    return make_return_data(True, ResultCode.SUCCESS, "成功", res_payload)
57
+
58
+
59
@router.post("/create", summary="新增商品接口")
async def upload_product_api(upload_product_item: ProductInfo, user_id: int = Depends(get_current_user_info)):
    """Create a new product record for the current user."""

    # Force ownership and let the DB assign a fresh primary key
    upload_product_item.user_id = user_id
    upload_product_item.product_id = None

    need_rebuild = create_or_update_db_product_by_id(0, upload_product_item)

    if WEB_CONFIGS.ENABLE_RAG and need_rebuild:
        # Rebuild the RAG vector DB so the new product becomes searchable
        await rebuild_rag_db(user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
72
+
73
+
74
@router.put("/edit/{product_id}", summary="编辑商品接口")
async def upload_product_api(product_id: int, upload_product_item: ProductInfo, user_id: int = Depends(get_current_user_info)):
    """Update an existing product record."""

    need_rebuild = create_or_update_db_product_by_id(product_id, upload_product_item, user_id)

    if WEB_CONFIGS.ENABLE_RAG and need_rebuild:
        # Rebuild the RAG vector DB so retrieval reflects the edit
        await rebuild_rag_db(user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
84
+
85
+
86
@router.delete("/delete/{productId}", summary="删除特定商品 ID 接口")
async def upload_product_api(productId: int, user_id: int = Depends(get_current_user_info)):
    """Delete a product; ownership is checked inside the DB helper."""

    if not await delete_product_id(productId, user_id):
        return make_return_data(False, ResultCode.FAIL, "失败", "")

    if WEB_CONFIGS.ENABLE_RAG:
        # Rebuild the RAG vector DB so the product disappears from retrieval
        await rebuild_rag_db(user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
99
+
100
+
101
@router.post("/instruction", summary="获取对应商品的说明书内容接口", dependencies=[Depends(get_current_user_info)])
async def get_product_instruction_info_api(instruction_path: ProductQueryItem):
    """Return the raw text of a product's instruction manual.

    Args:
        instruction_path (ProductQueryItem): carries the manual's server path/URL

    Returns:
        the manual content on success, or a FAIL payload when the file is missing
    """
    # TODO serve the file directly to the frontend via axios later
    # Map the request path onto the local server file root; only the file name
    # is used, which also avoids path traversal outside the instructions dir.
    # (Fix: renamed the misspelled local `loacl_path`.)
    local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(
        WEB_CONFIGS.PRODUCT_FILE_DIR, WEB_CONFIGS.INSTRUCTIONS_DIR, Path(instruction_path.instructionPath).name
    )
    if not local_path.exists():
        return make_return_data(False, ResultCode.FAIL, "文件不存在", "")

    # Fix: read with an explicit UTF-8 encoding — manuals contain Chinese text
    # and the platform default encoding may not be UTF-8.
    with open(local_path, "r", encoding="utf-8") as f:
        instruction_content = f.read()

    return make_return_data(True, ResultCode.SUCCESS, "成功", instruction_content)
server/base/routers/streamer_info.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : streamer_info.py
5
+ @Time : 2024/08/10
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 主播管理信息页面接口
10
+ """
11
+
12
+ from typing import Tuple
13
+ import uuid
14
+ from pathlib import Path
15
+
16
+ import requests
17
+ from fastapi import APIRouter, Depends
18
+ from loguru import logger
19
+
20
+ from ...web_configs import API_CONFIG, WEB_CONFIGS
21
+ from ..database.streamer_info_db import create_or_update_db_streamer_by_id, delete_streamer_id, get_db_streamer_info
22
+ from ..models.streamer_info_model import StreamerInfo
23
+ from ..utils import ResultCode, make_poster_by_video_first_frame, make_return_data
24
+ from .users import get_current_user_info
25
+
26
# Sub-router for streamer management endpoints; mounted by the main FastAPI app.
router = APIRouter(
    prefix="/streamer",
    tags=["streamer"],
    responses={404: {"description": "Not found"}},
)
31
+
32
+
33
async def gen_digital_human(user_id, streamer_id: int, new_streamer_info: StreamerInfo) -> Tuple[str, str]:
    """Generate the digital-human video and its poster image for a streamer.

    Args:
        user_id (int): user ID
        streamer_id (int): streamer ID
        new_streamer_info (StreamerInfo): updated streamer info

    Returns:
        str: digital-human video URL
        str: digital-human avatar/poster URL
    """

    streamer_info_db = await get_db_streamer_info(user_id, streamer_id)
    streamer_info_db = streamer_info_db[0]

    # Compare server-relative paths: identical base video means nothing to regenerate
    new_base_mp4_path = new_streamer_info.base_mp4_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
    if streamer_info_db.base_mp4_path.replace(API_CONFIG.REQUEST_FILES_URL, "") == new_base_mp4_path:
        # Digital-human video unchanged, keep the stored assets
        return streamer_info_db.base_mp4_path, streamer_info_db.poster_image

    # Call the digital-human preprocessing service

    # new_streamer_info.base_mp4_path is a server URL; convert to the local path
    video_local_dir = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(
        WEB_CONFIGS.STREAMER_FILE_DIR, WEB_CONFIGS.STREAMER_INFO_FILES_DIR
    )

    digital_human_gen_info = {
        "user_id": str(user_id),
        "request_id": str(uuid.uuid1()),
        "streamer_id": str(new_streamer_info.streamer_id),
        "video_path": str(video_local_dir.joinpath(Path(new_streamer_info.base_mp4_path).name)),
    }
    logger.info(f"Getting digital human preprocessing: {new_streamer_info.streamer_id}")
    # NOTE(review): response discarded and no timeout set — failures here are silent; confirm intended
    _ = requests.post(API_CONFIG.DIGITAL_HUMAN_PREPROCESS_URL, json=digital_human_gen_info)

    # Build the poster image from the first frame of the video
    poster_save_name = Path(new_streamer_info.base_mp4_path).stem + ".png"
    make_poster_by_video_first_frame(str(video_local_dir.joinpath(Path(new_streamer_info.base_mp4_path).name)), poster_save_name)

    # Derive the poster's server URL from the video URL
    poster_server_url = str(Path(new_streamer_info.base_mp4_path).parent.joinpath(poster_save_name))
    # Path() collapses "http://" to "http:/"; restore the double slash
    if "http://" not in poster_server_url and "http:/" in poster_server_url:
        poster_server_url = poster_server_url.replace("http:/", "http://")

    return new_streamer_info.base_mp4_path, poster_server_url
80
+
81
+
82
@router.get("/list", summary="获取所有主播信息接口,用于用户进行主播的选择")
async def get_streamer_info_api(user_id: int = Depends(get_current_user_info)):
    """Return every streamer belonging to the current user."""
    all_streamers = await get_db_streamer_info(user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", all_streamers)
87
+
88
+
89
@router.get("/info/{streamerId}", summary="用于获取特定主播的信息接口")
async def get_streamer_info_api(streamerId: int, user_id: int = Depends(get_current_user_info)):
    """Return one streamer's record for the current user."""

    match_list = await get_db_streamer_info(user_id, streamerId)

    # Unwrap when exactly one record matched; otherwise hand back the list as-is
    res_payload = match_list[0] if len(match_list) == 1 else match_list

    return make_return_data(True, ResultCode.SUCCESS, "成功", res_payload)
98
+
99
+
100
@router.post("/create", summary="新增主播信息接口")
async def create_streamer_info_api(streamerItem: StreamerInfo, user_id: int = Depends(get_current_user_info)):
    """Create a new streamer.

    The record is written twice: once with empty media paths to obtain the
    auto-generated streamer_id, then again after the digital-human assets
    have been generated for that id.
    """
    streamer_info = streamerItem
    streamer_info.user_id = user_id
    streamer_info.streamer_id = None  # let the DB assign the primary key

    # Stash the media paths; the first write must not contain them
    poster_image = streamer_info.poster_image
    base_mp4_path = streamer_info.base_mp4_path

    streamer_info.poster_image = ""
    streamer_info.base_mp4_path = ""

    # First DB write: required to obtain the streamer_id
    streamer_id = create_or_update_db_streamer_by_id(0, streamer_info, user_id)

    # Restore the media paths and attach the new id
    streamer_info.poster_image = poster_image
    streamer_info.base_mp4_path = base_mp4_path
    streamer_info.streamer_id = streamer_id

    # Initialize the digital-human video and build the poster from its first frame
    video_info = await gen_digital_human(user_id, streamer_id, streamer_info)

    streamer_info.base_mp4_path = video_info[0]  # video URL
    streamer_info.poster_image = video_info[1]  # poster URL
    streamer_info.avatar = video_info[1]  # avatar reuses the poster

    # Second DB write: persist the generated asset paths
    create_or_update_db_streamer_by_id(streamer_id, streamer_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", streamer_id)
129
+
130
+
131
@router.put("/edit/{streamer_id}", summary="修改主播信息接口")
async def edit_streamer_info_api(streamer_id: int, streamer_info: StreamerInfo, user_id: int = Depends(get_current_user_info)):
    """Update a streamer; regenerates the digital-human assets when the base video changed."""

    # gen_digital_human returns the stored paths unchanged when the video is the same
    new_video_url, new_poster_url = await gen_digital_human(user_id, streamer_id, streamer_info)

    streamer_info.base_mp4_path = new_video_url
    streamer_info.poster_image = new_poster_url
    streamer_info.avatar = new_poster_url

    # Persist the updated record
    create_or_update_db_streamer_by_id(streamer_id, streamer_info, user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", streamer_id)
146
+
147
+
148
@router.delete("/delete/{streamerId}", summary="删除主播接口")
async def upload_product_api(streamerId: int, user_id: int = Depends(get_current_user_info)):
    """Delete a streamer; ownership is enforced by the DB helper."""

    if not await delete_streamer_id(streamerId, user_id):
        return make_return_data(False, ResultCode.FAIL, "失败", "")

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
server/base/routers/streaming_room.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : streaming_room.py
5
+ @Time : 2024/08/31
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 主播间信息交互接口
10
+ """
11
+
12
+ import uuid
13
+ from copy import deepcopy
14
+ from pathlib import Path
15
+
16
+ import requests
17
+ from fastapi import APIRouter, Depends
18
+ from loguru import logger
19
+
20
+ from ...web_configs import API_CONFIG, WEB_CONFIGS
21
+ from ..database.product_db import get_db_product_info
22
+ from ..database.streamer_room_db import (
23
+ create_or_update_db_room_by_id,
24
+ get_live_room_info,
25
+ get_message_list,
26
+ update_db_room_status,
27
+ delete_room_id,
28
+ get_db_streaming_room_info,
29
+ update_message_info,
30
+ update_room_video_path,
31
+ )
32
+ from ..models.product_model import ProductInfo
33
+ from ..models.streamer_room_model import OnAirRoomStatusItem, RoomChatItem, SalesDocAndVideoInfo, StreamRoomInfo
34
+ from ..modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
35
+ from ..routers.users import get_current_user_info
36
+ from ..server_info import SERVER_PLUGINS_INFO
37
+ from ..utils import ResultCode, make_return_data
38
+ from .digital_human import gen_tts_and_digital_human_video_app
39
+ from .llm import combine_history, gen_poduct_base_prompt, get_agent_res, get_llm_res
40
+
41
# Sub-router for live-room endpoints; mounted by the main FastAPI app.
router = APIRouter(
    prefix="/streaming-room",
    tags=["streaming-room"],
    responses={404: {"description": "Not found"}},
)
46
+
47
+
48
@router.get("/list", summary="获取所有直播间信息接口")
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
    """Return every live room belonging to the current user."""
    # Load the rooms from the database
    room_records = await get_db_streaming_room_info(user_id)

    # Returning ORM objects directly drops fields; convert each to a dict first
    room_dicts = [dict(room) for room in room_records]

    return make_return_data(True, ResultCode.SUCCESS, "成功", room_dicts)
59
+
60
+
61
@router.get("/info/{roomId}", summary="获取特定直播间信息接口")
async def get_streaming_room_id_api(
    roomId: int, currentPage: int = 1, pageSize: int = 5, user_id: int = Depends(get_current_user_info)
):
    """Return one live room's detail with product video paths rewritten to server URLs."""
    # roomId 0 is reserved for "new room" and is not a valid lookup
    assert roomId != 0

    # TODO add pagination (currentPage / pageSize are currently unused)

    # Load the room record
    streaming_room_list = await get_db_streaming_room_info(user_id, room_id=roomId)

    if len(streaming_room_list) == 1:
        # Returning ORM objects directly drops fields; convert to dict first
        format_product_list = []
        for db_product in streaming_room_list[0].product_list:

            product_dict = dict(db_product)
            # Rewrite start_video into a full server URL
            if product_dict["start_video"] != "":
                product_dict["start_video"] = API_CONFIG.REQUEST_FILES_URL + product_dict["start_video"]

            format_product_list.append(product_dict)
        streaming_room_list = dict(streaming_room_list[0])
        streaming_room_list["product_list"] = format_product_list
    else:
        # Not found: respond with an empty payload
        streaming_room_list = []

    return make_return_data(True, ResultCode.SUCCESS, "成功", streaming_room_list)
91
+
92
+
93
@router.get("/product-edit-list/{roomId}", summary="获取直播间商品编辑列表,含有已选中的标识")
async def get_streaming_room_product_list_api(
    roomId: int, currentPage: int = 1, pageSize: int = 0, user_id: int = Depends(get_current_user_info)
):
    """Return the room's product edit list: selected products first, then every unselected one."""

    # Current product list of the room
    if roomId == 0:
        # Brand-new room: nothing selected yet
        merge_list = []
        exclude_list = []
    else:
        streaming_room_info = await get_db_streaming_room_info(user_id, roomId)

        if len(streaming_room_info) == 0:
            # NOTE(review): raising a str is a TypeError in Python 3 — this
            # looks like it should be an HTTPException(401/404); confirm and fix
            raise "401"

        streaming_room_info = streaming_room_info[0]
        # IDs already attached to the room, excluded from the "unselected" query
        exclude_list = [product.product_id for product in streaming_room_info.product_list]
        merge_list = deepcopy(streaming_room_info.product_list)

    # Fetch the products not yet selected for this room
    not_select_product_list, db_product_size = await get_db_product_info(user_id=user_id, exclude_list=exclude_list)

    # Append them after the selected ones, marked selected=False
    for not_select_product in not_select_product_list:
        merge_list.append(
            SalesDocAndVideoInfo(
                product_id=not_select_product.product_id,
                product_info=ProductInfo(**dict(not_select_product)),
                selected=False,
            )
        )

    # TODO lazy-load pagination

    # Format for the response
    format_merge_list = []
    for product in merge_list:
        # Returning ORM objects directly drops fields; convert to dict first
        dict_info = dict(product)
        if "stream_room" in dict_info:
            # Drop the back-reference to avoid cycles in the payload
            dict_info.pop("stream_room")
        format_merge_list.append(dict_info)

    page_info = dict(
        product_list=format_merge_list,
        current=currentPage,
        pageSize=db_product_size,  # NOTE(review): looks like this should be the pageSize param — confirm with frontend
        totalSize=db_product_size,
    )
    logger.info(page_info)
    return make_return_data(True, ResultCode.SUCCESS, "成功", page_info)
147
+
148
+
149
@router.post("/create", summary="新增直播间接口")
async def streaming_room_edit_api(edit_item: dict, user_id: int = Depends(get_current_user_info)):
    """Create a live room from the raw frontend payload.

    The payload arrives as a plain dict; derived / server-generated fields are
    stripped before validating it into a StreamRoomInfo.
    """
    # Pull out nested parts and drop fields the DB layer must generate itself
    product_list = edit_item.pop("product_list")
    status = edit_item.pop("status")
    edit_item.pop("streamer_info")
    edit_item.pop("room_id")

    if "status_id" in edit_item:
        edit_item.pop("status_id")

    # Keep only the products the user actually selected
    formate_product_list = []
    for product in product_list:
        if not product["selected"]:
            continue
        # product_info is derived, not stored on the association row
        product.pop("product_info")
        product_item = SalesDocAndVideoInfo(**product)
        formate_product_list.append(product_item)

    edit_item["user_id"] = user_id
    formate_info = StreamRoomInfo(**edit_item, product_list=formate_product_list, status=OnAirRoomStatusItem(**status))
    # id 0 == create a new room; returns the generated room_id
    room_id = create_or_update_db_room_by_id(0, formate_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
171
+
172
+
173
@router.put("/edit/{room_id}", summary="编辑直播间接口")
async def streaming_room_edit_api(room_id: int, edit_item: dict, user_id: int = Depends(get_current_user_info)):
    """Edit an existing live room.

    Args:
        room_id (int): room ID to update
        edit_item (dict): raw frontend payload (same shape as /create)
    """

    # Pull out nested parts; streamer_info is derived and never written back
    product_list = edit_item.pop("product_list")
    status = edit_item.pop("status")
    edit_item.pop("streamer_info")

    # Keep only the products the user actually selected
    formate_product_list = []
    for product in product_list:
        if not product["selected"]:
            continue
        # product_info is derived, not stored on the association row
        product.pop("product_info")
        product_item = SalesDocAndVideoInfo(**product)
        formate_product_list.append(product_item)

    formate_info = StreamRoomInfo(**edit_item, product_list=formate_product_list, status=OnAirRoomStatusItem(**status))
    create_or_update_db_room_by_id(room_id, formate_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
196
+
197
+
198
@router.delete("/delete/{roomId}", summary="删除直播间接口")
async def delete_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Delete a live room; ownership is enforced by the DB helper."""

    if not await delete_room_id(roomId, user_id):
        return make_return_data(False, ResultCode.FAIL, "失败", "")

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
207
+
208
+
209
+ # ============================================================
210
+ # 开播接口
211
+ # ============================================================
212
+
213
+
214
@router.post("/online/{roomId}", summary="直播间开播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Mark the room as on air."""
    # NOTE(review): handler is named offline_api but sets "online"; FastAPI
    # routes by decorator so behavior is correct, but the name misleads
    update_db_room_status(roomId, user_id, "online")
    return make_return_data(True, ResultCode.SUCCESS, "成功", "")


@router.put("/offline/{roomId}", summary="直播间下播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Mark the room as off air."""
    # NOTE(review): redefines offline_api — both routes stay registered (the
    # decorator captured each function), but the module name is rebound here
    update_db_room_status(roomId, user_id, "offline")
    return make_return_data(True, ResultCode.SUCCESS, "成功", "")


@router.post("/next-product/{roomId}", summary="直播间进行下一个商品讲解接口")
async def on_air_live_room_next_product_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Advance the room to presenting the next product.

    Args:
        roomId (int): room ID
    """
    update_db_room_status(roomId, user_id, "next-product")
    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
238
+
239
+
240
@router.get("/live-info/{roomId}", summary="获取正在直播的直播间信息接口")
async def get_on_air_live_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Return what the on-air room page needs to render.

    1. streamer video URL
    2. product info shown as the bottom-right thumbnail
    3. conversation_list chat history

    Args:
        roomId (int): room ID
    """
    live_info = await get_live_room_info(user_id, roomId)
    return make_return_data(True, ResultCode.SUCCESS, "成功", live_info)
255
+
256
+
257
@router.put("/chat", summary="直播间对话接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
    """Handle one user chat message in an on-air room.

    Pipeline: store user message -> build prompt from history -> Agent (or RAG)
    context injection -> LLM reply -> digital-human video -> store streamer reply.
    """
    # Load the room record for this room id
    streaming_room_info = await get_db_streaming_room_info(user_id, room_chat.roomId)
    streaming_room_info = streaming_room_info[0]

    # Product currently being presented
    product_detail = streaming_room_info.product_list[streaming_room_info.status.current_product_index].product_info

    # Sales-session id used as the conversation key
    sales_info_id = streaming_room_info.product_list[streaming_room_info.status.current_product_index].sales_info_id

    # Persist the user's message
    update_message_info(sales_info_id, user_id, role="user", message=room_chat.message)

    # Reload the full conversation history
    conversation_list = get_message_list(sales_info_id)

    # Build the prompt from the history
    prompt = await gen_poduct_base_prompt(
        user_id,
        streamer_info=streaming_room_info.streamer_info,
        product_info=product_detail,
    )  # system persona + product sales prompt

    prompt = combine_history(prompt, conversation_list)

    # ====================== Agent ======================
    # Try the agent first (live info such as delivery time)
    agent_response = await get_agent_res(prompt, product_detail.departure_place, product_detail.delivery_company)
    if agent_response != "":
        logger.info("Agent 执行成功,不执行 RAG")
        prompt[-1]["content"] = agent_response

    # ====================== RAG ======================
    # Fall back to knowledge-base retrieval when the agent produced nothing
    elif SERVER_PLUGINS_INFO.rag_enabled:
        logger.info("Agent 未执行 or 未开启,调用 RAG")
        # Enrich the latest turn with retrieved context
        rag_res = build_rag_prompt(RAG_RETRIEVER, product_detail.product_name, prompt[-1]["content"])
        if rag_res != "":
            prompt[-1]["content"] = rag_res

    # Ask the LLM for the streamer's reply
    streamer_res = await get_llm_res(prompt)

    # Generate the digital-human reply video
    server_video_path = await gen_tts_and_digital_human_video_app(streaming_room_info.streamer_info.streamer_id, streamer_res)

    # Point the room at the freshly generated video
    update_room_video_path(streaming_room_info.status_id, server_video_path)

    # Persist the streamer's reply
    update_message_info(sales_info_id, streaming_room_info.streamer_info.streamer_id, role="streamer", message=streamer_res)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
313
+
314
+
315
@router.post("/asr", summary="直播间调取 ASR 语音转文字 接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
    """Transcribe an uploaded audio file via the ASR service and return the text."""

    # room_chat.asrFileUrl is a server URL; map it to the local upload path
    asr_local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(WEB_CONFIGS.ASR_FILE_DIR, Path(room_chat.asrFileUrl).name)

    # Build the transcription request
    req_data = {
        "user_id": user_id,
        "request_id": str(uuid.uuid1()),
        "wav_path": str(asr_local_path),
    }
    logger.info(req_data)

    # NOTE(review): no timeout or status-code check on the ASR call — confirm acceptable
    res = requests.post(API_CONFIG.ASR_URL, json=req_data).json()
    asr_str = res["result"]
    logger.info(f"ASR res = {asr_str}")

    # Remove the temporary audio file
    asr_local_path.unlink()
    return make_return_data(True, ResultCode.SUCCESS, "成功", asr_str)
server/base/routers/users.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : users.py
5
+ @Time : 2024/08/30
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 用户登录和 Token 认证接口
10
+ """
11
+
12
+ from datetime import datetime, timedelta, timezone
13
+
14
+ import jwt
15
+ from fastapi import APIRouter, Depends, HTTPException
16
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
17
+ from loguru import logger
18
+ from passlib.context import CryptContext
19
+
20
+ from ...web_configs import WEB_CONFIGS
21
+ from ..database.user_db import get_db_user_info
22
+ from ..models.user_model import TokenItem
23
+ from ..utils import ResultCode, make_return_data
24
+
25
# Sub-router for user auth endpoints; mounted by the main FastAPI app.
router = APIRouter(
    prefix="/user",
    tags=["user"],
    responses={404: {"description": "Not found"}},
)

# OAuth2 bearer scheme; tokens are issued by the /user/login endpoint below
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login")

# Password hashing / verification context (bcrypt)
PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
35
+
36
+
37
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its stored hash.

    Args:
        plain_password (str): plaintext password
        hashed_password (str): stored hash to compare against

    Returns:
        bool: True when the password matches
    """
    # Fix: removed the debug log that hashed "123456" on every call — it leaked
    # a hint about the expected credential into the logs and wasted one extra
    # bcrypt round per login attempt.
    return PWD_CONTEXT.verify(plain_password, hashed_password)
49
+
50
+
51
def get_password_hash(password: str) -> str:
    """Hash a plaintext password for storage.

    Args:
        password (str): plaintext password

    Returns:
        str: bcrypt hash suitable for persisting
    """
    return PWD_CONTEXT.hash(password)
61
+
62
+
63
def authenticate_user(username: str, password: str) -> bool:
    """Validate a username/password pair.

    Args:
        username (str): user name
        password (str): plaintext password

    Returns:
        the full user record on success, False on any failure — note the
        declared bool return is inaccurate; callers rely on the truthy record
    """

    # Look up the user (all_info=True includes the hashed password)
    user_info = get_db_user_info(username=username, all_info=True)
    if not user_info:
        # Username not found
        logger.info(f"Cannot find username = {username}")
        return False

    # Verify the password against the stored hash
    if not verify_password(password, user_info.hashed_password):
        logger.info(f"verify_password fail")
        # Password mismatch
        return False

    return user_info
88
+
89
+
90
def get_current_user_info(token: str = Depends(oauth2_scheme)):
    """Extract the user id from a bearer token.

    Args:
        token (str, optional): JWT bearer token. Defaults to Depends(oauth2_scheme).

    Raises:
        HTTPException: 401 when the token is invalid or carries no user_id

    Returns:
        int: user ID
    """
    logger.info(token)
    try:
        # Signature (and `exp` claim, per PyJWT defaults) are checked here
        token_data = jwt.decode(token, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithms=WEB_CONFIGS.TOKEN_JWT_ALGORITHM)
        logger.info(token_data)
        user_id = token_data.get("user_id", None)
    except Exception as e:
        logger.error(e)
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    if not user_id:
        logger.error(f"can not get user_id: {user_id}")
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    # TODO force re-login after a session timeout

    logger.info(f"Got user_id: {user_id}")
    return user_id
119
+
120
+
121
@router.post("/login", summary="登录接口")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Issue a JWT bearer token for valid credentials.

    Returns the raw OAuth2 token payload (not wrapped in make_return_data) so
    standard OAuth2 clients can consume it.
    """

    # Validate username and password
    user_info = authenticate_user(form_data.username, form_data.password)

    if not user_info:
        raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})

    # Token lifetime: 7 days
    token_expires = datetime.now(timezone.utc) + timedelta(days=7)

    # Token payload; the IP is recorded so a stolen token used elsewhere can be detected
    token_data = {
        "user_id": user_info.user_id,
        "username": user_info.username,
        "exp": int(token_expires.timestamp()),
        "ip": user_info.ip_address,
        "login_time": int(datetime.now(timezone.utc).timestamp()),
    }
    logger.info(f"token_data = {token_data}")

    # Sign the token
    token = jwt.encode(token_data, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithm=WEB_CONFIGS.TOKEN_JWT_ALGORITHM)

    # Respond with the OAuth2-shaped payload
    res_json = TokenItem(access_token=token, token_type="bearer")
    # NOTE(review): this logs the raw token — consider redacting in production
    logger.info(f"Got token info = {res_json}")
    # return make_return_data(True, ResultCode.SUCCESS, "成功", content)
    return res_json
151
+
152
+
153
@router.get("/me", summary="获取用户信息")
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
    """Return the profile of the currently authenticated user."""
    current_user = get_db_user_info(id=user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", current_user)
server/base/server_info.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : server_info.py
5
+ @Time : 2024/09/02
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 组件信息获取逻辑
10
+ """
11
+
12
+
13
+ import random
14
+ import requests
15
+ from loguru import logger
16
+
17
+ from ..web_configs import API_CONFIG, WEB_CONFIGS
18
+
19
+
20
class ServerPluginsInfo:
    """Tracks which optional backend services (LLM / RAG / TTS / digital human / ASR / agent) are available."""

    def __init__(self) -> None:
        # Probe every dependent service once at construction time
        self.update_info()

    def update_info(self):
        """Re-probe all dependent services and refresh the enabled flags."""

        self.tts_server_enabled = self._check_server(API_CONFIG.TTS_URL + "/check")
        self.digital_human_server_enabled = self._check_server(API_CONFIG.DIGITAL_HUMAN_CHECK_URL)
        self.asr_server_enabled = self._check_server(API_CONFIG.ASR_URL + "/check")
        self.llm_enabled = self._check_server(API_CONFIG.LLM_URL)

        # The agent requires both third-party API keys to be configured
        if WEB_CONFIGS.AGENT_DELIVERY_TIME_API_KEY is None or WEB_CONFIGS.AGENT_WEATHER_API_KEY is None:
            self.agent_enabled = False
        else:
            self.agent_enabled = True

        # RAG is a local feature toggled purely by configuration
        self.rag_enabled = WEB_CONFIGS.ENABLE_RAG

        logger.info(
            "\nself check plugins info : \n"
            f"| llm | {self.llm_enabled} |\n"
            f"| rag | {self.rag_enabled} |\n"
            f"| tts | {self.tts_server_enabled} |\n"
            f"| digital hunam | {self.digital_human_server_enabled} |\n"
            f"| asr | {self.asr_server_enabled} |\n"
            f"| agent | {self.agent_enabled} |\n"
        )

    @staticmethod
    def _check_server(url):
        """Return True when a GET on `url` answers with HTTP 200, False otherwise."""

        try:
            # NOTE(review): no timeout — an unresponsive service blocks startup; confirm acceptable
            res = requests.get(url)
        except requests.exceptions.ConnectionError:
            return False

        if res.status_code == 200:
            return True
        else:
            return False

    @staticmethod
    def _make_color_list(color_num):
        """Pick `color_num` distinct hex colors from a fixed 20-entry palette.

        random.sample raises ValueError when color_num exceeds the palette size.
        """

        color_list = [
            "#FF3838",
            "#FF9D97",
            "#FF701F",
            "#FFB21D",
            "#CFD231",
            "#48F90A",
            "#92CC17",
            "#3DDB86",
            "#1A9334",
            "#00D4BB",
            "#2C99A8",
            "#00C2FF",
            "#344593",
            "#6473FF",
            "#0018EC",
            "#8438FF",
            "#520085",
            "#CB38FF",
            "#FF95C8",
            "#FF37C7",
        ]

        return random.sample(color_list, color_num)

    def get_status(self):
        """Re-probe the services and return the plugin card list for the dashboard UI."""
        self.update_info()

        info_list = [
            {
                "plugin_name": "LLM",
                "describe": "大语言模型,用于根据客户历史对话,生成对话信息",
                "enabled": self.llm_enabled,
            },
            {
                "plugin_name": "RAG",
                "describe": "用于调用知识库实时更新信息",
                "enabled": self.rag_enabled,
            },
            {
                "plugin_name": "TTS",
                "describe": "文字转语音,让主播的文字也能听到",
                "enabled": self.tts_server_enabled,
            },
            {
                "plugin_name": "数字人",
                "describe": "数字人服务,用于生成数字人,需要和 TTS 一起开启才有效果",
                "enabled": self.digital_human_server_enabled,
            },
            {
                "plugin_name": "Agent",
                "describe": "用于根据用户对话,获取网络的实时信息",
                "enabled": self.agent_enabled,
            },
            {
                "plugin_name": "ASR",
                "describe": "语音转文字,让用户无需打字就可以和主播进行对话",
                "enabled": self.asr_server_enabled,
            },
        ]

        # Assign each plugin card a random avatar background color
        color_list = self._make_color_list(len(info_list))
        for idx, color in enumerate(color_list):
            info_list[idx].update({"avatar_color": color})

        return info_list
132
+
133
+
134
# Module-level singleton; probes every dependent service once at import time
SERVER_PLUGINS_INFO = ServerPluginsInfo()
server/base/utils.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : utils.py
5
+ @Time : 2024/09/02
6
+ @Project : https://github.com/PeterH0323/Streamer-Sales
7
+ @Author : HinGwenWong
8
+ @Version : 1.0
9
+ @Desc : 工具集合文件
10
+ """
11
+
12
+
13
+ import asyncio
14
+ from ipaddress import IPv4Address
15
+ import json
16
+ import random
17
+ import wave
18
+ from dataclasses import dataclass
19
+ from datetime import datetime
20
+ from pathlib import Path
21
+ from typing import Dict, List
22
+
23
+ import cv2
24
+ from lmdeploy.serve.openai.api_client import APIClient
25
+ from loguru import logger
26
+ from pydantic import BaseModel
27
+ from sqlmodel import Session, select
28
+ from tqdm import tqdm
29
+
30
+ from server.base.models.user_model import UserInfo
31
+
32
+ from ..tts.tools import SYMBOL_SPLITS, make_text_chunk
33
+ from ..web_configs import API_CONFIG, WEB_CONFIGS
34
+ from .database.init_db import DB_ENGINE
35
+ from .models.product_model import ProductInfo
36
+ from .models.streamer_info_model import StreamerInfo
37
+ from .models.streamer_room_model import OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo
38
+
39
+ from .modules.agent.agent_worker import get_agent_result
40
+ from .modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
41
+ from .queue_thread import DIGITAL_HUMAN_QUENE, TTS_TEXT_QUENE
42
+ from .server_info import SERVER_PLUGINS_INFO
43
+
44
+
45
class ChatGenConfig(BaseModel):
    """LLM sampling configuration for one chat request."""

    top_p: float = 0.8
    temperature: float = 0.7
    repetition_penalty: float = 1.005
50
+
51
+
52
class ProductInfoItem(BaseModel):
    """Product details attached to a sales-chat request."""

    name: str
    heighlights: str  # selling-point keywords (field name kept for caller compatibility)
    introduce: str  # prompt used to generate the product sales copy

    image_path: str
    departure_place: str  # shipping origin, forwarded to the delivery-time Agent
    delivery_company_name: str  # courier name, forwarded to the delivery-time Agent
60
+
61
+
62
class PluginsInfo(BaseModel):
    """Per-request switches for each optional pipeline stage."""

    rag: bool = True
    agent: bool = True
    tts: bool = True
    digital_human: bool = True
67
+
68
+
69
class ChatItem(BaseModel):
    """One sales-chat request processed by streamer_sales_process."""

    user_id: str  # user identifier, distinguishes different callers
    request_id: str  # request ID, reused to name TTS & digital-human artifacts
    prompt: List[Dict[str, str]]  # chat history for this round
    product_info: ProductInfoItem  # product being presented
    plugins: PluginsInfo = PluginsInfo()  # plugin switches
    chat_config: ChatGenConfig = ChatGenConfig()  # LLM sampling settings
76
+
77
+
78
+ # 加载 LLM 模型
79
+ LLM_MODEL_HANDLER = APIClient(API_CONFIG.LLM_URL)
80
+
81
+
82
async def streamer_sales_process(chat_item: ChatItem):
    """Main sales-chat pipeline: Agent -> RAG -> streaming LLM -> TTS -> digital human.

    Yields SSE-style JSON strings for each stage ("llm" / "tts" / "dg"), and a
    final payload with step "all" and end_flag=True.

    Args:
        chat_item (ChatItem): user prompt, product info and plugin switches.

    Yields:
        str: JSON-encoded progress/event payloads.
    """

    # ====================== Agent ======================
    # Query the Agent for real-time info (e.g. delivery time) when enabled.
    agent_response = ""
    if chat_item.plugins.agent and SERVER_PLUGINS_INFO.agent_enabled:
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt template
        )
        input_prompt = chat_item.prompt[-1]["content"]
        agent_response = get_agent_result(
            LLM_MODEL_HANDLER, input_prompt, chat_item.product_info.departure_place, chat_item.product_info.delivery_company_name
        )
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)
            print(f"Agent response: {agent_response}")
            chat_item.prompt[-1]["content"] = agent_response

    # ====================== RAG ======================
    # Fall back to the RAG knowledge base only when the Agent produced nothing.
    if chat_item.plugins.rag and agent_response == "":
        rag_prompt = chat_item.prompt[-1]["content"]
        prompt_pro = build_rag_prompt(RAG_RETRIEVER, chat_item.product_info.name, rag_prompt)

        if prompt_pro != "":
            chat_item.prompt[-1]["content"] = prompt_pro

    # Stream LLM inference results back to the caller.
    logger.info(chat_item.prompt)

    current_predict = ""
    idx = 0
    last_text_index = 0
    sentence_id = 0
    model_name = LLM_MODEL_HANDLER.available_models[0]
    for item in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=chat_item.prompt, stream=True):
        logger.debug(f"LLM predict: {item}")
        if "content" not in item["choices"][0]["delta"]:
            continue
        current_res = item["choices"][0]["delta"]["content"]

        if "~" in current_res:
            # "~" reads badly in TTS; normalize it to a full stop.
            current_res = current_res.replace("~", "。").replace("。。", "。")

        current_predict += current_res
        idx += 1

        if chat_item.plugins.tts and SERVER_PLUGINS_INFO.tts_server_enabled:
            # Split the stream into sentences for TTS generation.
            sentence = ""
            for symbol in SYMBOL_SPLITS:
                if symbol in current_res:
                    last_text_index, sentence = make_text_chunk(current_predict, last_text_index)
                    if len(sentence) <= 3:
                        # Too short to be worth synthesizing.
                        sentence = ""
                    break

            if sentence != "":
                sentence_id += 1
                logger.info(f"get sentence: {sentence}")
                tts_request_dict = {
                    "user_id": chat_item.user_id,
                    "request_id": chat_item.request_id,
                    "sentence": sentence,
                    "chunk_id": sentence_id,
                    # "wav_save_name": chat_item.request_id + f"{str(sentence_id).zfill(8)}.wav",
                }

                TTS_TEXT_QUENE.put(tts_request_dict)
                await asyncio.sleep(0.01)

        yield json.dumps(
            {
                "event": "message",
                "retry": 100,
                "id": idx,
                "data": current_predict,
                "step": "llm",
                "end_flag": False,
            },
            ensure_ascii=False,
        )
        await asyncio.sleep(0.01)  # small delay so the event stream actually flushes

    if chat_item.plugins.digital_human and SERVER_PLUGINS_INFO.digital_human_server_enabled:

        wav_list = [
            Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + f"-{str(i).zfill(8)}.wav")
            for i in range(1, sentence_id + 1)
        ]
        while True:
            # Wait until every TTS chunk has been written to disk.
            not_exist_count = 0
            for tts_wav in wav_list:
                if not tts_wav.exists():
                    not_exist_count += 1

            logger.info(f"still need to wait for {not_exist_count}/{sentence_id} wav generating...")
            if not_exist_count == 0:
                break

            yield json.dumps(
                {
                    "event": "message",
                    "retry": 100,
                    "id": idx,
                    "data": current_predict,
                    "step": "tts",
                    "end_flag": False,
                },
                ensure_ascii=False,
            )
            await asyncio.sleep(1)  # small delay so the event stream actually flushes

        # Merge all TTS chunks into one wav file.
        tts_save_path = Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + ".wav")
        all_tts_data = []

        for wav_file in tqdm(wav_list):
            logger.info(f"Reading wav file {wav_file}...")
            with wave.open(str(wav_file), "rb") as wf:
                all_tts_data.append([wf.getparams(), wf.readframes(wf.getnframes())])

        logger.info(f"Merging wav file to {tts_save_path}...")
        # NOTE: max() on _wave_params tuples compares field-by-field; the
        # original comment claimed "use the first file's params", which the
        # code does not do. Behavior preserved, comment corrected.
        tts_params = max([tts_data[0] for tts_data in all_tts_data])
        with wave.open(str(tts_save_path), "wb") as wf:
            wf.setparams(tts_params)

            for wf_data in all_tts_data:
                wf.writeframes(wf_data[1])
        logger.info(f"Merged wav file to {tts_save_path} !")

        # Request the digital-human video for the merged audio.
        tts_request_dict = {
            "user_id": chat_item.user_id,
            "request_id": chat_item.request_id,
            "chunk_id": 0,
            "tts_path": str(tts_save_path),
        }

        logger.info("Generating digital human...")  # fixed: was an f-string with no placeholder
        DIGITAL_HUMAN_QUENE.put(tts_request_dict)
        while True:
            # The worker drops a sibling .txt marker next to the .mp4 when done.
            if (
                Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH)
                .joinpath(Path(tts_save_path).stem + ".mp4")
                .with_suffix(".txt")
                .exists()
            ):
                break
            yield json.dumps(
                {
                    "event": "message",
                    "retry": 100,
                    "id": idx,
                    "data": current_predict,
                    "step": "dg",
                    "end_flag": False,
                },
                ensure_ascii=False,
            )
            await asyncio.sleep(1)  # small delay so the event stream actually flushes

        # Clean up the intermediate per-sentence wav files.
        for wav_file in wav_list:
            wav_file.unlink()

    yield json.dumps(
        {
            "event": "message",
            "retry": 100,
            "id": idx,
            "data": current_predict,
            "step": "all",
            "end_flag": True,
        },
        ensure_ascii=False,
    )
262
+
263
+
264
def make_poster_by_video_first_frame(video_path: str, image_output_name: str):
    """Generate a poster image from the first frame of a video.

    Args:
        video_path (str): path of the video file.
        image_output_name (str): output image file name (saved next to the video).

    Returns:
        str: path where the first frame was (or would have been) saved.
    """
    poster_save_path = str(Path(video_path).parent.joinpath(image_output_name))

    cap = cv2.VideoCapture(video_path)
    try:
        # Grab the first frame.
        ret, frame = cap.read()
        if ret:
            cv2.imwrite(poster_save_path, frame)
            logger.info(f"第一帧已保存为 {poster_save_path}")
        else:
            # Covers both "cannot open video" and "cannot decode first frame".
            logger.error("无法读取视频帧")
    finally:
        # Always release the capture handle — the original leaked it
        # if cv2.imwrite raised.
        cap.release()

    return poster_save_path
293
+
294
+
295
@dataclass
class ResultCode:
    """Business result codes embedded in API response payloads."""

    SUCCESS: int = 0  # success (was spelled `0000`, which is just 0 and misleading)
    FAIL: int = 1000  # failure
299
+
300
+
301
def make_return_data(success_flag: bool, code: ResultCode, message: str, data: dict):
    """Build the unified API response envelope.

    Args:
        success_flag: whether the operation succeeded.
        code: business result code (see ResultCode).
        message: human-readable description.
        data: payload of the response.

    Returns:
        dict: envelope with success / code / message / data / timestamp keys.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return dict(
        success=success_flag,
        code=code,
        message=message,
        data=data,
        timestamp=timestamp,
    )
309
+
310
+
311
def gen_default_data():
    """Seed the database with default data on first run:
    - default admin user
    - product catalog
    - streamer profile
    - streaming room and its product associations
    """

    def create_default_user():
        """Insert the default admin account."""
        admin_user = UserInfo(
            username="hingwen.wong",
            ip_address=IPv4Address("127.0.0.1"),
            email="peterhuang0323@qq.com",
            hashed_password="$2b$12$zXXveodjipHZMoSxJz5ODul7Z9YeRJd0GeSBjpwHdqEtBbAFvEdre",  # "123456" hashed with get_password_hash
            avatar="/user/user-avatar.png",
        )

        with Session(DB_ENGINE) as session:
            session.add(admin_user)
            session.commit()

    def init_user() -> bool:
        """Create the default user if the user table is empty.

        Returns:
            bool: True if the default user was created (first run).
        """
        with Session(DB_ENGINE) as session:
            results = session.exec(select(UserInfo).where(UserInfo.user_id == 1)).first()

        if results is None:
            # Empty database: seed the initial user.
            create_default_user()
            logger.info("created default user info")
            return True

        return False

    def create_default_product_item():
        """Seed the default product catalog."""
        delivery_company_list = ["京东", "顺丰", "韵达", "圆通", "中通"]
        departure_place_list = ["广州", "北京", "武汉", "杭州", "上海", "深圳", "成都"]
        default_product_list = {
            "beef": {
                "product_name": "进口和牛羽下肉",
                "heighlights": "富含铁质;营养价值高;肌肉纤维好;红白相间纹理;适合烧烤炖煮;草食动物来源",
                "product_class": "食品",
            },
            "elec_toothblush": {
                "product_name": "声波电动牙刷",
                "heighlights": "高效清洁;减少手动压力;定时提醒;智能模式调节;无线充电;噪音低",
                "product_class": "电子",
            },
            "lip_stick": {
                "product_name": "唇膏",
                "heighlights": "丰富色号;滋润保湿;显色度高;持久不脱色;易于涂抹;便携包装",
                "product_class": "美妆",
            },
            "mask": {
                "product_name": "光感润颜面膜",
                "heighlights": "密集滋养;深层补水;急救修复;快速见效;定期护理;多种类型选择",
                "product_class": "美妆",
            },
            "oled_tv": {
                "product_name": "65英寸OLED电视",
                "heighlights": "色彩鲜艳;对比度极高;响应速度快;无背光眩光;厚度较薄;自发光无需额外照明",
                "product_class": "家电",
            },
            "pad": {
                "product_name": "14英寸平板电脑",
                "heighlights": "轻薄;触控操作;电池续航好;移动办公便利;娱乐性强;适合儿童学习",
                "product_class": "电子",
            },
            "pants": {
                "product_name": "速干运动裤",
                "heighlights": "快干;伸缩自如;吸湿排汗;防风保暖;高腰设计;多口袋实用",
                "product_class": "衣服",
            },
            "pen": {
                "product_name": "墨水钢笔",
                "heighlights": "耐用性;可书写性;不同颜色和类型;轻便设计;环保材料;易于携带",
                "product_class": "文具",
            },
            "perfume": {
                "product_name": "薰衣草淡香氛",
                "heighlights": "浪漫优雅;花香调为主;情感表达;适合各种年龄;瓶身设计精致;提升女性魅力",
                "product_class": "家居用品",
            },
            "shampoo": {
                "product_name": "本草精华洗发露",
                "heighlights": "温和配方;深层清洁;滋养头皮;丰富泡沫;易冲洗;适合各种发质",
                "product_class": "日用品",
            },
            "wok": {
                "product_name": "不粘煎炒锅",
                "heighlights": "不粘涂层;耐磨耐用;导热快;易清洗;多种烹饪方式;设计人性化",
                "product_class": "厨具",
            },
            "yoga_mat": {
                "product_name": "瑜伽垫",
                "heighlights": "防滑材质;吸湿排汗;厚度适中;耐用易清洁;各种瑜伽动作适用;轻巧便携",
                "product_class": "运动",
            },
        }

        with Session(DB_ENGINE) as session:
            for product_key, product_info in default_product_list.items():
                add_item = ProductInfo(
                    **product_info,
                    image_path=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.IMAGES_DIR}/{product_key}.png",
                    instruction=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.INSTRUCTIONS_DIR}/{product_key}.md",
                    departure_place=random.choice(departure_place_list),
                    delivery_company=random.choice(delivery_company_list),
                    selling_price=round(random.uniform(66.6, 1999.9), 2),
                    amount=random.randint(999, 9999),
                    user_id=1,
                )
                session.add(add_item)
            session.commit()

        logger.info("created default product info done!")

    def create_default_streamer():
        """Seed the default streamer profile."""
        with Session(DB_ENGINE) as session:
            streamer_item = StreamerInfo(
                name="乐乐喵",
                character="甜美;可爱;熟练使用各种网络热门梗造句;称呼客户为[家人们]",
                avatar=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.png",
                base_mp4_path=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.mp4",
                poster_image=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.png",
                tts_reference_audio=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.wav",
                tts_reference_sentence="列车巡游银河,我不一定都能帮上忙,但只要是花钱能解决的事,尽管和我说吧。",
                tts_weight_tag="艾丝妲",
                user_id=1,
            )
            session.add(streamer_item)
            session.commit()

    def create_default_room():
        """Seed the default streaming room and attach 3 distinct products."""
        with Session(DB_ENGINE) as session:

            product_list = session.exec(
                select(ProductInfo).where(ProductInfo.user_id == 1).order_by(ProductInfo.product_id)
            ).all()

            on_air_status = OnAirRoomStatusItem(user_id=1)
            session.add(on_air_status)
            session.commit()
            session.refresh(on_air_status)

            stream_item = StreamRoomInfo(
                name="001",
                user_id=1,
                status_id=on_air_status.status_id,
                streamer_id=1,
            )
            session.add(stream_item)
            session.commit()
            session.refresh(stream_item)

            # Use sample (without replacement) — random.choices could pick the
            # same product twice and create duplicate sales rows for the room.
            random_list = random.sample(product_list, k=min(3, len(product_list)))
            for product_random in random_list:
                add_sales_info = SalesDocAndVideoInfo(product_id=product_random.product_id, room_id=stream_item.room_id)
                session.add(add_sales_info)
                session.commit()
                session.refresh(add_sales_info)

    # Main logic: seed everything only on the very first run.
    created = init_user()
    if created:
        create_default_product_item()  # product catalog
        create_default_streamer()  # streamer profile
        create_default_room()  # streaming room
server/digital_human/digital_human_server.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.exceptions import RequestValidationError
3
+ from fastapi.responses import PlainTextResponse
4
+ from loguru import logger
5
+ from pydantic import BaseModel
6
+
7
+
8
+ from .modules.digital_human_worker import gen_digital_human_video_app, preprocess_digital_human_app
9
+
10
+
11
+ app = FastAPI()
12
+
13
+
14
class DigitalHumanItem(BaseModel):
    """Request body for generating one digital-human video chunk."""

    user_id: str  # user identifier, distinguishes different callers
    request_id: str  # request ID shared by TTS & digital-human artifacts
    streamer_id: str  # digital-human (streamer) ID
    tts_path: str = ""  # path of the TTS audio to lip-sync
    chunk_id: int = 0  # sentence chunk index; 0 means whole request
20
+
21
+
22
class DigitalHumanPreprocessItem(BaseModel):
    """Request body for preprocessing a new digital-human base video."""

    user_id: str  # user identifier, distinguishes different callers
    request_id: str  # request ID for this preprocessing job
    streamer_id: str  # digital-human (streamer) ID
    video_path: str  # base video of the digital human
27
+
28
+
29
@app.post("/digital_human/gen")
async def get_digital_human(dg_item: DigitalHumanItem):
    """Generate a digital-human video for one TTS audio file."""
    # chunk_id == 0 is a whole-request render; otherwise suffix with the chunk index.
    if dg_item.chunk_id == 0:
        save_tag = dg_item.request_id + ".mp4"
    else:
        save_tag = dg_item.request_id + f"-{str(dg_item.chunk_id).zfill(8)}.mp4"
    mp4_path = await gen_digital_human_video_app(dg_item.streamer_id, dg_item.tts_path, save_tag)
    logger.info(f"digital human mp4 path = {mp4_path}")
    return {"user_id": dg_item.user_id, "request_id": dg_item.request_id, "digital_human_mp4_path": mp4_path}
38
+
39
+
40
@app.post("/digital_human/preprocess")
async def preprocess_digital_human(preprocess_item: DigitalHumanPreprocessItem):
    """Preprocess a base video so it can be used as a new digital human."""
    _ = await preprocess_digital_human_app(str(preprocess_item.streamer_id), preprocess_item.video_path)

    logger.info(f"digital human process for {preprocess_item.streamer_id} done")
    return {"user_id": preprocess_item.user_id, "request_id": preprocess_item.request_id}
48
+
49
+
50
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Handle invalid API payloads.

    Logs the offending request and validation error, then returns the error
    text with HTTP 400 instead of FastAPI's default 422 JSON body.

    Args:
        request: the incoming request that failed validation.
        exc (RequestValidationError): the validation error raised by FastAPI.

    Returns:
        PlainTextResponse: the stringified error with status 400.
    """
    logger.info(request)
    logger.info(exc)
    return PlainTextResponse(str(exc), status_code=400)
64
+
65
+
66
@app.get("/digital_human/check")
async def check_server():
    """Health-check endpoint polled by the base server to detect availability."""
    return {"message": "server enabled"}
server/digital_human/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
from pathlib import Path

from torch import hub

from ...web_configs import WEB_CONFIGS

# Some models are fetched through torch hub downloads; point its cache
# directory at our model storage so weights land in a known location.
hub.set_dir(str(Path(WEB_CONFIGS.DIGITAL_HUMAN_MODEL_DIR).joinpath("face-alignment")))
server/digital_human/modules/digital_human_worker.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from .realtime_inference import DIGITAL_HUMAN_HANDLER, gen_digital_human_preprocess, gen_digital_human_video
3
+ from ...web_configs import WEB_CONFIGS
4
+
5
+
6
+ async def gen_digital_human_video_app(stream_id, audio_path, save_tag):
7
+ if DIGITAL_HUMAN_HANDLER is None:
8
+ return None
9
+
10
+ save_path = gen_digital_human_video(
11
+ DIGITAL_HUMAN_HANDLER,
12
+ stream_id,
13
+ audio_path,
14
+ work_dir=str(Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH).absolute()),
15
+ video_path=save_tag,
16
+ fps=DIGITAL_HUMAN_HANDLER.fps,
17
+ )
18
+
19
+ return save_path
20
+
21
+
22
+ async def preprocess_digital_human_app(stream_id, video_path):
23
+ if DIGITAL_HUMAN_HANDLER is None:
24
+ return None
25
+
26
+ res = gen_digital_human_preprocess(
27
+ DIGITAL_HUMAN_HANDLER,
28
+ stream_id,
29
+ work_dir=str(Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH).absolute()),
30
+ video_path=video_path,
31
+ )
32
+
33
+ return res
server/digital_human/modules/musetalk/models/unet.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import json
5
+
6
+ from diffusers import UNet2DConditionModel
7
+
8
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to (batch, seq_len, d_model) inputs."""

    def __init__(self, d_model=384, max_len=5000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
        # Buffer (not a parameter): moves with the module, never trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add the first seq_len encodings to the input.
        _, seq_len, _ = x.size()
        return x + self.pe[:, :seq_len, :].to(x.device)
24
+
25
class UNet():
    """Thin wrapper building a diffusers UNet2DConditionModel from a JSON config.

    Attributes:
        model: the UNet2DConditionModel with checkpoint weights loaded.
        pe: sinusoidal positional encoding applied to audio features upstream.
        device: cuda when available, else cpu.
    """

    def __init__(self,
                 unet_config,
                 model_path,
                 use_float16=False,
                 ):
        """
        :param unet_config: path of the JSON file with UNet2DConditionModel kwargs.
        :param model_path: path of the state-dict checkpoint.
        :param use_float16: cast the model to half precision after loading.
        """
        with open(unet_config, 'r') as f:
            config_dict = json.load(f)
        self.model = UNet2DConditionModel(**config_dict)
        self.pe = PositionalEncoding(d_model=384)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # map_location works on both CPU and GPU — replaces the original
        # redundant two-branch torch.load.
        weights = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(weights)
        if use_float16:
            self.model = self.model.half()
        self.model.to(self.device)


if __name__ == "__main__":
    # Smoke test. The original called UNet() with no arguments, which always
    # raised TypeError; supply the expected config/checkpoint paths instead.
    unet = UNet(
        unet_config="./models/musetalk/musetalk.json",
        model_path="./models/musetalk/pytorch_model.bin",
    )
server/digital_human/modules/musetalk/models/vae.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as transforms
8
+ from diffusers import AutoencoderKL
9
+
10
+
11
class VAE():
    """
    VAE (Variational Autoencoder) wrapper used for image encoding/decoding.
    """

    def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
        """
        Initialize the VAE instance.

        :param model_path: Path to the trained model.
        :param resized_img: The size to which images are resized.
        :param use_float16: Whether to use float16 precision.
        """
        self.model_path = model_path
        self.vae = AutoencoderKL.from_pretrained(self.model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vae.to(self.device)

        # Optional half-precision mode; the flag is kept for later dtype checks.
        self._use_float16 = bool(use_float16)
        if self._use_float16:
            self.vae = self.vae.half()

        self.scaling_factor = self.vae.config.scaling_factor
        self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self._resized_img = resized_img
        self._mask_tensor = self.get_mask_tensor()

    def get_mask_tensor(self):
        """
        Create a binary mask whose upper half is 1 and lower half is 0.
        :return: A (resized_img, resized_img) mask tensor.
        """
        size = self._resized_img
        mask_tensor = torch.zeros((size, size))
        mask_tensor[:size // 2, :] = 1
        # Threshold to guarantee a strictly binary mask.
        mask_tensor[mask_tensor < 0.5] = 0
        mask_tensor[mask_tensor >= 0.5] = 1
        return mask_tensor

    def preprocess_img(self, img_name, half_mask=False):
        """
        Preprocess an image for the VAE.

        :param img_name: An image file path, or an already-loaded BGR ndarray.
        :param half_mask: Whether to zero out the lower half of the image.
        :return: A preprocessed image tensor on the VAE's device.
        """
        window = []
        if isinstance(img_name, str):
            # Load from disk, convert to RGB and resize to the model size.
            for fname in [img_name]:
                img = cv2.imread(fname)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (self._resized_img, self._resized_img),
                                 interpolation=cv2.INTER_LANCZOS4)
                window.append(img)
        else:
            # Already an ndarray: only convert the channel order.
            window.append(cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB))

        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))
        x = torch.squeeze(torch.FloatTensor(x))
        if half_mask:
            x = x * (self._mask_tensor > 0.5)
        x = self.transform(x)

        x = x.unsqueeze(0)  # [1, 3, 256, 256] torch tensor
        return x.to(self.vae.device)

    def encode_latents(self, image):
        """
        Encode an image into latent variables.

        :param image: The image tensor to encode.
        :return: The scaled latent sample.
        """
        with torch.no_grad():
            init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
            return self.scaling_factor * init_latent_dist.sample()

    def decode_latents(self, latents):
        """
        Decode latent variables back into an image.
        :param latents: The latent variables to decode.
        :return: A uint8 BGR NumPy array representing the decoded image.
        """
        latents = (1 / self.scaling_factor) * latents
        image = self.vae.decode(latents.to(self.vae.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        return image[..., ::-1]  # RGB to BGR

    def get_latents_for_unet(self, img):
        """
        Prepare latent variables for a U-Net model.
        :param img: The image to process.
        :return: Masked + reference latents concatenated along the channel dim.
        """
        ref_image = self.preprocess_img(img, half_mask=True)  # [1, 3, 256, 256] RGB, torch tensor
        masked_latents = self.encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
        ref_image = self.preprocess_img(img, half_mask=False)  # [1, 3, 256, 256] RGB, torch tensor
        ref_latents = self.encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
        return torch.cat([masked_latents, ref_latents], dim=1)
124
+
125
if __name__ == "__main__":
    # Demo: encode every cropped frame and print its latent size.
    vae_mode_path = "./models/sd-vae-ft-mse/"
    vae = VAE(model_path=vae_mode_path, use_float16=False)

    crop_imgs_path = "./results/sun001_crop/"
    latents_out_path = "./results/latents/"
    if not os.path.exists(latents_out_path):
        os.mkdir(latents_out_path)

    png_files = sorted(f for f in os.listdir(crop_imgs_path) if f.split(".")[-1] == "png")
    for file in png_files:
        img_path = crop_imgs_path + file
        latents = vae.get_latents_for_unet(img_path)
        print(img_path, "latents", latents.size())
        # torch.save(latents, os.path.join(latents_out_path, file.split(".")[0] + ".pt"))
        # reload_tensor = torch.load('tensor.pt')
        # print(reload_tensor.size())
147
+
148
+
149
+
server/digital_human/modules/musetalk/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import sys
2
+ from os.path import abspath, dirname
3
+ current_dir = dirname(abspath(__file__))
4
+ parent_dir = dirname(current_dir)
5
+ sys.path.append(parent_dir+'/utils')
server/digital_human/modules/musetalk/utils/blending.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import cv2
4
+ from face_parsing import FaceParsing
5
+
6
+
7
def init_face_parsing_model(
    resnet_path="./models/face-parse-bisent/resnet18-5c106cde.pth", face_model_pth="./models/face-parse-bisent/79999_iter.pth"
):
    """Build the FaceParsing model from its backbone and head checkpoints.

    Args:
        resnet_path: ResNet-18 backbone weights.
        face_model_pth: BiSeNet face-parsing head weights.

    Returns:
        FaceParsing: the initialised parser.
    """
    return FaceParsing(resnet_path, face_model_pth)
12
+
13
+
14
def get_crop_box(box, expand):
    """Expand a face box into a square crop box centered on the face.

    Args:
        box: (x, y, x1, y1) face bounding box.
        expand: expansion factor applied to half of the longer side.

    Returns:
        tuple: ([x_s, y_s, x_e, y_e], s) — the square crop box and its
        half-side length s.
    """
    x, y, x1, y1 = box
    center_x, center_y = (x + x1) // 2, (y + y1) // 2
    w, h = x1 - x, y1 - y
    s = int(max(w, h) // 2 * expand)
    return [center_x - s, center_y - s, center_x + s, center_y + s], s
21
+
22
+
23
def face_seg(image, fp_model):
    """Run face parsing on `image` and return the mask resized to the input size.

    Args:
        image: PIL image of the (expanded) face crop.
        fp_model: face-parsing model callable.

    Returns:
        The segmentation mask resized to image.size, or None when no person
        was detected.
    """
    seg_image = fp_model(image)
    if seg_image is None:
        print("error, no person_segment")
        return None
    return seg_image.resize(image.size)
31
+
32
+
33
def get_image(image, face, face_box, fp_model, upper_boundary_ratio=0.5, expand=1.2):
    """Blend a generated face crop back into the full frame.

    Args:
        image: full frame as a BGR ndarray.
        face: generated face region as a BGR ndarray.
        face_box: (x, y, x1, y1) face location inside `image`.
        fp_model: face-parsing model used to build the blend mask.
        upper_boundary_ratio: fraction of the mask top to discard so only the
            talking (lower) area is blended.
        expand: crop-box expansion factor around the face.

    Returns:
        Blended full frame as a BGR ndarray.
    """
    body = Image.fromarray(image[:, :, ::-1])  # BGR -> RGB
    face = Image.fromarray(face[:, :, ::-1])

    x, y, x1, y1 = face_box
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    # Build a face mask limited to the original (unexpanded) face box.
    mask_image = face_seg(face_large, fp_model)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new("L", ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # Keep only the lower portion of the mask (the talking area).
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new("L", ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    # Feather the mask edge; kernel size scales with the crop and stays odd.
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    mask_image = Image.fromarray(mask_array)

    # Paste the generated face into the crop, then the crop back into the body.
    face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    body.paste(face_large, crop_box[:2], mask_image)
    return np.array(body)[:, :, ::-1]  # RGB -> BGR
68
+
69
+
70
def get_image_prepare_material(image, face_box, fp_model, upper_boundary_ratio=0.5, expand=1.2):
    """Precompute the blend mask and crop box for a frame (no pasting yet).

    Args:
        image: full frame as a BGR ndarray.
        face_box: (x, y, x1, y1) face location inside `image`.
        fp_model: face-parsing model used to build the blend mask.
        upper_boundary_ratio: fraction of the mask top to discard.
        expand: crop-box expansion factor around the face.

    Returns:
        tuple: (mask_array, crop_box) to be reused by get_image_blending.
    """
    body = Image.fromarray(image[:, :, ::-1])  # BGR -> RGB

    x, y, x1, y1 = face_box
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    # Build a face mask limited to the original (unexpanded) face box.
    mask_image = face_seg(face_large, fp_model)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new("L", ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # Keep only the lower portion of the mask (the talking area).
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new("L", ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    # Feather the mask edge; kernel size scales with the crop and stays odd.
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box
95
+
96
+
97
def get_image_blending(image, face, face_box, mask_array, crop_box):
    """Blend a generated face into the frame using a precomputed mask.

    Args:
        image: full frame as a BGR ndarray.
        face: generated face region as a BGR ndarray.
        face_box: (x, y, x1, y1) face location inside `image`.
        mask_array: blend mask from get_image_prepare_material.
        crop_box: crop box from get_image_prepare_material.

    Returns:
        Blended full frame as a BGR ndarray.
    """
    body = Image.fromarray(image[:, :, ::-1])  # BGR -> RGB
    face = Image.fromarray(face[:, :, ::-1])

    x, y, x1, y1 = face_box
    x_s, y_s, x_e, y_e = crop_box
    face_large = body.crop(crop_box)

    mask_image = Image.fromarray(mask_array).convert("L")
    face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    body.paste(face_large, crop_box[:2], mask_image)
    return np.array(body)[:, :, ::-1]  # RGB -> BGR
server/digital_human/modules/musetalk/utils/dwpose/default_runtime.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# mmpose runtime defaults shared by the dwpose configs.
default_scope = 'mmpose'

# Hooks executed during training/validation.
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='PoseVisualizationHook', enable=False),
    badcase=dict(
        type='BadCaseAnalysisHook',
        enable=False,
        out_dir='badcase',
        metric_type='loss',
        badcase_thr=5,
    ),
)

# Custom hooks: keep BN running stats (running_mean / running_var)
# synchronized across ranks at the end of each epoch.
custom_hooks = [
    dict(type='SyncBuffersHook'),
]

# Multi-processing / distributed environment settings.
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# Visualization backends (Tensorboard/Wandb left disabled by default).
vis_backends = [
    dict(type='LocalVisBackend'),
    # dict(type='TensorboardVisBackend'),
    # dict(type='WandbVisBackend'),
]
visualizer = dict(
    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# Logging configuration.
log_processor = dict(
    type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
log_level = 'INFO'
load_from = None
resume = False

# File I/O backend.
backend_args = dict(backend='local')

# Training / validation / testing loop configuration.
train_cfg = dict(by_epoch=True)
val_cfg = dict()
test_cfg = dict()
server/digital_human/modules/musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #_base_ = ['../../../_base_/default_runtime.py']
2
+ _base_ = ['default_runtime.py']
3
+
4
+ # runtime
5
+ max_epochs = 270
6
+ stage2_num_epochs = 30
7
+ base_lr = 4e-3
8
+ train_batch_size = 8
9
+ val_batch_size = 8
10
+
11
+ train_cfg = dict(max_epochs=max_epochs, val_interval=10)
12
+ randomness = dict(seed=21)
13
+
14
+ # optimizer
15
+ optim_wrapper = dict(
16
+ type='OptimWrapper',
17
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
18
+ paramwise_cfg=dict(
19
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
20
+
21
+ # learning rate
22
+ param_scheduler = [
23
+ dict(
24
+ type='LinearLR',
25
+ start_factor=1.0e-5,
26
+ by_epoch=False,
27
+ begin=0,
28
+ end=1000),
29
+ dict(
30
+ # use cosine lr from 150 to 300 epoch
31
+ type='CosineAnnealingLR',
32
+ eta_min=base_lr * 0.05,
33
+ begin=max_epochs // 2,
34
+ end=max_epochs,
35
+ T_max=max_epochs // 2,
36
+ by_epoch=True,
37
+ convert_to_iter_based=True),
38
+ ]
39
+
40
+ # automatically scaling LR based on the actual training batch size
41
+ auto_scale_lr = dict(base_batch_size=512)
42
+
43
+ # codec settings
44
+ codec = dict(
45
+ type='SimCCLabel',
46
+ input_size=(288, 384),
47
+ sigma=(6., 6.93),
48
+ simcc_split_ratio=2.0,
49
+ normalize=False,
50
+ use_dark=False)
51
+
52
+ # model settings
53
+ model = dict(
54
+ type='TopdownPoseEstimator',
55
+ data_preprocessor=dict(
56
+ type='PoseDataPreprocessor',
57
+ mean=[123.675, 116.28, 103.53],
58
+ std=[58.395, 57.12, 57.375],
59
+ bgr_to_rgb=True),
60
+ backbone=dict(
61
+ _scope_='mmdet',
62
+ type='CSPNeXt',
63
+ arch='P5',
64
+ expand_ratio=0.5,
65
+ deepen_factor=1.,
66
+ widen_factor=1.,
67
+ out_indices=(4, ),
68
+ channel_attention=True,
69
+ norm_cfg=dict(type='SyncBN'),
70
+ act_cfg=dict(type='SiLU'),
71
+ init_cfg=dict(
72
+ type='Pretrained',
73
+ prefix='backbone.',
74
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
75
+ 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
76
+ )),
77
+ head=dict(
78
+ type='RTMCCHead',
79
+ in_channels=1024,
80
+ out_channels=133,
81
+ input_size=codec['input_size'],
82
+ in_featuremap_size=(9, 12),
83
+ simcc_split_ratio=codec['simcc_split_ratio'],
84
+ final_layer_kernel_size=7,
85
+ gau_cfg=dict(
86
+ hidden_dims=256,
87
+ s=128,
88
+ expansion_factor=2,
89
+ dropout_rate=0.,
90
+ drop_path=0.,
91
+ act_fn='SiLU',
92
+ use_rel_bias=False,
93
+ pos_enc=False),
94
+ loss=dict(
95
+ type='KLDiscretLoss',
96
+ use_target_weight=True,
97
+ beta=10.,
98
+ label_softmax=True),
99
+ decoder=codec),
100
+ test_cfg=dict(flip_test=True, ))
101
+
102
+ # base dataset settings
103
+ dataset_type = 'UBody2dDataset'
104
+ data_mode = 'topdown'
105
+ data_root = 'data/UBody/'
106
+
107
+ backend_args = dict(backend='local')
108
+
109
+ scenes = [
110
+ 'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
111
+ 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
112
+ 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
113
+ ]
114
+
115
+ train_datasets = [
116
+ dict(
117
+ type='CocoWholeBodyDataset',
118
+ data_root='data/coco/',
119
+ data_mode=data_mode,
120
+ ann_file='annotations/coco_wholebody_train_v1.0.json',
121
+ data_prefix=dict(img='train2017/'),
122
+ pipeline=[])
123
+ ]
124
+
125
+ for scene in scenes:
126
+ train_dataset = dict(
127
+ type=dataset_type,
128
+ data_root=data_root,
129
+ data_mode=data_mode,
130
+ ann_file=f'annotations/{scene}/train_annotations.json',
131
+ data_prefix=dict(img='images/'),
132
+ pipeline=[],
133
+ sample_interval=10)
134
+ train_datasets.append(train_dataset)
135
+
136
+ # pipelines
137
+ train_pipeline = [
138
+ dict(type='LoadImage', backend_args=backend_args),
139
+ dict(type='GetBBoxCenterScale'),
140
+ dict(type='RandomFlip', direction='horizontal'),
141
+ dict(type='RandomHalfBody'),
142
+ dict(
143
+ type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
144
+ dict(type='TopdownAffine', input_size=codec['input_size']),
145
+ dict(type='mmdet.YOLOXHSVRandomAug'),
146
+ dict(
147
+ type='Albumentation',
148
+ transforms=[
149
+ dict(type='Blur', p=0.1),
150
+ dict(type='MedianBlur', p=0.1),
151
+ dict(
152
+ type='CoarseDropout',
153
+ max_holes=1,
154
+ max_height=0.4,
155
+ max_width=0.4,
156
+ min_holes=1,
157
+ min_height=0.2,
158
+ min_width=0.2,
159
+ p=1.0),
160
+ ]),
161
+ dict(type='GenerateTarget', encoder=codec),
162
+ dict(type='PackPoseInputs')
163
+ ]
164
+ val_pipeline = [
165
+ dict(type='LoadImage', backend_args=backend_args),
166
+ dict(type='GetBBoxCenterScale'),
167
+ dict(type='TopdownAffine', input_size=codec['input_size']),
168
+ dict(type='PackPoseInputs')
169
+ ]
170
+
171
+ train_pipeline_stage2 = [
172
+ dict(type='LoadImage', backend_args=backend_args),
173
+ dict(type='GetBBoxCenterScale'),
174
+ dict(type='RandomFlip', direction='horizontal'),
175
+ dict(type='RandomHalfBody'),
176
+ dict(
177
+ type='RandomBBoxTransform',
178
+ shift_factor=0.,
179
+ scale_factor=[0.5, 1.5],
180
+ rotate_factor=90),
181
+ dict(type='TopdownAffine', input_size=codec['input_size']),
182
+ dict(type='mmdet.YOLOXHSVRandomAug'),
183
+ dict(
184
+ type='Albumentation',
185
+ transforms=[
186
+ dict(type='Blur', p=0.1),
187
+ dict(type='MedianBlur', p=0.1),
188
+ dict(
189
+ type='CoarseDropout',
190
+ max_holes=1,
191
+ max_height=0.4,
192
+ max_width=0.4,
193
+ min_holes=1,
194
+ min_height=0.2,
195
+ min_width=0.2,
196
+ p=0.5),
197
+ ]),
198
+ dict(type='GenerateTarget', encoder=codec),
199
+ dict(type='PackPoseInputs')
200
+ ]
201
+
202
+ # data loaders
203
+ train_dataloader = dict(
204
+ batch_size=train_batch_size,
205
+ num_workers=10,
206
+ persistent_workers=True,
207
+ sampler=dict(type='DefaultSampler', shuffle=True),
208
+ dataset=dict(
209
+ type='CombinedDataset',
210
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
211
+ datasets=train_datasets,
212
+ pipeline=train_pipeline,
213
+ test_mode=False,
214
+ ))
215
+
216
+ val_dataloader = dict(
217
+ batch_size=val_batch_size,
218
+ num_workers=10,
219
+ persistent_workers=True,
220
+ drop_last=False,
221
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222
+ dataset=dict(
223
+ type='CocoWholeBodyDataset',
224
+ data_root=data_root,
225
+ data_mode=data_mode,
226
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
227
+ bbox_file='data/coco/person_detection_results/'
228
+ 'COCO_val2017_detections_AP_H_56_person.json',
229
+ data_prefix=dict(img='coco/val2017/'),
230
+ test_mode=True,
231
+ pipeline=val_pipeline,
232
+ ))
233
+ test_dataloader = val_dataloader
234
+
235
+ # hooks
236
+ default_hooks = dict(
237
+ checkpoint=dict(
238
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239
+
240
+ custom_hooks = [
241
+ dict(
242
+ type='EMAHook',
243
+ ema_type='ExpMomentumEMA',
244
+ momentum=0.0002,
245
+ update_buffers=True,
246
+ priority=49),
247
+ dict(
248
+ type='mmdet.PipelineSwitchHook',
249
+ switch_epoch=max_epochs - stage2_num_epochs,
250
+ switch_pipeline=train_pipeline_stage2)
251
+ ]
252
+
253
+ # evaluators
254
+ val_evaluator = dict(
255
+ type='CocoWholeBodyMetric',
256
+ ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
257
+ test_evaluator = val_evaluator
server/digital_human/modules/musetalk/utils/face_detection/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.