FAYO
model
1ef9436
raw
history blame
12.4 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : streaming_room.py
@Time : 2024/08/31
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author : HinGwenWong
@Version : 1.0
@Desc : 主播间信息交互接口
"""
import uuid
from copy import deepcopy
from pathlib import Path
import requests
from fastapi import APIRouter, Depends
from loguru import logger
from ...web_configs import API_CONFIG, WEB_CONFIGS
from ..database.product_db import get_db_product_info
from ..database.streamer_room_db import (
create_or_update_db_room_by_id,
get_live_room_info,
get_message_list,
update_db_room_status,
delete_room_id,
get_db_streaming_room_info,
update_message_info,
update_room_video_path,
)
from ..models.product_model import ProductInfo
from ..models.streamer_room_model import OnAirRoomStatusItem, RoomChatItem, SalesDocAndVideoInfo, StreamRoomInfo
from ..modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
from ..routers.users import get_current_user_info
from ..server_info import SERVER_PLUGINS_INFO
from ..utils import ResultCode, make_return_data
from .digital_human import gen_tts_and_digital_human_video_app
from .llm import combine_history, gen_poduct_base_prompt, get_agent_res, get_llm_res
router = APIRouter(
prefix="/streaming-room",
tags=["streaming-room"],
responses={404: {"description": "Not found"}},
)
@router.get("/list", summary="获取所有直播间信息接口")
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
"""获取所有直播间信息"""
# 加载直播间数据
streaming_room_list = await get_db_streaming_room_info(user_id)
for i in range(len(streaming_room_list)):
# 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段
streaming_room_list[i] = dict(streaming_room_list[i])
return make_return_data(True, ResultCode.SUCCESS, "成功", streaming_room_list)
@router.get("/info/{roomId}", summary="获取特定直播间信息接口")
async def get_streaming_room_id_api(
roomId: int, currentPage: int = 1, pageSize: int = 5, user_id: int = Depends(get_current_user_info)
):
"""获取特定直播间信息"""
# 加载直播间配置文件
assert roomId != 0
# TODO 加入分页
# 加载直播间数据
streaming_room_list = await get_db_streaming_room_info(user_id, room_id=roomId)
if len(streaming_room_list) == 1:
# 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段
format_product_list = []
for db_product in streaming_room_list[0].product_list:
product_dict = dict(db_product)
# 将 start_video 改为服务器地址
if product_dict["start_video"] != "":
product_dict["start_video"] = API_CONFIG.REQUEST_FILES_URL + product_dict["start_video"]
format_product_list.append(product_dict)
streaming_room_list = dict(streaming_room_list[0])
streaming_room_list["product_list"] = format_product_list
else:
streaming_room_list = []
return make_return_data(True, ResultCode.SUCCESS, "成功", streaming_room_list)
@router.get("/product-edit-list/{roomId}", summary="获取直播间商品编辑列表,含有已选中的标识")
async def get_streaming_room_product_list_api(
roomId: int, currentPage: int = 1, pageSize: int = 0, user_id: int = Depends(get_current_user_info)
):
"""获取直播间商品编辑列表,含有已选中的标识"""
# 获取目前直播间商品列表
if roomId == 0:
# 新的直播间
merge_list = []
exclude_list = []
else:
streaming_room_info = await get_db_streaming_room_info(user_id, roomId)
if len(streaming_room_info) == 0:
raise "401"
streaming_room_info = streaming_room_info[0]
# 获取未被选中的商品
exclude_list = [product.product_id for product in streaming_room_info.product_list]
merge_list = deepcopy(streaming_room_info.product_list)
# 获取未选中的商品信息
not_select_product_list, db_product_size = await get_db_product_info(user_id=user_id, exclude_list=exclude_list)
# 合并商品信息
for not_select_product in not_select_product_list:
merge_list.append(
SalesDocAndVideoInfo(
product_id=not_select_product.product_id,
product_info=ProductInfo(**dict(not_select_product)),
selected=False,
)
)
# TODO 懒加载分页
# 格式化
format_merge_list = []
for product in merge_list:
# 直接返回会导致字段丢失,需要转 dict 确保返回值里面有该字段
dict_info = dict(product)
if "stream_room" in dict_info:
dict_info.pop("stream_room")
format_merge_list.append(dict_info)
page_info = dict(
product_list=format_merge_list,
current=currentPage,
pageSize=db_product_size,
totalSize=db_product_size,
)
logger.info(page_info)
return make_return_data(True, ResultCode.SUCCESS, "成功", page_info)
@router.post("/create", summary="新增直播间接口")
async def streaming_room_edit_api(edit_item: dict, user_id: int = Depends(get_current_user_info)):
product_list = edit_item.pop("product_list")
status = edit_item.pop("status")
edit_item.pop("streamer_info")
edit_item.pop("room_id")
if "status_id" in edit_item:
edit_item.pop("status_id")
formate_product_list = []
for product in product_list:
if not product["selected"]:
continue
product.pop("product_info")
product_item = SalesDocAndVideoInfo(**product)
formate_product_list.append(product_item)
edit_item["user_id"] = user_id
formate_info = StreamRoomInfo(**edit_item, product_list=formate_product_list, status=OnAirRoomStatusItem(**status))
room_id = create_or_update_db_room_by_id(0, formate_info, user_id)
return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
@router.put("/edit/{room_id}", summary="编辑直播间接口")
async def streaming_room_edit_api(room_id: int, edit_item: dict, user_id: int = Depends(get_current_user_info)):
"""编辑直播间接口
Args:
edit_item (StreamRoomInfo): _description_
"""
product_list = edit_item.pop("product_list")
status = edit_item.pop("status")
edit_item.pop("streamer_info")
formate_product_list = []
for product in product_list:
if not product["selected"]:
continue
product.pop("product_info")
product_item = SalesDocAndVideoInfo(**product)
formate_product_list.append(product_item)
formate_info = StreamRoomInfo(**edit_item, product_list=formate_product_list, status=OnAirRoomStatusItem(**status))
create_or_update_db_room_by_id(room_id, formate_info, user_id)
return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
@router.delete("/delete/{roomId}", summary="删除直播间接口")
async def delete_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
process_success_flag = await delete_room_id(roomId, user_id)
if not process_success_flag:
return make_return_data(False, ResultCode.FAIL, "失败", "")
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
# ============================================================
# 开播接口
# ============================================================
@router.post("/online/{roomId}", summary="直播间开播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
update_db_room_status(roomId, user_id, "online")
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
@router.put("/offline/{roomId}", summary="直播间下播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
update_db_room_status(roomId, user_id, "offline")
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
@router.post("/next-product/{roomId}", summary="直播间进行下一个商品讲解接口")
async def on_air_live_room_next_product_api(roomId: int, user_id: int = Depends(get_current_user_info)):
"""直播间进行下一个商品讲解
Args:
roomId (int): 直播间 ID
"""
update_db_room_status(roomId, user_id, "next-product")
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
@router.get("/live-info/{roomId}", summary="获取正在直播的直播间信息接口")
async def get_on_air_live_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
"""获取正在直播的直播间信息
1. 主播视频地址
2. 商品信息,显示在右下角的商品缩略图
3. 对话记录 conversation_list
Args:
roomId (int): 直播间 ID
"""
res_data = await get_live_room_info(user_id, roomId)
return make_return_data(True, ResultCode.SUCCESS, "成功", res_data)
@router.put("/chat", summary="直播间对话接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
# 根据直播间 ID 获取信息
streaming_room_info = await get_db_streaming_room_info(user_id, room_chat.roomId)
streaming_room_info = streaming_room_info[0]
# 商品索引
product_detail = streaming_room_info.product_list[streaming_room_info.status.current_product_index].product_info
# 销售 ID
sales_info_id = streaming_room_info.product_list[streaming_room_info.status.current_product_index].sales_info_id
# 更新对话记录
update_message_info(sales_info_id, user_id, role="user", message=room_chat.message)
# 获取最新的对话记录
conversation_list = get_message_list(sales_info_id)
# 根据对话记录生成 prompt
prompt = await gen_poduct_base_prompt(
user_id,
streamer_info=streaming_room_info.streamer_info,
product_info=product_detail,
) # system + 获取商品文案prompt
prompt = combine_history(prompt, conversation_list)
# ====================== Agent ======================
# 调取 Agent
agent_response = await get_agent_res(prompt, product_detail.departure_place, product_detail.delivery_company)
if agent_response != "":
logger.info("Agent 执行成功,不执行 RAG")
prompt[-1]["content"] = agent_response
# ====================== RAG ======================
# 调取 rag
elif SERVER_PLUGINS_INFO.rag_enabled:
logger.info("Agent 未执行 or 未开启,调用 RAG")
# agent 失败,调取 rag, chat_item.plugins.rag 为 True,则使用 RAG 查询数据库
rag_res = build_rag_prompt(RAG_RETRIEVER, product_detail.product_name, prompt[-1]["content"])
if rag_res != "":
prompt[-1]["content"] = rag_res
# 调取 LLM
streamer_res = await get_llm_res(prompt)
# 生成数字人视频
server_video_path = await gen_tts_and_digital_human_video_app(streaming_room_info.streamer_info.streamer_id, streamer_res)
# 更新直播间数字人视频信息
update_room_video_path(streaming_room_info.status_id, server_video_path)
# 更新对话记录
update_message_info(sales_info_id, streaming_room_info.streamer_info.streamer_id, role="streamer", message=streamer_res)
return make_return_data(True, ResultCode.SUCCESS, "成功", "")
@router.post("/asr", summary="直播间调取 ASR 语音转文字 接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
# room_chat.asr_file 是 服务器地址,需要进行转换
asr_local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(WEB_CONFIGS.ASR_FILE_DIR, Path(room_chat.asrFileUrl).name)
# 获取 ASR 结果
req_data = {
"user_id": user_id,
"request_id": str(uuid.uuid1()),
"wav_path": str(asr_local_path),
}
logger.info(req_data)
res = requests.post(API_CONFIG.ASR_URL, json=req_data).json()
asr_str = res["result"]
logger.info(f"ASR res = {asr_str}")
# 删除过程文件
asr_local_path.unlink()
return make_return_data(True, ResultCode.SUCCESS, "成功", asr_str)