|
|
|
|
|
""" |
|
@File : streaming_room.py |
|
@Time : 2024/08/31 |
|
@Project : https://github.com/PeterH0323/Streamer-Sales |
|
@Author : HinGwenWong |
|
@Version : 1.0 |
|
@Desc : 主播间信息交互接口 |
|
""" |
|
|
|
import uuid
from copy import deepcopy
from pathlib import Path

import requests
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger

from ...web_configs import API_CONFIG, WEB_CONFIGS
from ..database.product_db import get_db_product_info
from ..database.streamer_room_db import (
    create_or_update_db_room_by_id,
    delete_room_id,
    get_db_streaming_room_info,
    get_live_room_info,
    get_message_list,
    update_db_room_status,
    update_message_info,
    update_room_video_path,
)
from ..models.product_model import ProductInfo
from ..models.streamer_room_model import OnAirRoomStatusItem, RoomChatItem, SalesDocAndVideoInfo, StreamRoomInfo
from ..modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
from ..routers.users import get_current_user_info
from ..server_info import SERVER_PLUGINS_INFO
from ..utils import ResultCode, make_return_data
from .digital_human import gen_tts_and_digital_human_video_app
from .llm import combine_history, gen_poduct_base_prompt, get_agent_res, get_llm_res
|
|
|
# Router that groups every endpoint of this module under the "/streaming-room" prefix.
router = APIRouter(
    prefix="/streaming-room",
    tags=["streaming-room"],
    responses={404: {"description": "Not found"}},
)
|
|
|
|
|
@router.get("/list", summary="获取所有直播间信息接口")
async def get_streaming_room_api(user_id: int = Depends(get_current_user_info)):
    """Return every streaming room stored for the current user.

    Args:
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data`` whose data field is the list
        of rooms, each converted to a plain ``dict``.
    """
    rooms = await get_db_streaming_room_info(user_id)

    # Swap each database record for its dict form, keeping the same list object.
    for idx, room in enumerate(rooms):
        rooms[idx] = dict(room)

    return make_return_data(True, ResultCode.SUCCESS, "成功", rooms)
|
|
|
|
|
@router.get("/info/{roomId}", summary="获取特定直播间信息接口")
async def get_streaming_room_id_api(
    roomId: int, currentPage: int = 1, pageSize: int = 5, user_id: int = Depends(get_current_user_info)
):
    """Return the details of a single streaming room.

    Args:
        roomId (int): ID of the room to look up; must not be 0.
        currentPage (int): Accepted for interface compatibility; not used here.
        pageSize (int): Accepted for interface compatibility; not used here.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data``. Its data field is the room as
        a ``dict`` whose ``product_list`` holds product dicts (every non-empty
        ``start_video`` prefixed with ``API_CONFIG.REQUEST_FILES_URL``), or an
        empty list when the lookup does not yield exactly one room.

    Raises:
        HTTPException: 400 when ``roomId`` is 0.
    """
    # Validate explicitly: an `assert` would be stripped under `python -O`
    # and would otherwise surface as an unhandled server error.
    if roomId == 0:
        raise HTTPException(status_code=400)

    matched_rooms = await get_db_streaming_room_info(user_id, room_id=roomId)

    if len(matched_rooms) != 1:
        return make_return_data(True, ResultCode.SUCCESS, "成功", [])

    room = matched_rooms[0]

    format_product_list = []
    for db_product in room.product_list:
        product_dict = dict(db_product)

        # Turn the stored video path into a full URL for the client.
        if product_dict["start_video"] != "":
            product_dict["start_video"] = API_CONFIG.REQUEST_FILES_URL + product_dict["start_video"]

        format_product_list.append(product_dict)

    room_info = dict(room)
    room_info["product_list"] = format_product_list

    return make_return_data(True, ResultCode.SUCCESS, "成功", room_info)
|
|
|
|
|
@router.get("/product-edit-list/{roomId}", summary="获取直播间商品编辑列表,含有已选中的标识")
async def get_streaming_room_product_list_api(
    roomId: int, currentPage: int = 1, pageSize: int = 0, user_id: int = Depends(get_current_user_info)
):
    """Return the product list used when editing a room, with selection flags.

    Products already attached to the room come first, followed by the user's
    remaining products marked ``selected=False``.

    Args:
        roomId (int): ID of the room being edited; 0 skips the room lookup
            (presumably a room that has not been created yet).
        currentPage (int): Echoed back as ``current`` in the response.
        pageSize (int): Accepted for interface compatibility; not used here.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data`` whose data field is a dict
        with ``product_list``, ``current``, ``pageSize`` and ``totalSize``.

    Raises:
        HTTPException: 401 when no room with ``roomId`` is found for the user.
    """
    if roomId == 0:
        merge_list = []
        exclude_list = []
    else:
        streaming_room_info = await get_db_streaming_room_info(user_id, roomId)

        if not streaming_room_info:
            # Raising a bare string is a TypeError in Python 3; raise a real
            # HTTP error carrying the status code the original intended.
            raise HTTPException(status_code=401)

        streaming_room_info = streaming_room_info[0]

        # IDs already in the room are excluded from the "not selected" query.
        exclude_list = [product.product_id for product in streaming_room_info.product_list]
        merge_list = deepcopy(streaming_room_info.product_list)

    not_select_product_list, db_product_size = await get_db_product_info(user_id=user_id, exclude_list=exclude_list)

    # Append the remaining products as unselected entries.
    for not_select_product in not_select_product_list:
        merge_list.append(
            SalesDocAndVideoInfo(
                product_id=not_select_product.product_id,
                product_info=ProductInfo(**dict(not_select_product)),
                selected=False,
            )
        )

    # Convert to plain dicts and drop the "stream_room" entry when present.
    format_merge_list = []
    for product in merge_list:
        dict_info = dict(product)
        dict_info.pop("stream_room", None)
        format_merge_list.append(dict_info)

    # Both size fields are set to the count returned by the product query.
    page_info = {
        "product_list": format_merge_list,
        "current": currentPage,
        "pageSize": db_product_size,
        "totalSize": db_product_size,
    }
    logger.info(page_info)
    return make_return_data(True, ResultCode.SUCCESS, "成功", page_info)
|
|
|
|
|
@router.post("/create", summary="新增直播间接口")
async def streaming_room_edit_api(edit_item: dict, user_id: int = Depends(get_current_user_info)):
    """Create a new streaming room from the submitted form data.

    Args:
        edit_item (dict): Request body. Must contain ``product_list``,
            ``status``, ``streamer_info`` and ``room_id``; ``status_id`` is
            optional. The dict and its product entries are modified in place.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data`` whose data field is whatever
        ``create_or_update_db_room_by_id`` returned.
    """
    # NOTE(review): the PUT /edit handler defined later reuses this function
    # name and rebinds it at module level; the route itself stays registered.
    raw_products = edit_item.pop("product_list")
    status = edit_item.pop("status")

    # These keys are removed before the remaining items are passed to StreamRoomInfo.
    edit_item.pop("streamer_info")
    edit_item.pop("room_id")
    edit_item.pop("status_id", None)

    selected_products = []
    for product in raw_products:
        if product["selected"]:
            product.pop("product_info")
            selected_products.append(SalesDocAndVideoInfo(**product))

    edit_item["user_id"] = user_id
    room_info = StreamRoomInfo(**edit_item, product_list=selected_products, status=OnAirRoomStatusItem(**status))

    # Room ID 0 is passed for the create case.
    room_id = create_or_update_db_room_by_id(0, room_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
|
|
|
|
|
@router.put("/edit/{room_id}", summary="编辑直播间接口")
async def streaming_room_edit_api(room_id: int, edit_item: dict, user_id: int = Depends(get_current_user_info)):
    """Update an existing streaming room with the submitted form data.

    Args:
        room_id (int): ID of the room to update.
        edit_item (dict): Request body. Must contain ``product_list``,
            ``status`` and ``streamer_info``; the remaining keys are passed
            to ``StreamRoomInfo``. The dict and its product entries are
            modified in place.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data`` whose data field is ``room_id``.
    """
    raw_products = edit_item.pop("product_list")
    status = edit_item.pop("status")
    edit_item.pop("streamer_info")

    selected_products = []
    for product in raw_products:
        if product["selected"]:
            product.pop("product_info")
            selected_products.append(SalesDocAndVideoInfo(**product))

    # NOTE(review): unlike the create handler, user_id is not written into
    # edit_item before building StreamRoomInfo — confirm this is intended.
    room_info = StreamRoomInfo(**edit_item, product_list=selected_products, status=OnAirRoomStatusItem(**status))
    create_or_update_db_room_by_id(room_id, room_info, user_id)
    return make_return_data(True, ResultCode.SUCCESS, "成功", room_id)
|
|
|
|
|
@router.delete("/delete/{roomId}", summary="删除直播间接口")
async def delete_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Delete one of the current user's streaming rooms.

    Args:
        roomId (int): ID of the room to delete.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data``: a success payload when
        ``delete_room_id`` returns a truthy value, a failure payload otherwise.
    """
    deleted = await delete_room_id(roomId, user_id)

    if deleted:
        return make_return_data(True, ResultCode.SUCCESS, "成功", "")

    return make_return_data(False, ResultCode.FAIL, "失败", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/online/{roomId}", summary="直播间开播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Switch a streaming room to the "online" status.

    Args:
        roomId (int): ID of the room going on air.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The success value built by ``make_return_data`` with an empty data field.
    """
    # NOTE(review): the function name looks copy-pasted from the offline
    # handler defined next, which rebinds it; the route stays registered.
    update_db_room_status(roomId, user_id, "online")

    response = make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return response
|
|
|
|
|
@router.put("/offline/{roomId}", summary="直播间下播接口")
async def offline_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Switch a streaming room to the "offline" status.

    Args:
        roomId (int): ID of the room going off air.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The success value built by ``make_return_data`` with an empty data field.
    """
    update_db_room_status(roomId, user_id, "offline")

    response = make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return response
|
|
|
|
|
@router.post("/next-product/{roomId}", summary="直播间进行下一个商品讲解接口")
async def on_air_live_room_next_product_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Move a live room on to presenting its next product.

    Args:
        roomId (int): ID of the streaming room.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The success value built by ``make_return_data`` with an empty data field.
    """
    # The status update is requested with the "next-product" action.
    update_db_room_status(roomId, user_id, "next-product")

    response = make_return_data(True, ResultCode.SUCCESS, "成功", "")
    return response
|
|
|
|
|
@router.get("/live-info/{roomId}", summary="获取正在直播的直播间信息接口")
async def get_on_air_live_room_api(roomId: int, user_id: int = Depends(get_current_user_info)):
    """Return the information of a room that is currently live.

    Per the original author's notes, the payload is expected to carry:

    1. the streamer video address,
    2. the product information shown as a thumbnail in the lower-right corner,
    3. the conversation records (``conversation_list``).

    Args:
        roomId (int): ID of the streaming room.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data`` whose data field is the
        result of ``get_live_room_info``.
    """
    live_room_info = await get_live_room_info(user_id, roomId)
    return make_return_data(True, ResultCode.SUCCESS, "成功", live_room_info)
|
|
|
|
|
@router.put("/chat", summary="直播间对话接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
    """Handle one viewer message in a live room and produce the streamer's reply.

    The steps run in order: store the user message, build the prompt from the
    product and streamer info plus the chat history, let the agent (or RAG,
    when enabled) rewrite the last prompt entry, query the LLM, generate the
    TTS / digital-human video, then store the video path and the reply.

    Args:
        room_chat (RoomChatItem): Request body; ``roomId`` and ``message``
            are read here.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data``: a success payload with an
        empty data field, or a failure payload when no room is found for
        ``room_chat.roomId``.
    """
    rooms = await get_db_streaming_room_info(user_id, room_chat.roomId)
    if not rooms:
        # Without this guard the index below raises IndexError for an unknown room.
        return make_return_data(False, ResultCode.FAIL, "失败", "")
    streaming_room_info = rooms[0]

    # The product currently being presented drives both the prompt and the chat log.
    current_product = streaming_room_info.product_list[streaming_room_info.status.current_product_index]
    product_detail = current_product.product_info
    sales_info_id = current_product.sales_info_id

    # Store the user's message first so it is part of the history fetched next.
    update_message_info(sales_info_id, user_id, role="user", message=room_chat.message)
    conversation_list = get_message_list(sales_info_id)

    prompt = await gen_poduct_base_prompt(
        user_id,
        streamer_info=streaming_room_info.streamer_info,
        product_info=product_detail,
    )
    prompt = combine_history(prompt, conversation_list)

    # Try the agent first; a non-empty answer replaces the last prompt entry.
    agent_response = await get_agent_res(prompt, product_detail.departure_place, product_detail.delivery_company)
    if agent_response != "":
        logger.info("Agent 执行成功,不执行 RAG")
        prompt[-1]["content"] = agent_response
    elif SERVER_PLUGINS_INFO.rag_enabled:
        # Fall back to RAG only when the agent returned an empty string.
        logger.info("Agent 未执行 or 未开启,调用 RAG")
        rag_res = build_rag_prompt(RAG_RETRIEVER, product_detail.product_name, prompt[-1]["content"])
        if rag_res != "":
            prompt[-1]["content"] = rag_res

    streamer_res = await get_llm_res(prompt)

    # Generate the reply video and record its path on the room status.
    server_video_path = await gen_tts_and_digital_human_video_app(streaming_room_info.streamer_info.streamer_id, streamer_res)
    update_room_video_path(streaming_room_info.status_id, server_video_path)

    # Store the streamer's reply in the conversation.
    update_message_info(sales_info_id, streaming_room_info.streamer_info.streamer_id, role="streamer", message=streamer_res)

    return make_return_data(True, ResultCode.SUCCESS, "成功", "")
|
|
|
|
|
@router.post("/asr", summary="直播间调取 ASR 语音转文字 接口")
async def get_on_air_live_room_api(room_chat: RoomChatItem, user_id: int = Depends(get_current_user_info)):
    """Send an uploaded audio file to the ASR service and return the text.

    Args:
        room_chat (RoomChatItem): Request body; only ``asrFileUrl`` is read.
        user_id (int): ID of the caller, resolved through the
            ``get_current_user_info`` dependency.

    Returns:
        The value built by ``make_return_data`` whose data field is the
        ``result`` entry of the ASR service's JSON response.
    """
    # Only the file name of the URL is used; it is resolved inside the
    # server's ASR directory.
    asr_local_path = Path(WEB_CONFIGS.SERVER_FILE_ROOT).joinpath(WEB_CONFIGS.ASR_FILE_DIR, Path(room_chat.asrFileUrl).name)

    req_data = {
        "user_id": user_id,
        "request_id": str(uuid.uuid1()),
        "wav_path": str(asr_local_path),
    }
    logger.info(req_data)

    try:
        # NOTE(review): this is a blocking call without a timeout inside an
        # async handler — consider a timeout / thread offload.
        res = requests.post(API_CONFIG.ASR_URL, json=req_data).json()
        asr_str = res["result"]
        logger.info(f"ASR res = {asr_str}")
    finally:
        # Always remove the audio file, even when the ASR call fails, and
        # tolerate it already being gone.
        asr_local_path.unlink(missing_ok=True)

    return make_return_data(True, ResultCode.SUCCESS, "成功", asr_str)
|
|