File size: 5,311 Bytes
1ef9436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : streamer_info_db.py
@Time : 2024/08/30
@Project : https://github.com/PeterH0323/Streamer-Sales
@Author : HinGwenWong
@Version : 1.0
@Desc : 主播信息数据库操作
"""
from typing import List
from loguru import logger
from sqlmodel import Session, and_, select
from ...web_configs import API_CONFIG
from ..models.streamer_info_model import StreamerInfo
from .init_db import DB_ENGINE
async def get_db_streamer_info(user_id: int, streamer_id: int | None = None) -> List[StreamerInfo]:
    """Query the user's (non-deleted) streamers from the database.

    Only rows owned by ``user_id`` whose ``delete`` flag is False are returned.
    The path columns (``avatar``, ``tts_reference_audio``, ``poster_image``,
    ``base_mp4_path``) are prefixed with ``API_CONFIG.REQUEST_FILES_URL`` on the
    returned objects. The prefixed values are never committed to the database.

    Args:
        user_id (int): ID of the user who owns the streamers.
        streamer_id (int | None, optional): When given, only this streamer is
            looked up. Defaults to None (all streamers of the user).

    Returns:
        List[StreamerInfo]: Matching streamers ordered by ``streamer_id``. This is
        always a list, even when a single ``streamer_id`` is requested. It is empty
        when nothing matches or the query fails.
    """
    # Base filter: the user's own streamers that are not soft-deleted.
    # `== False` builds a SQL expression here; it is not a Python truth test.
    conditions = [StreamerInfo.user_id == user_id, StreamerInfo.delete == False]  # noqa: E712
    if streamer_id is not None:
        # Narrow the query down to one specific streamer
        conditions.append(StreamerInfo.streamer_id == streamer_id)

    with Session(DB_ENGINE) as session:
        try:
            streamer_list = list(
                session.exec(select(StreamerInfo).where(and_(*conditions)).order_by(StreamerInfo.streamer_id)).all()
            )
        except Exception:
            # Keep the best-effort contract (empty list) but record why the query failed
            logger.exception(f"Failed to query streamer info for user {user_id}")
            streamer_list = []

    if not streamer_list:
        logger.warning("nothing to find in db...")

    # Convert stored relative paths into server URLs. The session is closed at this
    # point and nothing is committed, so these changes stay on the returned objects only.
    for streamer in streamer_list:
        streamer.avatar = API_CONFIG.REQUEST_FILES_URL + streamer.avatar
        streamer.tts_reference_audio = API_CONFIG.REQUEST_FILES_URL + streamer.tts_reference_audio
        streamer.poster_image = API_CONFIG.REQUEST_FILES_URL + streamer.poster_image
        streamer.base_mp4_path = API_CONFIG.REQUEST_FILES_URL + streamer.base_mp4_path

    logger.info(streamer_list)
    logger.info(f"len {len(streamer_list)}")
    return streamer_list
async def delete_streamer_id(streamer_id: int, user_id: int) -> bool:
    """Soft-delete a streamer by setting its ``delete`` flag to True.

    Args:
        streamer_id (int): ID of the streamer to delete.
        user_id (int): ID of the requesting user. The streamer must belong to this
            user, which prevents one user from deleting another user's streamer.

    Returns:
        bool: True when the flag was set and committed. False when no streamer with
        this ID belongs to the user, or when a database error occurred.
    """
    try:
        with Session(DB_ENGINE) as session:
            # Match on both IDs so a streamer_id owned by someone else finds nothing.
            # `one_or_none()` (unlike `one()`) returns None for "no row", so the
            # ownership guard below is actually reachable.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).one_or_none()
            if streamer_info is None:
                logger.error("Delete by other ID !!!")
                return False

            streamer_info.delete = True  # soft delete: the row is kept, only flagged
            session.add(streamer_info)
            session.commit()
    except Exception:
        logger.exception(f"Failed to delete streamer {streamer_id} for user {user_id}")
        return False

    return True
def create_or_update_db_streamer_by_id(streamer_id: int, new_info: StreamerInfo, user_id: int) -> int:
    """Create a new streamer, or update an existing one owned by ``user_id``.

    Args:
        streamer_id (int): ID of the streamer to update. A value <= 0 creates a
            new streamer instead.
        new_info (StreamerInfo): The new values. Its path fields are modified in
            place: the leading ``API_CONFIG.REQUEST_FILES_URL`` is stripped so only
            relative paths are stored.
        user_id (int): ID of the requesting user. An update only matches a streamer
            owned by this user, which prevents editing other users' streamers.

    Returns:
        int: ID of the created or updated streamer. Returns -1 when
        ``streamer_id`` > 0 does not match a streamer owned by ``user_id``.
    """
    path_fields = ("avatar", "tts_reference_audio", "poster_image", "base_mp4_path")
    editable_fields = (
        "name",
        "character",
        "avatar",
        "tts_weight_tag",
        "tts_reference_sentence",
        "tts_reference_audio",
        "poster_image",
        "base_mp4_path",
    )

    # Strip the server address (the prefix get_db_streamer_info adds for clients).
    # removeprefix only touches the leading URL, unlike replace().
    for field in path_fields:
        setattr(new_info, field, getattr(new_info, field).removeprefix(API_CONFIG.REQUEST_FILES_URL))

    with Session(DB_ENGINE) as session:
        if streamer_id > 0:
            # Update: the row must match both IDs. `one_or_none()` returns None
            # (instead of raising, as `one()` does) so the guard below works.
            streamer_info = session.exec(
                select(StreamerInfo).where(and_(StreamerInfo.streamer_id == streamer_id, StreamerInfo.user_id == user_id))
            ).one_or_none()
            if streamer_info is None:
                logger.error("Edit by other ID !!!")
                return -1
        else:
            # Create: a fresh row owned by this user
            streamer_info = StreamerInfo(user_id=user_id)

        # Copy the editable values onto the persistent object
        for field in editable_fields:
            setattr(streamer_info, field, getattr(new_info, field))

        session.add(streamer_info)
        session.commit()
        session.refresh(streamer_info)  # load the DB-assigned streamer_id for new rows
        return int(streamer_info.streamer_id)
|