""" |
@File : utils.py |
@Time : 2024/09/02 |
@Project : https://github.com/PeterH0323/Streamer-Sales |
@Author : HinGwenWong |
@Version : 1.0 |
@Desc : 工具集合文件 |
""" |
import asyncio |
from ipaddress import IPv4Address |
import json |
import random |
import wave |
from dataclasses import dataclass |
from datetime import datetime |
from pathlib import Path |
from typing import Dict, List |
import cv2 |
from lmdeploy.serve.openai.api_client import APIClient |
from loguru import logger |
from pydantic import BaseModel |
from sqlmodel import Session, select |
from tqdm import tqdm |
from server.base.models.user_model import UserInfo |
from ..tts.tools import SYMBOL_SPLITS, make_text_chunk |
from ..web_configs import API_CONFIG, WEB_CONFIGS |
from .database.init_db import DB_ENGINE |
from .models.product_model import ProductInfo |
from .models.streamer_info_model import StreamerInfo |
from .models.streamer_room_model import OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo |
from .modules.agent.agent_worker import get_agent_result |
from .modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt |
from .queue_thread import DIGITAL_HUMAN_QUENE, TTS_TEXT_QUENE |
from .server_info import SERVER_PLUGINS_INFO |
class ChatGenConfig(BaseModel):
    """Sampling options that can accompany a chat request.

    NOTE(review): ``streamer_sales_process`` in this module does not forward
    these values to the LLM call; presumably they are consumed elsewhere —
    confirm.
    """

    top_p: float = 0.8
    temperature: float = 0.7
    repetition_penalty: float = 1.005
class ProductInfoItem(BaseModel):
    """Product details attached to a chat request."""

    # Product name; passed to the RAG prompt builder.
    name: str
    # NOTE: the spelling "heighlights" is part of the request schema; renaming
    # the field would break existing callers.
    heighlights: str
    introduce: str
    image_path: str
    # Forwarded to the agent worker together with the user's question.
    departure_place: str
    # Forwarded to the agent worker together with the user's question.
    delivery_company_name: str
class PluginsInfo(BaseModel):
    """Per-request plugin switches; every plugin defaults to enabled.

    In ``streamer_sales_process`` the ``agent``, ``tts`` and ``digital_human``
    switches are additionally gated by the matching server-side flag in
    ``SERVER_PLUGINS_INFO``; ``rag`` is checked on its own.
    """

    rag: bool = True
    agent: bool = True
    tts: bool = True
    digital_human: bool = True
class ChatItem(BaseModel):
    """Request payload consumed by ``streamer_sales_process``."""

    # Copied into every TTS / digital-human queue message.
    user_id: str
    # Copied into queue messages and used as the stem of generated wav names.
    request_id: str
    # Chat messages; the "content" of the last entry is read and may be
    # rewritten in place by the agent / RAG steps.
    prompt: List[Dict[str, str]]
    product_info: ProductInfoItem
    plugins: PluginsInfo = PluginsInfo()
    chat_config: ChatGenConfig = ChatGenConfig()
def _make_stream_event(event_id: int, text: str, step: str, end_flag: bool) -> str:
    """Serialize one progress message of the chat stream to a JSON string."""
    return json.dumps(
        {
            "event": "message",
            "retry": 100,
            "id": event_id,
            "data": text,
            "step": step,
            "end_flag": end_flag,
        },
        ensure_ascii=False,
    )


async def streamer_sales_process(chat_item: ChatItem):
    """Stream the LLM answer for one chat request as JSON progress messages.

    Processing order:

    1. Agent (when enabled on both request and server): the last prompt entry
       is sent to the agent worker; a non-empty result is wrapped in a prompt
       template and replaces the last prompt entry.
    2. RAG (when enabled and the agent produced nothing): the last prompt entry
       is replaced by the RAG prompt if one could be built.
    3. LLM streaming: every delta is appended to the accumulated answer and a
       ``"llm"`` message is yielded. With TTS enabled, finished sentences are
       pushed to ``TTS_TEXT_QUENE``.
    4. Digital human (when enabled and at least one sentence was sent to TTS):
       waits for all sentence wav files, merges them into one wav, queues the
       merged file on ``DIGITAL_HUMAN_QUENE``, waits for the completion marker
       file and finally deletes the per-sentence wav files.
    5. A final ``"all"`` message with ``end_flag`` set to ``True``.

    ``chat_item.prompt[-1]["content"]`` is modified in place by steps 1 and 2.

    NOTE(review): this function reads a module-level ``LLM_MODEL_HANDLER`` that
    is not defined in the reviewed lines — confirm it is created elsewhere in
    this module.

    Args:
        chat_item (ChatItem): The chat request.

    Yields:
        str: JSON text with the keys ``event``, ``retry``, ``id``, ``data``,
            ``step`` (``"llm"``, ``"tts"``, ``"dg"`` or ``"all"``) and
            ``end_flag``.
    """
    # ====================== Agent ======================
    agent_response = ""
    if chat_item.plugins.agent and SERVER_PLUGINS_INFO.agent_enabled:
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"
        )
        input_prompt = chat_item.prompt[-1]["content"]
        agent_response = get_agent_result(
            LLM_MODEL_HANDLER,
            input_prompt,
            chat_item.product_info.departure_place,
            chat_item.product_info.delivery_company_name,
        )
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)
            logger.info(f"Agent response: {agent_response}")
            chat_item.prompt[-1]["content"] = agent_response

    # ====================== RAG ======================
    # Only consulted when the agent did not contribute anything.
    if chat_item.plugins.rag and agent_response == "":
        rag_prompt = chat_item.prompt[-1]["content"]
        prompt_pro = build_rag_prompt(RAG_RETRIEVER, chat_item.product_info.name, rag_prompt)
        if prompt_pro != "":
            chat_item.prompt[-1]["content"] = prompt_pro

    logger.info(chat_item.prompt)

    # ====================== LLM streaming ======================
    current_predict = ""  # answer text accumulated so far
    idx = 0  # number of content deltas received; used as the message id
    last_text_index = 0  # cursor handed back and forth with make_text_chunk
    sentence_id = 0  # number of sentences pushed to the TTS queue

    model_name = LLM_MODEL_HANDLER.available_models[0]
    for item in LLM_MODEL_HANDLER.chat_completions_v1(model=model_name, messages=chat_item.prompt, stream=True):
        logger.debug(f"LLM predict: {item}")
        delta = item["choices"][0]["delta"]
        if "content" not in delta:
            continue
        current_res = delta["content"]

        if "~" in current_res:
            # Swap "~" for a full stop and collapse a doubled full stop.
            current_res = current_res.replace("~", "。").replace("。。", "。")

        current_predict += current_res
        idx += 1

        if chat_item.plugins.tts and SERVER_PLUGINS_INFO.tts_server_enabled:
            # Cut a sentence as soon as the new delta contains a split symbol.
            sentence = ""
            for symbol in SYMBOL_SPLITS:
                if symbol in current_res:
                    last_text_index, sentence = make_text_chunk(current_predict, last_text_index)
                    if len(sentence) <= 3:
                        # Too short to be worth synthesizing.
                        sentence = ""
                    break

            if sentence != "":
                sentence_id += 1
                logger.info(f"get sentence: {sentence}")
                tts_request_dict = {
                    "user_id": chat_item.user_id,
                    "request_id": chat_item.request_id,
                    "sentence": sentence,
                    "chunk_id": sentence_id,
                }
                TTS_TEXT_QUENE.put(tts_request_dict)
                await asyncio.sleep(0.01)

        yield _make_stream_event(idx, current_predict, "llm", False)
        await asyncio.sleep(0.01)

    # ====================== Digital human ======================
    digital_human_enabled = chat_item.plugins.digital_human and SERVER_PLUGINS_INFO.digital_human_server_enabled
    if digital_human_enabled and sentence_id == 0:
        # Without any sentence wav there is nothing to merge; merging an empty
        # list would raise ValueError in max().
        logger.warning("No TTS sentence was generated, skip digital human generation")

    if digital_human_enabled and sentence_id > 0:
        wav_list = [
            Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + f"-{str(i).zfill(8)}.wav")
            for i in range(1, sentence_id + 1)
        ]

        # Wait until every per-sentence wav file exists.
        while True:
            not_exist_count = sum(1 for tts_wav in wav_list if not tts_wav.exists())
            logger.info(f"still need to wait for {not_exist_count}/{sentence_id} wav generating...")
            if not_exist_count == 0:
                break

            yield _make_stream_event(idx, current_predict, "tts", False)
            await asyncio.sleep(1)

        # Merge the per-sentence wav files into a single wav.
        tts_save_path = Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + ".wav")
        all_tts_data = []
        for wav_file in tqdm(wav_list):
            logger.info(f"Reading wav file {wav_file}...")
            with wave.open(str(wav_file), "rb") as wf:
                all_tts_data.append([wf.getparams(), wf.readframes(wf.getnframes())])

        logger.info(f"Merging wav file to {tts_save_path}...")
        # The header of the merged file uses the largest params tuple.
        tts_params = max(tts_data[0] for tts_data in all_tts_data)
        with wave.open(str(tts_save_path), "wb") as wf:
            wf.setparams(tts_params)
            for wf_data in all_tts_data:
                wf.writeframes(wf_data[1])
        logger.info(f"Merged wav file to {tts_save_path} !")

        # Hand the merged wav to the digital-human worker.
        tts_request_dict = {
            "user_id": chat_item.user_id,
            "request_id": chat_item.request_id,
            "chunk_id": 0,
            "tts_path": str(tts_save_path),
        }
        logger.info("Generating digital human...")
        DIGITAL_HUMAN_QUENE.put(tts_request_dict)

        # NOTE(review): the base directory expression was missing from the
        # source; WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH is assumed —
        # confirm against the web configuration.
        done_marker = (
            Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH)
            .joinpath(Path(tts_save_path).stem + ".mp4")
            .with_suffix(".txt")
        )
        while not done_marker.exists():
            yield _make_stream_event(idx, current_predict, "dg", False)
            await asyncio.sleep(1)

        # Remove the intermediate per-sentence wav files.
        for wav_file in wav_list:
            wav_file.unlink()

    yield _make_stream_event(idx, current_predict, "all", True)
def make_poster_by_video_first_frame(video_path: str, image_output_name: str) -> str:
    """Save the first frame of a video as a poster image next to the video.

    Args:
        video_path (str): Path of the video file.
        image_output_name (str): File name of the poster image; it is written
            into the directory that contains the video.

    Returns:
        str: Path the poster is saved to. The path is returned even when the
            frame could not be read or written; such failures are only logged.
    """
    poster_save_path = str(Path(video_path).parent.joinpath(image_output_name))

    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame = cap.read()
    finally:
        # Release the capture handle even if reading raises.
        cap.release()

    if not ret:
        logger.error("无法读取视频帧")
        return poster_save_path

    # cv2.imwrite signals a failed write through its boolean return value.
    if cv2.imwrite(poster_save_path, frame):
        logger.info(f"第一帧已保存为 {poster_save_path}")
    else:
        logger.error(f"Failed to write poster image to {poster_save_path}")

    return poster_save_path
@dataclass
class ResultCode:
    """Numeric result codes used as the ``code`` of ``make_return_data``."""

    # Success; the numeric value is 0.
    SUCCESS: int = 0
    # Failure.
    FAIL: int = 1000
def make_return_data(success_flag: bool, code: ResultCode, message: str, data: dict):
    """Wrap a result in the common response envelope.

    Args:
        success_flag (bool): Stored under ``"success"``.
        code (ResultCode): Stored under ``"code"``.
        message (str): Stored under ``"message"``.
        data (dict): Stored under ``"data"``.

    Returns:
        dict: The envelope with an extra ``"timestamp"`` entry holding the
            current local time formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return dict(
        success=success_flag,
        code=code,
        message=message,
        data=data,
        timestamp=timestamp,
    )
def gen_default_data():
    """Seed the database with default data.

    The seed consists of:
    - the default user
    - product rows
    - one streamer
    - one live room, its on-air status row and the product/room link rows

    Seeding only happens when no user with ``user_id == 1`` exists yet.
    """

    def create_default_user():
        """Insert the default user."""
        admin_user = UserInfo(
            username="hingwen.wong",
            # An empty string is rejected by IPv4Address (AddressValueError),
            # so the loopback address is used as the default.
            ip_address=IPv4Address("127.0.0.1"),
            # NOTE(review): this looks like a scrubbed placeholder rather than
            # a real address — confirm the intended default email.
            email="[email protected]",
            hashed_password="$2b$12$zXXveodjipHZMoSxJz5ODul7Z9YeRJd0GeSBjpwHdqEtBbAFvEdre",
            avatar="/user/user-avatar.png",
        )
        with Session(DB_ENGINE) as session:
            session.add(admin_user)
            session.commit()

    def init_user() -> bool:
        """Create the default user when it does not exist yet.

        Returns:
            bool: True when the default user was created by this call.
        """
        with Session(DB_ENGINE) as session:
            results = session.exec(select(UserInfo).where(UserInfo.user_id == 1)).first()

            if results is None:
                create_default_user()
                logger.info("created default user info")
                return True

            return False

    def create_default_product_item():
        """Insert the default product rows."""
        delivery_company_list = ["京东", "顺丰", "韵达", "圆通", "中通"]
        departure_place_list = ["广州", "北京", "武汉", "杭州", "上海", "深圳", "成都"]
        # Keyed by the stem of the product image file name.
        default_product_list = {
            "beef": {
                "product_name": "进口和牛羽下肉",
                "heighlights": "富含铁质;营养价值高;肌肉纤维好;红白相间纹理;适合烧烤炖煮;草食动物来源",
                "product_class": "食品",
            },
            "elec_toothblush": {
                "product_name": "声波电动牙刷",
                "heighlights": "高效清洁;减少手动压力;定时提醒;智能模式调节;无线充电;噪音低",
                "product_class": "电子",
            },
            "lip_stick": {
                "product_name": "唇膏",
                "heighlights": "丰富色号;滋润保湿;显色度高;持久不脱色;易于涂抹;便携包装",
                "product_class": "美妆",
            },
            "mask": {
                "product_name": "光感润颜面膜",
                "heighlights": "密集滋养;深层补水;急救修复;快速见效;定期护理;多种类型选择",
                "product_class": "美妆",
            },
            "oled_tv": {
                "product_name": "65英寸OLED电视",
                "heighlights": "色彩鲜艳;对比度极高;响应速度快;无背光眩光;厚度较薄;自发光无需额外照明",
                "product_class": "家电",
            },
            "pad": {
                "product_name": "14英寸平板电脑",
                "heighlights": "轻薄;触控操作;电池续航好;移动办公便利;娱乐性强;适合儿童学习",
                "product_class": "电子",
            },
            "pants": {
                "product_name": "速干运动裤",
                "heighlights": "快干;伸缩自如;吸湿排汗;防风保暖;高腰设计;多口袋实用",
                "product_class": "衣服",
            },
            "pen": {
                "product_name": "墨水钢笔",
                "heighlights": "耐用性;可书写性;不同颜色和类型;轻便设计;环保材料;易于携带",
                "product_class": "文具",
            },
            "perfume": {
                "product_name": "薰衣草淡香氛",
                "heighlights": "浪漫优雅;花香调为主;情感表达;适合各种年龄;瓶身设计精致;提升女性魅力",
                "product_class": "家居用品",
            },
            "shampoo": {
                "product_name": "本草精华洗发露",
                "heighlights": "温和配方;深层清洁;滋养头皮;丰富泡沫;易冲洗;适合各种发质",
                "product_class": "日用品",
            },
            "wok": {
                "product_name": "不粘煎炒锅",
                "heighlights": "不粘涂层;耐磨耐用;导热快;易清洗;多种烹饪方式;设计人性化",
                "product_class": "厨具",
            },
            "yoga_mat": {
                "product_name": "瑜伽垫",
                "heighlights": "防滑材质;吸湿排汗;厚度适中;耐用易清洁;各种瑜伽动作适用;轻巧便携",
                "product_class": "运动",
            },
        }

        with Session(DB_ENGINE) as session:
            for product_key, product_info in default_product_list.items():
                add_item = ProductInfo(
                    **product_info,
                    image_path=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.IMAGES_DIR}/{product_key}.png",
                    # Shipping details, price and stock are random demo values.
                    departure_place=random.choice(departure_place_list),
                    delivery_company=random.choice(delivery_company_list),
                    selling_price=round(random.uniform(66.6, 1999.9), 2),
                    amount=random.randint(999, 9999),
                    user_id=1,
                )
                session.add(add_item)
            session.commit()

        logger.info("created default product info done!")

    def create_default_streamer():
        """Insert the default streamer."""
        with Session(DB_ENGINE) as session:
            streamer_item = StreamerInfo(
                name="乐乐喵",
                character="甜美;可爱;熟练使用各种网络热门梗造句;称呼客户为[家人们]",
                tts_reference_sentence="列车巡游银河,我不一定都能帮上忙,但只要是花钱能解决的事,尽管和我说吧。",
                tts_weight_tag="艾丝妲",
                user_id=1,
            )
            session.add(streamer_item)
            session.commit()

    def create_default_room():
        """Insert the default live room and link up to three products to it."""
        with Session(DB_ENGINE) as session:
            product_list = session.exec(
                select(ProductInfo).where(ProductInfo.user_id == 1).order_by(ProductInfo.product_id)
            ).all()

            on_air_status = OnAirRoomStatusItem(user_id=1)
            session.add(on_air_status)
            session.commit()
            # Refresh to load the generated status_id used below.
            session.refresh(on_air_status)

            stream_item = StreamRoomInfo(
                name="001",
                user_id=1,
                status_id=on_air_status.status_id,
                streamer_id=1,
            )
            session.add(stream_item)
            session.commit()
            # Refresh to load the generated room_id used below.
            session.refresh(stream_item)

            # Sample without replacement so a product is linked at most once,
            # and never ask for more products than exist.
            random_list = random.sample(list(product_list), k=min(3, len(product_list)))
            for product_random in random_list:
                add_sales_info = SalesDocAndVideoInfo(product_id=product_random.product_id, room_id=stream_item.room_id)
                session.add(add_sales_info)
                session.commit()
                session.refresh(add_sales_info)

    # The remaining tables are only seeded together with the default user.
    created = init_user()
    if created:
        create_default_product_item()
        create_default_streamer()
        create_default_room()