FAYO
model
1ef9436
raw
history blame
5.31 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : streamer_info_db.py
@Time : 2024/08/30
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author : HinGwenWong
@Version : 1.0
@Desc : 主播信息数据库操作
"""
from typing import List
from loguru import logger
from sqlmodel import Session, and_, select
from ...web_configs import API_CONFIG
from ..models.streamer_info_model import StreamerInfo
from .init_db import DB_ENGINE
async def get_db_streamer_info(user_id: int, streamer_id: int | None = None) -> List[StreamerInfo] | None:
    """Query the streamer records that belong to a user.

    Rows whose ``delete`` flag is set are excluded. Every returned record has
    ``API_CONFIG.REQUEST_FILES_URL`` prepended to its ``avatar``,
    ``tts_reference_audio``, ``poster_image`` and ``base_mp4_path`` fields.

    Args:
        user_id (int): Owner user ID; only this user's streamers are returned.
        streamer_id (int | None, optional): When given, restrict the result to this
            streamer ID. Defaults to None (all of the user's streamers).

    Returns:
        List[StreamerInfo]: Matching streamers ordered by ``streamer_id``. The list is
        empty when nothing matches or when the query fails. A lookup by
        ``streamer_id`` still returns a list (with at most the matching rows),
        never a bare object.
    """
    # `== False` is intentional: the comparison is used as a query filter, not a Python bool.
    conditions = [StreamerInfo.user_id == user_id, StreamerInfo.delete == False]  # noqa: E712
    if streamer_id is not None:
        # Narrow the search to one specific streamer.
        conditions.append(StreamerInfo.streamer_id == streamer_id)

    with Session(DB_ENGINE) as session:
        try:
            statement = select(StreamerInfo).where(and_(*conditions)).order_by(StreamerInfo.streamer_id)
            streamer_list = session.exec(statement).all()
        except Exception as e:
            # Best-effort read: report the failure instead of hiding it, then fall
            # back to an empty result below.
            logger.error(f"Failed to query streamer info for user {user_id}: {e}")
            streamer_list = None

    if streamer_list is None:
        logger.warning("nothing to find in db...")
        streamer_list = []

    # Turn the stored relative paths into URLs served by the file server.
    # NOTE(review): assumes the four path fields are always str (never None) — confirm against the model.
    for streamer in streamer_list:
        streamer.avatar = API_CONFIG.REQUEST_FILES_URL + streamer.avatar
        streamer.tts_reference_audio = API_CONFIG.REQUEST_FILES_URL + streamer.tts_reference_audio
        streamer.poster_image = API_CONFIG.REQUEST_FILES_URL + streamer.poster_image
        streamer.base_mp4_path = API_CONFIG.REQUEST_FILES_URL + streamer.base_mp4_path

    logger.info(streamer_list)
    logger.info(f"len {len(streamer_list)}")
    return streamer_list
async def delete_streamer_id(streamer_id: int, user_id: int) -> bool:
    """Soft-delete a streamer by setting its ``delete`` flag.

    The row is not removed from the table; only its ``delete`` column is set to True.

    Args:
        streamer_id (int): ID of the streamer to delete.
        user_id (int): ID of the requesting user. The row must belong to this user,
            which keeps one user from deleting another user's streamer.

    Returns:
        bool: True if the row was found and flagged as deleted; False if no row
        matches both IDs or the database operation failed.
    """
    try:
        with Session(DB_ENGINE) as session:
            # first() returns None when nothing matches, whereas one() raises; using
            # first() makes the ownership check below reachable.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).first()

            if streamer_info is None:
                logger.error("Delete by other ID !!!")
                return False

            streamer_info.delete = True  # mark as deleted
            session.add(streamer_info)
            session.commit()
    except Exception as e:
        # Keep the boolean contract for callers, but record why the delete failed.
        logger.error(f"Failed to delete streamer {streamer_id} for user {user_id}: {e}")
        return False

    return True
def create_or_update_db_streamer_by_id(streamer_id: int, new_info: StreamerInfo, user_id: int) -> int:
    """Create a new streamer or update an existing one.

    Side effect: ``new_info`` is modified in place — every occurrence of
    ``API_CONFIG.REQUEST_FILES_URL`` is removed from its ``avatar``,
    ``tts_reference_audio``, ``poster_image`` and ``base_mp4_path`` fields before
    the values are stored.

    Args:
        streamer_id (int): ID of the streamer to update. A value <= 0 creates a
            new streamer owned by ``user_id`` instead.
        new_info (StreamerInfo): Object carrying the new field values.
        user_id (int): ID of the requesting user. An update only succeeds when the
            row belongs to this user, which keeps other users from editing it.

    Returns:
        int: The streamer's ID after the commit, or -1 when ``streamer_id`` > 0 but
        no row with that ID belongs to ``user_id``.
    """
    # Strip the file-server address so only the relative path is stored.
    for path_field in ("avatar", "tts_reference_audio", "poster_image", "base_mp4_path"):
        stripped = getattr(new_info, path_field).replace(API_CONFIG.REQUEST_FILES_URL, "")
        setattr(new_info, path_field, stripped)

    with Session(DB_ENGINE) as session:
        if streamer_id > 0:
            # Update path: first() returns None when nothing matches, whereas one()
            # raises, which would make the -1 return below unreachable.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).first()

            if streamer_info is None:
                logger.error("Edit by other ID !!!")
                return -1
        else:
            # Create path: start from a fresh row owned by this user.
            streamer_info = StreamerInfo(user_id=user_id)

        # Copy the editable fields onto the row being saved.
        editable_fields = (
            "name",
            "character",
            "avatar",
            "tts_weight_tag",
            "tts_reference_sentence",
            "tts_reference_audio",
            "poster_image",
            "base_mp4_path",
        )
        for field_name in editable_fields:
            setattr(streamer_info, field_name, getattr(new_info, field_name))

        session.add(streamer_info)
        session.commit()
        # Reload the row so streamer_id is populated before it is read below.
        session.refresh(streamer_info)
        return int(streamer_info.streamer_id)