FAYO
committed on
Commit
·
1ef9436
1
Parent(s):
77b0e0f
model
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- server/__init__.py +0 -0
- server/__pycache__/__init__.cpython-310.pyc +0 -0
- server/__pycache__/web_configs.cpython-310.pyc +0 -0
- server/asr/asr_server.py +58 -0
- server/asr/asr_worker.py +54 -0
- server/base/__init__.py +0 -0
- server/base/base_server.py +206 -0
- server/base/database/__init__.py +0 -0
- server/base/database/init_db.py +44 -0
- server/base/database/llm_db.py +22 -0
- server/base/database/product_db.py +186 -0
- server/base/database/streamer_info_db.py +152 -0
- server/base/database/streamer_room_db.py +415 -0
- server/base/database/user_db.py +48 -0
- server/base/models/__init__.py +0 -0
- server/base/models/llm_model.py +17 -0
- server/base/models/product_model.py +59 -0
- server/base/models/streamer_info_model.py +40 -0
- server/base/models/streamer_room_model.py +127 -0
- server/base/models/user_model.py +43 -0
- server/base/modules/__init__.py +0 -0
- server/base/modules/agent/__init__.py +0 -0
- server/base/modules/agent/agent_worker.py +200 -0
- server/base/modules/agent/delivery_time_query.py +300 -0
- server/base/modules/rag/__init__.py +0 -0
- server/base/modules/rag/feature_store.py +545 -0
- server/base/modules/rag/file_operation.py +228 -0
- server/base/modules/rag/rag_worker.py +122 -0
- server/base/modules/rag/retriever.py +244 -0
- server/base/modules/rag/test_queries.json +4 -0
- server/base/queue_thread.py +73 -0
- server/base/routers/__init__.py +0 -0
- server/base/routers/digital_human.py +85 -0
- server/base/routers/llm.py +187 -0
- server/base/routers/products.py +119 -0
- server/base/routers/streamer_info.py +156 -0
- server/base/routers/streaming_room.py +335 -0
- server/base/routers/users.py +157 -0
- server/base/server_info.py +134 -0
- server/base/utils.py +485 -0
- server/digital_human/digital_human_server.py +68 -0
- server/digital_human/modules/__init__.py +6 -0
- server/digital_human/modules/digital_human_worker.py +33 -0
- server/digital_human/modules/musetalk/models/unet.py +43 -0
- server/digital_human/modules/musetalk/models/vae.py +149 -0
- server/digital_human/modules/musetalk/utils/__init__.py +5 -0
- server/digital_human/modules/musetalk/utils/blending.py +110 -0
- server/digital_human/modules/musetalk/utils/dwpose/default_runtime.py +54 -0
- server/digital_human/modules/musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py +257 -0
- server/digital_human/modules/musetalk/utils/face_detection/README.md +1 -0
server/__init__.py
ADDED
File without changes
|
server/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (139 Bytes). View file
|
|
server/__pycache__/web_configs.cpython-310.pyc
ADDED
Binary file (4.19 kB). View file
|
|
server/asr/asr_server.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from fastapi.responses import PlainTextResponse
from loguru import logger
from pydantic import BaseModel

from ..web_configs import WEB_CONFIGS
from .asr_worker import load_asr_model, process_asr

app = FastAPI()

# Load the ASR model once at import time; None signals that ASR is disabled.
ASR_HANDLER = load_asr_model() if WEB_CONFIGS.ENABLE_ASR else None
|
16 |
+
|
17 |
+
|
18 |
+
class ASRItem(BaseModel):
    """Request payload for the /asr endpoint."""

    user_id: int  # user identifier, distinguishes different callers
    request_id: str  # request ID, reused downstream for TTS & digital-human generation
    wav_path: str  # path to the wav file to transcribe
|
22 |
+
|
23 |
+
|
24 |
+
@app.post("/asr")
async def get_asr(asr_item: ASRItem):
    """Speech-to-text endpoint.

    Args:
        asr_item (ASRItem): user ID, request ID and the wav file path.

    Returns:
        dict: echoes user/request IDs plus ``status`` ("success"/"fail")
              and ``result`` (transcription text, or an error message).
    """
    result = ""
    status = "success"
    if ASR_HANDLER is None:
        # ASR was disabled via WEB_CONFIGS.ENABLE_ASR at startup.
        result = "ASR not enabled in server"  # fix: was "ASR not enable in sever"
        status = "fail"
        logger.error("ASR not enable...")  # fix: f-string had no placeholders
    else:
        result = process_asr(ASR_HANDLER, asr_item.wav_path)
        logger.info(f"ASR res for id {asr_item.request_id}, res = {result}")

    return {"user_id": asr_item.user_id, "request_id": asr_item.request_id, "status": status, "result": result}
|
38 |
+
|
39 |
+
|
40 |
+
@app.get("/asr/check")
|
41 |
+
async def check_server():
    """Health-check endpoint; replies once the ASR service is up."""
    return {"message": "server enabled"}
|
43 |
+
|
44 |
+
|
45 |
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Callback invoked when API request validation fails.

    Args:
        request: the incoming request that failed validation.
        exc: the RequestValidationError raised by FastAPI.

    Returns:
        PlainTextResponse: the validation error text with HTTP 400.
    """
    logger.info(request)
    logger.info(exc)
    error_text = str(exc)
    return PlainTextResponse(error_text, status_code=400)
|
server/asr/asr_worker.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
|
3 |
+
from funasr import AutoModel
|
4 |
+
from funasr.download.name_maps_from_hub import name_maps_ms as NAME_MAPS_MS
|
5 |
+
from modelscope import snapshot_download
|
6 |
+
from modelscope.utils.constant import Invoke, ThirdParty
|
7 |
+
|
8 |
+
from ..web_configs import WEB_CONFIGS
|
9 |
+
|
10 |
+
|
11 |
+
def load_asr_model():
    """Download FunASR weights and assemble the AutoModel ASR pipeline.

    Returns:
        AutoModel: pipeline combining recognition, VAD and punctuation models.
    """
    # Download each required model and remember where its weights landed.
    model_path_info = {}
    for model_name in ["paraformer-zh", "fsmn-vad", "ct-punc"]:
        print(f"downloading asr model : {NAME_MAPS_MS[model_name]}")
        local_dir = snapshot_download(
            NAME_MAPS_MS[model_name],
            revision="master",
            user_agent={Invoke.KEY: Invoke.PIPELINE, ThirdParty.KEY: "funasr"},
            cache_dir=WEB_CONFIGS.ASR_MODEL_DIR,
        )
        model_path_info[model_name] = local_dir
        NAME_MAPS_MS[model_name] = local_dir  # repoint the hub name map at the local weights

    print(f"ASR model path info = {model_path_info}")

    # paraformer-zh is a multi-functional asr model
    # use vad, punc, spk or not as you need
    model = AutoModel(
        model="paraformer-zh",  # speech recognition with timestamps, offline
        vad_model="fsmn-vad",  # voice activity detection, streaming
        punc_model="ct-punc",  # punctuation restoration
        # spk_model="cam++"  # speaker verification / diarization
        model_path=model_path_info["paraformer-zh"],
        vad_kwargs={"model_path": model_path_info["fsmn-vad"]},
        punc_kwargs={"model_path": model_path_info["ct-punc"]},
    )
    return model
|
39 |
+
|
40 |
+
|
41 |
+
def process_asr(model: AutoModel, wav_path):
    """Transcribe a wav file to text.

    Args:
        model (AutoModel): loaded FunASR pipeline (see ``load_asr_model``).
        wav_path (str): path to the wav file to transcribe.

    Returns:
        str: the transcription, or "" if the model output could not be parsed.
    """
    # https://github.com/modelscope/FunASR/blob/main/README_zh.md#%E5%AE%9E%E6%97%B6%E8%AF%AD%E9%9F%B3%E8%AF%86%E5%88%AB
    f_start_time = datetime.datetime.now()
    res = model.generate(input=wav_path, batch_size_s=50, hotword="魔搭")
    delta_time = datetime.datetime.now() - f_start_time

    try:
        # fix: only catch output-parsing failures, not arbitrary exceptions
        res_str = res[0]["text"]
    except (IndexError, KeyError, TypeError):
        print("ASR 解析失败,无法获取到文字")
        return ""

    print(f"ASR using time {delta_time}s, text: ", res_str)
    return res_str
|
server/base/__init__.py
ADDED
File without changes
|
server/base/base_server.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : base_server.py
|
5 |
+
@Time : 2024/09/02
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 中台服务入口文件
|
10 |
+
"""
|
11 |
+
|
12 |
+
import time
|
13 |
+
import uuid
|
14 |
+
from contextlib import asynccontextmanager
|
15 |
+
from pathlib import Path
|
16 |
+
|
17 |
+
from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile
|
18 |
+
from fastapi.exceptions import RequestValidationError
|
19 |
+
from fastapi.responses import PlainTextResponse
|
20 |
+
from fastapi.staticfiles import StaticFiles
|
21 |
+
from loguru import logger
|
22 |
+
|
23 |
+
from ..web_configs import API_CONFIG, WEB_CONFIGS
|
24 |
+
from .database.init_db import create_db_and_tables
|
25 |
+
from .routers import digital_human, llm, products, streamer_info, streaming_room, users
|
26 |
+
from .server_info import SERVER_PLUGINS_INFO
|
27 |
+
from .utils import ChatItem, ResultCode, gen_default_data, make_return_data, streamer_sales_process
|
28 |
+
|
29 |
+
swagger_description = """
|
30 |
+
|
31 |
+
## 项目地址
|
32 |
+
|
33 |
+
[销冠 —— 卖货主播大模型 && 后台管理系统](https://github.com/PeterH0323/Streamer-Sales)
|
34 |
+
|
35 |
+
## 功能点
|
36 |
+
|
37 |
+
1. 📜 **主播文案一键生成**
|
38 |
+
2. 🚀 KV cache + Turbomind **推理加速**
|
39 |
+
3. 📚 RAG **检索增强生成**
|
40 |
+
4. 🔊 TTS **文字转语音**
|
41 |
+
5. 🦸 **数字人生成**
|
42 |
+
6. 🌐 **Agent 网络查询**
|
43 |
+
7. 🎙️ **ASR 语音转文字**
|
44 |
+
8. 🍍 **Vue + pinia + element-plus **搭建的前端,可自由扩展快速开发
|
45 |
+
9. 🗝️ 后端采用 FastAPI + Uvicorn + PostgreSQL,**高性能,高效编码,生产可用,同时具有 JWT 身份验证**
|
46 |
+
10. 🐋 采用 Docker-compose 部署,**一键实现分布式部署**
|
47 |
+
|
48 |
+
"""
|
49 |
+
|
50 |
+
|
51 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Service lifespan hook: setup before serving, teardown afterwards."""
    # Startup: create database and tables (existing ones are skipped).
    create_db_and_tables()

    # Fresh deployment: seed default data. Comment out or adjust as needed.
    gen_default_data()

    if WEB_CONFIGS.ENABLE_RAG:
        from .modules.rag.rag_worker import load_rag_model

        # Build the RAG vector database.
        await load_rag_model(user_id=1)

    yield

    # Shutdown.
    logger.info("Base server stopped.")
|
70 |
+
|
71 |
+
|
72 |
+
app = FastAPI(
|
73 |
+
title="销冠 —— 卖货主播大模型 && 后台管理系统",
|
74 |
+
description=swagger_description,
|
75 |
+
summary="一个能够根据给定的商品特点从激发用户购买意愿角度出发进行商品解说的卖货主播大模型。",
|
76 |
+
version="1.0.0",
|
77 |
+
license_info={
|
78 |
+
"name": "AGPL-3.0 license",
|
79 |
+
"url": "https://github.com/PeterH0323/Streamer-Sales/blob/main/LICENSE",
|
80 |
+
},
|
81 |
+
root_path=API_CONFIG.API_V1_STR,
|
82 |
+
lifespan=lifespan,
|
83 |
+
)
|
84 |
+
|
85 |
+
# 注册路由
|
86 |
+
app.include_router(users.router)
|
87 |
+
app.include_router(products.router)
|
88 |
+
app.include_router(llm.router)
|
89 |
+
app.include_router(streamer_info.router)
|
90 |
+
app.include_router(streaming_room.router)
|
91 |
+
app.include_router(digital_human.router)
|
92 |
+
|
93 |
+
|
94 |
+
# 挂载静态文件目录,以便访问上传的文件
|
95 |
+
WEB_CONFIGS.SERVER_FILE_ROOT = str(Path(WEB_CONFIGS.SERVER_FILE_ROOT).absolute())
|
96 |
+
Path(WEB_CONFIGS.SERVER_FILE_ROOT).mkdir(parents=True, exist_ok=True)
|
97 |
+
logger.info(f"上传文件挂载路径: {WEB_CONFIGS.SERVER_FILE_ROOT}")
|
98 |
+
logger.info(f"上传文件访问 URL: {API_CONFIG.REQUEST_FILES_URL}")
|
99 |
+
app.mount(
|
100 |
+
f"/{API_CONFIG.REQUEST_FILES_URL.split('/')[-1]}",
|
101 |
+
StaticFiles(directory=WEB_CONFIGS.SERVER_FILE_ROOT),
|
102 |
+
name=API_CONFIG.REQUEST_FILES_URL.split("/")[-1],
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
@app.get("/")
|
107 |
+
async def hello():
    """Root endpoint used as a liveness greeting."""
    return {"message": "Hello Streamer-Sales"}
|
109 |
+
|
110 |
+
|
111 |
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Callback for malformed API requests.

    Args:
        request: the offending request (its headers are logged for diagnosis).
        exc: the RequestValidationError describing the failure.

    Returns:
        PlainTextResponse: the error details with HTTP 400.
    """
    logger.info(request.headers)
    logger.info(exc)
    error_text = str(exc)
    return PlainTextResponse(error_text, status_code=400)
|
125 |
+
|
126 |
+
|
127 |
+
@app.get("/dashboard", tags=["base"], summary="获取主页信息接口")
async def get_dashboard_info():
    """Landing-page dashboard statistics (currently hard-coded demo data)."""
    fake_dashboard_data = {
        "registeredBrandNum": 98431,  # registered brands
        "productNum": 49132,  # products
        "dailyActivity": 68431,  # daily active users
        "todayOrder": 8461321,  # orders today
        "totalSales": 245578131857,  # total sales
        "conversionRate": 90.0,  # conversion rate
        # line charts
        "orderNumList": [46813, 68461, 99561, 138131, 233812, 84613, 846122],  # orders
        "totalSalesList": [46813, 68461, 99561, 138131, 23383, 84613, 841213],  # sales
        "newUserList": [3215, 65131, 6513, 6815, 2338, 84614, 84213],  # new users
        "activityUserList": [132, 684, 59431, 4618, 31354, 68431, 88431],  # active users
        # bar charts
        "knowledgeBasesNum": 12,  # knowledge bases
        "digitalHumanNum": 3,  # digital humans
        "LiveRoomNum": 5,  # live rooms
    }

    return make_return_data(True, ResultCode.SUCCESS, "成功", fake_dashboard_data)
|
149 |
+
|
150 |
+
|
151 |
+
@app.get("/plugins_info", tags=["base"], summary="获取组件信息接口")
async def get_plugins_info():
    """Return the component status reported by SERVER_PLUGINS_INFO."""
    plugins_info = SERVER_PLUGINS_INFO.get_status()
    return make_return_data(True, ResultCode.SUCCESS, "成功", plugins_info)
|
156 |
+
|
157 |
+
|
158 |
+
@app.post("/upload/file", tags=["base"], summary="上传文件接口")
async def upload_product_api(file: UploadFile = File(...), user_id: int = Depends(users.get_current_user_info)):
    """Receive an uploaded file, stream it to disk and return its public URL.

    Args:
        file (UploadFile): uploaded file; the destination folder is chosen by extension.
        user_id (int): current user (auth dependency); reserved for ownership records.

    Raises:
        HTTPException: 400 when the file extension is not supported.

    Returns:
        standard result wrapper containing the file's access URL.
    """
    file_type = file.filename.split(".")[-1].lower()  # eg. png (lower-cased for robustness)
    logger.info(f"upload file type = {file_type}")

    # extension -> sub-directory under the save root
    sub_dir_name_map = {
        "md": WEB_CONFIGS.INSTRUCTIONS_DIR,
        "png": WEB_CONFIGS.IMAGES_DIR,
        "jpg": WEB_CONFIGS.IMAGES_DIR,
        "mp4": WEB_CONFIGS.STREAMER_INFO_FILES_DIR,
        "wav": WEB_CONFIGS.STREAMER_INFO_FILES_DIR,
        "webm": WEB_CONFIGS.ASR_FILE_DIR,
    }
    if file_type not in sub_dir_name_map:
        # fix: previously an unsupported extension raised an unhandled KeyError (HTTP 500)
        raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_type}")

    if file_type in ["wav", "mp4"]:
        save_root = WEB_CONFIGS.STREAMER_FILE_DIR
    elif file_type in ["webm"]:
        save_root = ""
    else:
        save_root = WEB_CONFIGS.PRODUCT_FILE_DIR

    # unique name: timestamp + uuid prevents collisions between concurrent uploads
    upload_time = str(int(time.time())) + "__" + str(uuid.uuid4().hex)

    sub_dir_name = sub_dir_name_map[file_type]
    save_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(save_root, sub_dir_name, upload_time + "." + file_type)
    save_path.parent.mkdir(exist_ok=True, parents=True)
    logger.info(f"save path = {save_path}")

    # Stream the upload to disk in 5 MB chunks to bound memory use.
    with open(save_path, "wb") as buffer:
        while chunk := await file.read(1024 * 1024 * 5):
            buffer.write(chunk)

    split_dir_name = Path(WEB_CONFIGS.SERVER_FILE_ROOT).name  # root folder name of the save dir
    file_url = f"{API_CONFIG.REQUEST_FILES_URL}{str(save_path).split(split_dir_name)[-1]}"

    # TODO file-ownership record table

    return make_return_data(True, ResultCode.SUCCESS, "成功", file_url)
|
197 |
+
|
198 |
+
|
199 |
+
@app.post("/streamer-sales/chat", tags=["base"], summary="对话接口", deprecated=True)
async def streamer_sales_chat(chat_item: ChatItem, response: Response):
    """Deprecated chat endpoint: streams LLM replies as server-sent events."""
    from sse_starlette import EventSourceResponse

    # SSE headers for the streaming response.
    response.headers["Content-Type"] = "text/event-stream"
    response.headers["Cache-Control"] = "no-cache"
    return EventSourceResponse(streamer_sales_process(chat_item))
|
server/base/database/__init__.py
ADDED
File without changes
|
server/base/database/init_db.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : init_db.py
|
5 |
+
@Time : 2024/09/06
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 数据库初始化
|
10 |
+
"""
|
11 |
+
|
12 |
+
from loguru import logger
|
13 |
+
from pydantic import PostgresDsn
|
14 |
+
from pydantic_core import MultiHostUrl
|
15 |
+
from sqlmodel import SQLModel, create_engine
|
16 |
+
|
17 |
+
from ...web_configs import WEB_CONFIGS
|
18 |
+
|
19 |
+
ECHO_DB_MESG = True # 数据库执行中是否回显,for debug
|
20 |
+
|
21 |
+
|
22 |
+
def sqlalchemy_db_url() -> PostgresDsn:
    """Build the PostgreSQL connection DSN from WEB_CONFIGS.

    Returns:
        PostgresDsn: DSN for SQLAlchemy using the postgresql+psycopg driver.
    """
    cfg = WEB_CONFIGS
    return MultiHostUrl.build(
        scheme="postgresql+psycopg",
        host=cfg.POSTGRES_SERVER,
        port=cfg.POSTGRES_PORT,
        username=cfg.POSTGRES_USER,
        password=cfg.POSTGRES_PASSWORD,
        path=cfg.POSTGRES_DB,
    )
|
36 |
+
|
37 |
+
|
38 |
+
logger.info(f"connecting to db: {str(sqlalchemy_db_url())}")
|
39 |
+
DB_ENGINE = create_engine(str(sqlalchemy_db_url()), echo=ECHO_DB_MESG)
|
40 |
+
|
41 |
+
|
42 |
+
def create_db_and_tables():
    """Create every SQLModel-declared table; existing ones are left untouched."""
    SQLModel.metadata.create_all(DB_ENGINE)
|
server/base/database/llm_db.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : llm_db.py
|
5 |
+
@Time : 2024/09/01
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 大模型对话数据库交互
|
10 |
+
"""
|
11 |
+
|
12 |
+
import yaml
|
13 |
+
|
14 |
+
from ...web_configs import WEB_CONFIGS
|
15 |
+
|
16 |
+
|
17 |
+
async def get_llm_product_prompt_base_info():
    """Load the conversation/prompt configuration.

    Returns:
        the parsed YAML contents of WEB_CONFIGS.CONVERSATION_CFG_YAML_PATH.
    """
    # Re-read the config on every call; NOTE(review): presumably so edits are
    # picked up without a restart — confirm before caching.
    with open(WEB_CONFIGS.CONVERSATION_CFG_YAML_PATH, "r", encoding="utf-8") as f:
        dataset_yaml = yaml.safe_load(f)

    return dataset_yaml
|
server/base/database/product_db.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : product_db.py
|
5 |
+
@Time : 2024/08/30
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 商品数据表文件读写
|
10 |
+
"""
|
11 |
+
|
12 |
+
from typing import List, Tuple
|
13 |
+
|
14 |
+
from loguru import logger
|
15 |
+
from sqlalchemy import func
|
16 |
+
from sqlmodel import Session, and_, not_, select
|
17 |
+
|
18 |
+
from ...web_configs import API_CONFIG
|
19 |
+
from ..models.product_model import ProductInfo
|
20 |
+
from .init_db import DB_ENGINE
|
21 |
+
|
22 |
+
|
23 |
+
async def get_db_product_info(
    user_id: int,
    current_page: int = -1,
    page_size: int = 10,
    product_name: str | None = None,
    product_id: int | None = None,
    exclude_list: List[int] | None = None,
) -> Tuple[List[ProductInfo], int]:
    """Query a user's product records.

    Args:
        user_id (int): owner user ID.
        current_page (int, optional): page number; negative means "fetch all". Defaults to -1.
        page_size (int, optional): rows per page. Defaults to 10.
        product_name (str | None, optional): fuzzy (ILIKE) name filter. Defaults to None.
        product_id (int | None, optional): fetch one specific product. Defaults to None.
        exclude_list (List[int] | None, optional): product IDs to exclude. Defaults to None.

    Returns:
        Tuple[List[ProductInfo], int]: matching products (paths rewritten to server
        URLs) and the user's total non-deleted product count.
    """
    assert current_page != 0
    assert page_size != 0

    # Base filter: the user's own, non-deleted products.
    query_filter = and_(ProductInfo.user_id == user_id, ProductInfo.delete == False)

    with Session(DB_ENGINE) as session:
        # Total count of the user's products, deleted ones excluded.
        total_product_num = session.scalar(select(func.count(ProductInfo.product_id)).where(query_filter))

        if product_name is not None:
            # Switch to a fuzzy search on the product name.
            query_filter = and_(
                ProductInfo.user_id == user_id, ProductInfo.delete == False, ProductInfo.product_name.ilike(f"%{product_name}%")
            )
        elif product_id is not None:
            # Switch to an exact lookup by product ID.
            query_filter = and_(
                ProductInfo.user_id == user_id, ProductInfo.delete == False, ProductInfo.product_id == product_id
            )
        elif exclude_list is not None:
            # Switch to an exclusion query.
            query_filter = and_(
                ProductInfo.user_id == user_id, ProductInfo.delete == False, not_(ProductInfo.product_id.in_(exclude_list))
            )

        if current_page < 0:
            # Fetch everything, ordered by product ID.
            product_list = session.exec(select(ProductInfo).where(query_filter).order_by(ProductInfo.product_id)).all()
        else:
            # Paged fetch.
            offset_idx = (current_page - 1) * page_size
            product_list = session.exec(
                select(ProductInfo).where(query_filter).offset(offset_idx).limit(page_size).order_by(ProductInfo.product_id)
            ).all()

        if product_list is None:
            logger.warning("nothing to find in db...")
            product_list = []

        # Prefix stored relative paths with the server file URL.
        for product in product_list:
            product.image_path = API_CONFIG.REQUEST_FILES_URL + product.image_path
            product.instruction = API_CONFIG.REQUEST_FILES_URL + product.instruction

        logger.info(product_list)
        logger.info(f"len {len(product_list)}")

        return product_list, total_product_num
|
98 |
+
|
99 |
+
|
100 |
+
async def delete_product_id(product_id: int, user_id: int) -> bool:
    """Soft-delete a specific product.

    Args:
        product_id (int): product ID.
        user_id (int): owner user ID, prevents other users from deleting it.

    Returns:
        bool: True when the delete flag was set successfully.
    """
    delete_success = True

    try:
        with Session(DB_ENGINE) as session:
            # Look up the product, constrained to this user's ownership.
            product_info = session.exec(
                select(ProductInfo).where(and_(ProductInfo.product_id == product_id, ProductInfo.user_id == user_id))
            ).one()

            if product_info is None:
                logger.error("Delete by other ID !!!")
                return False

            product_info.delete = True  # soft delete: mark instead of removing the row
            session.add(product_info)
            session.commit()
    except Exception as e:
        # fix: was a silent swallow — log so failures (e.g. NoResultFound from
        # .one() when the product is not owned by user_id) are diagnosable.
        logger.error(f"delete product {product_id} failed: {e}")
        delete_success = False

    return delete_success
|
132 |
+
|
133 |
+
|
134 |
+
def create_or_update_db_product_by_id(product_id: int, new_info: ProductInfo, user_id: int) -> bool:
    """Create a product or update an existing one.

    Args:
        product_id (int): product ID; values <= 0 mean "create new".
        new_info (ProductInfo): new field values (paths may carry the server URL prefix).
        user_id (int): owner user ID, prevents other users from editing.

    Returns:
        bool: whether the instruction (manual) changed — always True for new products.
        NOTE(review): presumably used by the caller to re-index RAG — confirm.
    """
    instruction_updated = False

    # Strip the server URL prefix so only relative paths are stored.
    new_info.image_path = new_info.image_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.instruction = new_info.instruction.replace(API_CONFIG.REQUEST_FILES_URL, "")

    with Session(DB_ENGINE) as session:
        if product_id > 0:
            # Update: fetch the record, constrained to this user's ownership.
            product_info = session.exec(
                select(ProductInfo).where(and_(ProductInfo.product_id == product_id, ProductInfo.user_id == user_id))
            ).one()

            if product_info is None:
                logger.error("Edit by other ID !!!")
                return False

            # Detect whether the instruction document changed.
            instruction_updated = product_info.instruction != new_info.instruction

            # Copy over the editable fields.
            product_info.product_name = new_info.product_name
            product_info.product_class = new_info.product_class
            product_info.heighlights = new_info.heighlights
            product_info.image_path = new_info.image_path
            product_info.instruction = new_info.instruction
            product_info.departure_place = new_info.departure_place
            product_info.delivery_company = new_info.delivery_company
            product_info.selling_price = new_info.selling_price
            product_info.amount = new_info.amount

            session.add(product_info)
        else:
            # Create: insert the new record as-is.
            session.add(new_info)
            instruction_updated = True

        session.commit()
    return instruction_updated
|
server/base/database/streamer_info_db.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : streamer_info_db.py
|
5 |
+
@Time : 2024/08/30
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 主播信息数据库操作
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
from loguru import logger
|
16 |
+
from sqlmodel import Session, and_, select
|
17 |
+
|
18 |
+
from ...web_configs import API_CONFIG
|
19 |
+
from ..models.streamer_info_model import StreamerInfo
|
20 |
+
from .init_db import DB_ENGINE
|
21 |
+
|
22 |
+
|
23 |
+
async def get_db_streamer_info(user_id: int, streamer_id: int | None = None) -> List[StreamerInfo] | None:
    """Query a user's streamer records.

    Args:
        user_id (int): owner user ID.
        streamer_id (int | None, optional): fetch one specific streamer. Defaults to None.

    Returns:
        List[StreamerInfo] | None: streamers with file paths rewritten to server
        URLs; empty list when nothing matches.
    """
    # Default filter: the user's own, non-deleted streamers.
    query_filter = and_(StreamerInfo.user_id == user_id, StreamerInfo.delete == False)

    with Session(DB_ENGINE) as session:
        if streamer_id is not None:
            # Narrow to one specific streamer ID.
            query_filter = and_(
                StreamerInfo.user_id == user_id, StreamerInfo.delete == False, StreamerInfo.streamer_id == streamer_id
            )

        # Fetch streamers ordered by ID; treat any query failure as "not found".
        try:
            streamer_list = session.exec(select(StreamerInfo).where(query_filter).order_by(StreamerInfo.streamer_id)).all()
        except Exception:
            streamer_list = None

        if streamer_list is None:
            logger.warning("nothing to find in db...")
            streamer_list = []

        # Prefix stored relative paths with the server file URL.
        for streamer in streamer_list:
            streamer.avatar = API_CONFIG.REQUEST_FILES_URL + streamer.avatar
            streamer.tts_reference_audio = API_CONFIG.REQUEST_FILES_URL + streamer.tts_reference_audio
            streamer.poster_image = API_CONFIG.REQUEST_FILES_URL + streamer.poster_image
            streamer.base_mp4_path = API_CONFIG.REQUEST_FILES_URL + streamer.base_mp4_path

        logger.info(streamer_list)
        logger.info(f"len {len(streamer_list)}")

    return streamer_list
|
69 |
+
|
70 |
+
|
71 |
+
async def delete_streamer_id(streamer_id: int, user_id: int) -> bool:
    """Soft-delete a specific streamer.

    Args:
        streamer_id (int): streamer ID.
        user_id (int): owner user ID, prevents other users from deleting it.

    Returns:
        bool: True when the delete flag was set successfully.
    """
    delete_success = True

    try:
        with Session(DB_ENGINE) as session:
            # Look up the streamer, constrained to this user's ownership.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).one()

            if streamer_info is None:
                logger.error("Delete by other ID !!!")
                return False

            streamer_info.delete = True  # soft delete: mark instead of removing the row
            session.add(streamer_info)
            session.commit()
    except Exception as e:
        # fix: was a silent swallow — log so failures (e.g. NoResultFound from
        # .one() when the streamer is not owned by user_id) are diagnosable.
        logger.error(f"delete streamer {streamer_id} failed: {e}")
        delete_success = False

    return delete_success
|
103 |
+
|
104 |
+
|
105 |
+
def create_or_update_db_streamer_by_id(streamer_id: int, new_info: StreamerInfo, user_id: int) -> int:
    """Create a streamer or update an existing one.

    Args:
        streamer_id (int): streamer ID; values <= 0 mean "create new".
        new_info (StreamerInfo): new field values (paths may carry the server URL prefix).
        user_id (int): owner user ID, prevents other users from editing.

    Returns:
        int: the streamer ID, or -1 when the record is not owned by user_id.
    """
    # Strip the server URL prefix so only relative paths are stored.
    new_info.avatar = new_info.avatar.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.tts_reference_audio = new_info.tts_reference_audio.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.poster_image = new_info.poster_image.replace(API_CONFIG.REQUEST_FILES_URL, "")
    new_info.base_mp4_path = new_info.base_mp4_path.replace(API_CONFIG.REQUEST_FILES_URL, "")

    with Session(DB_ENGINE) as session:
        if streamer_id > 0:
            # Update: fetch the record, constrained to this user's ownership.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).one()

            if streamer_info is None:
                logger.error("Edit by other ID !!!")
                return -1
        else:
            # Create: start from a blank record owned by this user.
            streamer_info = StreamerInfo(user_id=user_id)

        # Copy over the editable fields.
        streamer_info.name = new_info.name
        streamer_info.character = new_info.character
        streamer_info.avatar = new_info.avatar
        streamer_info.tts_weight_tag = new_info.tts_weight_tag
        streamer_info.tts_reference_sentence = new_info.tts_reference_sentence
        streamer_info.tts_reference_audio = new_info.tts_reference_audio
        streamer_info.poster_image = new_info.poster_image
        streamer_info.base_mp4_path = new_info.base_mp4_path

        session.add(streamer_info)
        session.commit()  # persist
        session.refresh(streamer_info)  # reload to obtain the generated primary key

        return int(streamer_info.streamer_id)
|
server/base/database/streamer_room_db.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : streamer_room_db.py
|
5 |
+
@Time : 2024/08/31
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 直播间信息数据库操作
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
from datetime import datetime
|
14 |
+
from typing import List
|
15 |
+
|
16 |
+
from loguru import logger
|
17 |
+
from sqlmodel import Session, and_, not_, select
|
18 |
+
|
19 |
+
from ...web_configs import API_CONFIG
|
20 |
+
from ..models.streamer_room_model import ChatMessageInfo, OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo
|
21 |
+
from .init_db import DB_ENGINE
|
22 |
+
|
23 |
+
|
24 |
+
async def get_db_streaming_room_info(user_id: int, room_id: int | None = None) -> List[StreamRoomInfo] | None:
    """Query streaming-room records belonging to a user.

    Args:
        user_id (int): user ID owning the rooms.
        room_id (int | None, optional): room ID to fetch one specific room.
            Defaults to None (fetch all of the user's rooms).

    Returns:
        List[StreamRoomInfo] | None: room records with stored relative file
            paths rewritten to absolute file-server URLs; empty list when
            nothing matches.
    """

    # Base filter: this user's rooms that are not soft-deleted.
    query_condiction = and_(StreamRoomInfo.user_id == user_id, StreamRoomInfo.delete == False)

    with Session(DB_ENGINE) as session:
        if room_id is not None:
            # Narrow the filter down to one specific room.
            query_condiction = and_(
                StreamRoomInfo.user_id == user_id, StreamRoomInfo.delete == False, StreamRoomInfo.room_id == room_id
            )

        # Fetch the rooms, ordered by room ID.
        stream_room_list = session.exec(select(StreamRoomInfo).where(query_condiction).order_by(StreamRoomInfo.room_id)).all()

        if stream_room_list is None:
            logger.warning("nothing to find in db...")
            stream_room_list = []

        # Rewrite stored relative paths into absolute file-server URLs.
        # NOTE(review): this mutates the ORM objects in-place; presumably the
        # session is discarded without a commit so the DB keeps relative paths.
        for stream_room in stream_room_list:
            # Streamer assets.
            stream_room.streamer_info.avatar = API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.avatar
            stream_room.streamer_info.tts_reference_audio = (
                API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.tts_reference_audio
            )
            stream_room.streamer_info.poster_image = API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.poster_image
            stream_room.streamer_info.base_mp4_path = API_CONFIG.REQUEST_FILES_URL + stream_room.streamer_info.base_mp4_path

            # Product assets.
            for idx, product in enumerate(stream_room.product_list):
                stream_room.product_list[idx].product_info.image_path = API_CONFIG.REQUEST_FILES_URL + product.product_info.image_path
                stream_room.product_list[idx].product_info.instruction = (
                    API_CONFIG.REQUEST_FILES_URL + product.product_info.instruction
                )

        logger.info(stream_room_list)
        logger.info(f"len {len(stream_room_list)}")

        return stream_room_list
|
74 |
+
|
75 |
+
|
76 |
+
async def delete_room_id(room_id: int, user_id: int) -> bool:
    """Soft-delete a streaming room.

    Args:
        room_id (int): room ID to delete.
        user_id (int): owner user ID, used to prevent other users from
            deleting someone else's room.

    Returns:
        bool: True on success, False otherwise.
    """

    try:
        with Session(DB_ENGINE) as session:
            # Look up the room scoped to the owning user. first() returns
            # None on a miss so the ownership guard below actually triggers
            # (one() would raise and previously fell into the bare except).
            room_info = session.exec(
                select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
            ).first()

            if room_info is None:
                logger.error("Delete by other ID !!!")
                return False

            room_info.delete = True  # soft delete: flag the row, keep the data
            session.add(room_info)
            session.commit()  # persist
    except Exception as err:
        # Log instead of silently swallowing, but still report failure.
        logger.error(f"delete room {room_id} failed: {err}")
        return False

    return True
|
108 |
+
|
109 |
+
|
110 |
+
def create_or_update_db_room_by_id(room_id: int, new_info: StreamRoomInfo, user_id: int):
    """Create a new streaming room or update an existing one.

    Also upserts the room's on-air status row and synchronizes the selected
    product list (updating/creating sales records, then removing the ones
    that are no longer selected).

    Args:
        room_id (int): room ID; values <= 0 mean "create a new room".
        new_info (StreamRoomInfo): new room information to store.
        user_id (int): owner user ID, used to prevent other users from
            modifying someone else's room.

    Returns:
        the room ID of the created/updated room, or None when the
        (room_id, user_id) pair does not match an existing row.
    """

    with Session(DB_ENGINE) as session:

        # Upsert the on-air status row first so its generated primary key is
        # available when creating a brand-new room below.
        if new_info.status_id is not None:
            status_info = session.exec(
                select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == new_info.status_id)
            ).one()
        else:
            status_info = OnAirRoomStatusItem()

        # Persist relative paths only (strip the file-server prefix).
        status_info.streaming_video_path = new_info.status.streaming_video_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
        status_info.live_status = new_info.status.live_status
        session.add(status_info)
        session.commit()
        session.refresh(status_info)  # obtain status_id for a new status row

        if room_id > 0:

            # Update an existing room, scoped to the owning user.
            # NOTE(review): one() raises NoResultFound on a miss, so the
            # None guard below is effectively dead code — confirm intent.
            room_info = session.exec(
                select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
            ).one()

            if room_info is None:
                logger.error("Edit by other ID !!!")
                return

        else:
            # Brand-new room linked to the freshly committed status row.
            room_info = StreamRoomInfo(status_id=status_info.status_id, user_id=user_id)

        # Basic room fields; images are stored as relative paths.
        room_info.name = new_info.name
        room_info.prohibited_words_id = new_info.prohibited_words_id
        room_info.room_poster = new_info.room_poster.replace(API_CONFIG.REQUEST_FILES_URL, "")
        room_info.background_image = new_info.background_image.replace(API_CONFIG.REQUEST_FILES_URL, "")
        room_info.streamer_id = new_info.streamer_id

        session.add(room_info)
        session.commit()  # persist
        session.refresh(room_info)  # obtain room_id for a new room

        # Synchronize the product (sales) records of this room.
        if len(new_info.product_list) > 0:
            selected_id_list = [product.product_id for product in new_info.product_list]
            for product in new_info.product_list:
                if product.sales_info_id is not None:
                    # Update an existing sales record.
                    sales_info = session.exec(
                        select(SalesDocAndVideoInfo).where(
                            and_(
                                SalesDocAndVideoInfo.room_id == room_info.room_id,
                                SalesDocAndVideoInfo.product_id == product.product_id,
                                SalesDocAndVideoInfo.sales_info_id == product.sales_info_id,
                            )
                        )
                    ).one()
                else:
                    # Create a new sales record.
                    sales_info = SalesDocAndVideoInfo()

                sales_info.product_id = product.product_id
                sales_info.sales_doc = product.sales_doc
                sales_info.start_time = product.start_time
                sales_info.start_video = product.start_video.replace(API_CONFIG.REQUEST_FILES_URL, "")
                sales_info.selected = True
                sales_info.room_id = room_info.room_id
                session.add(sales_info)
                session.commit()

            # Remove sales records for products that are no longer selected.
            if len(selected_id_list) > 0:
                cancel_select_sales_info = session.exec(
                    select(SalesDocAndVideoInfo).where(
                        and_(
                            SalesDocAndVideoInfo.room_id == room_info.room_id,
                            not_(SalesDocAndVideoInfo.product_id.in_(selected_id_list)),
                        )
                    )
                ).all()

                if cancel_select_sales_info is not None:
                    for cancel_select in cancel_select_sales_info:
                        session.delete(cancel_select)
                        session.commit()

        return room_info.room_id
|
205 |
+
|
206 |
+
|
207 |
+
def init_conversation(db_session, sales_info_id: int, streamer_id: int, sales_doc: str):
    """Queue the opening chat message for a live room.

    Typically triggered when the host clicks "start streaming" or moves on to
    the next product: the streamer's sales pitch becomes the first message of
    the conversation.

    Args:
        db_session: active database session; the caller owns the commit.
        sales_info_id (int): sales record ID the message belongs to.
        streamer_id (int): ID of the streamer posting the message.
        sales_doc (str): sales pitch text used as the message body.
    """
    opening_message = ChatMessageInfo(
        sales_info_id=sales_info_id,
        streamer_id=streamer_id,
        role="streamer",
        message=sales_doc,
        send_time=datetime.now(),
    )
    db_session.add(opening_message)
|
220 |
+
|
221 |
+
|
222 |
+
def update_message_info(sales_info_id: int, role_id: int, role: str, message: str):
    """Append one chat message to a sales conversation.

    Args:
        sales_info_id (int): sales record ID the message belongs to.
        role_id (int): streamer ID or user ID, depending on ``role``.
        role (str): sender type, either "streamer" or "user".
        message (str): message body to store.
    """

    assert role in ["streamer", "user"]

    # The sender foreign-key column depends on who is talking.
    if role == "streamer":
        sender_fields = {"streamer_id": role_id}
    else:
        sender_fields = {"user_id": role_id}

    new_row = ChatMessageInfo(
        sales_info_id=sales_info_id,
        role=role,
        message=message,
        send_time=datetime.now(),
        **sender_fields,
    )

    with Session(DB_ENGINE) as session:
        session.add(new_row)
        session.commit()
|
244 |
+
|
245 |
+
|
246 |
+
def update_db_room_status(room_id: int, user_id: int, process_type: str):
    """Update the on-air status of a streaming room.

    Args:
        room_id (int): room ID.
        user_id (int): owner user ID, used to prevent other users from
            modifying someone else's room.
        process_type (str): one of "online" (start streaming),
            "next-product" (advance to the next product) or
            "offline" (end the stream).

    Raises:
        NotImplementedError: when ``process_type`` is not a known value.
    """

    with Session(DB_ENGINE) as session:

        # Fetch the room, scoped to the owning user. first() returns None on
        # a miss so the guard below actually runs (one() would raise).
        room_info = session.exec(
            select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
        ).first()

        if room_info is None:
            logger.error("Edit by other ID !!!")
            return

        # Load the status row linked to the room. Initialize to None so the
        # guard below also covers a room with no status_id (previously that
        # path left status_info unbound).
        status_info = None
        if room_info.status_id is not None:
            status_info = session.exec(
                select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == room_info.status_id)
            ).first()

        if status_info is None:
            logger.error("status_info is None !!!")
            return

        if process_type in ["online", "next-product"]:

            if process_type == "online":
                # Going live: reset the session to the first product.
                status_info.live_status = 1
                status_info.start_time = datetime.now()
                status_info.end_time = None
                status_info.current_product_index = 0

            elif process_type == "next-product":
                status_info.current_product_index += 1

            current_idx = status_info.current_product_index

            # Point the status at the product now being presented.
            status_info.streaming_video_path = room_info.product_list[current_idx].start_video
            status_info.sales_info_id = room_info.product_list[current_idx].sales_info_id

            sales_info = session.exec(
                select(SalesDocAndVideoInfo).where(
                    SalesDocAndVideoInfo.sales_info_id == room_info.product_list[current_idx].sales_info_id
                )
            ).one()

            sales_info.start_time = datetime.now()
            session.add(sales_info)

            # Seed the conversation with the streamer's sales pitch.
            init_conversation(
                session, status_info.sales_info_id, room_info.streamer_id, room_info.product_list[current_idx].sales_doc
            )

        elif process_type == "offline":
            status_info.streaming_video_path = ""
            status_info.live_status = 2
            status_info.end_time = datetime.now()

        else:
            # The original raised NotImplemented, which is a constant, not an
            # exception class — raising it is itself a TypeError. Raise the
            # proper exception type instead.
            raise NotImplementedError("process type error !!")

        session.add(status_info)
        session.commit()
|
316 |
+
|
317 |
+
|
318 |
+
def get_message_list(sales_info_id: int) -> List[ChatMessageInfo]:
    """Fetch the full conversation for one sales record.

    Args:
        sales_info_id (int): sales record ID.

    Returns:
        List[ChatMessageInfo]: chat items ordered by message ID; an empty
            list when there is no conversation yet.
    """
    with Session(DB_ENGINE) as session:

        rows = session.exec(
            select(ChatMessageInfo)
            .where(and_(ChatMessageInfo.sales_info_id == sales_info_id))
            .order_by(ChatMessageInfo.message_id)
        ).all()

        if rows is None:
            return []

        formatted_messages = []
        for row in rows:
            # Pick sender details from the user or the streamer relation.
            if row.role == "user":
                sender_avatar = row.user_info.avatar
                sender_name = row.user_info.username
            else:
                sender_avatar = row.streamer_info.avatar
                sender_name = row.streamer_info.name

            formatted_messages.append(
                {
                    "role": row.role,
                    "avatar": API_CONFIG.REQUEST_FILES_URL + sender_avatar,
                    "userName": sender_name,
                    "message": row.message,
                    "datetime": row.send_time,
                }
            )

        return formatted_messages
|
352 |
+
|
353 |
+
|
354 |
+
def update_room_video_path(status_id: int, news_video_server_path: str):
    """Persist a new streamer video on a room's on-air status row.

    Args:
        status_id (int): on-air status row ID.
        news_video_server_path (str): server URL of the new streamer video;
            the file-server prefix is stripped before saving.
    """
    # Store the path relative to the file-server root.
    relative_path = news_video_server_path.replace(API_CONFIG.REQUEST_FILES_URL, "")

    with Session(DB_ENGINE) as session:
        status_row = session.exec(
            select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == status_id)
        ).one()

        status_row.streaming_video_path = relative_path
        session.add(status_row)
        session.commit()
|
369 |
+
|
370 |
+
|
371 |
+
async def get_live_room_info(user_id: int, room_id: int):
    """Assemble the real-time view of a live streaming room.

    Args:
        user_id (int): user ID owning the room.
        room_id (int): room ID.

    Returns:
        dict: streamer info, conversation so far, the product currently
            being presented, and timing/progress details.
    """

    # Load the room record (the helper returns a one-element list here).
    room_info_list = await get_db_streaming_room_info(user_id, room_id)
    room_info = room_info_list[0]

    # Streamer presenting this room.
    streamer_info = room_info.streamer_info

    # Index of the product currently on air.
    product_index = room_info.status.current_product_index

    # Whether this is the last product of the session.
    is_final_product = product_index == len(room_info.product_list) - 1

    # Conversation history for the current sales record.
    conversation_list = get_message_list(room_info.status.sales_info_id)

    # Expose the streaming video through the file server.
    video_path = API_CONFIG.REQUEST_FILES_URL + room_info.status.streaming_video_path

    # Response payload (key names are part of the frontend contract,
    # including the historical "currentPoductStartTime" spelling).
    res_data = {
        "streamerInfo": streamer_info,
        "conversation": conversation_list,
        "currentProductInfo": room_info.product_list[product_index].product_info,
        "currentStreamerVideo": video_path,
        "currentProductIndex": room_info.status.current_product_index,
        "startTime": room_info.status.start_time,
        "currentPoductStartTime": room_info.product_list[product_index].start_time,
        "finalProduct": is_final_product,
    }

    logger.info(res_data)

    return res_data
|
server/base/database/user_db.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : user_db.py
|
5 |
+
@Time : 2024/08/31
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 用户信息数据库操作
|
10 |
+
"""
|
11 |
+
|
12 |
+
from sqlmodel import Session, select
|
13 |
+
|
14 |
+
from ...web_configs import API_CONFIG
|
15 |
+
from ..models.user_model import UserBaseInfo, UserInfo
|
16 |
+
from .init_db import DB_ENGINE
|
17 |
+
|
18 |
+
|
19 |
+
def get_db_user_info(id: int = -1, username: str = "", all_info: bool = False) -> UserBaseInfo | UserInfo | None:
    """Look up a user either by ID or by username.

    Args:
        id (int): user ID; used when ``username`` is empty.
        username (str): username; takes precedence when non-empty.
        all_info (bool): whether to return the full record including the
            hashed password (sensitive) instead of the public subset.

    Returns:
        UserBaseInfo | UserInfo | None: the user info, or None when no
            matching user exists.
    """

    if username == "":
        # Query by primary key.
        query = select(UserInfo).where(UserInfo.user_id == id)
    else:
        # Query by username.
        query = select(UserInfo).where(UserInfo.username == username)

    with Session(DB_ENGINE) as session:
        results = session.exec(query).first()

    if results is None:
        # Previously the avatar was rewritten BEFORE this check, raising
        # AttributeError for unknown users; bail out first instead.
        return None

    # Prefix the avatar with the file-server URL for client consumption
    # (avatar is Optional in the model, so guard against None).
    if results.avatar is not None:
        results.avatar = API_CONFIG.REQUEST_FILES_URL + results.avatar

    if all_info is False:
        # Strip sensitive fields (hashed password, IP) for normal callers.
        results = UserBaseInfo(**results.model_dump())

    return results
|
server/base/models/__init__.py
ADDED
File without changes
|
server/base/models/llm_model.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : llm_model.py
|
5 |
+
@Time : 2024/09/01
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 大模型对话数据结构
|
10 |
+
"""
|
11 |
+
|
12 |
+
from pydantic import BaseModel
|
13 |
+
|
14 |
+
|
15 |
+
class GenProductItem(BaseModel):
    """Request payload for LLM-based product content generation.

    Field meanings are inferred from names only — confirm against the
    route handler that consumes this model.
    """

    gen_type: str  # generation type selector — TODO confirm allowed values
    instruction: str  # instruction / prompt text fed to the model
|
server/base/models/product_model.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : product_model.py
|
5 |
+
@Time : 2024/08/30
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 商品数据类型定义
|
10 |
+
"""
|
11 |
+
|
12 |
+
from datetime import datetime
|
13 |
+
from typing import List
|
14 |
+
from pydantic import BaseModel
|
15 |
+
from sqlmodel import Field, Relationship, SQLModel
|
16 |
+
|
17 |
+
|
18 |
+
# =======================================================
|
19 |
+
# 数据库模型
|
20 |
+
# =======================================================
|
21 |
+
|
22 |
+
|
23 |
+
class ProductInfo(SQLModel, table=True):
    """Product database model."""

    __tablename__ = "product_info"

    product_id: int | None = Field(default=None, primary_key=True, unique=True)
    product_name: str = Field(index=True, unique=True)
    product_class: str  # product category
    heighlights: str  # selling points (keeps the historical typo: renaming would require a schema migration)
    image_path: str  # product image path (relative to the file server)
    instruction: str  # instruction-manual path
    departure_place: str  # shipping origin
    delivery_company: str  # courier company
    selling_price: float
    amount: int  # stock count
    # default_factory is evaluated per row; the previous `datetime.now()`
    # default was evaluated once at import time, stamping every product
    # with the process start time.
    upload_date: datetime = Field(default_factory=datetime.now)
    delete: bool = False  # soft-delete flag

    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")

    sales_info: list["SalesDocAndVideoInfo"] = Relationship(back_populates="product_info")
|
44 |
+
|
45 |
+
|
46 |
+
# =======================================================
|
47 |
+
# 基本模型
|
48 |
+
# =======================================================
|
49 |
+
|
50 |
+
|
51 |
+
class ProductPageItem(BaseModel):
    """One page of products, as returned to the frontend."""

    product_list: List[ProductInfo] = []  # products on this page (pydantic copies defaults per instance)
    currentPage: int = 0  # current page number
    pageSize: int = 0  # number of items per page
    totalSize: int = 0  # total number of products
|
56 |
+
|
57 |
+
|
58 |
+
class ProductQueryItem(BaseModel):
    """Query payload for fetching a product instruction manual."""

    instructionPath: str = ""  # path of the instruction file, used to load its content
|
server/base/models/streamer_info_model.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : streamer_info_model.py
|
5 |
+
@Time : 2024/08/30
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 主播信息数据结构
|
10 |
+
"""
|
11 |
+
|
12 |
+
from typing import Optional
|
13 |
+
from sqlmodel import Field, Relationship, SQLModel
|
14 |
+
|
15 |
+
|
16 |
+
# =======================================================
|
17 |
+
# 数据库模型
|
18 |
+
# =======================================================
|
19 |
+
class StreamerInfo(SQLModel, table=True):
    """Streamer (digital-human host) database model."""

    __tablename__ = "streamer_info"

    streamer_id: int | None = Field(default=None, primary_key=True, unique=True)
    name: str = Field(index=True, unique=True)
    character: str = ""  # persona / personality description
    avatar: str = ""  # avatar image path (relative to the file server)

    tts_weight_tag: str = ""  # TTS voice weight tag, e.g. "艾丝妲"
    tts_reference_sentence: str = ""  # reference transcript for TTS voice cloning
    tts_reference_audio: str = ""  # reference audio path for TTS voice cloning

    poster_image: str = ""  # poster image path
    base_mp4_path: str = ""  # base video driving the digital human

    delete: bool = False  # soft-delete flag

    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")

    room_info: Optional["StreamRoomInfo"] | None = Relationship(
        back_populates="streamer_info", sa_relationship_kwargs={"lazy": "selectin"}
    )
|
server/base/models/streamer_room_model.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : streamer_room_model.py
|
5 |
+
@Time : 2024/08/31
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 直播间信息数据结构定义
|
10 |
+
"""
|
11 |
+
|
12 |
+
from datetime import datetime
|
13 |
+
from typing import Optional
|
14 |
+
|
15 |
+
from pydantic import BaseModel
|
16 |
+
from sqlmodel import Field, Relationship, SQLModel
|
17 |
+
|
18 |
+
from ..models.user_model import UserInfo
|
19 |
+
from ..models.product_model import ProductInfo
|
20 |
+
from ..models.streamer_info_model import StreamerInfo
|
21 |
+
|
22 |
+
|
23 |
+
class RoomChatItem(BaseModel):
    """Chat request payload sent from the live-room frontend."""

    roomId: int  # target room ID
    message: str = ""  # text message body
    asrFileUrl: str = ""  # uploaded audio URL for ASR — TODO confirm against the route handler
|
27 |
+
|
28 |
+
|
29 |
+
# =======================================================
|
30 |
+
# 直播间数据库模型
|
31 |
+
# =======================================================
|
32 |
+
|
33 |
+
|
34 |
+
class SalesDocAndVideoInfo(SQLModel, table=True):
    """Per-product sales script and digital-human intro video of a room."""

    __tablename__ = "sales_doc_and_video_info"

    sales_info_id: int | None = Field(default=None, primary_key=True, unique=True)

    sales_doc: str = ""  # sales pitch script
    start_video: str = ""  # first explanation video played when the product goes on air
    start_time: datetime | None = None  # when this product started being presented
    selected: bool = True  # whether the product is selected for the room

    product_id: int | None = Field(default=None, foreign_key="product_info.product_id")
    product_info: ProductInfo | None = Relationship(back_populates="sales_info", sa_relationship_kwargs={"lazy": "selectin"})

    room_id: int | None = Field(default=None, foreign_key="stream_room_info.room_id")
    stream_room: Optional["StreamRoomInfo"] | None = Relationship(back_populates="product_list")
|
51 |
+
|
52 |
+
|
53 |
+
class OnAirRoomStatusItem(SQLModel, table=True):
    """Live (on-air) status of a streaming room."""

    __tablename__ = "on_air_room_status_item"

    status_id: int | None = Field(default=None, primary_key=True, unique=True)  # status row ID

    sales_info_id: int | None = Field(default=None, foreign_key="sales_doc_and_video_info.sales_info_id")

    current_product_index: int = 0  # index of the product currently being presented
    streaming_video_path: str = ""  # video currently being streamed

    live_status: int = 0  # 0 = not started, 1 = live, 2 = ended
    start_time: datetime | None = None  # when the stream started
    end_time: datetime | None = None  # when the stream ended

    room_info: Optional["StreamRoomInfo"] | None = Relationship(
        back_populates="status", sa_relationship_kwargs={"lazy": "selectin"}
    )
|
72 |
+
|
73 |
+
class StreamRoomInfo(SQLModel, table=True):
    """Streaming-room record, as persisted in the database.

    (This description previously sat as a stray no-op module-level string
    above the class; it is now the class docstring.)
    """

    __tablename__ = "stream_room_info"

    room_id: int | None = Field(default=None, primary_key=True, unique=True)  # room ID

    name: str = ""  # room name

    product_list: list[SalesDocAndVideoInfo] = Relationship(
        back_populates="stream_room",
        sa_relationship_kwargs={"lazy": "selectin", "order_by": "asc(SalesDocAndVideoInfo.product_id)"},
    )  # product list; order_by sorts automatically on load (desc = descending, asc = ascending)

    prohibited_words_id: int = 0  # banned-words table ID
    room_poster: str = ""  # poster image
    background_image: str = ""  # streamer background image

    delete: bool = False  # soft-delete flag

    status_id: int | None = Field(default=None, foreign_key="on_air_room_status_item.status_id")
    status: OnAirRoomStatusItem | None = Relationship(back_populates="room_info", sa_relationship_kwargs={"lazy": "selectin"})

    streamer_id: int | None = Field(default=None, foreign_key="streamer_info.streamer_id")  # streamer ID
    streamer_info: StreamerInfo | None = Relationship(back_populates="room_info", sa_relationship_kwargs={"lazy": "selectin"})

    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")
|
102 |
+
|
103 |
+
|
104 |
+
# =======================================================
|
105 |
+
# 直播对话数据库模型
|
106 |
+
# =======================================================
|
107 |
+
|
108 |
+
|
109 |
+
class ChatMessageInfo(SQLModel, table=True):
    """One chat message shown on the live-room page."""

    __tablename__ = "chat_message_info"

    message_id: int | None = Field(default=None, primary_key=True, unique=True)  # message ID

    sales_info_id: int | None = Field(default=None, foreign_key="sales_doc_and_video_info.sales_info_id")
    sales_info: SalesDocAndVideoInfo | None = Relationship(sa_relationship_kwargs={"lazy": "selectin"})

    user_id: int | None = Field(default=None, foreign_key="user_info.user_id")
    user_info: UserInfo | None = Relationship(sa_relationship_kwargs={"lazy": "selectin"})

    streamer_id: int | None = Field(default=None, foreign_key="streamer_info.streamer_id")
    streamer_info: StreamerInfo | None = Relationship(sa_relationship_kwargs={"lazy": "selectin"})

    role: str  # sender type: "streamer" or "user"
    message: str  # message body
    send_time: datetime | None = None  # when the message was sent
|
server/base/models/user_model.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : user_model.py
|
5 |
+
@Time : 2024/08/31
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 用户信息数据结构
|
10 |
+
"""
|
11 |
+
|
12 |
+
from datetime import datetime
|
13 |
+
from ipaddress import IPv4Address
|
14 |
+
from pydantic import BaseModel
|
15 |
+
from sqlmodel import Field, SQLModel
|
16 |
+
|
17 |
+
|
18 |
+
# =======================================================
|
19 |
+
# 基本模型
|
20 |
+
# =======================================================
|
21 |
+
class TokenItem(BaseModel):
    """OAuth2-style access-token response."""

    access_token: str  # bearer token string
    token_type: str  # token type — presumably "bearer"; confirm against the auth route
|
24 |
+
|
25 |
+
|
26 |
+
class UserBaseInfo(BaseModel):
    """Public (non-sensitive) user fields."""

    user_id: int | None = Field(default=None, primary_key=True, unique=True)
    username: str = Field(index=True, unique=True)
    email: str | None = None
    avatar: str | None = None  # avatar path — rewritten to a full URL by the DB helpers
    # default_factory is evaluated per instance; the previous
    # `datetime.now()` default was evaluated once at import time, stamping
    # every new user with the process start time.
    create_time: datetime = Field(default_factory=datetime.now)
|
32 |
+
|
33 |
+
|
34 |
+
# =======================================================
|
35 |
+
# 数据库模型
|
36 |
+
# =======================================================
|
37 |
+
class UserInfo(UserBaseInfo, SQLModel, table=True):
    """Full user record, including sensitive credential fields."""

    __tablename__ = "user_info"

    hashed_password: str  # password hash — never return to clients
    ip_address: IPv4Address | None = None  # client IP — presumably last login address; confirm against caller
    delete: bool = False  # soft-delete flag
|
server/base/modules/__init__.py
ADDED
File without changes
|
server/base/modules/agent/__init__.py
ADDED
File without changes
|
server/base/modules/agent/agent_worker.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from lagent.actions import ActionExecutor
|
5 |
+
from lagent.agents.internlm2_agent import Internlm2Protocol
|
6 |
+
from lagent.schema import ActionReturn, AgentReturn
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
from .delivery_time_query import DeliveryTimeQueryAction
|
10 |
+
|
11 |
+
|
12 |
+
def init_handlers(departure_place, delivery_company_name):
    """Assemble the lagent plumbing for one agent session.

    Args:
        departure_place: shipping origin forwarded to ``DeliveryTimeQueryAction``.
        delivery_company_name: courier name forwarded to ``DeliveryTimeQueryAction``.

    Returns:
        tuple: ``(ActionExecutor, Internlm2Protocol)`` — the executor holds the
        delivery-time plugin, the protocol is configured with InternLM2's
        tool-calling special tokens.
    """
    meta_cn = "当开启工具以及代码时,根据需求选择合适的工具进行调用"

    interpreter_cn = (
        "你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。"
        "当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。"
        "这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),"
        "复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),"
        "文本处理和分析(比如文本解析和自然语言处理),"
        "机器学习和数据科学(用于展示模型训练和数据可视化),"
        "以及文件操作和数据导入(处理CSV、JSON等格式的文件)。"
    )

    plugin_cn = (
        "你可以使用如下工具:"
        "\n{prompt}\n"
        "如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! "
        "同时注意你可以使用的工具,不要随意捏造!"
    )

    # InternLM2 marks tool calls with dedicated special tokens; the protocol
    # handler needs them verbatim to format prompts and parse agent turns.
    protocol_handler = Internlm2Protocol(
        meta_prompt=meta_cn,
        interpreter_prompt=interpreter_cn,
        plugin_prompt=plugin_cn,
        tool={
            "begin": "{start_token}{name}\n",
            "start_token": "<|action_start|>",
            "name_map": {"plugin": "<|plugin|>", "interpreter": "<|interpreter|>"},
            "belong": "assistant",
            "end": "<|action_end|>\n",
        },
    )

    actions = [
        DeliveryTimeQueryAction(
            departure_place=departure_place,
            delivery_company_name=delivery_company_name,
        ),
    ]
    # Index the plugins by name, then hand them to the executor in that order.
    registry = {action.name: action for action in actions}
    action_executor = ActionExecutor(actions=[registry[name] for name in registry])

    return action_executor, protocol_handler
|
56 |
+
|
57 |
+
|
58 |
+
def get_agent_result(llm_model_handler, prompt_input, departure_place, delivery_company_name):
    """Run one agent tool-calling loop against the LLM and return the tool output.

    The loop: format the history into an agent-style prompt, stream the model
    response (special tokens such as ``<|action_start|><|plugin|>`` kept), parse
    out a tool invocation, execute it, and return the tool's text result.

    Args:
        llm_model_handler: lmdeploy API handler exposing ``available_models``
            and ``chat_completions_v1``.
        prompt_input (str): raw user query.
        departure_place (str): shipping origin configured by the backend.
        delivery_company_name (str): courier company name.

    Returns:
        str: the first tool result's ``content``, or ``""`` when no tool was
        invoked or the result could not be read.
    """
    action_executor, protocol_handler = init_handlers(departure_place, delivery_company_name)

    inner_history = [{"role": "user", "content": prompt_input}]
    interpreter_executor = None
    max_turn = 7
    for _ in range(max_turn):

        # Render system prompt + plugin descriptions + chat turns into the
        # agent prompt expected by InternLM2.
        prompt = protocol_handler.format(
            inner_step=inner_history,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )
        cur_response = ""

        agent_return = AgentReturn()

        logger.info(f"agent input for llm: {prompt}")

        # skip_special_tokens=False so <|action_start|>/<|plugin|>/<|action_end|>
        # survive in the streamed output and can be parsed below.
        model_name = llm_model_handler.available_models[0]
        for item in llm_model_handler.chat_completions_v1(
            model=model_name, messages=prompt, stream=True, skip_special_tokens=False
        ):
            logger.info(f"agent return = {item}")
            delta = item["choices"][0]["delta"]
            if "content" not in delta:
                continue
            current_res = delta["content"]

            # BUG FIX: `item` is a dict chunk with no `.text` attribute; the
            # "~" normalization must operate on the delta string itself.
            if "~" in current_res:
                current_res = current_res.replace("~", "。").replace("。。", "。")

            cur_response += current_res

        # Extract a tool invocation, e.g.
        # '<|action_start|><|plugin|>\n{"name": ..., "parameters": ...}<|action_end|>'
        name, language, action = protocol_handler.parse(
            message=cur_response,
            plugin_executor=action_executor,
            interpreter_executor=interpreter_executor,
        )

        executor = None  # guard: stays None unless a usable executor is selected
        if name:
            if name == "plugin":
                if action_executor:
                    executor = action_executor
                else:
                    logging.info(msg="No plugin is instantiated!")
                    continue
                try:
                    action = json.loads(action)
                except Exception as e:
                    logging.info(msg=f"Invalid action {e}")
                    continue
            elif name == "interpreter":
                if interpreter_executor:
                    executor = interpreter_executor
                else:
                    logging.info(msg="No interpreter is instantiated!")
                    continue
            agent_return.response = action

        print(f"Agent response: {cur_response}")

        if name and executor is not None:
            print(f"Agent action: {action}")
            action_return: ActionReturn = executor(action["name"], action["parameters"])
            try:
                return action_return.result[0]["content"]
            except Exception:
                return ""

        if not name:
            # No tool requested: the model answered directly; stop looping.
            agent_return.response = language
            break

    return ""
|
server/base/modules/agent/delivery_time_query.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from datetime import datetime
|
3 |
+
import hashlib
|
4 |
+
import json
|
5 |
+
from typing import Optional, Type
|
6 |
+
|
7 |
+
import jionlp as jio
|
8 |
+
import requests
|
9 |
+
from lagent.actions.base_action import BaseAction, tool_api
|
10 |
+
from lagent.actions.parser import BaseParser, JsonParser
|
11 |
+
from lagent.schema import ActionReturn, ActionStatusCode
|
12 |
+
|
13 |
+
from ....web_configs import WEB_CONFIGS
|
14 |
+
|
15 |
+
|
16 |
+
class DeliveryTimeQueryAction(BaseAction):
    """快递时效查询插件,用于根据用户提出的收货地址查询到达期限"""

    def __init__(
        self,
        departure_place: str,
        delivery_company_name: str,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
        enable: bool = True,
    ) -> None:
        # Wire up the two backend helpers once; every run() call reuses them.
        super().__init__(description, parser, enable)
        self.departure_place = departure_place  # shipping origin
        self.weather_query_handler = WeatherQuery(departure_place, WEB_CONFIGS.AGENT_WEATHER_API_KEY)
        self.delivery_time_handler = DeliveryTimeQuery(delivery_company_name, WEB_CONFIGS.AGENT_DELIVERY_TIME_API_KEY)

    @tool_api
    def run(self, query: str) -> ActionReturn:
        """一个到货时间查询API。可以根据城市名查询到货时间信息。

        Args:
            query (:class:`str`): 需要查询的城市名。
        """

        # Re-parse the city out of the raw query ourselves, so a city name
        # mis-extracted by the LLM cannot derail the lookup.
        location_info = jio.parse_location(query, town_village=True)
        receive_city = location_info["city"]

        # Weather at both ends of the route, then the courier's ETA.
        receive_weather = self.weather_query_handler(receive_city)
        send_weather = self.weather_query_handler(self.departure_place)
        delivery_info = self.delivery_time_handler(self.departure_place, receive_city)

        summary = (
            f"今天日期:{datetime.now().strftime('%m月%d日')}\n"
            f"收货地天气:{receive_weather.result[0]['content']}\n"
            f"发货地天气:{send_weather.result[0]['content']}\n"
            f"物流信息:{delivery_info.result[0]['content']}\n"
            "回答突出“预计送达时间”和“收货地天气”,如果收货地或者发货地遇到暴雨暴雪等极端天气,须告知用户快递到达时间会有所增加。"
        )

        ret = ActionReturn(type=self.name)
        ret.result = [{"type": "text", "content": summary}]
        return ret
|
67 |
+
|
68 |
+
|
69 |
+
class WeatherQuery:
    """Real-time weather lookup via the QWeather HTTP API.

    NOTE: the original docstring was copy-pasted from the delivery-time plugin;
    this class queries weather, not delivery deadlines.
    """

    def __init__(
        self,
        departure_place: str,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Args:
            departure_place: shipping origin city, kept for the caller's use.
            api_key: QWeather API key (mandatory).

        Raises:
            ValueError: if no API key is supplied.
        """
        self.departure_place = departure_place  # shipping origin

        if api_key is None:
            raise ValueError("Please set Weather API key either in the environment as WEATHER_API_KEY")
        self.api_key = api_key
        # City name -> QWeather location id lookup endpoint.
        self.location_query_url = "https://geoapi.qweather.com/v2/city/lookup"
        # Current ("now") weather endpoint.
        self.weather_query_url = "https://devapi.qweather.com/v7/weather/now"

    def parse_results(self, city_name: str, results: dict) -> str:
        """Format a QWeather "now" payload into a one-line Chinese summary.

        Args:
            city_name (str): city the report belongs to.
            results (dict): decoded JSON body returned by the weather endpoint.

        Returns:
            str: summary covering temperature, condition, wind, humidity,
            precipitation and visibility.
        """
        now = results["now"]
        data = (
            f"城市名: {city_name};"
            f'温度: {now["temp"]}°C;'
            f'体感温度: {now["feelsLike"]}°C;'
            f'天气: {now["text"]};'
            f'风力等级: {now["windScale"]},风速为 {now["windSpeed"]} km/h;'
            f'相对湿度: {now["humidity"]};'
            f'当前小时累计降水量: {now["precip"]} mm;'
            f'能见度: {now["vis"]} km。'
        )
        return data

    def __call__(self, query):
        """Query weather for a city and wrap the outcome in an ActionReturn.

        Args:
            query (str): city name.

        Returns:
            ActionReturn: SUCCESS with a text result, or HTTP_ERROR/API_ERROR
            with ``errmsg`` set.
        """
        tool_return = ActionReturn()
        status_code, response = self.search_weather_with_city(query)
        if status_code == -1:
            # Transport-level failure or unknown city (see below).
            tool_return.errmsg = response
            tool_return.state = ActionStatusCode.HTTP_ERROR
        elif status_code == 200:
            parsed_res = self.parse_results(query, response)
            tool_return.result = [dict(type="text", content=str(parsed_res))]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = str(status_code)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    def search_weather_with_city(self, query: str):
        """Resolve a city name to its QWeather id, then fetch current weather.

        Args:
            query (str): city name.

        Returns:
            int: HTTP status code, or -1 for transport errors / unknown city.
            dict | str: decoded response body, or an error message string.
        """

        # Step 1: city name -> location id.
        try:
            # timeout added so a stalled API cannot hang the caller forever.
            city_code_response = requests.get(
                self.location_query_url, params={"key": self.api_key, "location": query}, timeout=10
            )
        except Exception as e:
            return -1, str(e)

        if city_code_response.status_code != 200:
            return city_code_response.status_code, city_code_response.json()
        city_code_response = city_code_response.json()
        if len(city_code_response["location"]) == 0:
            return -1, "未查询到城市"
        city_code = city_code_response["location"][0]["id"]

        # Step 2: location id -> current weather.
        try:
            weather_response = requests.get(
                self.weather_query_url, params={"key": self.api_key, "location": city_code}, timeout=10
            )
        except Exception as e:
            return -1, str(e)
        return weather_response.status_code, weather_response.json()
|
157 |
+
|
158 |
+
|
159 |
+
class DeliveryTimeQuery:
    """Estimated-delivery-time lookup backed by the KDNiao (快递鸟) API."""

    def __init__(
        self,
        delivery_company_name: Optional[str] = "中通",
        api_key: Optional[str] = None,
    ) -> None:
        """
        Args:
            delivery_company_name: Chinese courier name, mapped to a KDNiao code.
            api_key: "<EBusinessID>,<ApiKey>" pair, comma separated.

        Raises:
            ValueError: when the key is missing or not in the two-part form.
        """
        if api_key is None or "," not in api_key:
            raise ValueError(
                'Please set Delivery time API key either in the environment as DELIVERY_TIME_API_KEY="${e_business_id},${api_key}"'
            )
        key_parts = api_key.split(",")
        self.e_business_id = key_parts[0]
        self.api_key = key_parts[1]
        self.api_url = "http://api.kdniao.com/api/dist"  # KDNiao endpoint
        # Cached province/city/district table used by get_city_detail().
        self.china_location = jio.china_location_loader()
        # Courier name -> KDNiao ShipperCode.
        company_code_map = {
            "德邦": "DBL",
            "邮政": "EMS",
            "京东": "JD",
            "极兔速递": "JTSD",
            "顺丰": "SF",
            "申通": "STO",
            "韵达": "YD",
            "圆通": "YTO",
            "中通": "ZTO",
        }
        self.delivery_company_name = delivery_company_name
        self.delivery_company_id = company_code_map[delivery_company_name]

    @staticmethod
    def data_md5(n):
        """Return the hex MD5 digest of ``str(n)`` (UTF-8 encoded)."""
        return hashlib.md5(str(n).encode("utf-8")).hexdigest()

    def get_data_sign(self, n):
        """Compute the KDNiao request signature: base64(md5(json(payload) + key))."""
        digest = self.data_md5(json.dumps(n) + self.api_key)
        return base64.b64encode(digest.encode("utf-8")).decode("utf-8")

    def get_city_detail(self, name):
        """Resolve a place name to province/city plus the city's first "区" district."""
        location = jio.parse_location(name, town_village=True)

        district = ""
        for candidate in self.china_location[location["province"]][location["city"]].keys():
            if "区" == candidate[-1]:
                district = candidate
                break

        return {
            "province": location["province"],
            "city": location["city"],
            "county": district,
        }

    def get_params(self, send_city, receive_city):
        """Build the full KDNiao request body for a send->receive city pair."""
        origin = self.get_city_detail(send_city)
        destination = self.get_city_detail(receive_city)

        # ETA API doc: https://www.yuque.com/kdnjishuzhichi/dfcrg1/ynkmts0e5owsnpvu
        # Application-level payload; "6004" is the ETA request instruction.
        request_type = "6004"
        request_data = {
            "ShipperCode": self.delivery_company_id,
            "ReceiveArea": destination["county"],
            "ReceiveCity": destination["city"],
            "ReceiveProvince": destination["province"],
            "SendArea": origin["county"],
            "SendCity": origin["city"],
            "SendProvince": origin["province"],
        }
        # System-level envelope wrapping the signed payload.
        return {
            "RequestData": json.dumps(request_data),
            "RequestType": request_type,
            "EBusinessID": self.e_business_id,
            "DataSign": self.get_data_sign(request_data),
            "DataType": 2,
        }

    def parse_results(self, response):
        """Format a successful KDNiao ETA response into a Chinese summary string.

        Example response body::

            {"EBusinessID": "...",
             "Data": {"DeliveryTime": "06月15日下午可达", "Hour": "52h",
                      "SendProvince": "广东省", "SendCity": "广州市",
                      "ReceiveProvince": "湖南省", "ReceiveCity": "长沙市", ...},
             "ResultCode": "100", "Success": true}
        """
        payload = response["Data"]
        summary = (
            f'发货地点: {payload["SendProvince"]} {payload["SendCity"]};'
            f'收货地点: {payload["ReceiveProvince"]} {payload["ReceiveCity"]};'
            f'预计送达时间: {payload["DeliveryTime"]};'
            f"快递公司: {self.delivery_company_name};"
            f'预计时效: {payload["Hour"]}。'
        )
        return summary

    def __call__(self, send_city, receive_city):
        """POST the ETA query and wrap the outcome in an ActionReturn."""
        tool_return = ActionReturn()
        try:
            resp = requests.post(self.api_url, self.get_params(send_city, receive_city))
            status = resp.status_code
            body = resp.json()
        except Exception as e:
            tool_return.errmsg = str(e)
            tool_return.state = ActionStatusCode.API_ERROR
            return tool_return

        if status == 200:
            tool_return.result = [dict(type="text", content=str(self.parse_results(body)))]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = str(status)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
|
server/base/modules/rag/__init__.py
ADDED
File without changes
|
server/base/modules/rag/feature_store.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""extract feature and search with user query."""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import shutil
|
8 |
+
from multiprocessing import Pool
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any, List, Optional
|
11 |
+
|
12 |
+
import yaml
|
13 |
+
|
14 |
+
# 解决 Warning:huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks…
|
15 |
+
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
16 |
+
|
17 |
+
from BCEmbedding.tools.langchain import BCERerank
|
18 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
19 |
+
from langchain.text_splitter import (MarkdownHeaderTextSplitter,
|
20 |
+
MarkdownTextSplitter,
|
21 |
+
RecursiveCharacterTextSplitter)
|
22 |
+
from langchain.vectorstores.faiss import FAISS as Vectorstore
|
23 |
+
from langchain_core.documents import Document
|
24 |
+
from loguru import logger
|
25 |
+
from torch.cuda import empty_cache
|
26 |
+
|
27 |
+
from .file_operation import FileName, FileOperation
|
28 |
+
from .retriever import CacheRetriever, Retriever
|
29 |
+
|
30 |
+
|
31 |
+
def read_and_save(file: FileName):
    """Extract text from ``file.origin`` and persist it to ``file.copypath``.

    Skips the work when the copy already exists, when extraction fails, or
    when the extracted text is empty.
    """
    if os.path.exists(file.copypath):
        # A previous run already produced the copy; nothing to do.
        logger.info("already exist, skip load")
        return

    reader = FileOperation()
    logger.info(f"reading {file.origin}, would save to {file.copypath}")

    content, error = reader.read(file.origin)
    if error is not None:
        logger.error(f"{file.origin} load error: {str(error)}")
        return
    if content is None or len(content) < 1:
        logger.warning(f"{file.origin} empty, skip save")
        return

    with open(file.copypath, "w") as dst:
        dst.write(content)
|
49 |
+
|
50 |
+
|
51 |
+
def _split_text_with_regex_from_end(text: str, separator: str, keep_separator: bool) -> List[str]:
|
52 |
+
# Now that we have the separator, split the text
|
53 |
+
if separator:
|
54 |
+
if keep_separator:
|
55 |
+
# The parentheses in the pattern keep the delimiters in the result.
|
56 |
+
_splits = re.split(f"({separator})", text)
|
57 |
+
splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
|
58 |
+
if len(_splits) % 2 == 1:
|
59 |
+
splits += _splits[-1:]
|
60 |
+
# splits = [_splits[0]] + splits
|
61 |
+
else:
|
62 |
+
splits = re.split(separator, text)
|
63 |
+
else:
|
64 |
+
splits = list(text)
|
65 |
+
return [s for s in splits if s != ""]
|
66 |
+
|
67 |
+
|
68 |
+
# copy from https://github.com/chatchat-space/Langchain-Chatchat/blob/master/text_splitter/chinese_recursive_text_splitter.py
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive text splitter tuned for Chinese punctuation.

    Tries separators in coarse-to-fine order (paragraph break, newline,
    sentence enders, semicolons, commas) and recurses into the finer
    separators for any chunk that still exceeds the chunk size.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Coarse-to-fine separator regexes: blank line, newline, Chinese and
        # Western sentence enders, semicolons, commas.
        self._separators = separators or ["\n\n", "\n", "。|!|?", "\.\s|\!\s|\?\s", ";|;\s", ",|,\s"]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use: the first one that actually occurs
        # in the text; the separators after it are kept for recursion.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        # Short pieces accumulate in _good_splits and get merged up to the
        # chunk size (via the parent class's _merge_splits/_length_function);
        # oversized pieces recurse with the finer separators, or pass through
        # unchanged when none remain.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        # Collapse runs of blank lines inside each chunk and drop empty chunks.
        return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip() != ""]
|
122 |
+
|
123 |
+
|
124 |
+
class FeatureStore:
    """Tokenize and extract features from the project's documents, for use in
    the reject pipeline and response pipeline."""

    def __init__(
        self, embeddings: HuggingFaceEmbeddings, reranker: BCERerank, config_path: str = "config.ini", language: str = "zh"
    ) -> None:
        """Init with model device type and config.

        Args:
            embeddings: sentence-embedding model used to vectorize chunks.
            reranker: reranker handle (stored, not used during ingestion).
            config_path: yaml config file; only `feature_store.reject_throttle` is read.
            language: "zh" selects the Chinese recursive splitter, anything else
                falls back to the generic recursive character splitter.
        """
        self.config_path = config_path
        self.reject_throttle = -1
        self.language = language
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)["feature_store"]
            self.reject_throttle = config["reject_throttle"]

        logger.warning(
            "!!! If your feature generated by `text2vec-large-chinese` before 20240208, please rerun `python3 -m huixiangdou.service.feature_store`"  # noqa E501
        )

        logger.debug("loading text2vec model..")
        self.embeddings = embeddings
        self.reranker = reranker
        # Retrieval-side handles; populated elsewhere, kept None during ingestion.
        self.compression_retriever = None
        self.rejecter = None
        self.retriever = None
        # Secondary splitter for oversized markdown sections (>= 1024 chars).
        self.md_splitter = MarkdownTextSplitter(chunk_size=768, chunk_overlap=32)

        if language == "zh":
            self.text_splitter = ChineseRecursiveTextSplitter(
                keep_separator=True, is_separator_regex=True, chunk_size=768, chunk_overlap=32
            )
        else:
            self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=768, chunk_overlap=32)

        # Header-aware splitter: keeps H1-H3 titles as metadata for each section.
        self.head_splitter = MarkdownHeaderTextSplitter(
            headers_to_split_on=[
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
            ]
        )

    def split_md(self, text: str, source: str):
        """Split the markdown document in a nested way, first extracting the
        header.

        If the extraction result exceeds 1024, split it again according to
        length.
        """
        docs = self.head_splitter.split_text(text)

        final = []
        for doc in docs:
            # Rebuild a "H1 H2 H3" prefix from the header metadata so each
            # chunk stays self-describing after splitting.
            header = ""
            if len(doc.metadata) > 0:
                if "Header 1" in doc.metadata:
                    header += doc.metadata["Header 1"]
                if "Header 2" in doc.metadata:
                    header += " "
                    header += doc.metadata["Header 2"]
                if "Header 3" in doc.metadata:
                    header += " "
                    header += doc.metadata["Header 3"]

            if len(doc.page_content) >= 1024:
                # Oversized section: re-split by length, prefix each sub-chunk.
                subdocs = self.md_splitter.create_documents([doc.page_content])
                for subdoc in subdocs:
                    if len(subdoc.page_content) >= 10:
                        final.append("{} {}".format(header, subdoc.page_content.lower()))
            elif len(doc.page_content) >= 10:
                final.append("{} {}".format(header, doc.page_content.lower()))  # noqa E501

        # Diagnostics only: flag chunks that are still over the size budget.
        for item in final:
            if len(item) >= 1024:
                logger.debug("source {} split length {}".format(source, len(item)))
        return final

    def clean_md(self, text: str):
        """Remove parts of the markdown document that do not contain the key
        question words, such as code blocks, URL links, etc."""
        # remove ref: keep the link text, drop the URL
        pattern_ref = r"\[(.*?)\]\(.*?\)"
        new_text = re.sub(pattern_ref, r"\1", text)

        # remove code block
        pattern_code = r"```.*?```"
        new_text = re.sub(pattern_code, "", new_text, flags=re.DOTALL)

        # remove underline (5+ consecutive underscores, i.e. horizontal rules)
        new_text = re.sub("_{5,}", "", new_text)

        # remove table
        # new_text = re.sub('\|.*?\|\n\| *\:.*\: *\|.*\n(\|.*\|.*\n)*', '', new_text, flags=re.DOTALL) # noqa E501

        # use lower
        new_text = new_text.lower()
        return new_text

    def get_md_documents(self, file: FileName):
        """Read a markdown file, clean it and split it into Documents.

        Returns:
            (documents, length): chunk Documents plus total character count.
        """
        documents = []
        length = 0
        text = ""
        with open(file.copypath, encoding="utf8") as f:
            text = f.read()
        # Prepend the flattened path so provenance survives chunking.
        text = file.prefix + "\n" + self.clean_md(text)
        if len(text) <= 1:
            return [], length

        chunks = self.split_md(text=text, source=os.path.abspath(file.copypath))
        for chunk in chunks:
            # `source` is for return references, `read` for LLM response.
            new_doc = Document(page_content=chunk, metadata={"source": file.basename, "read": file.copypath})
            length += len(chunk)
            documents.append(new_doc)
        return documents, length

    def get_text_documents(self, text: str, file: FileName):
        """Split plain text into Documents tagged with the file's provenance."""
        if len(text) <= 1:
            return []
        chunks = self.text_splitter.create_documents([text])
        documents = []
        for chunk in chunks:
            # `source` is for return references
            # `read` is for LLM response
            chunk.metadata = {"source": file.basename, "read": file.copypath}
            documents.append(chunk)
        return documents

    def ingress_response(self, files: list, work_dir: str):
        """Extract the features required for the response pipeline based on the
        document.

        Builds a FAISS store from all readable files and saves it under
        `work_dir/db_response`.
        """
        feature_dir = os.path.join(work_dir, "db_response")
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        # logger.info('glob {} in dir {}'.format(files, file_dir))
        file_opr = FileOperation()
        documents = []

        for i, file in enumerate(files):
            logger.debug("{}/{}.. {}".format(i + 1, len(files), file.basename))
            if not file.state:
                continue

            if file._type == "md":
                md_documents, md_length = self.get_md_documents(file)
                documents += md_documents
                logger.info("{} content length {}".format(file._type, md_length))
                file.reason = str(md_length)

            else:
                # now read pdf/word/excel/ppt text
                text, error = file_opr.read(file.copypath)
                if error is not None:
                    file.state = False
                    file.reason = str(error)
                    continue
                file.reason = str(len(text))
                logger.info("{} content length {}".format(file._type, len(text)))
                text = file.prefix + text
                documents += self.get_text_documents(text, file)

        if len(documents) < 1:
            return
        vs = Vectorstore.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)

    def ingress_reject(self, files: list, work_dir: str):
        """Extract the features required for the reject pipeline based on
        documents.

        Unlike the response store, markdown is NOT cleaned here, so the reject
        store sees the raw text. Saved under `work_dir/db_reject`.
        """
        feature_dir = os.path.join(work_dir, "db_reject")
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        file_opr = FileOperation()

        logger.debug("ingress reject..")
        for i, file in enumerate(files):
            if not file.state:
                continue

            if file._type == "md":
                # reject base not clean md
                text = file.basename + "\n"
                with open(file.copypath, encoding="utf8") as f:
                    text += f.read()
                if len(text) <= 1:
                    continue

                chunks = self.split_md(text=text, source=os.path.abspath(file.copypath))
                for chunk in chunks:
                    new_doc = Document(page_content=chunk, metadata={"source": file.basename, "read": file.copypath})
                    documents.append(new_doc)

            else:
                text, error = file_opr.read(file.copypath)
                if error is not None:
                    continue
                text = file.basename + text
                documents += self.get_text_documents(text, file)

        if len(documents) < 1:
            return
        vs = Vectorstore.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)

    def preprocess(self, files: list, work_dir: str):
        """Preprocesses files in a given directory. Copies each file to
        'preprocess' with new name formed by joining all subdirectories with
        '_'.

        Args:
            files (list): original file list.
            work_dir (str): Working directory where preprocessed files will be stored.  # noqa E501

        Returns:
            str: Path to the directory where preprocessed markdown files are saved.

        Raises:
            Exception: Raise an exception if no markdown files are found in the provided repository directory.  # noqa E501
        """
        preproc_dir = os.path.join(work_dir, "preprocess")
        if not os.path.exists(preproc_dir):
            os.makedirs(preproc_dir)

        # Heavy formats are converted to text in parallel worker processes.
        pool = Pool(processes=16)
        file_opr = FileOperation()
        for idx, file in enumerate(files):
            if not os.path.exists(file.origin):
                file.state = False
                file.reason = "skip not exist"
                continue

            if file._type == "image":
                file.state = False
                file.reason = "skip image"

            elif file._type in ["pdf", "word", "excel", "ppt", "html"]:
                # read pdf/word/excel file and save to text format
                md5 = file_opr.md5(file.origin)
                file.copypath = os.path.join(preproc_dir, "{}.text".format(md5))
                pool.apply_async(read_and_save, (file,))

            elif file._type in ["md", "text"]:
                # rename text files to new dir
                # NOTE(review): `md5` is computed but unused in this branch —
                # the copy name is derived from the path instead; confirm intended.
                md5 = file_opr.md5(file.origin)
                file.copypath = os.path.join(preproc_dir, file.origin.replace("/", "_")[-84:])
                try:
                    shutil.copy(file.origin, file.copypath)
                    file.state = True
                    file.reason = "preprocessed"
                except Exception as e:
                    file.state = False
                    file.reason = str(e)

            else:
                file.state = False
                file.reason = "skip unknown format"
        pool.close()
        logger.debug("waiting for preprocess read finish..")
        pool.join()

        # check process result
        # NOTE(review): only pdf/word/excel results are verified here although
        # ppt/html were also sent to the pool above — confirm intentional.
        for file in files:
            if file._type in ["pdf", "word", "excel"]:
                if os.path.exists(file.copypath):
                    file.state = True
                    file.reason = "preprocessed"
                else:
                    file.state = False
                    file.reason = "read error"

    def initialize(self, files: list, work_dir: str):
        """Initializes response and reject feature store.

        Only needs to be called once. Also calculates the optimal threshold
        based on provided good and bad question examples, and saves it in the
        configuration file.
        """
        logger.info("initialize response and reject feature store, you only need call this once.")  # noqa E501
        self.preprocess(files=files, work_dir=work_dir)
        self.ingress_response(files=files, work_dir=work_dir)
        self.ingress_reject(files=files, work_dir=work_dir)
|
407 |
+
|
408 |
+
|
409 |
+
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with: work_dir, repo_dir, config_path,
        good_questions, bad_questions, sample.
    """
    parser = argparse.ArgumentParser(description="Feature store for processing directories.")
    parser.add_argument("--work_dir", type=str, default="work_dir", help="Working directory.")
    parser.add_argument("--repo_dir", type=str, default="repodir", help="Root directory where the repositories are located.")
    parser.add_argument(
        "--config_path", default="config.ini", help="Feature store configuration path. Default value is config.ini"
    )
    parser.add_argument(
        "--good_questions",
        default="resource/good_questions.json",
        help="Positive examples in the dataset. Default value is resource/good_questions.json",  # noqa E251  # noqa E501
    )
    parser.add_argument(
        "--bad_questions",
        default="resource/bad_questions.json",
        help="Negative examples json path. Default value is resource/bad_questions.json",  # noqa E251  # noqa E501
    )
    parser.add_argument("--sample", help="Input an json file, save reject and search output.")
    args = parser.parse_args()
    return args
|
430 |
+
|
431 |
+
|
432 |
+
def test_reject(retriever: Retriever, sample: str = None):
    """Simple test reject pipeline.

    Runs each query through the retriever's reject check; when `sample` is
    given, queries come from that json file and results are appended to
    workdir/negative.txt / workdir/positive.txt.
    """
    if sample is None:
        # Built-in mixed set of in-domain and out-of-domain Chinese queries.
        real_questions = [
            "SAM 10个T 的训练集,怎么比比较公平呢~?速度上还有缺陷吧?",
            "想问下,如果只是推理的话,amp的fp16是不会省显存么,我看parameter仍然是float32,开和不开推理的显存占用都是一样的。能不能直接用把数据和model都 .half() 代替呢,相比之下amp好在哪里",  # noqa E501
            "mmdeploy支持ncnn vulkan部署么,我只找到了ncnn cpu 版本",
            "大佬们,如果我想在高空检测安全帽,我应该用 mmdetection 还是 mmrotate",
            "请问 ncnn 全称是什么",
            "有啥中文的 text to speech 模型吗?",
            "今天中午吃什么?",
            "huixiangdou 是什么?",
            "mmpose 如何安装?",
            "使用科研仪器需要注意什么?",
        ]
    else:
        with open(sample) as f:
            real_questions = json.load(f)

    for example in real_questions:
        reject, _ = retriever.is_reject(example)

        if reject:
            logger.error(f"reject query: {example}")
        else:
            logger.warning(f"process query: {example}")

        if sample is not None:
            # Persist the classification so thresholds can be tuned offline.
            if reject:
                with open("workdir/negative.txt", "a+") as f:
                    f.write(example)
                    f.write("\n")
            else:
                with open("workdir/positive.txt", "a+") as f:
                    f.write(example)
                    f.write("\n")

    # Release GPU memory held by the embedding/rerank models.
    empty_cache()
|
470 |
+
|
471 |
+
|
472 |
+
def test_query(retriever: Retriever, sample: str = None):
    """Simple test response pipeline.

    Args:
        retriever: loaded Retriever to query.
        sample: optional json file of queries; defaults to two built-in ones.
    """
    if sample is not None:
        with open(sample) as f:
            real_questions = json.load(f)
        # Log query results to a rotating file when running a sample batch.
        logger.add("logs/feature_store_query.log", rotation="4MB")
    else:
        real_questions = ["mmpose installation", "how to use std::vector ?"]

    for example in real_questions:
        # Truncate very long queries before retrieval.
        example = example[0:400]
        print(retriever.query(example))
        empty_cache()

    empty_cache()
|
487 |
+
|
488 |
+
|
489 |
+
def fix_system_error():
    """
    Fix `No module named 'faiss.swigfaiss_avx2'`.

    Some faiss builds only ship `swigfaiss.py`; creating a `swigfaiss_avx2.py`
    symlink next to it lets the AVX2 import fall back to the plain module.
    No-op when the shim already exists.
    """
    from pathlib import Path

    import faiss

    faiss_dir = Path(faiss.__file__).parent
    shim = faiss_dir.joinpath("swigfaiss_avx2.py")
    if shim.exists():
        return

    print("Fixing faiss error...")
    # Create the relative symlink directly instead of shelling out with
    # os.system(f"cd ... && ln -s ..."): no shell-injection risk from the
    # interpolated path, works without a POSIX shell, and failures raise
    # OSError instead of being silently ignored.
    import os

    os.symlink("swigfaiss.py", shim)
|
503 |
+
|
504 |
+
|
505 |
+
def gen_vector_db(config_path, source_dir, work_dir, test_mode=False, update_reject=False, sample=None):
    """Build the RAG vector database from a directory of documents.

    Args:
        config_path: feature-store yaml configuration path.
        source_dir: root directory containing the documents to index.
        work_dir: output directory for the generated FAISS stores.
        test_mode: when True, run sample queries against the freshly built DB.
        update_reject: when True, recompute the reject threshold from
            resource/good_questions.json and resource/bad_questions.json.
        sample: optional json file of test queries used when test_mode is on.
            Added as a keyword argument because the previous code referenced
            the module-global ``args`` (only defined under ``__main__``),
            which raised NameError for library callers using test_mode=True.
    """

    # Work around the faiss `swigfaiss_avx2` import problem
    fix_system_error()

    # Must be an absolute path, otherwise loading the store later fails
    work_dir = str(Path(work_dir).absolute())

    cache = CacheRetriever(config_path=config_path)

    # Generate the vector database
    fs_init = FeatureStore(embeddings=cache.embeddings, reranker=cache.reranker, config_path=config_path)

    # walk all files in repo dir
    file_opr = FileOperation()
    files = file_opr.scan_dir(repo_dir=source_dir)
    fs_init.initialize(files=files, work_dir=work_dir)
    file_opr.summarize(files)
    del fs_init

    # update reject throttle
    if update_reject:
        retriever = cache.get(config_path=config_path, work_dir=work_dir)
        with open(os.path.join("resource", "good_questions.json")) as f:
            good_questions = json.load(f)
        with open(os.path.join("resource", "bad_questions.json")) as f:
            bad_questions = json.load(f)
        retriever.update_throttle(config_path=config_path, good_questions=good_questions, bad_questions=bad_questions)

    cache.pop("default")

    if test_mode:
        # test: reload the store and run the sample queries
        retriever = cache.get(config_path=config_path, work_dir=work_dir)
        # test_reject(retriever, sample)
        test_query(retriever, sample)
|
541 |
+
|
542 |
+
|
543 |
+
if __name__ == "__main__":
    # CLI entry point: build the vector DB from --repo_dir into --work_dir,
    # then immediately run the sample queries (test_mode=True).
    args = parse_args()
    gen_vector_db(args.config_path, args.repo_dir, args.work_dir, test_mode=True)
|
server/base/modules/rag/file_operation.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
from loguru import logger
|
7 |
+
|
8 |
+
|
9 |
+
class FileName:
    """Bookkeeping record for one input file.

    Tracks the original location, the flattened name prefix, the path of the
    plain-text working copy, and the processing state/reason reported when a
    batch is summarized.
    """

    def __init__(self, root: str, filename: str, _type: str):
        # Identity of the source file.
        self.root = root
        self.origin = os.path.join(root, filename)
        self.basename = os.path.basename(filename)
        # Path flattened into a single token, used as a provenance prefix.
        self.prefix = filename.replace("/", "_")
        self._type = _type
        # Processing bookkeeping, updated during preprocess/ingestion.
        self.copypath = ""
        self.state = True
        self.reason = ""

    def __str__(self):
        fields = (self.basename, self.copypath, self.state, self.reason)
        return "{},{},{},{}\n".format(*fields)
|
25 |
+
|
26 |
+
|
27 |
+
class FileOperation:
    """Encapsulate all file reading operations."""

    def __init__(self):
        # Recognized suffixes, grouped by how read() extracts their text.
        self.image_suffix = [".jpg", ".jpeg", ".png", ".bmp"]
        self.md_suffix = ".md"
        self.text_suffix = [".txt", ".text"]
        self.excel_suffix = [".xlsx", ".xls", ".csv"]
        self.pdf_suffix = ".pdf"
        self.ppt_suffix = ".pptx"
        self.html_suffix = [".html", ".htm", ".shtml", ".xhtml"]
        self.word_suffix = [".docx", ".doc"]
        # Every suffix read() can extract text from (images excluded).
        self.normal_suffix = (
            [self.md_suffix]
            + self.text_suffix
            + self.excel_suffix
            + [self.pdf_suffix]
            + self.word_suffix
            + [self.ppt_suffix]
            + self.html_suffix
        )

    def get_type(self, filepath: str):
        """Return the coarse file category for `filepath`.

        Returns one of "pdf", "md", "ppt", "image", "text", "word", "excel",
        "html", or None when the suffix is not supported.
        """
        filepath = filepath.lower()
        if filepath.endswith(self.pdf_suffix):
            return "pdf"

        if filepath.endswith(self.md_suffix):
            return "md"

        if filepath.endswith(self.ppt_suffix):
            return "ppt"

        for suffix in self.image_suffix:
            if filepath.endswith(suffix):
                return "image"

        for suffix in self.text_suffix:
            if filepath.endswith(suffix):
                return "text"

        for suffix in self.word_suffix:
            if filepath.endswith(suffix):
                return "word"

        for suffix in self.excel_suffix:
            if filepath.endswith(suffix):
                return "excel"

        for suffix in self.html_suffix:
            if filepath.endswith(suffix):
                return "html"
        return None

    def md5(self, filepath: str):
        """Return a short content fingerprint of the file.

        NOTE(review): despite the name, this computes SHA-256 (not MD5) and
        returns only the first 8 hex characters — callers use it as a short
        dedup/rename fingerprint, not a cryptographic digest.
        """
        hash_object = hashlib.sha256()
        with open(filepath, "rb") as file:
            chunk_size = 8192
            # Stream the file so large inputs don't load fully into memory.
            while chunk := file.read(chunk_size):
                hash_object.update(chunk)

        return hash_object.hexdigest()[0:8]

    def summarize(self, files: list):
        """Log per-file status plus success/skip/failure totals for a batch."""
        success = 0
        skip = 0
        failed = 0

        for file in files:
            if file.state:
                success += 1
            elif file.reason == "skip":
                skip += 1
            else:
                logger.info("{} {}".format(file.origin, file.reason))
                failed += 1

            logger.info("{} {}".format(file.reason, file.copypath))
        # Totals: "{total} files, {success} ok, {skip} skipped, {failed} failed"
        logger.info("累计{}文件,成功{}个,跳过{}个,异常{}个".format(len(files), success, skip, failed))

    def scan_dir(self, repo_dir: str):
        """Recursively collect all supported files under `repo_dir` as FileName records."""
        files = []
        for root, _, filenames in os.walk(repo_dir):
            for filename in filenames:
                _type = self.get_type(filename)
                if _type is not None:
                    files.append(FileName(root=root, filename=filename, _type=_type))
        return files

    def read_pdf(self, filepath: str):
        """Extract text from a PDF, serializing detected tables as JSON."""
        # load pdf and serialize table

        # TODO: fitz installation has some incompatibilities; refine as needed
        import fitz

        text = ""
        with fitz.open(filepath) as pages:
            for page in pages:
                text += page.get_text()
                tables = page.find_tables()
                for table in tables:
                    # Join real header names, skipping auto-generated "Col*" ones.
                    tablename = "_".join(filter(lambda x: x is not None and "Col" not in x, table.header.names))
                    pan = table.to_pandas()
                    json_text = pan.dropna(axis=1).to_json(force_ascii=False)
                    text += tablename
                    text += "\n"
                    text += json_text
                    text += "\n"
        return text

    def read_excel(self, filepath: str):
        """Read a csv/xls/xlsx file and return its content serialized as JSON."""
        table = None
        if filepath.endswith(".csv"):
            table = pd.read_csv(filepath)
        else:
            table = pd.read_excel(filepath)
        if table is None:
            return ""
        json_text = table.dropna(axis=1).to_json(force_ascii=False)
        return json_text

    def read(self, filepath: str):
        """Extract text from a file based on its detected type.

        Returns:
            (text, error): extracted text and None on success; ("", exception)
            on read failure. A missing file returns ("", None).
        """
        file_type = self.get_type(filepath)

        text = ""

        if not os.path.exists(filepath):
            return text, None

        try:

            if file_type == "md" or file_type == "text":
                with open(filepath) as f:
                    text = f.read()

            elif file_type == "pdf":
                text += self.read_pdf(filepath)

            elif file_type == "excel":
                text += self.read_excel(filepath)

            elif file_type == "word" or file_type == "ppt":
                # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
                # https://textract.readthedocs.io/en/latest/installation.html

                # TODO: textract cannot be installed with pip >= 24.1 due to
                # its own packaging; install manually: pip install textract==1.6.5
                import textract  # for word and ppt

                text = textract.process(filepath).decode("utf8")
                if file_type == "ppt":
                    text = text.replace("\n", " ")

            elif file_type == "html":
                with open(filepath) as f:
                    soup = BeautifulSoup(f.read(), "html.parser")
                    text += soup.text

        except Exception as e:
            logger.error((filepath, str(e)))
            return "", e
        # Shrink runs of blank lines / double spaces. Each pass roughly halves
        # run length; three passes are best-effort, not a full collapse of
        # arbitrarily long runs.
        text = text.replace("\n\n", "\n")
        text = text.replace("\n\n", "\n")
        text = text.replace("\n\n", "\n")
        text = text.replace("  ", " ")
        text = text.replace("  ", " ")
        text = text.replace("  ", " ")
        return text, None
|
195 |
+
|
196 |
+
|
197 |
+
if __name__ == "__main__":
    # Ad-hoc smoke test: try to read every PDF under a hard-coded directory
    # and print the extracted text length (or a blank line on failure).

    def get_pdf_files(directory):
        pdf_files = []
        # Walk the directory tree
        for root, dirs, files in os.walk(directory):
            for file in files:
                # Keep files whose extension is .pdf
                if file.lower().endswith(".pdf"):
                    # Record the absolute path
                    pdf_files.append(os.path.abspath(os.path.join(root, file)))
        return pdf_files

    # Replace with the directory you want to scan
    pdf_list = get_pdf_files("/home/khj/huixiangdou-web-online-data/hxd-bad-file")

    # Print results for every PDF found

    opr = FileOperation()
    for pdf_path in pdf_list:
        text, error = opr.read(pdf_path)
        print("processing {}".format(pdf_path))
        if error is not None:
            # pdb.set_trace()
            print("")

        else:
            if text is not None:
                print(len(text))
            else:
                # pdb.set_trace()
                print("")
|
server/base/modules/rag/rag_worker.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from ....web_configs import WEB_CONFIGS
|
8 |
+
from ...database.product_db import get_db_product_info
|
9 |
+
from .feature_store import gen_vector_db
|
10 |
+
from .retriever import CacheRetriever
|
11 |
+
|
12 |
+
# Basic configuration
CONTEXT_MAX_LENGTH = 3000  # maximum context length fed into the prompt
# RAG prompt template (Chinese): slot 1 = manual excerpt, slot 2 = customer question
GENERATE_TEMPLATE = "这是说明书:“{}”\n 客户的问题:“{}” \n 请阅读说明并运用你的性格进行解答。"  # RAG prompt template

# Module-global RAG retriever handle; populated by load_rag_model()
RAG_RETRIEVER = None
|
18 |
+
|
19 |
+
|
20 |
+
def build_rag_prompt(rag_retriever: CacheRetriever, product_name, prompt):
    """Build an LLM prompt enriched with RAG context for a product question.

    Queries the "default" feature store with the product name plus the user's
    question; when context is retrieved, wraps it into GENERATE_TEMPLATE,
    otherwise falls back to the raw prompt. Returns "" when the retriever
    failed to load (cache returns a tuple in that case).
    """
    fs_handle = rag_retriever.get(fs_id="default")

    # A tuple return from the cache signals a load error rather than a retriever.
    if isinstance(fs_handle, tuple):
        logger.info(f" @@@ GOT real_retriever == tuple : {fs_handle}")
        return ""

    # Reserve room in the context budget for the template text itself.
    context_budget = CONTEXT_MAX_LENGTH - 2 * len(GENERATE_TEMPLATE)
    chunk, db_context, references = fs_handle.query(f"商品名:{product_name}。{prompt}", context_max_length=context_budget)
    logger.info(f"db_context = {db_context}")

    if db_context is None or len(db_context) <= 1:
        # No usable context retrieved: fall back to the bare user prompt.
        logger.info("db_context get error")
        prompt_rag = prompt
    else:
        prompt_rag = GENERATE_TEMPLATE.format(db_context, prompt)

    logger.info(f"RAG reference = {references}")
    logger.info("=" * 20)

    return prompt_rag
|
43 |
+
|
44 |
+
|
45 |
+
def init_rag_retriever(rag_config: str, db_path: str):
    """Create a CacheRetriever and warm up its default feature store.

    Args:
        rag_config: path to the RAG yaml configuration.
        db_path: directory holding the generated vector database.

    Returns:
        The CacheRetriever with the "default" store already loaded.
    """
    # Free GPU memory before loading the retrieval models.
    torch.cuda.empty_cache()

    retriever = CacheRetriever(config_path=rag_config)

    # Warm-up: load the default feature store into the cache
    retriever.get(fs_id="default", config_path=rag_config, work_dir=db_path)

    return retriever
|
54 |
+
|
55 |
+
|
56 |
+
async def gen_rag_db(user_id, force_gen=False):
    """
    Generate the vector database from the user's product instruction files.

    Args:
        user_id: owner whose product instructions are indexed.
        force_gen: when True, regenerate the database even if it already exists.
    """

    # If the database already exists and force_gen is False, do nothing
    if Path(WEB_CONFIGS.RAG_VECTOR_DB_DIR).exists() and not force_gen:
        return

    if force_gen and Path(WEB_CONFIGS.RAG_VECTOR_DB_DIR).exists():
        shutil.rmtree(WEB_CONFIGS.RAG_VECTOR_DB_DIR)

    # Only index the files referenced by the `instructions` field;
    # stage them in a fresh temporary directory.
    if Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).exists():
        shutil.rmtree(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP)
    Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).mkdir(exist_ok=True, parents=True)

    # Look up all instruction file paths from the product DB and copy
    # each file into the staging directory
    product_list, _ = await get_db_product_info(user_id)

    for info in product_list:

        shutil.copyfile(
            Path(
                WEB_CONFIGS.SERVER_FILE_ROOT,
                WEB_CONFIGS.PRODUCT_FILE_DIR,
                WEB_CONFIGS.INSTRUCTIONS_DIR,
                Path(info.instruction).name,
            ),
            Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).joinpath(Path(info.instruction).name),
        )

    logger.info("Generating rag database, pls wait ...")
    # Build the vector database from the staged instruction files
    gen_vector_db(
        WEB_CONFIGS.RAG_CONFIG_PATH,
        str(Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP).absolute()),
        WEB_CONFIGS.RAG_VECTOR_DB_DIR,
    )

    # Remove the intermediate staging files
    shutil.rmtree(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP)
|
101 |
+
|
102 |
+
|
103 |
+
async def load_rag_model(user_id):
    """Build the RAG vector DB if needed and load the retriever into the
    module-global RAG_RETRIEVER handle."""

    global RAG_RETRIEVER

    # Generate the RAG vector database (no-op if it already exists)
    await gen_rag_db(user_id)

    # Load the rag model
    RAG_RETRIEVER = init_rag_retriever(rag_config=WEB_CONFIGS.RAG_CONFIG_PATH, db_path=WEB_CONFIGS.RAG_VECTOR_DB_DIR)
    logger.info("load rag model done !...")
|
113 |
+
|
114 |
+
|
115 |
+
async def rebuild_rag_db(user_id, db_name="default"):
    """Regenerate the RAG vector database and reload the retriever.

    Args:
        user_id: owner whose product instructions are re-indexed.
        db_name: feature-store id to evict and reload (default "default").
    """

    # Regenerate the RAG vector database from scratch
    await gen_rag_db(user_id, force_gen=True)

    # Guard: previously this dereferenced the module-global unconditionally,
    # raising AttributeError on None when rebuild was requested before
    # load_rag_model() had ever run.
    if RAG_RETRIEVER is None:
        logger.warning("RAG retriever not initialized yet, skip reload")
        return

    # Drop the stale retriever and load the rebuilt database
    RAG_RETRIEVER.pop(db_name)
    RAG_RETRIEVER.get(fs_id=db_name, config_path=WEB_CONFIGS.RAG_CONFIG_PATH, work_dir=WEB_CONFIGS.RAG_VECTOR_DB_DIR)
|
server/base/modules/rag/retriever.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""extract feature and search with user query."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import yaml
|
8 |
+
from BCEmbedding.tools.langchain import BCERerank
|
9 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
10 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
11 |
+
from langchain.vectorstores.faiss import FAISS as Vectorstore
|
12 |
+
from langchain_community.vectorstores.utils import DistanceStrategy
|
13 |
+
from loguru import logger
|
14 |
+
from modelscope import snapshot_download
|
15 |
+
from sklearn.metrics import precision_recall_curve
|
16 |
+
|
17 |
+
from ....web_configs import WEB_CONFIGS
|
18 |
+
from .file_operation import FileOperation
|
19 |
+
|
20 |
+
|
21 |
+
class Retriever:
|
22 |
+
"""Tokenize and extract features from the project's documents, for use in
|
23 |
+
the reject pipeline and response pipeline."""
|
24 |
+
|
25 |
+
def __init__(self, embeddings, reranker, work_dir: str, reject_throttle: float) -> None:
    """Init with model device type and config.

    Args:
        embeddings: embedding model used to load both FAISS stores.
        reranker: compressor applied on top of the response retriever.
        work_dir: directory containing db_reject/ and db_response/ stores.
        reject_throttle: relevance threshold below which queries are rejected.
    """
    self.reject_throttle = reject_throttle
    # Store used only to decide whether a query is in-domain.
    self.rejecter = Vectorstore.load_local(
        os.path.join(work_dir, "db_reject"), embeddings=embeddings, allow_dangerous_deserialization=True
    )
    # Response store wrapped as a similarity retriever (top-30, score >= 0.15).
    self.retriever = Vectorstore.load_local(
        os.path.join(work_dir, "db_response"),
        embeddings=embeddings,
        allow_dangerous_deserialization=True,
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    ).as_retriever(search_type="similarity", search_kwargs={"score_threshold": 0.15, "k": 30})
    # Rerank the retrieved candidates with the provided reranker.
    self.compression_retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=self.retriever)
|
38 |
+
|
39 |
+
def is_reject(self, question, k=30, disable_throttle=False):
|
40 |
+
"""If no search results below the threshold can be found from the
|
41 |
+
database, reject this query."""
|
42 |
+
if disable_throttle:
|
43 |
+
# for searching throttle during update sample
|
44 |
+
docs_with_score = self.rejecter.similarity_search_with_relevance_scores(question, k=1)
|
45 |
+
if len(docs_with_score) < 1:
|
46 |
+
return True, docs_with_score
|
47 |
+
return False, docs_with_score
|
48 |
+
else:
|
49 |
+
# for retrieve result
|
50 |
+
# if no chunk passed the throttle, give the max
|
51 |
+
docs_with_score = self.rejecter.similarity_search_with_relevance_scores(question, k=k)
|
52 |
+
ret = []
|
53 |
+
max_score = -1
|
54 |
+
top1 = None
|
55 |
+
for doc, score in docs_with_score:
|
56 |
+
if score >= self.reject_throttle:
|
57 |
+
ret.append(doc)
|
58 |
+
if score > max_score:
|
59 |
+
max_score = score
|
60 |
+
top1 = (doc, score)
|
61 |
+
reject = False if len(ret) > 0 else True
|
62 |
+
return reject, [top1]
|
63 |
+
|
64 |
+
def update_throttle(self, config_path: str = "config.yaml", good_questions=[], bad_questions=[]):
|
65 |
+
"""Update reject throttle based on positive and negative examples."""
|
66 |
+
|
67 |
+
if len(good_questions) == 0 or len(bad_questions) == 0:
|
68 |
+
raise Exception("good and bad question examples cat not be empty.")
|
69 |
+
questions = good_questions + bad_questions
|
70 |
+
predictions = []
|
71 |
+
for question in questions:
|
72 |
+
self.reject_throttle = -1
|
73 |
+
_, docs = self.is_reject(question=question, disable_throttle=True)
|
74 |
+
score = docs[0][1]
|
75 |
+
predictions.append(max(0, score))
|
76 |
+
|
77 |
+
labels = [1 for _ in range(len(good_questions))] + [0 for _ in range(len(bad_questions))]
|
78 |
+
precision, recall, thresholds = precision_recall_curve(labels, predictions)
|
79 |
+
|
80 |
+
# get the best index for sum(precision, recall)
|
81 |
+
sum_precision_recall = precision[:-1] + recall[:-1]
|
82 |
+
index_max = np.argmax(sum_precision_recall)
|
83 |
+
optimal_threshold = max(thresholds[index_max], 0.0)
|
84 |
+
|
85 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
86 |
+
config = yaml.safe_load(f)
|
87 |
+
config["feature_store"]["reject_throttle"] = float(optimal_threshold)
|
88 |
+
with open(config_path, "w", encoding="utf8") as f:
|
89 |
+
yaml.dump(config, f)
|
90 |
+
logger.info(f"The optimal threshold is: {optimal_threshold}, saved it to {config_path}") # noqa E501
|
91 |
+
|
92 |
+
def query(self, question: str, context_max_length: int = 16000): # , tracker: QueryTracker = None):
|
93 |
+
"""Processes a query and returns the best match from the vector store
|
94 |
+
database. If the question is rejected, returns None.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
question (str): The question asked by the user.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
str: The best matching chunk, or None.
|
101 |
+
str: The best matching text, or None
|
102 |
+
"""
|
103 |
+
print(f"DEBUG -1: enter query")
|
104 |
+
|
105 |
+
if question is None or len(question) < 1:
|
106 |
+
print(f"DEBUG 0: len error")
|
107 |
+
|
108 |
+
return None, None, []
|
109 |
+
|
110 |
+
if len(question) > 512:
|
111 |
+
logger.warning("input too long, truncate to 512")
|
112 |
+
question = question[0:512]
|
113 |
+
|
114 |
+
# reject, docs = self.is_reject(question=question)
|
115 |
+
# assert (len(docs) > 0)
|
116 |
+
# if reject:
|
117 |
+
# return None, None, [docs[0][0].metadata['source']]
|
118 |
+
|
119 |
+
docs = self.compression_retriever.get_relevant_documents(question)
|
120 |
+
|
121 |
+
print(f"DEBUG 1: {docs}")
|
122 |
+
|
123 |
+
# if tracker is not None:
|
124 |
+
# tracker.log('retrieve', [doc.metadata['source'] for doc in docs])
|
125 |
+
chunks = []
|
126 |
+
context = ""
|
127 |
+
references = []
|
128 |
+
|
129 |
+
# add file text to context, until exceed `context_max_length`
|
130 |
+
|
131 |
+
file_opr = FileOperation()
|
132 |
+
for idx, doc in enumerate(docs):
|
133 |
+
chunk = doc.page_content
|
134 |
+
chunks.append(chunk)
|
135 |
+
|
136 |
+
if "read" not in doc.metadata:
|
137 |
+
logger.error(
|
138 |
+
"If you are using the version before 20240319, please rerun `python3 -m huixiangdou.service.feature_store`"
|
139 |
+
)
|
140 |
+
raise Exception("huixiangdou version mismatch")
|
141 |
+
file_text, error = file_opr.read(doc.metadata["read"])
|
142 |
+
if error is not None:
|
143 |
+
# read file failed, skip
|
144 |
+
print(f"DEBUG 2: error")
|
145 |
+
|
146 |
+
continue
|
147 |
+
|
148 |
+
source = doc.metadata["source"]
|
149 |
+
logger.info("target {} file length {}".format(source, len(file_text)))
|
150 |
+
|
151 |
+
print(f"DEBUG 3: target {source}, file length {len(file_text)}")
|
152 |
+
|
153 |
+
if len(file_text) + len(context) > context_max_length:
|
154 |
+
if source in references:
|
155 |
+
continue
|
156 |
+
references.append(source)
|
157 |
+
# add and break
|
158 |
+
add_len = context_max_length - len(context)
|
159 |
+
if add_len <= 0:
|
160 |
+
break
|
161 |
+
chunk_index = file_text.find(chunk)
|
162 |
+
if chunk_index == -1:
|
163 |
+
# chunk not in file_text
|
164 |
+
context += chunk
|
165 |
+
context += "\n"
|
166 |
+
context += file_text[0 : add_len - len(chunk) - 1]
|
167 |
+
else:
|
168 |
+
start_index = max(0, chunk_index - (add_len - len(chunk)))
|
169 |
+
context += file_text[start_index : start_index + add_len]
|
170 |
+
break
|
171 |
+
|
172 |
+
if source not in references:
|
173 |
+
context += file_text
|
174 |
+
context += "\n"
|
175 |
+
references.append(source)
|
176 |
+
|
177 |
+
context = context[0:context_max_length]
|
178 |
+
logger.debug("query:{} top1 file:{}".format(question, references[0]))
|
179 |
+
return "\n".join(chunks), context, [os.path.basename(r) for r in references]
|
180 |
+
|
181 |
+
|
182 |
+
class CacheRetriever:
    """LRU-style pool of :class:`Retriever` instances keyed by feature-store id.

    The shared embedding and rerank models are loaded once at construction;
    each :meth:`get` either refreshes and returns a cached retriever or builds
    a new one, evicting the least-recently-used entry once the pool is full.
    """

    def __init__(self, config_path: str, max_len: int = 4):
        """Read model paths from ``config_path`` and load the shared models."""
        self.cache = dict()
        self.max_len = max_len
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)["feature_store"]
        embedding_model_path = config["embedding_model_path"]
        reranker_model_path = config["reranker_model_path"]

        # resolve model directories (downloads into the RAG model dir on first use)
        embedding_model_path = snapshot_download(embedding_model_path, cache_dir=WEB_CONFIGS.RAG_MODEL_DIR)
        reranker_model_path = snapshot_download(reranker_model_path, cache_dir=WEB_CONFIGS.RAG_MODEL_DIR)

        # load text2vec and rerank model
        logger.info("loading test2vec and rerank models")
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"batch_size": 1, "normalize_embeddings": True},
        )
        self.embeddings.client = self.embeddings.client.half()
        self.reranker = BCERerank(
            **{"model": reranker_model_path, "top_n": 7, "device": "cuda", "use_fp16": True}
        )

    def get(self, fs_id: str = "default", config_path="config.yaml", work_dir="workdir"):
        """Return a (possibly cached) retriever for ``fs_id``.

        NOTE(review): the failure path returns a ``(None, message)`` tuple while
        the success path returns a bare Retriever -- callers must handle both.
        """
        entry = self.cache.get(fs_id)
        if entry is not None:
            # cache hit: refresh the LRU timestamp
            entry["time"] = time.time()
            return entry["retriever"]

        if not os.path.exists(work_dir) or not os.path.exists(config_path):
            return None, "workdir or config.yaml not exist"

        with open(config_path, "r", encoding="utf-8") as f:
            reject_throttle = yaml.safe_load(f)["feature_store"]["reject_throttle"]

        if len(self.cache) >= self.max_len:
            # evict the least-recently-used entry (strictly older than "now")
            oldest_key = None
            oldest_time = time.time()
            for key, cached in self.cache.items():
                stamp = cached["time"]
                if stamp < oldest_time:
                    oldest_time = stamp
                    oldest_key = key

            if oldest_key is not None:
                evicted = self.cache.pop(oldest_key)
                del evicted["retriever"]

        retriever = Retriever(
            embeddings=self.embeddings, reranker=self.reranker, work_dir=work_dir, reject_throttle=reject_throttle
        )
        self.cache[fs_id] = {"retriever": retriever, "time": time.time()}
        return retriever

    def pop(self, fs_id: str):
        """Drop ``fs_id`` from the pool; silently no-op when absent."""
        evicted = self.cache.pop(fs_id, None)
        # manually free memory
        del evicted
|
server/base/modules/rag/test_queries.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
"我的商品是牛肉。饲养天数",
|
3 |
+
"我的商品是唇膏。净含量是多少"
|
4 |
+
]
|
server/base/queue_thread.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : queue_thread.py
|
5 |
+
@Time : 2024/09/02
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 队列调取相关逻辑(半废弃状态)
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
from loguru import logger
|
14 |
+
import requests
|
15 |
+
import multiprocessing
|
16 |
+
|
17 |
+
from ..web_configs import API_CONFIG
|
18 |
+
from .server_info import SERVER_PLUGINS_INFO
|
19 |
+
|
20 |
+
|
21 |
+
def process_tts(tts_text_queue):
    """Worker loop: drain ``tts_text_queue`` and POST each chunk to the TTS
    HTTP service. Runs forever; intended to live in its own process."""
    while True:
        try:
            text_chunk = tts_text_queue.get(block=True, timeout=1)
        except Exception:
            # queue empty within the timeout window -- keep polling
            continue

        logger.info(f"Get tts quene: {type(text_chunk)} , {text_chunk}")
        res = requests.post(API_CONFIG.TTS_URL, json=text_chunk)
        logger.info(f"tts res = {res}")
|
44 |
+
|
45 |
+
|
46 |
+
def process_digital_human(digital_human_queue):
    """Worker loop: drain ``digital_human_queue`` and POST each chunk to the
    digital-human HTTP service. Runs forever in its own process."""
    while True:
        try:
            text_chunk = digital_human_queue.get(block=True, timeout=1)
        except Exception:
            # nothing arrived before the timeout -- keep polling
            continue

        logger.info(f"Get digital human quene: {type(text_chunk)} , {text_chunk}")
        res = requests.post(API_CONFIG.DIGITAL_HUMAN_URL, json=text_chunk)
        logger.info(f"digital human res = {res}")
|
57 |
+
|
58 |
+
|
59 |
+
# Module-level plugin wiring: start a background worker *process* per enabled
# plugin. When a plugin is disabled the queue global stays None so producers
# can detect and skip it.
if SERVER_PLUGINS_INFO.tts_server_enabled:
    TTS_TEXT_QUENE = multiprocessing.Queue(maxsize=100)
    # NOTE: despite the variable name this is a separate process, not a thread
    tts_thread = multiprocessing.Process(target=process_tts, args=(TTS_TEXT_QUENE,), name="tts_processer")
    tts_thread.start()
else:
    TTS_TEXT_QUENE = None

if SERVER_PLUGINS_INFO.digital_human_server_enabled:
    DIGITAL_HUMAN_QUENE = multiprocessing.Queue(maxsize=100)
    digital_human_thread = multiprocessing.Process(
        target=process_digital_human, args=(DIGITAL_HUMAN_QUENE,), name="digital_human_processer"
    )
    digital_human_thread.start()
else:
    DIGITAL_HUMAN_QUENE = None
|
server/base/routers/__init__.py
ADDED
File without changes
|
server/base/routers/digital_human.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : digital_human.py
|
5 |
+
@Time : 2024/09/02
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 数字人接口
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
from pathlib import Path
|
14 |
+
import uuid
|
15 |
+
import requests
|
16 |
+
from fastapi import APIRouter
|
17 |
+
from loguru import logger
|
18 |
+
from pydantic import BaseModel
|
19 |
+
|
20 |
+
from ...web_configs import API_CONFIG, WEB_CONFIGS
|
21 |
+
from ..utils import ResultCode, make_return_data
|
22 |
+
|
23 |
+
router = APIRouter(
|
24 |
+
prefix="/digital-human",
|
25 |
+
tags=["digital-human"],
|
26 |
+
responses={404: {"description": "Not found"}},
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class GenDigitalHumanVideoItem(BaseModel):
    """Request body for ``POST /digital-human/gen``."""

    # streamer whose digital-human model presents the script
    streamerId: int
    # sales script text to be spoken in the generated video
    salesDoc: str
|
33 |
+
|
34 |
+
|
35 |
+
async def gen_tts_and_digital_human_video_app(streamer_id: int, sales_doc: str):
    """Generate a TTS wav for ``sales_doc``, then a digital-human video from it.

    Args:
        streamer_id (int): streamer whose digital-human model is used.
        sales_doc (str): script text to speak.

    Returns:
        str: server URL of the generated mp4.
    """
    logger.info(sales_doc)

    request_id = str(uuid.uuid1())
    sentence_id = 1  # single-shot inference, so always chunk 1
    user_id = "123"

    # 1) generate the TTS wav
    tts_json = {
        "user_id": user_id,
        "request_id": request_id,
        "sentence": sales_doc,
        "chunk_id": sentence_id,
    }
    tts_save_path = Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, request_id + f"-{str(1).zfill(8)}.wav")
    logger.info(f"waiting for wav generating done: {tts_save_path}")
    _ = requests.post(API_CONFIG.TTS_URL, json=tts_json)

    # 2) generate the digital-human video from the wav
    digital_human_gen_info = {
        "user_id": user_id,
        "request_id": request_id,
        "chunk_id": 0,
        "tts_path": str(tts_save_path),
        "streamer_id": str(streamer_id),
    }
    video_path = Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH).joinpath(request_id + ".mp4")
    logger.info(f"Generating digital human: {video_path}")
    _ = requests.post(API_CONFIG.DIGITAL_HUMAN_URL, json=digital_human_gen_info)

    # 3) remove the intermediate wav file.
    # FIX: missing_ok prevents a FileNotFoundError crash when the TTS service
    # failed to produce the file (the POST result is never checked above).
    tts_save_path.unlink(missing_ok=True)

    server_video_path = f"{API_CONFIG.REQUEST_FILES_URL}/{WEB_CONFIGS.STREAMER_FILE_DIR}/vid_output/{request_id}.mp4"
    logger.info(server_video_path)

    return server_video_path
|
73 |
+
|
74 |
+
|
75 |
+
@router.post("/gen")
async def get_digital_human_according_doc_api(gen_item: GenDigitalHumanVideoItem):
    """Generate a digital-human intro video from a spoken sales script.

    Args:
        gen_item (GenDigitalHumanVideoItem): streamer id + script text.
    """
    video_url = await gen_tts_and_digital_human_video_app(gen_item.streamerId, gen_item.salesDoc)
    return make_return_data(True, ResultCode.SUCCESS, "成功", video_url)
|
server/base/routers/llm.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : llm.py
|
5 |
+
@Time : 2024/09/02
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 大模型接口
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
from typing import Dict, List
|
14 |
+
|
15 |
+
from fastapi import APIRouter, Depends
|
16 |
+
from loguru import logger
|
17 |
+
|
18 |
+
from ..database.llm_db import get_llm_product_prompt_base_info
|
19 |
+
from ..database.product_db import get_db_product_info
|
20 |
+
from ..database.streamer_info_db import get_db_streamer_info
|
21 |
+
from ..models.product_model import ProductInfo
|
22 |
+
from ..models.streamer_info_model import StreamerInfo
|
23 |
+
from ..modules.agent.agent_worker import get_agent_result
|
24 |
+
from ..server_info import SERVER_PLUGINS_INFO
|
25 |
+
from ..utils import LLM_MODEL_HANDLER, ResultCode, make_return_data
|
26 |
+
from .users import get_current_user_info
|
27 |
+
|
28 |
+
router = APIRouter(
|
29 |
+
prefix="/llm",
|
30 |
+
tags=["llm"],
|
31 |
+
responses={404: {"description": "Not found"}},
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def combine_history(prompt: list, history_msg: list):
    """Append chat history onto an existing prompt message list.

    Mutates ``prompt`` in place and also returns it, mapping the project role
    "streamer" to the OpenAI-style role "assistant".

    Args:
        prompt (list): existing ``{"role", "content"}`` messages.
        history_msg (list): history entries with ``"role"`` and ``"message"`` keys.

    Returns:
        list: the same ``prompt`` list, extended with the history.
    """
    # project role -> chat-completions role
    role_map = {"streamer": "assistant", "user": "user"}

    prompt.extend(
        {"role": role_map[msg["role"]], "content": msg["message"]} for msg in history_msg
    )
    return prompt
|
53 |
+
|
54 |
+
|
55 |
+
async def gen_poduct_base_prompt(
    user_id: int,
    streamer_id: int = -1,
    product_id: int = -1,
    streamer_info: StreamerInfo | None = None,
    product_info: ProductInfo | None = None,
) -> List[Dict[str, str]]:
    """Build the base prompt (system turn + first user turn) for pitching a product.

    Exactly one of ``streamer_id`` / ``streamer_info`` must be supplied, and
    likewise for ``product_id`` / ``product_info`` (asserted below).

    Args:
        user_id (int): user ID
        streamer_id (int): streamer ID; must stay -1 when ``streamer_info`` is given
        product_id (int): product ID; must stay -1 when ``product_info`` is given
        streamer_info (StreamerInfo, optional): streamer record; looked up by ``streamer_id`` when None
        product_info (ProductInfo, optional): product record; looked up by ``product_id`` when None

    Returns:
        List[Dict[str, str]]: ``[{"role": "system", ...}, {"role": "user", ...}]``
    """

    # NOTE(review): `assert` is stripped under `python -O`; these only guard the
    # mutually-exclusive id/object argument contract.
    assert (streamer_id == -1 and streamer_info is not None) or (streamer_id != -1 and streamer_info is None)
    assert (product_id == -1 and product_info is not None) or (product_id != -1 and product_info is None)

    # Load the conversation configuration:
    #   system: sales-persona system prompt template
    #   first_input: first user-turn template
    #   product_info_struct: product-description structure templates
    dataset_yaml = await get_llm_product_prompt_base_info()

    system = dataset_yaml["conversation_setting"]["system"]
    first_input_template = dataset_yaml["conversation_setting"]["first_input"]
    product_info_struct_template = dataset_yaml["product_info_struct"]

    # Resolve streamer info from the DB when only an ID was given.
    if streamer_info is None:
        streamer_info = await get_db_streamer_info(user_id, streamer_id)
        streamer_info = streamer_info[0]

    # Inject the persona name and character traits into the system prompt.
    character_str = streamer_info.character.replace(";", "、")
    system_str = system.replace("{role_type}", streamer_info.name).replace("{character}", character_str)

    # Resolve product info from the DB when only an ID was given.
    if product_info is None:
        product_list, _ = await get_db_product_info(user_id, product_id=product_id)
        product_info = product_list[0]

    # Fill in the product name and highlight list.
    heighlights_str = product_info.heighlights.replace(";", "、")
    product_info_str = product_info_struct_template[0].replace("{name}", product_info.product_name)
    product_info_str += product_info_struct_template[1].replace("{highlights}", heighlights_str)

    # Build the first user turn that asks for the sales script.
    sales_doc_prompt = first_input_template.replace("{product_info}", product_info_str)

    prompt = [{"role": "system", "content": system_str}, {"role": "user", "content": sales_doc_prompt}]
    logger.info(prompt)

    return prompt
|
114 |
+
|
115 |
+
|
116 |
+
async def get_agent_res(prompt, departure_place, delivery_company):
    """Run the Agent plugin against the latest user message.

    Returns an empty string when the plugin is disabled or the agent produced
    nothing; otherwise returns the agent result wrapped in the answer template.
    """
    if not SERVER_PLUGINS_INFO.agent_enabled:
        # plugin disabled -> behave as if the agent had nothing to add
        return ""

    # Template combining web-sourced info with the customer's question.
    GENERATE_AGENT_TEMPLATE = (
        "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt 模板
    )
    input_prompt = prompt[-1]["content"]
    agent_response = get_agent_result(LLM_MODEL_HANDLER, input_prompt, departure_place, delivery_company)
    if agent_response == "":
        return ""

    agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)
    logger.info(f"Agent response: {agent_response}")
    return agent_response
|
134 |
+
|
135 |
+
|
136 |
+
async def get_llm_res(prompt):
    """Run LLM inference for ``prompt`` and return the answer text.

    Args:
        prompt (list): chat messages to send to the model.

    Returns:
        str: content of the last completion item yielded by the handler.
    """
    logger.info(prompt)
    model_name = LLM_MODEL_HANDLER.available_models[0]

    res_data = ""
    # keep only the most recent payload; its content is returned
    for completion in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=prompt):
        res_data = completion["choices"][0]["message"]["content"]

    return res_data
|
154 |
+
|
155 |
+
|
156 |
+
@router.get("/gen_sales_doc", summary="生成主播文案接口")
async def get_product_info_api(streamer_id: int, product_id: int, user_id: int = Depends(get_current_user_info)):
    """Generate the streamer's spoken sales script for one product.

    NOTE(review): this function name shadows other handlers named
    ``get_product_info_api`` in this module; FastAPI routes via the decorator
    so behavior is unaffected, but renaming would help debugging.

    Args:
        streamer_id (int): streamer ID, used to fetch persona/character info
        product_id (int): product ID
    """

    prompt = await gen_poduct_base_prompt(user_id, streamer_id, product_id)

    res_data = await get_llm_res(prompt)

    return make_return_data(True, ResultCode.SUCCESS, "成功", res_data)
|
170 |
+
|
171 |
+
|
172 |
+
@router.get("/gen_product_info")
async def get_product_info_api(product_id: int, user_id: int = Depends(get_current_user_info)):
    """TODO: generate product info from its instruction manual (unimplemented).

    Args:
        product_id (int): product whose manual should be summarised.

    Raises:
        NotImplementedError: always -- the endpoint is a stub.
    """
    # FIX: the original `raise NotImplemented()` is a bug -- `NotImplemented`
    # is a constant, not an exception, so calling it raised TypeError instead
    # of signalling "not implemented". The dead draft code that followed the
    # raise has been removed.
    raise NotImplementedError("gen_product_info is not implemented yet")
|
server/base/routers/products.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : products.py
|
5 |
+
@Time : 2024/08/30
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 商品信息接口
|
10 |
+
"""
|
11 |
+
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
from fastapi import APIRouter, Depends
|
15 |
+
|
16 |
+
from ...web_configs import WEB_CONFIGS
|
17 |
+
from ..database.product_db import (
|
18 |
+
create_or_update_db_product_by_id,
|
19 |
+
delete_product_id,
|
20 |
+
get_db_product_info,
|
21 |
+
)
|
22 |
+
from ..models.product_model import ProductInfo, ProductPageItem, ProductQueryItem
|
23 |
+
from ..modules.rag.rag_worker import rebuild_rag_db
|
24 |
+
from ..utils import ResultCode, make_return_data
|
25 |
+
from .users import get_current_user_info
|
26 |
+
|
27 |
+
router = APIRouter(
|
28 |
+
prefix="/products",
|
29 |
+
tags=["products"],
|
30 |
+
responses={404: {"description": "Not found"}},
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
@router.get("/list", summary="获取分页商品信息接口")
async def get_product_info_api(
    currentPage: int = 1, pageSize: int = 5, productName: str | None = None, user_id: int = Depends(get_current_user_info)
):
    """Return one page of the current user's products, optionally filtered by name."""
    products, total_count = await get_db_product_info(
        user_id=user_id,
        current_page=currentPage,
        page_size=pageSize,
        product_name=productName,
    )

    page_payload = ProductPageItem(product_list=products, currentPage=currentPage, pageSize=pageSize, totalSize=total_count)
    return make_return_data(True, ResultCode.SUCCESS, "成功", page_payload)
|
47 |
+
|
48 |
+
|
49 |
+
@router.get("/info/{productId}", summary="获取特定商品 ID 的详细信息接口")
async def get_product_id_info_api(productId: int, user_id: int = Depends(get_current_user_info)):
    """Fetch details for one product; unwraps the single-element result list."""
    product_list, _ = await get_db_product_info(user_id=user_id, product_id=productId)

    res_data = product_list[0] if len(product_list) == 1 else product_list
    return make_return_data(True, ResultCode.SUCCESS, "成功", res_data)
|
57 |
+
|
58 |
+
|
59 |
+
@router.post("/create", summary="新增商品接口")
async def upload_product_api(upload_product_item: ProductInfo, user_id: int = Depends(get_current_user_info)):
    """Create a new product owned by the current user, then refresh the RAG DB when needed."""
    upload_product_item.user_id = user_id
    upload_product_item.product_id = None  # let the DB assign a fresh id

    need_rebuild = create_or_update_db_product_by_id(0, upload_product_item)

    if WEB_CONFIGS.ENABLE_RAG and need_rebuild:
        # regenerate the RAG vector database so the new product is searchable
        await rebuild_rag_db(user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
72 |
+
|
73 |
+
|
74 |
+
@router.put("/edit/{product_id}", summary="编辑商品接口")
async def upload_product_api(product_id: int, upload_product_item: ProductInfo, user_id: int = Depends(get_current_user_info)):
    """Update an existing product, then refresh the RAG DB when required."""
    need_rebuild = create_or_update_db_product_by_id(product_id, upload_product_item, user_id)

    if WEB_CONFIGS.ENABLE_RAG and need_rebuild:
        # regenerate the RAG vector database with the edited product data
        await rebuild_rag_db(user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
84 |
+
|
85 |
+
|
86 |
+
@router.delete("/delete/{productId}", summary="删除特定商品 ID 接口")
async def upload_product_api(productId: int, user_id: int = Depends(get_current_user_info)):
    """Delete one product; on success the RAG DB is rebuilt unconditionally."""
    if not await delete_product_id(productId, user_id):
        return make_return_data(False, ResultCode.FAIL, "失败", "")

    if WEB_CONFIGS.ENABLE_RAG:
        # regenerate the RAG vector database without the deleted product
        await rebuild_rag_db(user_id)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
99 |
+
|
100 |
+
|
101 |
+
@router.post("/instruction", summary="获取对应商品的说明书内容接口", dependencies=[Depends(get_current_user_info)])
async def get_product_instruction_info_api(instruction_path: ProductQueryItem):
    """Return the text content of a product's instruction manual.

    Args:
        instruction_path (ProductQueryItem): carries ``instructionPath``; only
            its basename is used, anchored under the server file root.
    """
    # TODO: let the frontend fetch this directly via axios later
    local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(
        WEB_CONFIGS.PRODUCT_FILE_DIR, WEB_CONFIGS.INSTRUCTIONS_DIR, Path(instruction_path.instructionPath).name
    )
    if not local_path.exists():
        return make_return_data(False, ResultCode.FAIL, "文件不存在", "")

    # FIX: read with an explicit encoding -- instruction files contain UTF-8
    # Chinese text and the platform default (e.g. GBK on Windows) would raise
    # UnicodeDecodeError.
    with open(local_path, "r", encoding="utf-8") as f:
        instruction_content = f.read()

    return make_return_data(True, ResultCode.SUCCESS, "成功", instruction_content)
|
server/base/routers/streamer_info.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : streamer_info.py
|
5 |
+
@Time : 2024/08/10
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 主播管理信息页面接口
|
10 |
+
"""
|
11 |
+
|
12 |
+
from typing import Tuple
|
13 |
+
import uuid
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
import requests
|
17 |
+
from fastapi import APIRouter, Depends
|
18 |
+
from loguru import logger
|
19 |
+
|
20 |
+
from ...web_configs import API_CONFIG, WEB_CONFIGS
|
21 |
+
from ..database.streamer_info_db import create_or_update_db_streamer_by_id, delete_streamer_id, get_db_streamer_info
|
22 |
+
from ..models.streamer_info_model import StreamerInfo
|
23 |
+
from ..utils import ResultCode, make_poster_by_video_first_frame, make_return_data
|
24 |
+
from .users import get_current_user_info
|
25 |
+
|
26 |
+
router = APIRouter(
|
27 |
+
prefix="/streamer",
|
28 |
+
tags=["streamer"],
|
29 |
+
responses={404: {"description": "Not found"}},
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
async def gen_digital_human(user_id, streamer_id: int, new_streamer_info: StreamerInfo) -> Tuple[str, str]:
    """Trigger digital-human preprocessing for a streamer's base video and make
    its poster image.

    Args:
        user_id (int): user ID
        streamer_id (int): streamer ID
        new_streamer_info (StreamerInfo): streamer record carrying the new video URL

    Returns:
        str: digital-human base video URL (unchanged from the input record)
        str: poster / avatar image URL derived from the video's first frame
    """

    streamer_info_db = await get_db_streamer_info(user_id, streamer_id)
    streamer_info_db = streamer_info_db[0]

    # Compare server-relative paths; if the base video did not change, skip the
    # (expensive) preprocessing and reuse the stored poster.
    new_base_mp4_path = new_streamer_info.base_mp4_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
    if streamer_info_db.base_mp4_path.replace(API_CONFIG.REQUEST_FILES_URL, "") == new_base_mp4_path:
        return streamer_info_db.base_mp4_path, streamer_info_db.poster_image

    # Call the digital-human service to preprocess the new base video.
    # base_mp4_path is a server URL, so translate it back to a local file path.
    video_local_dir = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(
        WEB_CONFIGS.STREAMER_FILE_DIR, WEB_CONFIGS.STREAMER_INFO_FILES_DIR
    )

    digital_human_gen_info = {
        "user_id": str(user_id),
        "request_id": str(uuid.uuid1()),
        "streamer_id": str(new_streamer_info.streamer_id),
        "video_path": str(video_local_dir.joinpath(Path(new_streamer_info.base_mp4_path).name)),
    }
    logger.info(f"Getting digital human preprocessing: {new_streamer_info.streamer_id}")
    # NOTE(review): response status is never checked -- failures go unnoticed here
    _ = requests.post(API_CONFIG.DIGITAL_HUMAN_PREPROCESS_URL, json=digital_human_gen_info)

    # Build the poster image from the video's first frame.
    poster_save_name = Path(new_streamer_info.base_mp4_path).stem + ".png"
    make_poster_by_video_first_frame(str(video_local_dir.joinpath(Path(new_streamer_info.base_mp4_path).name)), poster_save_name)

    # Derive the poster's server URL from the video URL; Path() collapses the
    # double slash in "http://", so restore it afterwards.
    poster_server_url = str(Path(new_streamer_info.base_mp4_path).parent.joinpath(poster_save_name))
    if "http://" not in poster_server_url and "http:/" in poster_server_url:
        poster_server_url = poster_server_url.replace("http:/", "http://")

    return new_streamer_info.base_mp4_path, poster_server_url
|
80 |
+
|
81 |
+
|
82 |
+
@router.get("/list", summary="获取所有主播信息接口,用于用户进行主播的选择")
async def get_streamer_info_api(user_id: int = Depends(get_current_user_info)):
    """Return every streamer belonging to the current user, for selection UIs."""
    all_streamers = await get_db_streamer_info(user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", all_streamers)
|
87 |
+
|
88 |
+
|
89 |
+
@router.get("/info/{streamerId}", summary="用于获取特定主播的信息接口")
|
90 |
+
async def get_streamer_info_api(streamerId: int, user_id: int = Depends(get_current_user_info)):
|
91 |
+
"""用于获取特定主播的信息"""
|
92 |
+
|
93 |
+
streamer_list = await get_db_streamer_info(user_id, streamerId)
|
94 |
+
if len(streamer_list) == 1:
|
95 |
+
streamer_list = streamer_list[0]
|
96 |
+
|
97 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", streamer_list)
|
98 |
+
|
99 |
+
|
100 |
+
@router.post("/create", summary="新增主播信息接口")
|
101 |
+
async def create_streamer_info_api(streamerItem: StreamerInfo, user_id: int = Depends(get_current_user_info)):
|
102 |
+
"""新增主播信息"""
|
103 |
+
streamer_info = streamerItem
|
104 |
+
streamer_info.user_id = user_id
|
105 |
+
streamer_info.streamer_id = None
|
106 |
+
|
107 |
+
poster_image = streamer_info.poster_image
|
108 |
+
base_mp4_path = streamer_info.base_mp4_path
|
109 |
+
|
110 |
+
streamer_info.poster_image = ""
|
111 |
+
streamer_info.base_mp4_path = ""
|
112 |
+
|
113 |
+
# 更新数据库,才能拿到 stream_id
|
114 |
+
streamer_id = create_or_update_db_streamer_by_id(0, streamer_info, user_id)
|
115 |
+
|
116 |
+
streamer_info.poster_image = poster_image
|
117 |
+
streamer_info.base_mp4_path = base_mp4_path
|
118 |
+
streamer_info.streamer_id = streamer_id
|
119 |
+
|
120 |
+
# 数字人视频对其进行初始化,同时生成头图
|
121 |
+
video_info = await gen_digital_human(user_id, streamer_id, streamer_info)
|
122 |
+
|
123 |
+
streamer_info.base_mp4_path = video_info[0]
|
124 |
+
streamer_info.poster_image = video_info[1]
|
125 |
+
streamer_info.avatar = video_info[1]
|
126 |
+
|
127 |
+
create_or_update_db_streamer_by_id(streamer_id, streamer_info, user_id)
|
128 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", streamer_id)
|
129 |
+
|
130 |
+
|
131 |
+
@router.put("/edit/{streamer_id}", summary="修改主播信息接口")
|
132 |
+
async def edit_streamer_info_api(streamer_id: int, streamer_info: StreamerInfo, user_id: int = Depends(get_current_user_info)):
|
133 |
+
"""修改主播信息"""
|
134 |
+
|
135 |
+
# 如果更新了数字人视频对其进行初始化,同时生成头图
|
136 |
+
video_info = await gen_digital_human(user_id, streamer_id, streamer_info)
|
137 |
+
|
138 |
+
streamer_info.base_mp4_path = video_info[0]
|
139 |
+
streamer_info.poster_image = video_info[1]
|
140 |
+
streamer_info.avatar = video_info[1]
|
141 |
+
|
142 |
+
# 更新数据库
|
143 |
+
create_or_update_db_streamer_by_id(streamer_id, streamer_info, user_id)
|
144 |
+
|
145 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", streamer_id)
|
146 |
+
|
147 |
+
|
148 |
+
@router.delete("/delete/{streamerId}", summary="删除主播接口")
|
149 |
+
async def upload_product_api(streamerId: int, user_id: int = Depends(get_current_user_info)):
|
150 |
+
|
151 |
+
process_success_flag = await delete_streamer_id(streamerId, user_id)
|
152 |
+
|
153 |
+
if not process_success_flag:
|
154 |
+
return make_return_data(False, ResultCode.FAIL, "失败", "")
|
155 |
+
|
156 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
server/base/routers/streaming_room.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : streaming_room.py
|
5 |
+
@Time : 2024/08/31
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 主播间信息交互接口
|
10 |
+
"""
|
11 |
+
|
12 |
+
import uuid
|
13 |
+
from copy import deepcopy
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
import requests
|
17 |
+
from fastapi import APIRouter, Depends, HTTPException
|
18 |
+
from loguru import logger
|
19 |
+
|
20 |
+
from ...web_configs import API_CONFIG, WEB_CONFIGS
|
21 |
+
from ..database.product_db import get_db_product_info
|
22 |
+
from ..database.streamer_room_db import (
|
23 |
+
create_or_update_db_room_by_id,
|
24 |
+
get_live_room_info,
|
25 |
+
get_message_list,
|
26 |
+
update_db_room_status,
|
27 |
+
delete_room_id,
|
28 |
+
get_db_streaming_room_info,
|
29 |
+
update_message_info,
|
30 |
+
update_room_video_path,
|
31 |
+
)
|
32 |
+
from ..models.product_model import ProductInfo
|
33 |
+
from ..models.streamer_room_model import OnAirRoomStatusItem, RoomChatItem, SalesDocAndVideoInfo, StreamRoomInfo
|
34 |
+
from ..modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
|
35 |
+
from ..routers.users import get_current_user_info
|
36 |
+
from ..server_info import SERVER_PLUGINS_INFO
|
37 |
+
from ..utils import ResultCode, make_return_data
|
38 |
+
from .digital_human import gen_tts_and_digital_human_video_app
|
39 |
+
from .llm import combine_history, gen_poduct_base_prompt, get_agent_res, get_llm_res
|
40 |
+
|
41 |
+
router = APIRouter(
|
42 |
+
prefix="/streaming-room",
|
43 |
+
tags=["streaming-room"],
|
44 |
+
responses={404: {"description": "Not found"}},
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
@router.get("/list", summary="获取所有直播间信息接口")
|
49 |
+
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
|
50 |
+
"""获取所有直播间信息"""
|
51 |
+
# 加载直播间数据
|
52 |
+
streaming_room_list = await get_db_streaming_room_info(user_id)
|
53 |
+
|
54 |
+
for i in range(len(streaming_room_list)):
|
55 |
+
# 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段
|
56 |
+
streaming_room_list[i] = dict(streaming_room_list[i])
|
57 |
+
|
58 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", streaming_room_list)
|
59 |
+
|
60 |
+
|
61 |
+
@router.get("/info/{roomId}", summary="获取特定直播间信息接口")
|
62 |
+
async def get_streaming_room_id_api(
|
63 |
+
roomId: int, currentPage: int = 1, pageSize: int = 5, user_id: int = Depends(get_current_user_info)
|
64 |
+
):
|
65 |
+
"""获取特定直播间信息"""
|
66 |
+
# 加载直播间配置文件
|
67 |
+
assert roomId != 0
|
68 |
+
|
69 |
+
# TODO 加入分页
|
70 |
+
|
71 |
+
# 加载直播间数据
|
72 |
+
streaming_room_list = await get_db_streaming_room_info(user_id, room_id=roomId)
|
73 |
+
|
74 |
+
if len(streaming_room_list) == 1:
|
75 |
+
# 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段
|
76 |
+
format_product_list = []
|
77 |
+
for db_product in streaming_room_list[0].product_list:
|
78 |
+
|
79 |
+
product_dict = dict(db_product)
|
80 |
+
# 将 start_video 改为服务器地址
|
81 |
+
if product_dict["start_video"] != "":
|
82 |
+
product_dict["start_video"] = API_CONFIG.REQUEST_FILES_URL + product_dict["start_video"]
|
83 |
+
|
84 |
+
format_product_list.append(product_dict)
|
85 |
+
streaming_room_list = dict(streaming_room_list[0])
|
86 |
+
streaming_room_list["product_list"] = format_product_list
|
87 |
+
else:
|
88 |
+
streaming_room_list = []
|
89 |
+
|
90 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", streaming_room_list)
|
91 |
+
|
92 |
+
|
93 |
+
@router.get("/product-edit-list/{roomId}", summary="获取直播间商品编辑列表,含有已选中的标识")
|
94 |
+
async def get_streaming_room_product_list_api(
|
95 |
+
roomId: int, currentPage: int = 1, pageSize: int = 0, user_id: int = Depends(get_current_user_info)
|
96 |
+
):
|
97 |
+
"""获取直播间商品编辑列表,含有已选中的标识"""
|
98 |
+
|
99 |
+
# 获取目前直播间商品列表
|
100 |
+
if roomId == 0:
|
101 |
+
# 新的直播间
|
102 |
+
merge_list = []
|
103 |
+
exclude_list = []
|
104 |
+
else:
|
105 |
+
streaming_room_info = await get_db_streaming_room_info(user_id, roomId)
|
106 |
+
|
107 |
+
if len(streaming_room_info) == 0:
|
108 |
+
raise "401"
|
109 |
+
|
110 |
+
streaming_room_info = streaming_room_info[0]
|
111 |
+
# 获取未被选中的商品
|
112 |
+
exclude_list = [product.product_id for product in streaming_room_info.product_list]
|
113 |
+
merge_list = deepcopy(streaming_room_info.product_list)
|
114 |
+
|
115 |
+
# 获取未选中的商品信息
|
116 |
+
not_select_product_list, db_product_size = await get_db_product_info(user_id=user_id, exclude_list=exclude_list)
|
117 |
+
|
118 |
+
# 合并商品信息
|
119 |
+
for not_select_product in not_select_product_list:
|
120 |
+
merge_list.append(
|
121 |
+
SalesDocAndVideoInfo(
|
122 |
+
product_id=not_select_product.product_id,
|
123 |
+
product_info=ProductInfo(**dict(not_select_product)),
|
124 |
+
selected=False,
|
125 |
+
)
|
126 |
+
)
|
127 |
+
|
128 |
+
# TODO 懒加载分页
|
129 |
+
|
130 |
+
# 格式化
|
131 |
+
format_merge_list = []
|
132 |
+
for product in merge_list:
|
133 |
+
# 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段
|
134 |
+
dict_info = dict(product)
|
135 |
+
if "stream_room" in dict_info:
|
136 |
+
dict_info.pop("stream_room")
|
137 |
+
format_merge_list.append(dict_info)
|
138 |
+
|
139 |
+
page_info = dict(
|
140 |
+
product_list=format_merge_list,
|
141 |
+
current=currentPage,
|
142 |
+
pageSize=db_product_size,
|
143 |
+
totalSize=db_product_size,
|
144 |
+
)
|
145 |
+
logger.info(page_info)
|
146 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", page_info)
|
147 |
+
|
148 |
+
|
149 |
+
@router.post("/create", summary="新增直播间接口")
|
150 |
+
async def streaming_room_edit_api(edit_item: dict, user_id: int = Depends(get_current_user_info)):
|
151 |
+
product_list = edit_item.pop("product_list")
|
152 |
+
status = edit_item.pop("status")
|
153 |
+
edit_item.pop("streamer_info")
|
154 |
+
edit_item.pop("room_id")
|
155 |
+
|
156 |
+
if "status_id" in edit_item:
|
157 |
+
edit_item.pop("status_id")
|
158 |
+
|
159 |
+
formate_product_list = []
|
160 |
+
for product in product_list:
|
161 |
+
if not product["selected"]:
|
162 |
+
continue
|
163 |
+
product.pop("product_info")
|
164 |
+
product_item = SalesDocAndVideoInfo(**product)
|
165 |
+
formate_product_list.append(product_item)
|
166 |
+
|
167 |
+
edit_item["user_id"] = user_id
|
168 |
+
formate_info = StreamRoomInfo(**edit_item, product_list=formate_product_list, status=OnAirRoomStatusItem(**status))
|
169 |
+
room_id = create_or_update_db_room_by_id(0, formate_info, user_id)
|
170 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
|
171 |
+
|
172 |
+
|
173 |
+
@router.put("/edit/{room_id}", summary="编辑直播间接口")
|
174 |
+
async def streaming_room_edit_api(room_id: int, edit_item: dict, user_id: int = Depends(get_current_user_info)):
|
175 |
+
"""编辑直播间接口
|
176 |
+
|
177 |
+
Args:
|
178 |
+
edit_item (StreamRoomInfo): _description_
|
179 |
+
"""
|
180 |
+
|
181 |
+
product_list = edit_item.pop("product_list")
|
182 |
+
status = edit_item.pop("status")
|
183 |
+
edit_item.pop("streamer_info")
|
184 |
+
|
185 |
+
formate_product_list = []
|
186 |
+
for product in product_list:
|
187 |
+
if not product["selected"]:
|
188 |
+
continue
|
189 |
+
product.pop("product_info")
|
190 |
+
product_item = SalesDocAndVideoInfo(**product)
|
191 |
+
formate_product_list.append(product_item)
|
192 |
+
|
193 |
+
formate_info = StreamRoomInfo(**edit_item, product_list=formate_product_list, status=OnAirRoomStatusItem(**status))
|
194 |
+
create_or_update_db_room_by_id(room_id, formate_info, user_id)
|
195 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
|
196 |
+
|
197 |
+
|
198 |
+
@router.delete("/delete/{roomId}", summary="删除直播间接口")
|
199 |
+
async def delete_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
|
200 |
+
|
201 |
+
process_success_flag = await delete_room_id(roomId, user_id)
|
202 |
+
|
203 |
+
if not process_success_flag:
|
204 |
+
return make_return_data(False, ResultCode.FAIL, "失败", "")
|
205 |
+
|
206 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
207 |
+
|
208 |
+
|
209 |
+
# ============================================================
|
210 |
+
# 开播接口
|
211 |
+
# ============================================================
|
212 |
+
|
213 |
+
|
214 |
+
@router.post("/online/{roomId}", summary="直播间开播接口")
|
215 |
+
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
|
216 |
+
|
217 |
+
update_db_room_status(roomId, user_id, "online")
|
218 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
219 |
+
|
220 |
+
|
221 |
+
@router.put("/offline/{roomId}", summary="直播间下播接口")
|
222 |
+
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
|
223 |
+
|
224 |
+
update_db_room_status(roomId, user_id, "offline")
|
225 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
226 |
+
|
227 |
+
|
228 |
+
@router.post("/next-product/{roomId}", summary="直播间进行下一个商品讲解接口")
|
229 |
+
async def on_air_live_room_next_product_api(roomId: int, user_id: int = Depends(get_current_user_info)):
|
230 |
+
"""直播间进行下一个商品讲解
|
231 |
+
|
232 |
+
Args:
|
233 |
+
roomId (int): 直播间 ID
|
234 |
+
"""
|
235 |
+
|
236 |
+
update_db_room_status(roomId, user_id, "next-product")
|
237 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
238 |
+
|
239 |
+
|
240 |
+
@router.get("/live-info/{roomId}", summary="获取正在直播的直播间信息接口")
|
241 |
+
async def get_on_air_live_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
|
242 |
+
"""获取正在直播的直播间信息
|
243 |
+
|
244 |
+
1. 主播视频地址
|
245 |
+
2. 商品信息,显示在右下角的商品缩略图
|
246 |
+
3. 对话记录 conversation_list
|
247 |
+
|
248 |
+
Args:
|
249 |
+
roomId (int): 直播间 ID
|
250 |
+
"""
|
251 |
+
|
252 |
+
res_data = await get_live_room_info(user_id, roomId)
|
253 |
+
|
254 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", res_data)
|
255 |
+
|
256 |
+
|
257 |
+
@router.put("/chat", summary="直播间对话接口")
|
258 |
+
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
|
259 |
+
# 根据直播间 ID 获取信息
|
260 |
+
streaming_room_info = await get_db_streaming_room_info(user_id, room_chat.roomId)
|
261 |
+
streaming_room_info = streaming_room_info[0]
|
262 |
+
|
263 |
+
# 商品索引
|
264 |
+
product_detail = streaming_room_info.product_list[streaming_room_info.status.current_product_index].product_info
|
265 |
+
|
266 |
+
# 销售 ID
|
267 |
+
sales_info_id = streaming_room_info.product_list[streaming_room_info.status.current_product_index].sales_info_id
|
268 |
+
|
269 |
+
# 更新对话记录
|
270 |
+
update_message_info(sales_info_id, user_id, role="user", message=room_chat.message)
|
271 |
+
|
272 |
+
# 获取最新的对话记录
|
273 |
+
conversation_list = get_message_list(sales_info_id)
|
274 |
+
|
275 |
+
# 根据对话记录生成 prompt
|
276 |
+
prompt = await gen_poduct_base_prompt(
|
277 |
+
user_id,
|
278 |
+
streamer_info=streaming_room_info.streamer_info,
|
279 |
+
product_info=product_detail,
|
280 |
+
) # system + 获取商品文案prompt
|
281 |
+
|
282 |
+
prompt = combine_history(prompt, conversation_list)
|
283 |
+
|
284 |
+
# ====================== Agent ======================
|
285 |
+
# 调取 Agent
|
286 |
+
agent_response = await get_agent_res(prompt, product_detail.departure_place, product_detail.delivery_company)
|
287 |
+
if agent_response != "":
|
288 |
+
logger.info("Agent 执行成功,不执行 RAG")
|
289 |
+
prompt[-1]["content"] = agent_response
|
290 |
+
|
291 |
+
# ====================== RAG ======================
|
292 |
+
# 调取 rag
|
293 |
+
elif SERVER_PLUGINS_INFO.rag_enabled:
|
294 |
+
logger.info("Agent 未执行 or 未开启,调用 RAG")
|
295 |
+
# agent 失败,调取 rag, chat_item.plugins.rag 为 True,则使用 RAG 查询数据库
|
296 |
+
rag_res = build_rag_prompt(RAG_RETRIEVER, product_detail.product_name, prompt[-1]["content"])
|
297 |
+
if rag_res != "":
|
298 |
+
prompt[-1]["content"] = rag_res
|
299 |
+
|
300 |
+
# 调取 LLM
|
301 |
+
streamer_res = await get_llm_res(prompt)
|
302 |
+
|
303 |
+
# 生成数字人视频
|
304 |
+
server_video_path = await gen_tts_and_digital_human_video_app(streaming_room_info.streamer_info.streamer_id, streamer_res)
|
305 |
+
|
306 |
+
# 更新直播间数字人视频信息
|
307 |
+
update_room_video_path(streaming_room_info.status_id, server_video_path)
|
308 |
+
|
309 |
+
# 更新对话记录
|
310 |
+
update_message_info(sales_info_id, streaming_room_info.streamer_info.streamer_id, role="streamer", message=streamer_res)
|
311 |
+
|
312 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
313 |
+
|
314 |
+
|
315 |
+
@router.post("/asr", summary="直播间调取 ASR 语音转文字 接口")
|
316 |
+
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
|
317 |
+
|
318 |
+
# room_chat.asr_file 是 服务器地址,需要进行转换
|
319 |
+
asr_local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(WEB_CONFIGS.ASR_FILE_DIR, Path(room_chat.asrFileUrl).name)
|
320 |
+
|
321 |
+
# 获取 ASR 结果
|
322 |
+
req_data = {
|
323 |
+
"user_id": user_id,
|
324 |
+
"request_id": str(uuid.uuid1()),
|
325 |
+
"wav_path": str(asr_local_path),
|
326 |
+
}
|
327 |
+
logger.info(req_data)
|
328 |
+
|
329 |
+
res = requests.post(API_CONFIG.ASR_URL, json=req_data).json()
|
330 |
+
asr_str = res["result"]
|
331 |
+
logger.info(f"ASR res = {asr_str}")
|
332 |
+
|
333 |
+
# 删除过程文件
|
334 |
+
asr_local_path.unlink()
|
335 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", asr_str)
|
server/base/routers/users.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : users.py
|
5 |
+
@Time : 2024/08/30
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 用户登录和 Token 认证接口
|
10 |
+
"""
|
11 |
+
|
12 |
+
from datetime import datetime, timedelta, timezone
|
13 |
+
|
14 |
+
import jwt
|
15 |
+
from fastapi import APIRouter, Depends, HTTPException
|
16 |
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
17 |
+
from loguru import logger
|
18 |
+
from passlib.context import CryptContext
|
19 |
+
|
20 |
+
from ...web_configs import WEB_CONFIGS
|
21 |
+
from ..database.user_db import get_db_user_info
|
22 |
+
from ..models.user_model import TokenItem
|
23 |
+
from ..utils import ResultCode, make_return_data
|
24 |
+
|
25 |
+
router = APIRouter(
|
26 |
+
prefix="/user",
|
27 |
+
tags=["user"],
|
28 |
+
responses={404: {"description": "Not found"}},
|
29 |
+
)
|
30 |
+
|
31 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/user/login")
|
32 |
+
|
33 |
+
# 密码加解密
|
34 |
+
PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
35 |
+
|
36 |
+
|
37 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its stored hash.

    Args:
        plain_password (str): plaintext password supplied by the user
        hashed_password (str): stored hash to compare against

    Returns:
        bool: True when the password matches
    """
    # Fix: removed the debug line that bcrypt-hashed the literal "123456" and
    # logged it on every call — it cost a full bcrypt round per login attempt
    # and wrote password-related material into the logs.
    return PWD_CONTEXT.verify(plain_password, hashed_password)
|
49 |
+
|
50 |
+
|
51 |
+
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the configured password context.

    Args:
        password (str): plaintext password

    Returns:
        str: salted hash suitable for storage
    """
    hashed = PWD_CONTEXT.hash(password)
    return hashed
|
61 |
+
|
62 |
+
|
63 |
+
def authenticate_user(username: str, password: str):
    """Validate a username/password pair.

    Args:
        username (str): account name
        password (str): plaintext password

    Returns:
        The full user-info record when the credentials are valid, otherwise
        False. (Fix: the original annotated the return as ``-> bool`` even
        though the success path returns the user record.)
    """

    # Look up the account, including the stored password hash
    user_info = get_db_user_info(username=username, all_info=True)
    if not user_info:
        # Username not found
        logger.info(f"Cannot find username = {username}")
        return False

    # Check the password against the stored hash
    if not verify_password(password, user_info.hashed_password):
        # Fix: was an f-string with no placeholders
        logger.info("verify_password fail")
        return False

    return user_info
|
88 |
+
|
89 |
+
|
90 |
+
def get_current_user_info(token: str = Depends(oauth2_scheme)):
    """Extract the user ID from a JWT bearer token.

    Args:
        token (str, optional): bearer token. Defaults to Depends(oauth2_scheme).

    Raises:
        HTTPException: 401 when the token is invalid or carries no user id

    Returns:
        int: user ID
    """
    # Fix: the raw bearer token and its decoded claims are no longer logged —
    # tokens are credentials and must not end up in log files.
    try:
        token_data = jwt.decode(token, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithms=WEB_CONFIGS.TOKEN_JWT_ALGORITHM)
        user_id = token_data.get("user_id", None)
    except Exception as e:
        logger.error(e)
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    if not user_id:
        logger.error(f"can not get user_id: {user_id}")
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    # TODO force re-login once the token has expired

    logger.info(f"Got user_id: {user_id}")
    return user_id
|
119 |
+
|
120 |
+
|
121 |
+
@router.post("/login", summary="登录接口")
|
122 |
+
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
123 |
+
|
124 |
+
# 校验用户名和密码
|
125 |
+
user_info = authenticate_user(form_data.username, form_data.password)
|
126 |
+
|
127 |
+
if not user_info:
|
128 |
+
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
|
129 |
+
|
130 |
+
# 过期时间
|
131 |
+
token_expires = datetime.now(timezone.utc) + timedelta(days=7)
|
132 |
+
|
133 |
+
# token 生成包含内容,记录 IP 的原因是防止被其他人拿到用户的 token 进行假冒访问
|
134 |
+
token_data = {
|
135 |
+
"user_id": user_info.user_id,
|
136 |
+
"username": user_info.username,
|
137 |
+
"exp": int(token_expires.timestamp()),
|
138 |
+
"ip": user_info.ip_address,
|
139 |
+
"login_time": int(datetime.now(timezone.utc).timestamp()),
|
140 |
+
}
|
141 |
+
logger.info(f"token_data = {token_data}")
|
142 |
+
|
143 |
+
# 生成 token
|
144 |
+
token = jwt.encode(token_data, WEB_CONFIGS.TOKEN_JWT_SECURITY_KEY, algorithm=WEB_CONFIGS.TOKEN_JWT_ALGORITHM)
|
145 |
+
|
146 |
+
# 返回
|
147 |
+
res_json = TokenItem(access_token=token, token_type="bearer")
|
148 |
+
logger.info(f"Got token info = {res_json}")
|
149 |
+
# return make_return_data(True, ResultCode.SUCCESS, "成功", content)
|
150 |
+
return res_json
|
151 |
+
|
152 |
+
|
153 |
+
@router.get("/me", summary="获取用户信息")
|
154 |
+
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
|
155 |
+
"""获取用户信息"""
|
156 |
+
user_info = get_db_user_info(id=user_id)
|
157 |
+
return make_return_data(True, ResultCode.SUCCESS, "成功", user_info)
|
server/base/server_info.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : server_info.py
|
5 |
+
@Time : 2024/09/02
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 组件信息获取逻辑
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
import random
|
14 |
+
import requests
|
15 |
+
from loguru import logger
|
16 |
+
|
17 |
+
from ..web_configs import API_CONFIG, WEB_CONFIGS
|
18 |
+
|
19 |
+
|
20 |
+
class ServerPluginsInfo:
    """Self-check and report the availability of the optional service plugins
    (LLM / RAG / TTS / digital human / ASR / Agent)."""

    def __init__(self) -> None:
        self.update_info()

    def update_info(self):
        """Probe every dependent service and refresh the enabled flags."""

        self.tts_server_enabled = self._check_server(API_CONFIG.TTS_URL + "/check")
        self.digital_human_server_enabled = self._check_server(API_CONFIG.DIGITAL_HUMAN_CHECK_URL)
        self.asr_server_enabled = self._check_server(API_CONFIG.ASR_URL + "/check")
        self.llm_enabled = self._check_server(API_CONFIG.LLM_URL)

        # The agent needs both third-party API keys to be configured
        if WEB_CONFIGS.AGENT_DELIVERY_TIME_API_KEY is None or WEB_CONFIGS.AGENT_WEATHER_API_KEY is None:
            self.agent_enabled = False
        else:
            self.agent_enabled = True

        self.rag_enabled = WEB_CONFIGS.ENABLE_RAG

        logger.info(
            "\nself check plugins info : \n"
            f"| llm | {self.llm_enabled} |\n"
            f"| rag | {self.rag_enabled} |\n"
            f"| tts | {self.tts_server_enabled} |\n"
            f"| digital human | {self.digital_human_server_enabled} |\n"
            f"| asr | {self.asr_server_enabled} |\n"
            f"| agent | {self.agent_enabled} |\n"
        )

    @staticmethod
    def _check_server(url):
        """Return True when `url` answers with HTTP 200.

        Fix: the original only caught ConnectionError, so a timeout or any
        other request failure crashed the whole health check — and with no
        timeout set the probe could hang indefinitely.
        """
        try:
            res = requests.get(url, timeout=10)
        except requests.exceptions.RequestException:
            return False

        return res.status_code == 200

    @staticmethod
    def _make_color_list(color_num):
        """Pick `color_num` distinct hex colors from a fixed palette.

        Args:
            color_num (int): number of colors wanted (must be <= palette size, 20)

        Returns:
            list[str]: randomly sampled "#RRGGBB" strings, no repeats
        """

        color_list = [
            "#FF3838",
            "#FF9D97",
            "#FF701F",
            "#FFB21D",
            "#CFD231",
            "#48F90A",
            "#92CC17",
            "#3DDB86",
            "#1A9334",
            "#00D4BB",
            "#2C99A8",
            "#00C2FF",
            "#344593",
            "#6473FF",
            "#0018EC",
            "#8438FF",
            "#520085",
            "#CB38FF",
            "#FF95C8",
            "#FF37C7",
        ]

        return random.sample(color_list, color_num)

    def get_status(self):
        """Re-probe the services and return a frontend-ready plugin list,
        each entry tagged with an avatar background color."""
        self.update_info()

        info_list = [
            {
                "plugin_name": "LLM",
                "describe": "大语言模型,用于根据客户历史对话,生成对话信息",
                "enabled": self.llm_enabled,
            },
            {
                "plugin_name": "RAG",
                "describe": "用于调用知识库实时更新信息",
                "enabled": self.rag_enabled,
            },
            {
                "plugin_name": "TTS",
                "describe": "文字转语音,让主播的文字也能听到",
                "enabled": self.tts_server_enabled,
            },
            {
                "plugin_name": "数字人",
                "describe": "数字人服务,用于生成数字人,需要和 TTS 一起开启才有效果",
                "enabled": self.digital_human_server_enabled,
            },
            {
                "plugin_name": "Agent",
                "describe": "用于根据用户对话,获取网络的实时信息",
                "enabled": self.agent_enabled,
            },
            {
                "plugin_name": "ASR",
                "describe": "语音转文字,让用户无需打字就可以和主播进行对话",
                "enabled": self.asr_server_enabled,
            },
        ]

        # Assign a random avatar background color to each plugin card
        color_list = self._make_color_list(len(info_list))
        for idx, color in enumerate(color_list):
            info_list[idx].update({"avatar_color": color})

        return info_list
|
132 |
+
|
133 |
+
|
134 |
+
SERVER_PLUGINS_INFO = ServerPluginsInfo()
|
server/base/utils.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : utils.py
|
5 |
+
@Time : 2024/09/02
|
6 |
+
@Project : https://github.com/PeterH0323/Streamer-Sales
|
7 |
+
@Author : HinGwenWong
|
8 |
+
@Version : 1.0
|
9 |
+
@Desc : 工具集合文件
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
import asyncio
|
14 |
+
from ipaddress import IPv4Address
|
15 |
+
import json
|
16 |
+
import random
|
17 |
+
import wave
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from datetime import datetime
|
20 |
+
from pathlib import Path
|
21 |
+
from typing import Dict, List
|
22 |
+
|
23 |
+
import cv2
|
24 |
+
from lmdeploy.serve.openai.api_client import APIClient
|
25 |
+
from loguru import logger
|
26 |
+
from pydantic import BaseModel
|
27 |
+
from sqlmodel import Session, select
|
28 |
+
from tqdm import tqdm
|
29 |
+
|
30 |
+
from server.base.models.user_model import UserInfo
|
31 |
+
|
32 |
+
from ..tts.tools import SYMBOL_SPLITS, make_text_chunk
|
33 |
+
from ..web_configs import API_CONFIG, WEB_CONFIGS
|
34 |
+
from .database.init_db import DB_ENGINE
|
35 |
+
from .models.product_model import ProductInfo
|
36 |
+
from .models.streamer_info_model import StreamerInfo
|
37 |
+
from .models.streamer_room_model import OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo
|
38 |
+
|
39 |
+
from .modules.agent.agent_worker import get_agent_result
|
40 |
+
from .modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
|
41 |
+
from .queue_thread import DIGITAL_HUMAN_QUENE, TTS_TEXT_QUENE
|
42 |
+
from .server_info import SERVER_PLUGINS_INFO
|
43 |
+
|
44 |
+
|
45 |
+
class ChatGenConfig(BaseModel):
    # LLM inference (sampling) configuration
    top_p: float = 0.8
    temperature: float = 0.7
    repetition_penalty: float = 1.005
|
50 |
+
|
51 |
+
|
52 |
+
class ProductInfoItem(BaseModel):
    """Product information attached to a chat request.

    NOTE(review): ``heighlights`` is a misspelling of "highlights", but the
    field name matches the DB/API contract used elsewhere, so it is kept.
    """

    name: str  # product display name
    heighlights: str  # selling points
    introduce: str  # prompt used to generate the product sales copy

    image_path: str  # product image path
    departure_place: str  # shipping origin
    delivery_company_name: str  # courier company name
|
60 |
+
|
61 |
+
|
62 |
+
class PluginsInfo(BaseModel):
    """Per-request switches for the optional pipeline stages."""

    rag: bool = True  # retrieval-augmented generation lookup
    agent: bool = True  # agent (delivery-time query) lookup
    tts: bool = True  # text-to-speech synthesis
    digital_human: bool = True  # digital-human video generation
|
67 |
+
|
68 |
+
|
69 |
+
class ChatItem(BaseModel):
    """Payload for one streamer-sales chat request."""

    user_id: str  # user identifier, distinguishes different callers
    request_id: str  # request ID, used to name generated TTS & digital-human artifacts
    prompt: List[Dict[str, str]]  # chat messages for this request; last entry is the current user turn
    product_info: ProductInfoItem  # product being presented
    plugins: PluginsInfo = PluginsInfo()  # which optional stages to run
    chat_config: ChatGenConfig = ChatGenConfig()  # LLM sampling settings
|
76 |
+
|
77 |
+
|
78 |
+
# Client for the LLM inference service (lmdeploy OpenAI-compatible API).
LLM_MODEL_HANDLER = APIClient(API_CONFIG.LLM_URL)
|
80 |
+
|
81 |
+
|
82 |
+
async def streamer_sales_process(chat_item: ChatItem):
    """Run the full sales-chat pipeline for one request, yielding SSE events.

    Stages (each gated by ``chat_item.plugins`` AND server-side availability):
    Agent lookup -> RAG prompt rewrite -> streaming LLM generation (events with
    ``step == "llm"``) -> per-sentence TTS enqueue -> wait for TTS wavs and
    merge them (``step == "tts"``) -> digital-human video generation
    (``step == "dg"``) -> final event with ``step == "all"`` / ``end_flag=True``.

    Args:
        chat_item (ChatItem): request payload; its last prompt entry may be
            rewritten in place by the Agent/RAG stages.

    Yields:
        str: JSON-encoded SSE event dicts.
    """

    # ====================== Agent ======================
    # Invoke the agent; when it produces a result, the user's last message is
    # replaced with an augmented prompt containing the fetched information.
    agent_response = ""
    if chat_item.plugins.agent and SERVER_PLUGINS_INFO.agent_enabled:
        # Agent prompt template (runtime string, kept verbatim).
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt template
        )
        input_prompt = chat_item.prompt[-1]["content"]
        agent_response = get_agent_result(
            LLM_MODEL_HANDLER, input_prompt, chat_item.product_info.departure_place, chat_item.product_info.delivery_company_name
        )
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)
            print(f"Agent response: {agent_response}")
            chat_item.prompt[-1]["content"] = agent_response

    # ====================== RAG ======================
    # Invoke RAG only when the Agent did not fire; it queries the vector DB
    # and, on a hit, rewrites the last user message with retrieved context.
    if chat_item.plugins.rag and agent_response == "":
        rag_prompt = chat_item.prompt[-1]["content"]
        prompt_pro = build_rag_prompt(RAG_RETRIEVER, chat_item.product_info.name, rag_prompt)

        if prompt_pro != "":
            chat_item.prompt[-1]["content"] = prompt_pro

    # Stream the LLM response back to the caller.
    logger.info(chat_item.prompt)

    current_predict = ""  # full text generated so far
    idx = 0  # event id counter (one per content delta)
    last_text_index = 0  # cursor into current_predict for sentence chunking
    sentence_id = 0  # 1-based TTS chunk counter
    model_name = LLM_MODEL_HANDLER.available_models[0]
    for item in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=chat_item.prompt, stream=True):
        logger.debug(f"LLM predict: {item}")
        if "content" not in item["choices"][0]["delta"]:
            continue
        current_res = item["choices"][0]["delta"]["content"]

        # Normalize "~" to a full stop so sentence splitting (and TTS) works.
        if "~" in current_res:
            current_res = current_res.replace("~", "。").replace("。。", "。")

        current_predict += current_res
        idx += 1

        if chat_item.plugins.tts and SERVER_PLUGINS_INFO.tts_server_enabled:
            # Split off a complete sentence once a split symbol appears.
            sentence = ""
            for symbol in SYMBOL_SPLITS:
                if symbol in current_res:
                    last_text_index, sentence = make_text_chunk(current_predict, last_text_index)
                    if len(sentence) <= 3:
                        # Sentence too short: skip TTS generation for it.
                        sentence = ""
                    break

            if sentence != "":
                sentence_id += 1
                logger.info(f"get sentence: {sentence}")
                tts_request_dict = {
                    "user_id": chat_item.user_id,
                    "request_id": chat_item.request_id,
                    "sentence": sentence,
                    "chunk_id": sentence_id,
                    # "wav_save_name": chat_item.request_id + f"{str(sentence_id).zfill(8)}.wav",
                }

                TTS_TEXT_QUENE.put(tts_request_dict)
                await asyncio.sleep(0.01)

        yield json.dumps(
            {
                "event": "message",
                "retry": 100,
                "id": idx,
                "data": current_predict,
                "step": "llm",
                "end_flag": False,
            },
            ensure_ascii=False,
        )
        await asyncio.sleep(0.01)  # small delay so the event stream flushes

    if chat_item.plugins.digital_human and SERVER_PLUGINS_INFO.digital_human_server_enabled:

        # Expected wav file for each TTS chunk enqueued above.
        wav_list = [
            Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + f"-{str(i).zfill(8)}.wav")
            for i in range(1, sentence_id + 1)
        ]
        while True:
            # Poll until every TTS wav has been written by the TTS worker.
            not_exist_count = 0
            for tts_wav in wav_list:
                if not tts_wav.exists():
                    not_exist_count += 1

            logger.info(f"still need to wait for {not_exist_count}/{sentence_id} wav generating...")
            if not_exist_count == 0:
                break

            yield json.dumps(
                {
                    "event": "message",
                    "retry": 100,
                    "id": idx,
                    "data": current_predict,
                    "step": "tts",
                    "end_flag": False,
                },
                ensure_ascii=False,
            )
            await asyncio.sleep(1)  # small delay so the event stream flushes

        # Merge the per-sentence wavs into one file for the digital human.
        tts_save_path = Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + ".wav")
        all_tts_data = []

        for wav_file in tqdm(wav_list):
            logger.info(f"Reading wav file {wav_file}...")
            with wave.open(str(wav_file), "rb") as wf:
                all_tts_data.append([wf.getparams(), wf.readframes(wf.getnframes())])

        logger.info(f"Merging wav file to {tts_save_path}...")
        # NOTE(review): the original comment said "use the first audio params",
        # but the code takes the lexicographic max over getparams() tuples —
        # confirm which is intended.
        tts_params = max([tts_data[0] for tts_data in all_tts_data])
        with wave.open(str(tts_save_path), "wb") as wf:
            wf.setparams(tts_params)

            for wf_data in all_tts_data:
                wf.writeframes(wf_data[1])
        logger.info(f"Merged wav file to {tts_save_path} !")

        # Generate the digital-human video from the merged audio.
        tts_request_dict = {
            "user_id": chat_item.user_id,
            "request_id": chat_item.request_id,
            "chunk_id": 0,
            "tts_path": str(tts_save_path),
        }

        logger.info(f"Generating digital human...")
        DIGITAL_HUMAN_QUENE.put(tts_request_dict)
        while True:
            # The digital-human worker signals completion by writing a .txt
            # marker next to the output .mp4; poll for it.
            if (
                Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH)
                .joinpath(Path(tts_save_path).stem + ".mp4")
                .with_suffix(".txt")
                .exists()
            ):
                break
            yield json.dumps(
                {
                    "event": "message",
                    "retry": 100,
                    "id": idx,
                    "data": current_predict,
                    "step": "dg",
                    "end_flag": False,
                },
                ensure_ascii=False,
            )
            await asyncio.sleep(1)  # small delay so the event stream flushes

        # Remove intermediate per-sentence wav files.
        for wav_file in wav_list:
            wav_file.unlink()

    yield json.dumps(
        {
            "event": "message",
            "retry": 100,
            "id": idx,
            "data": current_predict,
            "step": "all",
            "end_flag": True,
        },
        ensure_ascii=False,
    )
|
262 |
+
|
263 |
+
|
264 |
+
def make_poster_by_video_first_frame(video_path: str, image_output_name: str):
    """Save the first frame of a video as its poster/thumbnail image.

    Args:
        video_path (str): path to the video file.
        image_output_name (str): file name for the poster image; the image is
            written into the same directory as the video.

    Returns:
        str: path where the poster image was (or would have been) saved.
            NOTE: the path is returned even when the frame could not be read;
            callers needing a guarantee should check the file exists.
    """

    poster_save_path = str(Path(video_path).parent.joinpath(image_output_name))

    # Open the video file.
    cap = cv2.VideoCapture(video_path)
    try:
        # Read only the first frame.
        ret, frame = cap.read()

        if ret:
            # Persist the frame as the poster image.
            cv2.imwrite(poster_save_path, frame)
            logger.info(f"第一帧已保存为 {poster_save_path}")
        else:
            logger.error("无法读取视频帧")
    finally:
        # Always release the capture handle, even if imwrite raises.
        cap.release()

    return poster_save_path
|
293 |
+
|
294 |
+
|
295 |
+
@dataclass
class ResultCode:
    """API result codes embedded in `make_return_data` responses."""

    # Previously spelled `0000`, which is just the int 0 written in a
    # misleading zero-padded form; the value is unchanged.
    SUCCESS: int = 0  # operation succeeded
    FAIL: int = 1000  # operation failed
|
299 |
+
|
300 |
+
|
301 |
+
def make_return_data(success_flag: bool, code: ResultCode, message: str, data: dict):
    """Build the standard API response envelope.

    Args:
        success_flag: whether the call succeeded.
        code: result code (see `ResultCode`).
        message: human-readable status message.
        data: response payload.

    Returns:
        dict: success/code/message/data plus a "%Y-%m-%d %H:%M:%S" timestamp.
    """
    response = dict(
        success=success_flag,
        code=code,
        message=message,
        data=data,
        timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    )
    return response
|
309 |
+
|
310 |
+
|
311 |
+
def gen_default_data():
    """Seed the database with default data on first start-up, including:
    - product data
    - streamer data
    - streaming-room info and its association rows
    """

    def create_default_user():
        """Create the default admin user."""
        admin_user = UserInfo(
            username="hingwen.wong",
            ip_address=IPv4Address("127.0.0.1"),
            email="[email protected]",
            hashed_password="$2b$12$zXXveodjipHZMoSxJz5ODul7Z9YeRJd0GeSBjpwHdqEtBbAFvEdre",  # "123456" hashed with get_password_hash
            avatar="/user/user-avatar.png",
        )

        with Session(DB_ENGINE) as session:
            session.add(admin_user)
            session.commit()

    def init_user() -> bool:
        """Decide whether the default user needs to be created.

        Returns:
            bool: True if the default user was just created (fresh database).
        """
        with Session(DB_ENGINE) as session:
            results = session.exec(select(UserInfo).where(UserInfo.user_id == 1)).first()

            if results is None:
                # Empty database: create the initial user.
                create_default_user()
                logger.info("created default user info")
                return True

        return False

    def create_default_product_item():
        """Populate the product table with default demo products."""
        delivery_company_list = ["京东", "顺丰", "韵达", "圆通", "中通"]
        departure_place_list = ["广州", "北京", "武汉", "杭州", "上海", "深圳", "成都"]
        # Keys double as the image/instruction file stems under the product dir.
        default_product_list = {
            "beef": {
                "product_name": "进口和牛羽下肉",
                "heighlights": "富含铁质;营养价值高;肌肉纤维好;红白相间纹理;适合烧烤炖煮;草食动物来源",
                "product_class": "食品",
            },
            "elec_toothblush": {
                "product_name": "声波电动牙刷",
                "heighlights": "高效清洁;减少手动压力;定时提醒;智能模式调节;无线充电;噪音低",
                "product_class": "电子",
            },
            "lip_stick": {
                "product_name": "唇膏",
                "heighlights": "丰富色号;滋润保湿;显色度高;持久不脱色;易于涂抹;便携包装",
                "product_class": "美妆",
            },
            "mask": {
                "product_name": "光感润颜面膜",
                "heighlights": "密集滋养;深层补水;急救修复;快速见效;定期护理;多种类型选择",
                "product_class": "美妆",
            },
            "oled_tv": {
                "product_name": "65英寸OLED电视",
                "heighlights": "色彩鲜艳;对比度极高;响应速度快;无背光眩光;厚度较薄;自发光无需额外照明",
                "product_class": "家电",
            },
            "pad": {
                "product_name": "14英寸平板电脑",
                "heighlights": "轻薄;触控操作;电池续航好;移动办公便利;娱乐性强;适合儿童学习",
                "product_class": "电子",
            },
            "pants": {
                "product_name": "速干运动裤",
                "heighlights": "快干;伸缩自如;吸湿排汗;防风保暖;高腰设计;多口袋实用",
                "product_class": "衣服",
            },
            "pen": {
                "product_name": "墨水钢笔",
                "heighlights": "耐用性;可书写性;不同颜色和类型;轻便设计;环保材料;易于携带",
                "product_class": "文具",
            },
            "perfume": {
                "product_name": "薰衣草淡香氛",
                "heighlights": "浪漫优雅;花香调为主;情感表达;适合各种年龄;瓶身设计精致;提升女性魅力",
                "product_class": "家居用品",
            },
            "shampoo": {
                "product_name": "本草精华洗发露",
                "heighlights": "温和配方;深层清洁;滋养头皮;丰富泡沫;易冲洗;适合各种发质",
                "product_class": "日用品",
            },
            "wok": {
                "product_name": "不粘煎炒锅",
                "heighlights": "不粘涂层;耐磨耐用;导热快;易清洗;多种烹饪方式;设计人性化",
                "product_class": "厨具",
            },
            "yoga_mat": {
                "product_name": "瑜伽垫",
                "heighlights": "防滑材质;吸湿排汗;厚度适中;耐用易清洁;各种瑜伽动作适用;轻巧便携",
                "product_class": "运动",
            },
        }

        with Session(DB_ENGINE) as session:
            for product_key, product_info in default_product_list.items():
                # Origin, courier, price and stock are randomized demo values.
                add_item = ProductInfo(
                    **product_info,
                    image_path=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.IMAGES_DIR}/{product_key}.png",
                    instruction=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.INSTRUCTIONS_DIR}/{product_key}.md",
                    departure_place=random.choice(departure_place_list),
                    delivery_company=random.choice(delivery_company_list),
                    selling_price=round(random.uniform(66.6, 1999.9), 2),
                    amount=random.randint(999, 9999),
                    user_id=1,
                )
                session.add(add_item)
            session.commit()

        logger.info("created default product info done!")

    def create_default_streamer():
        """Create the default streamer (digital human) profile."""

        with Session(DB_ENGINE) as session:
            streamer_item = StreamerInfo(
                name="乐乐喵",
                character="甜美;可爱;熟练使用各种网络热门梗造句;称呼客户为[家人们]",
                avatar=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.png",
                base_mp4_path=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.mp4",
                poster_image=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.png",
                tts_reference_audio=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.wav",
                tts_reference_sentence="列车巡游银河,我不一定都能帮上忙,但只要是花钱能解决的事,尽管和我说吧。",
                tts_weight_tag="艾丝妲",
                user_id=1,
            )
            session.add(streamer_item)
            session.commit()

    def create_default_room():
        """Create the default streaming room and attach 3 random products."""

        with Session(DB_ENGINE) as session:

            product_list = session.exec(
                select(ProductInfo).where(ProductInfo.user_id == 1).order_by(ProductInfo.product_id)
            ).all()

            # On-air status row must exist first so the room can reference it.
            on_air_status = OnAirRoomStatusItem(user_id=1)
            session.add(on_air_status)
            session.commit()
            session.refresh(on_air_status)

            stream_item = StreamRoomInfo(
                name="001",
                user_id=1,
                status_id=on_air_status.status_id,
                streamer_id=1,
            )
            session.add(stream_item)
            session.commit()
            session.refresh(stream_item)

            # NOTE(review): random.choices samples WITH replacement, so the
            # same product may be attached twice — confirm intent.
            random_list = random.choices(product_list, k=3)
            for product_random in random_list:
                add_sales_info = SalesDocAndVideoInfo(product_id=product_random.product_id, room_id=stream_item.room_id)
                session.add(add_sales_info)
                session.commit()
                session.refresh(add_sales_info)

    # Main flow: seed everything only when the database was empty.
    created = init_user()
    if created:
        create_default_product_item()  # product info
        create_default_streamer()  # streamer info
        create_default_room()  # streaming-room info
|
server/digital_human/digital_human_server.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from fastapi.exceptions import RequestValidationError
|
3 |
+
from fastapi.responses import PlainTextResponse
|
4 |
+
from loguru import logger
|
5 |
+
from pydantic import BaseModel
|
6 |
+
|
7 |
+
|
8 |
+
from .modules.digital_human_worker import gen_digital_human_video_app, preprocess_digital_human_app
|
9 |
+
|
10 |
+
|
11 |
+
app = FastAPI()
|
12 |
+
|
13 |
+
|
14 |
+
class DigitalHumanItem(BaseModel):
    """Request payload for generating a digital-human video."""

    user_id: str  # user identifier, distinguishes different callers
    request_id: str  # request ID, used to name generated TTS & digital-human artifacts
    streamer_id: str  # digital-human (streamer) ID
    tts_path: str = ""  # path to the TTS audio file (original comment said "text"; the value is passed to the video generator as an audio path — confirm)
    chunk_id: int = 0  # sentence/chunk index; 0 selects the un-suffixed "<request_id>.mp4" output name
|
20 |
+
|
21 |
+
|
22 |
+
class DigitalHumanPreprocessItem(BaseModel):
    """Request payload for preprocessing a new digital-human base video."""

    user_id: str  # user identifier, distinguishes different callers
    request_id: str  # request ID, used to name generated TTS & digital-human artifacts
    streamer_id: str  # digital-human (streamer) ID
    video_path: str  # digital-human base video to preprocess
|
27 |
+
|
28 |
+
|
29 |
+
@app.post("/digital_human/gen")
async def get_digital_human(dg_item: DigitalHumanItem):
    """Generate a digital-human video for one TTS audio chunk.

    chunk_id == 0 produces "<request_id>.mp4"; any other chunk gets a
    zero-padded "-NNNNNNNN" suffix.
    """
    if dg_item.chunk_id == 0:
        save_tag = dg_item.request_id + ".mp4"
    else:
        save_tag = dg_item.request_id + f"-{str(dg_item.chunk_id).zfill(8)}.mp4"

    mp4_path = await gen_digital_human_video_app(dg_item.streamer_id, dg_item.tts_path, save_tag)

    logger.info(f"digital human mp4 path = {mp4_path}")
    return {"user_id": dg_item.user_id, "request_id": dg_item.request_id, "digital_human_mp4_path": mp4_path}
|
38 |
+
|
39 |
+
|
40 |
+
@app.post("/digital_human/preprocess")
async def preprocess_digital_human(preprocess_item: DigitalHumanPreprocessItem):
    """Preprocess a digital-human base video; used when adding a new streamer."""

    streamer_id_str = str(preprocess_item.streamer_id)
    await preprocess_digital_human_app(streamer_id_str, preprocess_item.video_path)

    logger.info(f"digital human process for {preprocess_item.streamer_id} done")
    return {"user_id": preprocess_item.user_id, "request_id": preprocess_item.request_id}
|
48 |
+
|
49 |
+
|
50 |
+
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Handle request-validation failures.

    Logs the offending request and the validation error, then returns the
    error text as a plain-text HTTP 400 response.

    Args:
        request: the incoming request that failed validation.
        exc: the RequestValidationError raised by FastAPI.

    Returns:
        PlainTextResponse: the error rendered as text, status 400.
    """
    logger.info(request)
    logger.info(exc)
    error_text = str(exc)
    return PlainTextResponse(error_text, status_code=400)
|
64 |
+
|
65 |
+
|
66 |
+
@app.get("/digital_human/check")
async def check_server():
    """Liveness probe confirming the digital-human service is running."""
    return {"message": "server enabled"}
|
server/digital_human/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import hub
from ...web_configs import WEB_CONFIGS
from pathlib import Path

# Some models are fetched through torch.hub's downloader; point its cache
# directory inside the configured digital-human model dir so downloaded
# weights (face-alignment) live alongside the other model files.
hub.set_dir(str(Path(WEB_CONFIGS.DIGITAL_HUMAN_MODEL_DIR).joinpath("face-alignment")))
|
server/digital_human/modules/digital_human_worker.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from .realtime_inference import DIGITAL_HUMAN_HANDLER, gen_digital_human_preprocess, gen_digital_human_video
|
3 |
+
from ...web_configs import WEB_CONFIGS
|
4 |
+
|
5 |
+
|
6 |
+
async def gen_digital_human_video_app(stream_id, audio_path, save_tag):
    """Render a digital-human video for one audio file.

    Args:
        stream_id: streamer (digital human) identifier.
        audio_path: path to the driving TTS audio.
        save_tag: output video file name.

    Returns:
        The generated video path, or None when the digital-human handler
        is not loaded.
    """
    if DIGITAL_HUMAN_HANDLER is None:
        return None

    output_dir = str(Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH).absolute())
    return gen_digital_human_video(
        DIGITAL_HUMAN_HANDLER,
        stream_id,
        audio_path,
        work_dir=output_dir,
        video_path=save_tag,
        fps=DIGITAL_HUMAN_HANDLER.fps,
    )
|
20 |
+
|
21 |
+
|
22 |
+
async def preprocess_digital_human_app(stream_id, video_path):
    """Preprocess a base video for a new digital human.

    Returns the preprocessing result, or None when the digital-human
    handler is not loaded.
    """
    if DIGITAL_HUMAN_HANDLER is None:
        return None

    work_dir = str(Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH).absolute())
    return gen_digital_human_preprocess(
        DIGITAL_HUMAN_HANDLER,
        stream_id,
        work_dir=work_dir,
        video_path=video_path,
    )
|
server/digital_human/modules/musetalk/models/unet.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
import json
|
5 |
+
|
6 |
+
from diffusers import UNet2DConditionModel
|
7 |
+
|
8 |
+
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., "Attention Is All
    You Need").

    Precomputes a (1, max_len, d_model) table of interleaved sin/cos values
    and adds its first `seq_len` rows to the input on each forward pass.
    """

    def __init__(self, d_model=384, max_len=5000):
        # Modernized from the legacy two-argument super() spelling.
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency progression over the even feature indices.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Registered as a buffer: saved with the module, excluded from training.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to `x` of shape (batch, seq_len, d_model)."""
        b, seq_len, d_model = x.size()
        pe = self.pe[:, :seq_len, :]
        x = x + pe.to(x.device)
        return x
|
24 |
+
|
25 |
+
class UNet():
    """Wrapper that builds a diffusers UNet2DConditionModel from a JSON
    config file and loads pretrained weights onto the best available device.
    """

    def __init__(self,
                 unet_config,
                 model_path,
                 use_float16=False,
                 ):
        """
        :param unet_config: path to a JSON file of UNet2DConditionModel kwargs.
        :param model_path: path to the state-dict checkpoint.
        :param use_float16: cast the model to half precision after loading.
        """
        with open(unet_config, 'r') as f:
            unet_config = json.load(f)
        self.model = UNet2DConditionModel(**unet_config)
        self.pe = PositionalEncoding(d_model=384)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # map_location is effectively a no-op when loading onto the same CUDA
        # device and correctly remaps CUDA-saved tensors to CPU otherwise, so
        # one call replaces the previous cuda/cpu conditional.
        weights = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(weights)
        if use_float16:
            self.model = self.model.half()
        self.model.to(self.device)
|
41 |
+
|
42 |
+
if __name__ == "__main__":
    # NOTE(review): UNet() requires `unet_config` and `model_path`, so this
    # smoke-test entry point raises TypeError as written — confirm intent.
    unet = UNet()
|
server/digital_human/modules/musetalk/models/vae.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
from diffusers import AutoencoderKL
|
9 |
+
|
10 |
+
|
11 |
+
class VAE():
    """
    VAE (Variational Autoencoder) class for image processing.

    Wraps a pretrained diffusers AutoencoderKL and provides image
    preprocessing, latent encoding/decoding, and the concatenated
    (masked + reference) latent input used by the MuseTalk U-Net.
    """

    def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
        """
        Initialize the VAE instance.

        :param model_path: Path to the trained model.
        :param resized_img: The size to which images are resized.
        :param use_float16: Whether to use float16 precision.
        """
        self.model_path = model_path
        self.vae = AutoencoderKL.from_pretrained(self.model_path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vae.to(self.device)

        if use_float16:
            self.vae = self.vae.half()
            self._use_float16 = True
        else:
            self._use_float16 = False

        # Scaling factor from the pretrained VAE config; applied to latents.
        self.scaling_factor = self.vae.config.scaling_factor
        # Map pixel values from [0, 1] to [-1, 1].
        self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self._resized_img = resized_img
        self._mask_tensor = self.get_mask_tensor()

    def get_mask_tensor(self):
        """
        Creates a mask tensor for image processing.

        The mask is 1 in the upper half and 0 in the lower half, so applying
        it zeroes out the lower (mouth) region of a face image.

        :return: A mask tensor.
        """
        mask_tensor = torch.zeros((self._resized_img, self._resized_img))
        mask_tensor[:self._resized_img // 2, :] = 1
        # Binarize (already 0/1 here, but kept as a safeguard).
        mask_tensor[mask_tensor < 0.5] = 0
        mask_tensor[mask_tensor >= 0.5] = 1
        return mask_tensor

    def preprocess_img(self, img_name, half_mask=False):
        """
        Preprocess an image for the VAE.

        :param img_name: The image file path or an already-loaded BGR ndarray.
        :param half_mask: Whether to zero out the lower half of the image.
        :return: A preprocessed image tensor of shape [1, 3, H, W] on the
            VAE's device.
        """
        window = []
        if isinstance(img_name, str):
            # Path input: read from disk, convert BGR->RGB, resize.
            window_fnames = [img_name]
            for fname in window_fnames:
                img = cv2.imread(fname)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (self._resized_img, self._resized_img),
                                 interpolation=cv2.INTER_LANCZOS4)
                window.append(img)
        else:
            # Array input: assumed BGR and already sized — TODO confirm callers
            # pass images at (resized_img, resized_img).
            img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
            window.append(img)

        # (1, H, W, 3) in [0, 1] -> (3, H, W) float tensor.
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))
        x = torch.squeeze(torch.FloatTensor(x))
        if half_mask:
            # Keep only the upper half of the face (mask out the mouth area).
            x = x * (self._mask_tensor > 0.5)
        x = self.transform(x)

        x = x.unsqueeze(0)  # [1, 3, 256, 256] torch tensor
        x = x.to(self.vae.device)

        return x

    def encode_latents(self, image):
        """
        Encode an image into latent variables.

        :param image: The image tensor to encode.
        :return: The encoded latent variables (scaled by scaling_factor).
        """
        with torch.no_grad():
            init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
        init_latents = self.scaling_factor * init_latent_dist.sample()
        return init_latents

    def decode_latents(self, latents):
        """
        Decode latent variables back into an image.

        :param latents: The latent variables to decode.
        :return: A NumPy uint8 BGR array representing the decoded image.
        """
        latents = (1 / self.scaling_factor) * latents
        image = self.vae.decode(latents.to(self.vae.dtype)).sample
        # [-1, 1] -> [0, 1], then to uint8 HWC.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = image[..., ::-1]  # RGB to BGR
        return image

    def get_latents_for_unet(self, img):
        """
        Prepare latent variables for a U-Net model.

        Encodes the image twice — once with the lower half masked out and
        once unmasked — and concatenates both latents along the channel dim.

        :param img: The image to process.
        :return: A concatenated tensor of latents for U-Net input.
        """

        ref_image = self.preprocess_img(img, half_mask=True)  # [1, 3, 256, 256] RGB, torch tensor
        masked_latents = self.encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
        ref_image = self.preprocess_img(img, half_mask=False)  # [1, 3, 256, 256] RGB, torch tensor
        ref_latents = self.encode_latents(ref_image)  # [1, 4, 32, 32], torch tensor
        latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
        return latent_model_input
|
124 |
+
|
125 |
+
if __name__ == "__main__":
    # Manual smoke test: encode every cropped frame in a directory and print
    # the latent shapes. Requires local model weights and sample frames.
    vae_mode_path = "./models/sd-vae-ft-mse/"
    vae = VAE(model_path=vae_mode_path, use_float16=False)
    img_path = "./results/sun001_crop/00000.png"

    crop_imgs_path = "./results/sun001_crop/"
    latents_out_path = "./results/latents/"
    if not os.path.exists(latents_out_path):
        os.mkdir(latents_out_path)

    files = os.listdir(crop_imgs_path)
    files.sort()
    files = [file for file in files if file.split(".")[-1] == "png"]

    for file in files:
        index = file.split(".")[0]
        img_path = crop_imgs_path + file
        latents = vae.get_latents_for_unet(img_path)
        print(img_path, "latents", latents.size())
        # torch.save(latents,os.path.join(latents_out_path,index+".pt"))
        # reload_tensor = torch.load('tensor.pt')
        # print(reload_tensor.size())
|
147 |
+
|
148 |
+
|
149 |
+
|
server/digital_human/modules/musetalk/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
from os.path import abspath, dirname
current_dir = dirname(abspath(__file__))
parent_dir = dirname(current_dir)
# Extend sys.path so musetalk submodules can be imported without package
# qualification. NOTE(review): parent_dir + '/utils' resolves back to this
# very directory — presumably intentional, but confirm.
sys.path.append(parent_dir+'/utils')
|
server/digital_human/modules/musetalk/utils/blending.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
from face_parsing import FaceParsing
|
5 |
+
|
6 |
+
|
7 |
+
def init_face_parsing_model(
    resnet_path="./models/face-parse-bisent/resnet18-5c106cde.pth", face_model_pth="./models/face-parse-bisent/79999_iter.pth"
):
    """Build and return the face-parsing model used for blending masks."""
    return FaceParsing(resnet_path, face_model_pth)
|
12 |
+
|
13 |
+
|
14 |
+
def get_crop_box(box, expand):
    """Return a square crop box centred on `box`, scaled by `expand`.

    Args:
        box: (x, y, x1, y1) face bounding box.
        expand: multiplier applied to half the larger box side.

    Returns:
        ([x_s, y_s, x_e, y_e], s): square crop corners and half-size s.
    """
    x, y, x1, y1 = box
    center_x = (x + x1) // 2
    center_y = (y + y1) // 2
    width = x1 - x
    height = y1 - y
    s = int(max(width, height) // 2 * expand)
    crop_box = [center_x - s, center_y - s, center_x + s, center_y + s]
    return crop_box, s
|
21 |
+
|
22 |
+
|
23 |
+
def face_seg(image, fp_model):
    """Run face parsing on a PIL image and return the mask resized to match.

    Returns None (after printing an error) when no person is detected.
    """
    seg_image = fp_model(image)
    if seg_image is None:
        print("error, no person_segment")
        return None
    return seg_image.resize(image.size)
|
31 |
+
|
32 |
+
|
33 |
+
def get_image(image, face, face_box, fp_model, upper_boundary_ratio=0.5, expand=1.2):
    """Blend a generated face patch back into the full frame using a
    face-parsing mask feathered with a Gaussian blur.

    Args:
        image: full frame, BGR ndarray.
        face: generated face patch, BGR ndarray matching `face_box`.
        face_box: (x, y, x1, y1) face bounding box within `image`.
        fp_model: face-parsing model (see `init_face_parsing_model`).
        upper_boundary_ratio: fraction of the mask height cut from the top,
            keeping only the lower "talking area" of the face.
        expand: crop-box expansion factor (see `get_crop_box`).

    Returns:
        Blended full frame as a BGR ndarray.
    """
    # print(image.shape)
    # print(face.shape)

    # Arrays are BGR; [:, :, ::-1] converts to RGB for PIL.
    body = Image.fromarray(image[:, :, ::-1])
    face = Image.fromarray(face[:, :, ::-1])

    x, y, x1, y1 = face_box
    # print(x1-x,y1-y)
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box
    face_position = (x, y)  # NOTE(review): unused

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    # Parse the expanded crop, then restrict the mask to the face box area.
    mask_image = face_seg(face_large, fp_model)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new("L", ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # keep upper_boundary_ratio of talking area
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new("L", ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    # Feather the mask edge: odd Gaussian kernel ~10% of crop width.
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    mask_image = Image.fromarray(mask_array)

    # Paste the generated face into the crop, then the crop into the frame
    # using the feathered mask as alpha.
    face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    body.paste(face_large, crop_box[:2], mask_image)
    body = np.array(body)
    return body[:, :, ::-1]  # back to BGR
|
68 |
+
|
69 |
+
|
70 |
+
def get_image_prepare_material(image, face_box, fp_model, upper_boundary_ratio=0.5, expand=1.2):
    """Precompute the feathered blending mask and crop box for a frame.

    Same mask construction as `get_image`, but without pasting a face —
    used to cache per-frame material for repeated blending (see
    `get_image_blending`).

    Returns:
        (mask_array, crop_box): uint8 mask ndarray and the square crop box.
    """
    # Frame is BGR; convert to RGB for PIL.
    body = Image.fromarray(image[:, :, ::-1])

    x, y, x1, y1 = face_box
    # print(x1-x,y1-y)
    crop_box, s = get_crop_box(face_box, expand)
    x_s, y_s, x_e, y_e = crop_box

    face_large = body.crop(crop_box)
    ori_shape = face_large.size

    # Parse the expanded crop, then restrict the mask to the face box area.
    mask_image = face_seg(face_large, fp_model)
    mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    mask_image = Image.new("L", ori_shape, 0)
    mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))

    # keep upper_boundary_ratio of talking area
    width, height = mask_image.size
    top_boundary = int(height * upper_boundary_ratio)
    modified_mask_image = Image.new("L", ori_shape, 0)
    modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))

    # Feather the mask edge: odd Gaussian kernel ~10% of crop width.
    blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
    mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
    return mask_array, crop_box
|
95 |
+
|
96 |
+
|
97 |
+
def get_image_blending(image, face, face_box, mask_array, crop_box):
    """Composite a generated face back into the full frame with a soft mask.

    Args:
        image: Full frame as a BGR numpy array.
        face: Generated face patch as a BGR numpy array.
        face_box: (x, y, x1, y1) face rectangle in frame coordinates.
        mask_array: Precomputed alpha mask (from ``get_image_prepare_material``)
            sized to the expanded crop region.
        crop_box: (x_s, y_s, x_e, y_e) expanded crop rectangle.

    Returns:
        The blended frame as a BGR numpy array.
    """
    frame_pil = Image.fromarray(image[:, :, ::-1])
    face_pil = Image.fromarray(face[:, :, ::-1])

    x, y, x1, y1 = face_box
    x_s, y_s = crop_box[0], crop_box[1]

    # Paste the new face into the cropped region, then alpha-blend that
    # region back into the frame using the feathered mask.
    face_region = frame_pil.crop(crop_box)
    blend_mask = Image.fromarray(mask_array).convert("L")
    face_region.paste(face_pil, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
    frame_pil.paste(face_region, crop_box[:2], blend_mask)

    # Convert back to BGR for the OpenCV-based caller.
    return np.array(frame_pil)[:, :, ::-1]
|
server/digital_human/modules/musetalk/utils/dwpose/default_runtime.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Shared mmengine/mmpose runtime config: hooks, environment, visualizer,
# logging, and train/val/test loop settings inherited by model configs.
default_scope = 'mmpose'

# hooks
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    # Both visualization and bad-case analysis are disabled by default;
    # flip `enable` to turn them on for debugging runs.
    visualization=dict(type='PoseVisualizationHook', enable=False),
    badcase=dict(
        type='BadCaseAnalysisHook',
        enable=False,
        out_dir='badcase',
        metric_type='loss',
        badcase_thr=5))

# custom hooks
custom_hooks = [
    # Synchronize model buffers such as running_mean and running_var in BN
    # at the end of each epoch
    dict(type='SyncBuffersHook')
]

# multi-processing backend
env_cfg = dict(
    cudnn_benchmark=False,
    # Fork start method with OpenCV threads disabled avoids worker
    # thread contention in dataloaders.
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# visualizer
vis_backends = [
    dict(type='LocalVisBackend'),
    # dict(type='TensorboardVisBackend'),
    # dict(type='WandbVisBackend'),
]
visualizer = dict(
    type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# logger
log_processor = dict(
    type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
log_level = 'INFO'
# No checkpoint is loaded and training does not resume by default.
load_from = None
resume = False

# file I/O backend
backend_args = dict(backend='local')

# training/validation/testing progress
train_cfg = dict(by_epoch=True)
val_cfg = dict()
test_cfg = dict()
|
server/digital_human/modules/musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RTMPose-L whole-body (133 keypoints) training config at 384x288,
# trained on COCO-WholeBody plus UBody scenes. Vendored into this repo
# for DWPose inference; the _base_ path was redirected to the local copy.
#_base_ = ['../../../_base_/default_runtime.py']
_base_ = ['default_runtime.py']

# runtime
max_epochs = 270
stage2_num_epochs = 30
base_lr = 4e-3
# NOTE(review): batch sizes of 8 look reduced from the original 8xb32
# recipe implied by the filename — confirm before retraining.
train_batch_size = 8
val_batch_size = 8

train_cfg = dict(max_epochs=max_epochs, val_interval=10)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    # No weight decay on norm layers and biases.
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    # Linear warmup over the first 1000 iterations.
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # use cosine lr from 150 to 300 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
codec = dict(
    type='SimCCLabel',
    input_size=(288, 384),
    sigma=(6., 6.93),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        # Standard ImageNet mean/std (BGR input converted to RGB).
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        # Backbone class is registered under the mmdet scope.
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=1.,
        widen_factor=1.,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'  # noqa: E501
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=1024,
        # 133 whole-body keypoints (body + feet + face + hands).
        out_channels=133,
        input_size=codec['input_size'],
        in_featuremap_size=(9, 12),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True, ))

# base dataset settings
dataset_type = 'UBody2dDataset'
data_mode = 'topdown'
data_root = 'data/UBody/'

backend_args = dict(backend='local')

# UBody scene subsets combined with COCO-WholeBody for training.
scenes = [
    'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
    'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
    'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
]

train_datasets = [
    dict(
        type='CocoWholeBodyDataset',
        data_root='data/coco/',
        data_mode=data_mode,
        ann_file='annotations/coco_wholebody_train_v1.0.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[])
]

# One sub-dataset per UBody scene, sampling every 10th frame.
for scene in scenes:
    train_dataset = dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file=f'annotations/{scene}/train_annotations.json',
        data_prefix=dict(img='images/'),
        pipeline=[],
        sample_interval=10)
    train_datasets.append(train_dataset)

# pipelines
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# Second-stage pipeline (last `stage2_num_epochs` epochs): same as stage 1
# but with zero shift and weaker CoarseDropout (p=0.5 instead of 1.0).
train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.5, 1.5],
        rotate_factor=90),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=train_batch_size,
    num_workers=10,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
        datasets=train_datasets,
        pipeline=train_pipeline,
        test_mode=False,
    ))

# Validation runs on COCO-WholeBody val with detector-provided boxes.
# NOTE(review): ann_file/bbox_file are rooted at 'data/coco/...' while
# data_root is 'data/UBody/' and img prefix is 'coco/val2017/' — the
# resulting paths look inconsistent; verify against the dataset layout.
val_dataloader = dict(
    batch_size=val_batch_size,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type='CocoWholeBodyDataset',
        data_root=data_root,
        data_mode=data_mode,
        ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
        bbox_file='data/coco/person_detection_results/'
        'COCO_val2017_detections_AP_H_56_person.json',
        data_prefix=dict(img='coco/val2017/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# hooks
default_hooks = dict(
    checkpoint=dict(
        save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    # Switch to the stage-2 augmentation pipeline for the final epochs.
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
    type='CocoWholeBodyMetric',
    ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
|
server/digital_human/modules/musetalk/utils/face_detection/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|