File size: 17,745 Bytes
1ef9436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File    :   utils.py
@Time    :   2024/09/02
@Project :   https://github.com/PeterH0323/Streamer-Sales
@Author  :   HinGwenWong
@Version :   1.0
@Desc    :   工具集合文件
"""


import asyncio
from ipaddress import IPv4Address
import json
import random
import wave
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List

import cv2
from lmdeploy.serve.openai.api_client import APIClient
from loguru import logger
from pydantic import BaseModel
from sqlmodel import Session, select
from tqdm import tqdm

from server.base.models.user_model import UserInfo

from ..tts.tools import SYMBOL_SPLITS, make_text_chunk
from ..web_configs import API_CONFIG, WEB_CONFIGS
from .database.init_db import DB_ENGINE
from .models.product_model import ProductInfo
from .models.streamer_info_model import StreamerInfo
from .models.streamer_room_model import OnAirRoomStatusItem, SalesDocAndVideoInfo, StreamRoomInfo

from .modules.agent.agent_worker import get_agent_result
from .modules.rag.rag_worker import RAG_RETRIEVER, build_rag_prompt
from .queue_thread import DIGITAL_HUMAN_QUENE, TTS_TEXT_QUENE
from .server_info import SERVER_PLUGINS_INFO


class ChatGenConfig(BaseModel):
    """Sampling parameters for LLM inference, carried on ``ChatItem.chat_config``."""

    # LLM inference configuration
    top_p: float = 0.8  # nucleus-sampling probability cutoff
    temperature: float = 0.7  # sampling temperature
    repetition_penalty: float = 1.005  # penalty for repeated tokens (1.0 would mean no penalty)


class ProductInfoItem(BaseModel):
    """Product information attached to a chat request (``ChatItem.product_info``)."""

    name: str  # product name; passed to build_rag_prompt by streamer_sales_process
    heighlights: str  # product highlights (misspelled field name kept: it is part of the request schema)
    introduce: str  # prompt used to generate the product sales copy

    image_path: str
    departure_place: str  # departure place; passed to the agent by streamer_sales_process
    delivery_company_name: str  # delivery company; passed to the agent by streamer_sales_process


class PluginsInfo(BaseModel):
    """Per-request plugin switches; every plugin is enabled by default."""

    rag: bool = True  # allow RAG prompt augmentation
    agent: bool = True  # allow calling the agent
    tts: bool = True  # allow text-to-speech
    digital_human: bool = True  # allow digital-human video generation


class ChatItem(BaseModel):
    """Input of ``streamer_sales_process``: one chat turn plus its context."""

    user_id: str  # user identifier, used to tell different callers apart
    request_id: str  # request ID, used to name the TTS & digital-human outputs
    prompt: List[Dict[str, str]]  # chat messages of this turn; the last entry's "content" may be rewritten by agent/RAG
    product_info: ProductInfoItem  # product information
    plugins: PluginsInfo = PluginsInfo()  # plugin switches
    chat_config: ChatGenConfig = ChatGenConfig()  # LLM sampling configuration


# Client for the LLM inference server (lmdeploy OpenAI-compatible API client).
# Created at import time from API_CONFIG.LLM_URL and shared by the whole module.
LLM_MODEL_HANDLER = APIClient(API_CONFIG.LLM_URL)


def _make_event(idx: int, data: str, step: str, end_flag: bool) -> str:
    """Serialize one event-stream message as a JSON string.

    Args:
        idx (int): number of LLM content chunks processed so far, used as the message id.
        data (str): full text generated so far.
        step (str): pipeline stage ("llm", "tts", "dg" or "all" in this module).
        end_flag (bool): True only for the final message.

    Returns:
        str: JSON text with non-ASCII characters kept as-is.
    """
    return json.dumps(
        {
            "event": "message",
            "retry": 100,
            "id": idx,
            "data": data,
            "step": step,
            "end_flag": end_flag,
        },
        ensure_ascii=False,
    )


def _merge_wav_files(wav_list: List[Path], save_path: Path) -> None:
    """Concatenate the audio frames of ``wav_list`` (in order) into ``save_path``.

    The output header is taken from the first file, so all inputs are expected
    to share channels / sample width / frame rate. ``wav_list`` must not be empty.
    """
    all_tts_data = []
    for wav_file in tqdm(wav_list):
        logger.info(f"Reading wav file {wav_file}...")
        with wave.open(str(wav_file), "rb") as wf:
            all_tts_data.append((wf.getparams(), wf.readframes(wf.getnframes())))

    logger.info(f"Merging wav file to {save_path}...")
    with wave.open(str(save_path), "wb") as wf:
        # Use the first file's parameters; `wave` corrects the frame count in the
        # header to the number of frames actually written.
        wf.setparams(all_tts_data[0][0])
        for _, frames in all_tts_data:
            wf.writeframes(frames)
    logger.info(f"Merged wav file to {save_path} !")


async def streamer_sales_process(chat_item: ChatItem):
    """Run the streamer pipeline for one chat turn as an event stream.

    Stages: optional agent or RAG prompt augmentation -> streaming LLM
    generation (sending finished sentences to the TTS queue) -> optional
    merge of the TTS audio and digital-human video generation.

    Side effect: ``chat_item.prompt[-1]["content"]`` is overwritten when the
    agent or RAG produces an augmented prompt.

    Args:
        chat_item (ChatItem): request payload.

    Yields:
        str: JSON event messages (see ``_make_event``); the last one has
        ``step == "all"`` and ``end_flag == True``.
    """

    # ====================== Agent ======================
    # Call the agent and wrap its answer into the user prompt
    agent_response = ""
    if chat_item.plugins.agent and SERVER_PLUGINS_INFO.agent_enabled:
        GENERATE_AGENT_TEMPLATE = (
            "这是网上获取到的信息:“{}”\n 客户的问题:“{}” \n 请认真阅读信息并运用你的性格进行解答。"  # Agent prompt template
        )
        input_prompt = chat_item.prompt[-1]["content"]
        agent_response = get_agent_result(
            LLM_MODEL_HANDLER, input_prompt, chat_item.product_info.departure_place, chat_item.product_info.delivery_company_name
        )
        if agent_response != "":
            agent_response = GENERATE_AGENT_TEMPLATE.format(agent_response, input_prompt)
            logger.info(f"Agent response: {agent_response}")
            chat_item.prompt[-1]["content"] = agent_response

    # ====================== RAG ======================
    # Only query the RAG database when the agent produced nothing
    if chat_item.plugins.rag and agent_response == "":
        rag_prompt = chat_item.prompt[-1]["content"]
        prompt_pro = build_rag_prompt(RAG_RETRIEVER, chat_item.product_info.name, rag_prompt)

        if prompt_pro != "":
            chat_item.prompt[-1]["content"] = prompt_pro

    # ====================== LLM streaming ======================
    logger.info(chat_item.prompt)

    current_predict = ""  # full text generated so far
    idx = 0  # number of content chunks received
    last_text_index = 0  # split cursor returned by make_text_chunk
    sentence_id = 0  # number of sentences sent to TTS (also the TTS chunk id)
    gen_config = chat_item.chat_config
    model_name = LLM_MODEL_HANDLER.available_models[0]
    # Forward the per-request sampling configuration to the LLM server
    stream = LLM_MODEL_HANDLER.chat_completions_v1(
        model=model_name,
        messages=chat_item.prompt,
        stream=True,
        temperature=gen_config.temperature,
        top_p=gen_config.top_p,
        repetition_penalty=gen_config.repetition_penalty,
    )
    for item in stream:
        logger.debug(f"LLM predict: {item}")
        delta = item["choices"][0]["delta"]
        if "content" not in delta:
            continue
        current_res = delta["content"]

        if "~" in current_res:
            # Replace "~" with a Chinese full stop and collapse doubled stops
            current_res = current_res.replace("~", "。").replace("。。", "。")

        current_predict += current_res
        idx += 1

        if chat_item.plugins.tts and SERVER_PLUGINS_INFO.tts_server_enabled:
            # Cut a sentence once this chunk contains a split symbol
            sentence = ""
            if any(symbol in current_res for symbol in SYMBOL_SPLITS):
                last_text_index, sentence = make_text_chunk(current_predict, last_text_index)
                if len(sentence) <= 3:
                    # Too short to be worth synthesizing
                    sentence = ""

            if sentence != "":
                sentence_id += 1
                logger.info(f"get sentence: {sentence}")
                tts_request_dict = {
                    "user_id": chat_item.user_id,
                    "request_id": chat_item.request_id,
                    "sentence": sentence,
                    "chunk_id": sentence_id,
                }

                TTS_TEXT_QUENE.put(tts_request_dict)
                await asyncio.sleep(0.01)

        yield _make_event(idx, current_predict, "llm", False)
        await asyncio.sleep(0.01)  # short delay so the event stream gets flushed

    # ====================== Digital human ======================
    digital_human_enabled = chat_item.plugins.digital_human and SERVER_PLUGINS_INFO.digital_human_server_enabled
    if digital_human_enabled and sentence_id == 0:
        # No sentence was sent to TTS (TTS disabled or every sentence too short):
        # there is no audio to merge, and merging an empty file list would fail.
        logger.warning("No TTS audio generated, skip digital human generation.")
    elif digital_human_enabled:
        # TTS chunk files are named "<request_id>-<8-digit chunk id>.wav"
        wav_list = [
            Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + f"-{str(i).zfill(8)}.wav")
            for i in range(1, sentence_id + 1)
        ]

        # Wait until every TTS chunk file exists
        while True:
            not_exist_count = sum(1 for tts_wav in wav_list if not tts_wav.exists())
            logger.info(f"still need to wait for {not_exist_count}/{sentence_id} wav generating...")
            if not_exist_count == 0:
                break
            yield _make_event(idx, current_predict, "tts", False)
            await asyncio.sleep(1)  # short delay so the event stream gets flushed

        # Merge the TTS chunks into a single wav
        tts_save_path = Path(WEB_CONFIGS.TTS_WAV_GEN_PATH, chat_item.request_id + ".wav")
        _merge_wav_files(wav_list, tts_save_path)

        # Queue digital-human video generation
        digital_human_request = {
            "user_id": chat_item.user_id,
            "request_id": chat_item.request_id,
            "chunk_id": 0,
            "tts_path": str(tts_save_path),
        }
        logger.info("Generating digital human...")
        DIGITAL_HUMAN_QUENE.put(digital_human_request)

        # Completion is detected by a "<wav stem>.txt" file in the video output directory.
        # NOTE(review): presumably written by the digital-human worker once the video is ready.
        done_flag_path = (
            Path(WEB_CONFIGS.DIGITAL_HUMAN_VIDEO_OUTPUT_PATH)
            .joinpath(Path(tts_save_path).stem + ".mp4")
            .with_suffix(".txt")
        )
        while not done_flag_path.exists():
            yield _make_event(idx, current_predict, "dg", False)
            await asyncio.sleep(1)  # short delay so the event stream gets flushed

        # Remove the intermediate per-sentence wav files
        for wav_file in wav_list:
            wav_file.unlink()

    yield _make_event(idx, current_predict, "all", True)


def make_poster_by_video_first_frame(video_path: str, image_output_name: str):
    """Generate a thumbnail image from the first frame of a video.

    The image is written into the same directory as the video.

    Args:
        video_path (str): path of the video file.
        image_output_name (str): file name (with extension) of the image to write.

    Returns:
        str: path where the first frame is saved. The path is returned even when
        the frame could not be read or written; failures are only logged.
    """

    cap = cv2.VideoCapture(video_path)
    try:
        # Read the first frame
        ret, frame = cap.read()

        poster_save_path = str(Path(video_path).parent.joinpath(image_output_name))
        if not ret:
            logger.error("无法读取视频帧")
        elif cv2.imwrite(poster_save_path, frame):
            logger.info(f"第一帧已保存为 {poster_save_path}")
        else:
            # imwrite reports some failures by returning False instead of raising
            logger.error(f"Failed to write first frame to {poster_save_path}")
    finally:
        # Release the capture even when reading/writing raises
        # (e.g. cv2.imwrite raises cv2.error for an unsupported image extension)
        cap.release()

    return poster_save_path


@dataclass
class ResultCode:
    """Numeric result codes used for the ``code`` field of response payloads."""

    SUCCESS: int = 0  # request succeeded
    FAIL: int = 1000  # request failed


def make_return_data(success_flag: bool, code: int, message: str, data: dict):
    """Build the common response payload.

    Args:
        success_flag (bool): whether the operation succeeded.
        code (int): result code, normally ``ResultCode.SUCCESS`` or ``ResultCode.FAIL``
            (these are plain ints, not ``ResultCode`` instances).
        message (str): human-readable message.
        data (dict): response data.

    Returns:
        dict: keys ``success``, ``code``, ``message``, ``data`` and ``timestamp``
        (local time formatted as ``"%Y-%m-%d %H:%M:%S"``).
    """
    return {
        "success": success_flag,
        "code": code,
        "message": message,
        "data": data,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }


def gen_default_data():
    """Seed the database with default demo data on first start-up.

    Only when no user with ``user_id == 1`` exists yet, this creates:
    - the default user
    - default products
    - a default streamer
    - a default live-stream room linked to a few of those products

    All default records are owned by ``user_id=1``; this assumes the freshly
    created default user receives that id (the existence check uses the same id).
    """

    def create_default_user():
        """Create the default user."""
        admin_user = UserInfo(
            username="hingwen.wong",
            ip_address=IPv4Address("127.0.0.1"),
            # NOTE(review): this value looks like an e-mail-obfuscation placeholder, not a real address — confirm
            email="[email protected]",
            # bcrypt hash of the default password "123456" (produced with get_password_hash).
            # NOTE(review): well-known default credential; change it on real deployments.
            hashed_password="$2b$12$zXXveodjipHZMoSxJz5ODul7Z9YeRJd0GeSBjpwHdqEtBbAFvEdre",
            avatar="/user/user-avatar.png",
        )

        with Session(DB_ENGINE) as session:
            session.add(admin_user)
            session.commit()

    def init_user() -> bool:
        """Create the default user if it does not exist yet.

        Returns:
            bool: True if the default user was created by this call.
        """
        with Session(DB_ENGINE) as session:
            results = session.exec(select(UserInfo).where(UserInfo.user_id == 1)).first()

        if results is None:
            # Empty database: create the initial user
            create_default_user()
            logger.info("created default user info")
            return True

        return False

    def create_default_product_item():
        """Insert the default products, with randomized logistics, price and stock."""
        delivery_company_list = ["京东", "顺丰", "韵达", "圆通", "中通"]
        departure_place_list = ["广州", "北京", "武汉", "杭州", "上海", "深圳", "成都"]
        # key -> product fields; the key is also the base name of the image / instruction files
        default_product_list = {
            "beef": {
                "product_name": "进口和牛羽下肉",
                "heighlights": "富含铁质;营养价值高;肌肉纤维好;红白相间纹理;适合烧烤炖煮;草食动物来源",
                "product_class": "食品",
            },
            "elec_toothblush": {
                "product_name": "声波电动牙刷",
                "heighlights": "高效清洁;减少手动压力;定时提醒;智能模式调节;无线充电;噪音低",
                "product_class": "电子",
            },
            "lip_stick": {
                "product_name": "唇膏",
                "heighlights": "丰富色号;滋润保湿;显色度高;持久不脱色;易于涂抹;便携包装",
                "product_class": "美妆",
            },
            "mask": {
                "product_name": "光感润颜面膜",
                "heighlights": "密集滋养;深层补水;急救修复;快速见效;定期护理;多种类型选择",
                "product_class": "美妆",
            },
            "oled_tv": {
                "product_name": "65英寸OLED电视",
                "heighlights": "色彩鲜艳;对比度极高;响应速度快;无背光眩光;厚度较薄;自发光无需额外照明",
                "product_class": "家电",
            },
            "pad": {
                "product_name": "14英寸平板电脑",
                "heighlights": "轻薄;触控操作;电池续航好;移动办公便利;娱乐性强;适合儿童学习",
                "product_class": "电子",
            },
            "pants": {
                "product_name": "速干运动裤",
                "heighlights": "快干;伸缩自如;吸湿排汗;防风保暖;高腰设计;多口袋实用",
                "product_class": "衣服",
            },
            "pen": {
                "product_name": "墨水钢笔",
                "heighlights": "耐用性;可书写性;不同颜色和类型;轻便设计;环保材料;易于携带",
                "product_class": "文具",
            },
            "perfume": {
                "product_name": "薰衣草淡香氛",
                "heighlights": "浪漫优雅;花香调为主;情感表达;适合各种年龄;瓶身设计精致;提升女性魅力",
                "product_class": "家居用品",
            },
            "shampoo": {
                "product_name": "本草精华洗发露",
                "heighlights": "温和配方;深层清洁;滋养头皮;丰富泡沫;易冲洗;适合各种发质",
                "product_class": "日用品",
            },
            "wok": {
                "product_name": "不粘煎炒锅",
                "heighlights": "不粘涂层;耐磨耐用;导热快;易清洗;多种烹饪方式;设计人性化",
                "product_class": "厨具",
            },
            "yoga_mat": {
                "product_name": "瑜伽垫",
                "heighlights": "防滑材质;吸湿排汗;厚度适中;耐用易清洁;各种瑜伽动作适用;轻巧便携",
                "product_class": "运动",
            },
        }

        with Session(DB_ENGINE) as session:
            for product_key, product_info in default_product_list.items():
                add_item = ProductInfo(
                    **product_info,
                    image_path=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.IMAGES_DIR}/{product_key}.png",
                    instruction=f"/{WEB_CONFIGS.PRODUCT_FILE_DIR}/{WEB_CONFIGS.INSTRUCTIONS_DIR}/{product_key}.md",
                    departure_place=random.choice(departure_place_list),
                    delivery_company=random.choice(delivery_company_list),
                    selling_price=round(random.uniform(66.6, 1999.9), 2),
                    amount=random.randint(999, 9999),
                    user_id=1,
                )
                session.add(add_item)
            session.commit()

        logger.info("created default product info done!")

    def create_default_streamer():
        """Insert the default streamer (avatar, base video and TTS reference assets)."""
        with Session(DB_ENGINE) as session:
            streamer_item = StreamerInfo(
                name="乐乐喵",
                character="甜美;可爱;熟练使用各种网络热门梗造句;称呼客户为[家人们]",
                avatar=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.png",
                base_mp4_path=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.mp4",
                poster_image=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.png",
                tts_reference_audio=f"/{WEB_CONFIGS.STREAMER_FILE_DIR}/{WEB_CONFIGS.STREAMER_INFO_FILES_DIR}/lelemiao.wav",
                tts_reference_sentence="列车巡游银河,我不一定都能帮上忙,但只要是花钱能解决的事,尽管和我说吧。",
                tts_weight_tag="艾丝妲",
                user_id=1,
            )
            session.add(streamer_item)
            session.commit()

    def create_default_room():
        """Insert the default room, its on-air status row and links to up to 3 distinct products."""
        with Session(DB_ENGINE) as session:

            product_list = session.exec(
                select(ProductInfo).where(ProductInfo.user_id == 1).order_by(ProductInfo.product_id)
            ).all()

            on_air_status = OnAirRoomStatusItem(user_id=1)
            session.add(on_air_status)
            session.commit()
            session.refresh(on_air_status)  # load the generated status_id

            stream_item = StreamRoomInfo(
                name="001",
                user_id=1,
                status_id=on_air_status.status_id,
                streamer_id=1,
            )
            session.add(stream_item)
            session.commit()
            session.refresh(stream_item)  # load the generated room_id

            # Sample WITHOUT replacement so a product is linked to the room at most
            # once; random.choices could pick duplicates (and fails on an empty list).
            selected_products = random.sample(product_list, k=min(3, len(product_list)))
            session.add_all(
                [
                    SalesDocAndVideoInfo(product_id=product.product_id, room_id=stream_item.room_id)
                    for product in selected_products
                ]
            )
            session.commit()

    # Main logic: seed the remaining data only when the default user was just created
    if init_user():
        create_default_product_item()  # products
        create_default_streamer()  # streamer
        create_default_room()  # live-stream room