FaYo
commited on
Commit
·
9f68218
1
Parent(s):
5661e58
model
Browse files
finetune_configs/internlm_chat_7b_qlora_alpace_e3.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch
|
3 |
+
from datasets import load_dataset
|
4 |
+
from mmengine.dataset import DefaultSampler
|
5 |
+
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
|
6 |
+
LoggerHook, ParamSchedulerHook)
|
7 |
+
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
|
8 |
+
from peft import LoraConfig
|
9 |
+
from torch.optim import AdamW
|
10 |
+
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
11 |
+
BitsAndBytesConfig)
|
12 |
+
|
13 |
+
from xtuner.dataset import process_hf_dataset
|
14 |
+
from xtuner.dataset.collate_fns import default_collate_fn
|
15 |
+
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
|
16 |
+
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
|
17 |
+
VarlenAttnArgsToMessageHubHook)
|
18 |
+
from xtuner.engine.runner import TrainLoop
|
19 |
+
from xtuner.model import SupervisedFinetune
|
20 |
+
from xtuner.parallel.sequence import SequenceParallelSampler
|
21 |
+
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
|
22 |
+
|
23 |
+
#######################################################################
|
24 |
+
# PART 1 Settings #
|
25 |
+
#######################################################################
|
26 |
+
# Model
|
27 |
+
pretrained_model_name_or_path = '/group_share/lntelligent-Medical-Guidance-Large-Model/InternLM/XTuner/model/internlm2-chat-7b'
|
28 |
+
use_varlen_attn = False
|
29 |
+
|
30 |
+
# Data
|
31 |
+
alpaca_en_path = 'dataset/gen_dataset/train_dataset/90_train.jsonl'
|
32 |
+
prompt_template = PROMPT_TEMPLATE.internlm2_chat
|
33 |
+
max_length = 2048
|
34 |
+
pack_to_max_length = True
|
35 |
+
|
36 |
+
# parallel
|
37 |
+
sequence_parallel_size = 1
|
38 |
+
|
39 |
+
# Scheduler & Optimizer
|
40 |
+
batch_size = 1 # per_device
|
41 |
+
accumulative_counts = 16
|
42 |
+
accumulative_counts *= sequence_parallel_size
|
43 |
+
dataloader_num_workers = 0
|
44 |
+
max_epochs = 3
|
45 |
+
optim_type = AdamW
|
46 |
+
lr = 2e-4
|
47 |
+
betas = (0.9, 0.999)
|
48 |
+
weight_decay = 0
|
49 |
+
max_norm = 1 # grad clip
|
50 |
+
warmup_ratio = 0.03
|
51 |
+
|
52 |
+
# Save
|
53 |
+
save_steps = 500
|
54 |
+
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
|
55 |
+
|
56 |
+
# Evaluate the generation performance during the training
|
57 |
+
evaluation_freq = 500
|
58 |
+
SYSTEM = SYSTEM_TEMPLATE.alpaca
|
59 |
+
evaluation_inputs = [
|
60 |
+
'请介绍一下你自己', 'Please introduce yourself'
|
61 |
+
]
|
62 |
+
|
63 |
+
#######################################################################
|
64 |
+
# PART 2 Model & Tokenizer #
|
65 |
+
#######################################################################
|
66 |
+
tokenizer = dict(
|
67 |
+
type=AutoTokenizer.from_pretrained,
|
68 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
69 |
+
trust_remote_code=True,
|
70 |
+
padding_side='right')
|
71 |
+
|
72 |
+
model = dict(
|
73 |
+
type=SupervisedFinetune,
|
74 |
+
use_varlen_attn=use_varlen_attn,
|
75 |
+
llm=dict(
|
76 |
+
type=AutoModelForCausalLM.from_pretrained,
|
77 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
78 |
+
trust_remote_code=True,
|
79 |
+
torch_dtype=torch.float16,
|
80 |
+
quantization_config=dict(
|
81 |
+
type=BitsAndBytesConfig,
|
82 |
+
load_in_4bit=True,
|
83 |
+
load_in_8bit=False,
|
84 |
+
llm_int8_threshold=6.0,
|
85 |
+
llm_int8_has_fp16_weight=False,
|
86 |
+
bnb_4bit_compute_dtype=torch.float16,
|
87 |
+
bnb_4bit_use_double_quant=True,
|
88 |
+
bnb_4bit_quant_type='nf4')),
|
89 |
+
lora=dict(
|
90 |
+
type=LoraConfig,
|
91 |
+
r=64,
|
92 |
+
lora_alpha=16,
|
93 |
+
lora_dropout=0.1,
|
94 |
+
bias='none',
|
95 |
+
task_type='CAUSAL_LM'))
|
96 |
+
|
97 |
+
#######################################################################
|
98 |
+
# PART 3 Dataset & Dataloader #
|
99 |
+
#######################################################################
|
100 |
+
alpaca_en = dict(
|
101 |
+
type=process_hf_dataset,
|
102 |
+
dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
|
103 |
+
tokenizer=tokenizer,
|
104 |
+
max_length=max_length,
|
105 |
+
dataset_map_fn=None,
|
106 |
+
template_map_fn=dict(
|
107 |
+
type=template_map_fn_factory, template=prompt_template),
|
108 |
+
remove_unused_columns=True,
|
109 |
+
shuffle_before_pack=True,
|
110 |
+
pack_to_max_length=pack_to_max_length,
|
111 |
+
use_varlen_attn=use_varlen_attn)
|
112 |
+
|
113 |
+
sampler = SequenceParallelSampler \
|
114 |
+
if sequence_parallel_size > 1 else DefaultSampler
|
115 |
+
train_dataloader = dict(
|
116 |
+
batch_size=batch_size,
|
117 |
+
num_workers=dataloader_num_workers,
|
118 |
+
dataset=alpaca_en,
|
119 |
+
sampler=dict(type=sampler, shuffle=True),
|
120 |
+
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
|
121 |
+
|
122 |
+
#######################################################################
|
123 |
+
# PART 4 Scheduler & Optimizer #
|
124 |
+
#######################################################################
|
125 |
+
# optimizer
|
126 |
+
optim_wrapper = dict(
|
127 |
+
type=AmpOptimWrapper,
|
128 |
+
optimizer=dict(
|
129 |
+
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
|
130 |
+
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
|
131 |
+
accumulative_counts=accumulative_counts,
|
132 |
+
loss_scale='dynamic',
|
133 |
+
dtype='float16')
|
134 |
+
|
135 |
+
# learning policy
|
136 |
+
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
|
137 |
+
param_scheduler = [
|
138 |
+
dict(
|
139 |
+
type=LinearLR,
|
140 |
+
start_factor=1e-5,
|
141 |
+
by_epoch=True,
|
142 |
+
begin=0,
|
143 |
+
end=warmup_ratio * max_epochs,
|
144 |
+
convert_to_iter_based=True),
|
145 |
+
dict(
|
146 |
+
type=CosineAnnealingLR,
|
147 |
+
eta_min=0.0,
|
148 |
+
by_epoch=True,
|
149 |
+
begin=warmup_ratio * max_epochs,
|
150 |
+
end=max_epochs,
|
151 |
+
convert_to_iter_based=True)
|
152 |
+
]
|
153 |
+
|
154 |
+
# train, val, test setting
|
155 |
+
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
|
156 |
+
|
157 |
+
#######################################################################
|
158 |
+
# PART 5 Runtime #
|
159 |
+
#######################################################################
|
160 |
+
# Log the dialogue periodically during the training process, optional
|
161 |
+
custom_hooks = [
|
162 |
+
dict(type=DatasetInfoHook, tokenizer=tokenizer),
|
163 |
+
dict(
|
164 |
+
type=EvaluateChatHook,
|
165 |
+
tokenizer=tokenizer,
|
166 |
+
every_n_iters=evaluation_freq,
|
167 |
+
evaluation_inputs=evaluation_inputs,
|
168 |
+
system=SYSTEM,
|
169 |
+
prompt_template=prompt_template)
|
170 |
+
]
|
171 |
+
|
172 |
+
if use_varlen_attn:
|
173 |
+
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
|
174 |
+
|
175 |
+
# configure default hooks
|
176 |
+
default_hooks = dict(
|
177 |
+
# record the time of every iteration.
|
178 |
+
timer=dict(type=IterTimerHook),
|
179 |
+
# print log every 10 iterations.
|
180 |
+
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
|
181 |
+
# enable the parameter scheduler.
|
182 |
+
param_scheduler=dict(type=ParamSchedulerHook),
|
183 |
+
# save checkpoint per `save_steps`.
|
184 |
+
checkpoint=dict(
|
185 |
+
type=CheckpointHook,
|
186 |
+
by_epoch=False,
|
187 |
+
interval=save_steps,
|
188 |
+
max_keep_ckpts=save_total_limit),
|
189 |
+
# set sampler seed in distributed evrionment.
|
190 |
+
sampler_seed=dict(type=DistSamplerSeedHook),
|
191 |
+
)
|
192 |
+
|
193 |
+
# configure environment
|
194 |
+
env_cfg = dict(
|
195 |
+
# whether to enable cudnn benchmark
|
196 |
+
cudnn_benchmark=False,
|
197 |
+
# set multi process parameters
|
198 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
199 |
+
# set distributed parameters
|
200 |
+
dist_cfg=dict(backend='nccl'),
|
201 |
+
)
|
202 |
+
|
203 |
+
# set visualizer
|
204 |
+
visualizer = None
|
205 |
+
|
206 |
+
# set log level
|
207 |
+
log_level = 'INFO'
|
208 |
+
|
209 |
+
# load from which checkpoint
|
210 |
+
load_from = None
|
211 |
+
|
212 |
+
# whether to resume training from the loaded checkpoint
|
213 |
+
resume = False
|
214 |
+
|
215 |
+
# Defaults to use random seed and disable `deterministic`
|
216 |
+
randomness = dict(seed=None, deterministic=False)
|
217 |
+
|
218 |
+
# set log processor
|
219 |
+
log_processor = dict(by_epoch=False)
|
pages/selling_page.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# @Time : 2024.4.16
|
4 |
+
# @Author : HinGwenWong
|
5 |
+
|
6 |
+
import random
|
7 |
+
from datetime import datetime
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import streamlit as st
|
11 |
+
|
12 |
+
from utils.web_configs import WEB_CONFIGS
|
13 |
+
|
14 |
+
# 设置页面配置,包括标题、图标、布局和菜单项
|
15 |
+
st.set_page_config(
|
16 |
+
page_title="智能医导大模型",
|
17 |
+
page_icon="🛒",
|
18 |
+
layout="wide",
|
19 |
+
initial_sidebar_state="expanded",
|
20 |
+
menu_items={
|
21 |
+
"Get Help": "https://github.com/nhbdgtgefr/Intelligent-Medical-Guidance-Large-Model/tree/main",
|
22 |
+
"About": "# 智能医导大模型",
|
23 |
+
},
|
24 |
+
)
|
25 |
+
|
26 |
+
from audiorecorder import audiorecorder
|
27 |
+
|
28 |
+
from utils.asr.asr_worker import process_asr
|
29 |
+
from utils.digital_human.digital_human_worker import show_video
|
30 |
+
from utils.infer.lmdeploy_infer import get_turbomind_response
|
31 |
+
from utils.model_loader import ASR_HANDLER, LLM_MODEL, RAG_RETRIEVER
|
32 |
+
from utils.tools import resize_image
|
33 |
+
|
34 |
+
|
35 |
+
def on_btn_click(*args, **kwargs):
|
36 |
+
"""
|
37 |
+
处理按钮点击事件的函数。
|
38 |
+
"""
|
39 |
+
if kwargs["info"] == "清除对话历史":
|
40 |
+
st.session_state.messages = []
|
41 |
+
elif kwargs["info"] == "返回科室页":
|
42 |
+
st.session_state.page_switch = "app.py"
|
43 |
+
else:
|
44 |
+
st.session_state.button_msg = kwargs["info"]
|
45 |
+
|
46 |
+
|
47 |
+
def init_sidebar():
|
48 |
+
"""
|
49 |
+
初始化侧边栏界面,展示商品信息,并提供操作按钮。
|
50 |
+
"""
|
51 |
+
asr_text = ""
|
52 |
+
with st.sidebar:
|
53 |
+
# 标题
|
54 |
+
st.markdown("## 智能医导大模型")
|
55 |
+
st.markdown("[智能医导大模型](https://github.com/nhbdgtgefr/Intelligent-Medical-Guidance-Large-Model)")
|
56 |
+
st.subheader("功能点:", divider="grey")
|
57 |
+
# st.markdown(
|
58 |
+
# "1. 📜 **主播文案一键生成**\n2. 🚀 KV cache + Turbomind **推理加速**\n3. 📚 RAG **检索增强生成**\n4. 🔊 TTS **文字转语音**\n5. 🦸 **数字人生成**\n6. 🌐 **Agent 网络查询**\n7. 🎙️ **ASR 语音转文字**"
|
59 |
+
# )
|
60 |
+
|
61 |
+
st.subheader("目前讲解")
|
62 |
+
with st.container(height=400, border=True):
|
63 |
+
st.subheader(st.session_state.product_name)
|
64 |
+
|
65 |
+
image = resize_image(st.session_state.image_path, max_height=100)
|
66 |
+
st.image(image, channels="bgr")
|
67 |
+
|
68 |
+
st.subheader("科室特点", divider="grey")
|
69 |
+
st.markdown(st.session_state.hightlight)
|
70 |
+
|
71 |
+
want_to_buy_list = [
|
72 |
+
"我打算买了。",
|
73 |
+
"我准备入手了。",
|
74 |
+
"我决定要买了。",
|
75 |
+
"我准备下单了。",
|
76 |
+
"我将要购买这款产品。",
|
77 |
+
"我准备买下来了。",
|
78 |
+
"我准备将这个买下。",
|
79 |
+
"我准备要购买了。",
|
80 |
+
"我决定买下它。",
|
81 |
+
"我准备将其买下。",
|
82 |
+
]
|
83 |
+
buy_flag = st.button("加入信息🛒", on_click=on_btn_click, kwargs={"info": random.choice(want_to_buy_list)})
|
84 |
+
|
85 |
+
# TODO 加入卖货信息
|
86 |
+
# 卖出 xxx 个
|
87 |
+
# 成交额
|
88 |
+
|
89 |
+
if WEB_CONFIGS.ENABLE_ASR:
|
90 |
+
Path(WEB_CONFIGS.ASR_WAV_SAVE_PATH).mkdir(parents=True, exist_ok=True)
|
91 |
+
|
92 |
+
st.subheader(f"语音输入", divider="grey")
|
93 |
+
audio = audiorecorder(
|
94 |
+
start_prompt="开始录音", stop_prompt="停止录音", pause_prompt="", show_visualizer=True, key=None
|
95 |
+
)
|
96 |
+
|
97 |
+
if len(audio) > 0:
|
98 |
+
|
99 |
+
# 将录音保存 wav 文件
|
100 |
+
save_tag = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".wav"
|
101 |
+
wav_path = str(Path(WEB_CONFIGS.ASR_WAV_SAVE_PATH).joinpath(save_tag).absolute())
|
102 |
+
|
103 |
+
# st.audio(audio.export().read()) # 前端显示
|
104 |
+
audio.export(wav_path, format="wav") # 使用 pydub 保存到 wav 文件
|
105 |
+
|
106 |
+
# To get audio properties, use pydub AudioSegment properties:
|
107 |
+
# st.write(
|
108 |
+
# f"Frame rate: {audio.frame_rate}, Frame width: {audio.frame_width}, Duration: {audio.duration_seconds} seconds"
|
109 |
+
# )
|
110 |
+
|
111 |
+
# 语音识别
|
112 |
+
asr_text = process_asr(ASR_HANDLER, wav_path)
|
113 |
+
|
114 |
+
# 删除过程文件
|
115 |
+
# Path(wav_path).unlink()
|
116 |
+
|
117 |
+
# 是否生成 TTS
|
118 |
+
if WEB_CONFIGS.ENABLE_TTS:
|
119 |
+
st.subheader("TTS 配置", divider="grey")
|
120 |
+
st.session_state.gen_tts_checkbox = st.toggle("生成语音", value=st.session_state.gen_tts_checkbox)
|
121 |
+
|
122 |
+
if WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
|
123 |
+
# 是否生成 数字人
|
124 |
+
st.subheader(f"数字人 配置", divider="grey")
|
125 |
+
st.session_state.gen_digital_human_checkbox = st.toggle(
|
126 |
+
"生成数字人视频", value=st.session_state.gen_digital_human_checkbox
|
127 |
+
)
|
128 |
+
|
129 |
+
if WEB_CONFIGS.ENABLE_AGENT:
|
130 |
+
# 是否使用 agent
|
131 |
+
st.subheader(f"Agent 配置", divider="grey")
|
132 |
+
with st.container(border=True):
|
133 |
+
st.markdown("**插件列表**")
|
134 |
+
st.button("结合天气查询到货时间", type="primary")
|
135 |
+
st.session_state.enable_agent_checkbox = st.toggle("使用 Agent 能力", value=st.session_state.enable_agent_checkbox)
|
136 |
+
|
137 |
+
st.subheader("页面切换", divider="grey")
|
138 |
+
st.button("返回科室页", on_click=on_btn_click, kwargs={"info": "返回科室页"})
|
139 |
+
|
140 |
+
st.subheader("对话设置", divider="grey")
|
141 |
+
st.button("清除对话历史", on_click=on_btn_click, kwargs={"info": "清除对话历史"})
|
142 |
+
|
143 |
+
# 模型配置
|
144 |
+
# st.markdown("## 模型配置")
|
145 |
+
# max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
|
146 |
+
# top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
|
147 |
+
# temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
|
148 |
+
|
149 |
+
return asr_text
|
150 |
+
|
151 |
+
|
152 |
+
def init_message_block(meta_instruction, user_avator, robot_avator):
|
153 |
+
|
154 |
+
# 在应用重新运行时显示聊天历史消息
|
155 |
+
for message in st.session_state.messages:
|
156 |
+
with st.chat_message(message["role"], avatar=message.get("avatar")):
|
157 |
+
st.markdown(message["content"])
|
158 |
+
|
159 |
+
if message.get("wav") is not None:
|
160 |
+
# 展示语音
|
161 |
+
print(f"Load wav {message['wav']}")
|
162 |
+
with open(message["wav"], "rb") as f_wav:
|
163 |
+
audio_bytes = f_wav.read()
|
164 |
+
st.audio(audio_bytes, format="audio/wav")
|
165 |
+
|
166 |
+
# 如果聊天历史为空,则显示产品介绍
|
167 |
+
if len(st.session_state.messages) == 0:
|
168 |
+
# 直接产品介绍
|
169 |
+
get_turbomind_response(
|
170 |
+
st.session_state.first_input,
|
171 |
+
meta_instruction,
|
172 |
+
user_avator,
|
173 |
+
robot_avator,
|
174 |
+
LLM_MODEL,
|
175 |
+
session_messages=st.session_state.messages,
|
176 |
+
add_session_msg=False,
|
177 |
+
first_input_str="",
|
178 |
+
enable_agent=False,
|
179 |
+
)
|
180 |
+
|
181 |
+
# 初始化按钮消息状态
|
182 |
+
if "button_msg" not in st.session_state:
|
183 |
+
st.session_state.button_msg = "x-x"
|
184 |
+
|
185 |
+
|
186 |
+
def process_message(user_avator, prompt, meta_instruction, robot_avator):
|
187 |
+
# Display user message in chat message container
|
188 |
+
with st.chat_message("user", avatar=user_avator):
|
189 |
+
st.markdown(prompt)
|
190 |
+
|
191 |
+
get_turbomind_response(
|
192 |
+
prompt,
|
193 |
+
meta_instruction,
|
194 |
+
user_avator,
|
195 |
+
robot_avator,
|
196 |
+
LLM_MODEL,
|
197 |
+
session_messages=st.session_state.messages,
|
198 |
+
add_session_msg=True,
|
199 |
+
first_input_str=st.session_state.first_input,
|
200 |
+
rag_retriever=RAG_RETRIEVER,
|
201 |
+
product_name=st.session_state.product_name,
|
202 |
+
enable_agent=st.session_state.enable_agent_checkbox,
|
203 |
+
# departure_place=st.session_state.departure_place,
|
204 |
+
# delivery_company_name=st.session_state.delivery_company_name,
|
205 |
+
)
|
206 |
+
|
207 |
+
|
208 |
+
def main(meta_instruction):
|
209 |
+
|
210 |
+
# 检查页面切换状态并进行切换
|
211 |
+
if st.session_state.page_switch != st.session_state.current_page:
|
212 |
+
st.switch_page(st.session_state.page_switch)
|
213 |
+
|
214 |
+
# 页面标题
|
215 |
+
st.title("智能医导大模型")
|
216 |
+
|
217 |
+
# 说明
|
218 |
+
st.info(
|
219 |
+
"本项目是基于人工智能的文字、语音、视频生成领域搭建的智能医导大模型。用户被授予使用此工具创建文字、语音、视频的自由,但用户在使用过程中应该遵守当地法律,并负责任地使用。开发人员不对用户可能的不当使用承担任何责任。",
|
220 |
+
icon="❗",
|
221 |
+
)
|
222 |
+
|
223 |
+
# 初始化侧边栏
|
224 |
+
asr_text = init_sidebar()
|
225 |
+
|
226 |
+
# 初始化聊天历史记录
|
227 |
+
if "messages" not in st.session_state:
|
228 |
+
st.session_state.messages = []
|
229 |
+
|
230 |
+
message_col = None
|
231 |
+
if st.session_state.gen_digital_human_checkbox and WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
|
232 |
+
|
233 |
+
with st.container():
|
234 |
+
message_col, video_col = st.columns([0.6, 0.4])
|
235 |
+
|
236 |
+
with video_col:
|
237 |
+
# 创建 empty 控件
|
238 |
+
st.session_state.video_placeholder = st.empty()
|
239 |
+
with st.session_state.video_placeholder.container():
|
240 |
+
show_video(st.session_state.digital_human_video_path, autoplay=True, loop=True, muted=True)
|
241 |
+
|
242 |
+
with message_col:
|
243 |
+
init_message_block(meta_instruction, WEB_CONFIGS.USER_AVATOR, WEB_CONFIGS.ROBOT_AVATOR)
|
244 |
+
else:
|
245 |
+
init_message_block(meta_instruction, WEB_CONFIGS.USER_AVATOR, WEB_CONFIGS.ROBOT_AVATOR)
|
246 |
+
|
247 |
+
# 输入框显示提示信息
|
248 |
+
hint_msg = "你好,你可以向我提出任何关于就诊的问题,我将竭诚为您服务"
|
249 |
+
if st.session_state.button_msg != "x-x":
|
250 |
+
prompt = st.session_state.button_msg
|
251 |
+
st.session_state.button_msg = "x-x"
|
252 |
+
st.chat_input(hint_msg)
|
253 |
+
elif asr_text != "" and st.session_state.asr_text_cache != asr_text:
|
254 |
+
prompt = asr_text
|
255 |
+
st.chat_input(hint_msg)
|
256 |
+
st.session_state.asr_text_cache = asr_text
|
257 |
+
else:
|
258 |
+
prompt = st.chat_input(hint_msg)
|
259 |
+
|
260 |
+
# 接收用户输入
|
261 |
+
if prompt:
|
262 |
+
|
263 |
+
if message_col is None:
|
264 |
+
process_message(WEB_CONFIGS.USER_AVATOR, prompt, meta_instruction, WEB_CONFIGS.ROBOT_AVATOR)
|
265 |
+
else:
|
266 |
+
# 数字人启动,页面会分块,放入信息块中
|
267 |
+
with message_col:
|
268 |
+
process_message(WEB_CONFIGS.USER_AVATOR, prompt, meta_instruction, WEB_CONFIGS.ROBOT_AVATOR)
|
269 |
+
|
270 |
+
|
271 |
+
# st.sidebar.page_link("app.py", label="商品页")
|
272 |
+
# st.sidebar.page_link("./pages/selling_page.py", label="主播卖货", disabled=True)
|
273 |
+
|
274 |
+
# META_INSTRUCTION
|
275 |
+
print("into sales page")
|
276 |
+
st.session_state.current_page = "pages/selling_page.py"
|
277 |
+
|
278 |
+
if "sales_info" not in st.session_state or st.session_state.sales_info == "":
|
279 |
+
st.session_state.page_switch = "app.py"
|
280 |
+
st.switch_page("app.py")
|
281 |
+
|
282 |
+
main((st.session_state.sales_info))
|