""" |
@File : streamer_room_db.py |
@Time : 2024/08/31 |
@Project : https://github.com/PeterH0323/Streamer-Sales |
@Author : HinGwenWong |
@Version : 1.0 |
@Desc : 直播间信息数据库操作 |
""" |
from datetime import datetime |
from typing import List |
from loguru import logger |
from sqlmodel import Session, and_, not_, select |
from ...web_configs import API_CONFIG |
from ..models.streamer_room_model import ChatMessageInfo, OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo |
from .init_db import DB_ENGINE |
async def get_db_streaming_room_info(user_id: int, room_id: int | None = None) -> List[StreamRoomInfo] | None:
    """Query a user's live-stream rooms that are not soft-deleted.

    File paths on each room's streamer and products are turned into URLs by prefixing
    ``API_CONFIG.REQUEST_FILES_URL``. The prefixed values only live on the returned objects:
    the session is never committed (and autoflush is disabled while rewriting), so nothing
    is written back to the database.

    Args:
        user_id (int): User ID.
        room_id (int | None, optional): Room ID, to fetch one specific room. Defaults to None (all rooms).

    Returns:
        List[StreamRoomInfo] | None: Matching rooms ordered by ``room_id``; an empty list when nothing matches.
    """
    # NOTE(review): declared ``async`` but the DB calls below are synchronous and block the event loop.
    conditions = [StreamRoomInfo.user_id == user_id, StreamRoomInfo.delete == False]  # noqa: E712 (SQL expression)
    if room_id is not None:
        conditions.append(StreamRoomInfo.room_id == room_id)

    files_url = API_CONFIG.REQUEST_FILES_URL
    with Session(DB_ENGINE) as session:
        stream_room_list = session.exec(
            select(StreamRoomInfo).where(and_(*conditions)).order_by(StreamRoomInfo.room_id)
        ).all()
        if not stream_room_list:
            logger.warning("nothing to find in db...")

        # The URL rewrites are for the response only: without no_autoflush, lazy-loading a
        # relationship would flush them as UPDATE statements into the (never committed) transaction.
        with session.no_autoflush:
            # The identity map returns the *same* Python object for a given primary key, so a
            # streamer or product shared by several rooms is reached several times here.
            # Remember what has been prefixed so the URL is prepended only once.
            prefixed_ids = set()
            for stream_room in stream_room_list:
                streamer = stream_room.streamer_info
                if id(streamer) not in prefixed_ids:
                    prefixed_ids.add(id(streamer))
                    streamer.avatar = files_url + streamer.avatar
                    streamer.tts_reference_audio = files_url + streamer.tts_reference_audio
                    streamer.poster_image = files_url + streamer.poster_image
                    streamer.base_mp4_path = files_url + streamer.base_mp4_path

                for product in stream_room.product_list:
                    product_info = product.product_info
                    if id(product_info) in prefixed_ids:
                        continue
                    prefixed_ids.add(id(product_info))
                    product_info.image_path = files_url + product_info.image_path
                    product_info.instruction = files_url + product_info.instruction

        logger.info(stream_room_list)
        logger.info(f"len {len(stream_room_list)}")
    return stream_room_list
async def delete_room_id(room_id: int, user_id: int) -> bool:
    """Soft-delete a live-stream room by setting its ``delete`` flag.

    Args:
        room_id (int): Room ID.
        user_id (int): User ID; the room must belong to this user, which stops other users from deleting it.

    Returns:
        bool: True when the room was marked as deleted; False when it does not exist,
        belongs to another user, or the database operation failed.
    """
    # NOTE(review): declared ``async`` but the DB calls below are synchronous and block the event loop.
    try:
        with Session(DB_ENGINE) as session:
            # ``.first()`` returns None on no match (``.one()`` would raise), so the ownership check below works.
            room_info = session.exec(
                select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
            ).first()
            if room_info is None:
                logger.error("Delete by other ID !!!")
                return False

            room_info.delete = True
            session.add(room_info)
            session.commit()
    except Exception:
        # Keep reporting failure through the return value, but record why it failed.
        logger.exception(f"Failed to delete room {room_id}")
        return False
    return True
def create_or_update_db_room_by_id(room_id: int, new_info: StreamRoomInfo, user_id: int):
    """Create a new live-stream room or edit an existing one.

    ``room_id > 0`` edits that room; any other value creates a new room for ``user_id``.
    Occurrences of ``API_CONFIG.REQUEST_FILES_URL`` are removed from incoming paths before storing.

    Args:
        room_id (int): Room ID to edit; ``<= 0`` creates a new room.
        new_info (StreamRoomInfo): New room information.
        user_id (int): User ID; an edited room must belong to this user, which stops other users from modifying it.

    Returns:
        int | None: ID of the created/edited room, or None when ``room_id`` does not exist
        or belongs to another user (in which case nothing is written).
    """
    files_url = API_CONFIG.REQUEST_FILES_URL

    with Session(DB_ENGINE) as session:
        # Verify ownership *before* writing anything, so a rejected edit has no side effects.
        room_info = None
        if room_id > 0:
            room_info = session.exec(
                select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
            ).first()
            if room_info is None:
                logger.error("Edit by other ID !!!")
                return None

        # For an existing room use the status linked in the DB rather than the client-supplied
        # status_id, so a request cannot rewrite the status of a room it does not own.
        status_id = new_info.status_id if room_info is None else room_info.status_id
        if status_id is not None:
            status_info = session.exec(
                select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == status_id)
            ).one()
        else:
            status_info = OnAirRoomStatusItem()
        status_info.streaming_video_path = new_info.status.streaming_video_path.replace(files_url, "")
        status_info.live_status = new_info.status.live_status
        session.add(status_info)
        session.commit()
        session.refresh(status_info)  # a new status needs its generated status_id

        if room_info is None:
            room_info = StreamRoomInfo(status_id=status_info.status_id, user_id=user_id)
        else:
            # Link the (possibly newly created) status so it is never left orphaned.
            room_info.status_id = status_info.status_id
        room_info.name = new_info.name
        room_info.prohibited_words_id = new_info.prohibited_words_id
        room_info.room_poster = new_info.room_poster.replace(files_url, "")
        room_info.background_image = new_info.background_image.replace(files_url, "")
        room_info.streamer_id = new_info.streamer_id
        session.add(room_info)
        session.commit()
        session.refresh(room_info)  # a new room needs its generated room_id

        # NOTE(review): an empty product_list leaves the room's existing sales entries untouched — confirm intended.
        if new_info.product_list:
            selected_id_list = [product.product_id for product in new_info.product_list]
            for product in new_info.product_list:
                if product.sales_info_id is not None:
                    # Existing entry: ``.one()`` raises unless it belongs to this room and product.
                    sales_info = session.exec(
                        select(SalesDocAndVideoInfo).where(
                            and_(
                                SalesDocAndVideoInfo.room_id == room_info.room_id,
                                SalesDocAndVideoInfo.product_id == product.product_id,
                                SalesDocAndVideoInfo.sales_info_id == product.sales_info_id,
                            )
                        )
                    ).one()
                else:
                    sales_info = SalesDocAndVideoInfo()
                sales_info.product_id = product.product_id
                sales_info.sales_doc = product.sales_doc
                sales_info.start_time = product.start_time
                sales_info.start_video = product.start_video.replace(files_url, "")
                sales_info.selected = True
                sales_info.room_id = room_info.room_id
                session.add(sales_info)

            # Remove the sales entries of products that are no longer selected for this room.
            unselected_sales_info = session.exec(
                select(SalesDocAndVideoInfo).where(
                    and_(
                        SalesDocAndVideoInfo.room_id == room_info.room_id,
                        not_(SalesDocAndVideoInfo.product_id.in_(selected_id_list)),
                    )
                )
            ).all()
            for cancel_select in unselected_sales_info:
                session.delete(cancel_select)

            # One commit for all product changes, so they are applied together or not at all.
            session.commit()

        return room_info.room_id
def init_conversation(db_session, sales_info_id: int, streamer_id: int, sales_doc: str):
    """Open a product's conversation with the streamer's sales script as the first message.

    Typically triggered by "start streaming" or "next product". The message is only added
    to ``db_session``; committing is left to the caller.

    Args:
        db_session: Open database session the message is added to.
        sales_info_id (int): Sales info ID the conversation belongs to.
        streamer_id (int): ID of the streamer the message is attributed to.
        sales_doc (str): Streamer's sales script, stored as the message text.
    """
    opening_fields = dict(
        sales_info_id=sales_info_id,
        streamer_id=streamer_id,
        role="streamer",
        message=sales_doc,
        send_time=datetime.now(),
    )
    db_session.add(ChatMessageInfo(**opening_fields))
def update_message_info(sales_info_id: int, role_id: int, role: str, message: str):
    """Append one chat message to a product's conversation.

    Args:
        sales_info_id (int): Sales info ID the message belongs to.
        role_id (int): Sender ID; stored as ``streamer_id`` when ``role`` is "streamer", otherwise as ``user_id``.
        role (str): Sender type: "streamer" or "user".
        message (str): Message text to store.

    Raises:
        ValueError: ``role`` is neither "streamer" nor "user".
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
    # which would let an unknown role be stored under ``user_id``.
    if role not in ("streamer", "user"):
        raise ValueError(f"role must be 'streamer' or 'user', got {role!r}")

    role_key = "streamer_id" if role == "streamer" else "user_id"
    with Session(DB_ENGINE) as session:
        message_info = ChatMessageInfo(
            **{role_key: role_id}, sales_info_id=sales_info_id, role=role, message=message, send_time=datetime.now()
        )
        session.add(message_info)
        session.commit()
def update_db_room_status(room_id: int, user_id: int, process_type: str):
    """Update a live-stream room's on-air status.

    Args:
        room_id (int): Room ID.
        user_id (int): User ID; the room must belong to this user, which stops other users from modifying it.
        process_type (str): One of
            - "online": set ``live_status`` to 1, record the start time, clear the end time and
              switch to the first product;
            - "next-product": switch to the next product in the room's product list;
            - "offline": set ``live_status`` to 2, clear the streaming video and record the end time.

    Raises:
        ValueError: ``process_type`` is not one of the values above.
    """
    # Validate up front (``NotImplemented`` is not callable, so the old ``raise NotImplemented(...)`` raised TypeError).
    if process_type not in ("online", "next-product", "offline"):
        raise ValueError("process type error !!")

    with Session(DB_ENGINE) as session:
        # ``.first()`` returns None on no match (``.one()`` would raise), so these checks are reachable.
        room_info = session.exec(
            select(StreamRoomInfo).where(and_(StreamRoomInfo.room_id == room_id, StreamRoomInfo.user_id == user_id))
        ).first()
        if room_info is None:
            logger.error("Edit by other ID !!!")
            return

        status_info = None
        if room_info.status_id is not None:
            status_info = session.exec(
                select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == room_info.status_id)
            ).first()
        if status_info is None:
            logger.error("status_info is None !!!")
            return

        if process_type == "offline":
            status_info.streaming_video_path = ""
            status_info.live_status = 2
            status_info.end_time = datetime.now()
        else:
            # "online" restarts from the first product; "next-product" advances by one.
            next_idx = 0 if process_type == "online" else status_info.current_product_index + 1
            if not 0 <= next_idx < len(room_info.product_list):
                logger.error(f"Product index {next_idx} is out of range for room {room_id} !!!")
                return

            if process_type == "online":
                status_info.live_status = 1
                status_info.start_time = datetime.now()
                status_info.end_time = None
            status_info.current_product_index = next_idx

            # Point the status at the current product's start video and sales entry.
            current_sales = room_info.product_list[next_idx]
            status_info.streaming_video_path = current_sales.start_video
            status_info.sales_info_id = current_sales.sales_info_id

            # Record when this product's segment started.
            sales_info = session.exec(
                select(SalesDocAndVideoInfo).where(SalesDocAndVideoInfo.sales_info_id == current_sales.sales_info_id)
            ).one()
            sales_info.start_time = datetime.now()
            session.add(sales_info)

            # Seed this product's conversation with the streamer's sales script.
            init_conversation(session, status_info.sales_info_id, room_info.streamer_id, current_sales.sales_doc)

        session.add(status_info)
        session.commit()
def get_message_list(sales_info_id: int) -> List[dict]:
    """Fetch every chat message of one sales entry, formatted for the response.

    Args:
        sales_info_id (int): Sales info ID.

    Returns:
        List[dict]: One dict per message, ordered by ``message_id``, with keys ``role``,
        ``avatar`` (prefixed with ``API_CONFIG.REQUEST_FILES_URL``), ``userName``,
        ``message`` and ``datetime``. Empty when there are no messages.
    """
    with Session(DB_ENGINE) as session:
        message_info = session.exec(
            select(ChatMessageInfo)
            .where(ChatMessageInfo.sales_info_id == sales_info_id)
            .order_by(ChatMessageInfo.message_id)
        ).all()

        formate_message_list = []
        for message_ in message_info:
            # "user" messages take name/avatar from user_info; every other role from streamer_info.
            if message_.role == "user":
                avatar, user_name = message_.user_info.avatar, message_.user_info.username
            else:
                avatar, user_name = message_.streamer_info.avatar, message_.streamer_info.name
            formate_message_list.append(
                {
                    "role": message_.role,
                    "avatar": API_CONFIG.REQUEST_FILES_URL + avatar,
                    "userName": user_name,
                    "message": message_.message,
                    "datetime": message_.send_time,
                }
            )
    return formate_message_list
def update_room_video_path(status_id: int, news_video_server_path: str):
    """Store a new streamer video path on a room's on-air status.

    Args:
        status_id (int): ID of the room's ``OnAirRoomStatusItem``.
        news_video_server_path (str): Video address on the server; occurrences of
            ``API_CONFIG.REQUEST_FILES_URL`` are removed before it is stored.
    """
    status_query = select(OnAirRoomStatusItem).where(OnAirRoomStatusItem.status_id == status_id)
    with Session(DB_ENGINE) as session:
        status_item = session.exec(status_query).one()
        status_item.streaming_video_path = news_video_server_path.replace(API_CONFIG.REQUEST_FILES_URL, "")
        session.add(status_item)
        session.commit()
async def get_live_room_info(user_id: int, room_id: int):
    """Collect the real-time information of a room that is on air.

    Args:
        user_id (int): User ID owning the room.
        room_id (int): Room ID.

    Returns:
        dict: Keys ``streamerInfo``, ``conversation``, ``currentProductInfo``, ``currentStreamerVideo``,
        ``currentProductIndex``, ``startTime``, ``currentPoductStartTime`` (spelling kept as-is;
        presumably the front end reads this exact key) and ``finalProduct``.

    Raises:
        IndexError: from indexing ``[0]`` when no matching room is returned.
    """
    room = (await get_db_streaming_room_info(user_id, room_id))[0]
    streamer_info = room.streamer_info
    status = room.status

    current_index = status.current_product_index
    is_final_product = current_index == len(room.product_list) - 1

    conversation = get_message_list(status.sales_info_id)
    video_url = API_CONFIG.REQUEST_FILES_URL + status.streaming_video_path
    current_product = room.product_list[current_index]

    res_data = {
        "streamerInfo": streamer_info,
        "conversation": conversation,
        "currentProductInfo": current_product.product_info,
        "currentStreamerVideo": video_url,
        "currentProductIndex": current_index,
        "startTime": status.start_time,
        "currentPoductStartTime": current_product.start_time,
        "finalProduct": is_final_product,
    }
    logger.info(res_data)
    return res_data