FaYo commited on
Commit
9f68218
·
1 Parent(s): 5661e58
finetune_configs/internlm_chat_7b_qlora_alpace_e3.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from datasets import load_dataset
4
+ from mmengine.dataset import DefaultSampler
5
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
6
+ LoggerHook, ParamSchedulerHook)
7
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
8
+ from peft import LoraConfig
9
+ from torch.optim import AdamW
10
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
11
+ BitsAndBytesConfig)
12
+
13
+ from xtuner.dataset import process_hf_dataset
14
+ from xtuner.dataset.collate_fns import default_collate_fn
15
+ from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
16
+ from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
17
+ VarlenAttnArgsToMessageHubHook)
18
+ from xtuner.engine.runner import TrainLoop
19
+ from xtuner.model import SupervisedFinetune
20
+ from xtuner.parallel.sequence import SequenceParallelSampler
21
+ from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
22
+
23
+ #######################################################################
24
+ # PART 1 Settings #
25
+ #######################################################################
26
+ # Model
27
+ pretrained_model_name_or_path = '/group_share/lntelligent-Medical-Guidance-Large-Model/InternLM/XTuner/model/internlm2-chat-7b'
28
+ use_varlen_attn = False
29
+
30
+ # Data
31
+ alpaca_en_path = 'dataset/gen_dataset/train_dataset/90_train.jsonl'
32
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
33
+ max_length = 2048
34
+ pack_to_max_length = True
35
+
36
+ # parallel
37
+ sequence_parallel_size = 1
38
+
39
+ # Scheduler & Optimizer
40
+ batch_size = 1 # per_device
41
+ accumulative_counts = 16
42
+ accumulative_counts *= sequence_parallel_size
43
+ dataloader_num_workers = 0
44
+ max_epochs = 3
45
+ optim_type = AdamW
46
+ lr = 2e-4
47
+ betas = (0.9, 0.999)
48
+ weight_decay = 0
49
+ max_norm = 1 # grad clip
50
+ warmup_ratio = 0.03
51
+
52
+ # Save
53
+ save_steps = 500
54
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
55
+
56
+ # Evaluate the generation performance during the training
57
+ evaluation_freq = 500
58
+ SYSTEM = SYSTEM_TEMPLATE.alpaca
59
+ evaluation_inputs = [
60
+ '请介绍一下你自己', 'Please introduce yourself'
61
+ ]
62
+
63
+ #######################################################################
64
+ # PART 2 Model & Tokenizer #
65
+ #######################################################################
66
+ tokenizer = dict(
67
+ type=AutoTokenizer.from_pretrained,
68
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
69
+ trust_remote_code=True,
70
+ padding_side='right')
71
+
72
+ model = dict(
73
+ type=SupervisedFinetune,
74
+ use_varlen_attn=use_varlen_attn,
75
+ llm=dict(
76
+ type=AutoModelForCausalLM.from_pretrained,
77
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
78
+ trust_remote_code=True,
79
+ torch_dtype=torch.float16,
80
+ quantization_config=dict(
81
+ type=BitsAndBytesConfig,
82
+ load_in_4bit=True,
83
+ load_in_8bit=False,
84
+ llm_int8_threshold=6.0,
85
+ llm_int8_has_fp16_weight=False,
86
+ bnb_4bit_compute_dtype=torch.float16,
87
+ bnb_4bit_use_double_quant=True,
88
+ bnb_4bit_quant_type='nf4')),
89
+ lora=dict(
90
+ type=LoraConfig,
91
+ r=64,
92
+ lora_alpha=16,
93
+ lora_dropout=0.1,
94
+ bias='none',
95
+ task_type='CAUSAL_LM'))
96
+
97
+ #######################################################################
98
+ # PART 3 Dataset & Dataloader #
99
+ #######################################################################
100
+ alpaca_en = dict(
101
+ type=process_hf_dataset,
102
+ dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
103
+ tokenizer=tokenizer,
104
+ max_length=max_length,
105
+ dataset_map_fn=None,
106
+ template_map_fn=dict(
107
+ type=template_map_fn_factory, template=prompt_template),
108
+ remove_unused_columns=True,
109
+ shuffle_before_pack=True,
110
+ pack_to_max_length=pack_to_max_length,
111
+ use_varlen_attn=use_varlen_attn)
112
+
113
+ sampler = SequenceParallelSampler \
114
+ if sequence_parallel_size > 1 else DefaultSampler
115
+ train_dataloader = dict(
116
+ batch_size=batch_size,
117
+ num_workers=dataloader_num_workers,
118
+ dataset=alpaca_en,
119
+ sampler=dict(type=sampler, shuffle=True),
120
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
121
+
122
+ #######################################################################
123
+ # PART 4 Scheduler & Optimizer #
124
+ #######################################################################
125
+ # optimizer
126
+ optim_wrapper = dict(
127
+ type=AmpOptimWrapper,
128
+ optimizer=dict(
129
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
130
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
131
+ accumulative_counts=accumulative_counts,
132
+ loss_scale='dynamic',
133
+ dtype='float16')
134
+
135
+ # learning policy
136
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
137
+ param_scheduler = [
138
+ dict(
139
+ type=LinearLR,
140
+ start_factor=1e-5,
141
+ by_epoch=True,
142
+ begin=0,
143
+ end=warmup_ratio * max_epochs,
144
+ convert_to_iter_based=True),
145
+ dict(
146
+ type=CosineAnnealingLR,
147
+ eta_min=0.0,
148
+ by_epoch=True,
149
+ begin=warmup_ratio * max_epochs,
150
+ end=max_epochs,
151
+ convert_to_iter_based=True)
152
+ ]
153
+
154
+ # train, val, test setting
155
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
156
+
157
+ #######################################################################
158
+ # PART 5 Runtime #
159
+ #######################################################################
160
+ # Log the dialogue periodically during the training process, optional
161
+ custom_hooks = [
162
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
163
+ dict(
164
+ type=EvaluateChatHook,
165
+ tokenizer=tokenizer,
166
+ every_n_iters=evaluation_freq,
167
+ evaluation_inputs=evaluation_inputs,
168
+ system=SYSTEM,
169
+ prompt_template=prompt_template)
170
+ ]
171
+
172
+ if use_varlen_attn:
173
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
174
+
175
+ # configure default hooks
176
+ default_hooks = dict(
177
+ # record the time of every iteration.
178
+ timer=dict(type=IterTimerHook),
179
+ # print log every 10 iterations.
180
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
181
+ # enable the parameter scheduler.
182
+ param_scheduler=dict(type=ParamSchedulerHook),
183
+ # save checkpoint per `save_steps`.
184
+ checkpoint=dict(
185
+ type=CheckpointHook,
186
+ by_epoch=False,
187
+ interval=save_steps,
188
+ max_keep_ckpts=save_total_limit),
189
+ # set sampler seed in distributed evrionment.
190
+ sampler_seed=dict(type=DistSamplerSeedHook),
191
+ )
192
+
193
+ # configure environment
194
+ env_cfg = dict(
195
+ # whether to enable cudnn benchmark
196
+ cudnn_benchmark=False,
197
+ # set multi process parameters
198
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
199
+ # set distributed parameters
200
+ dist_cfg=dict(backend='nccl'),
201
+ )
202
+
203
+ # set visualizer
204
+ visualizer = None
205
+
206
+ # set log level
207
+ log_level = 'INFO'
208
+
209
+ # load from which checkpoint
210
+ load_from = None
211
+
212
+ # whether to resume training from the loaded checkpoint
213
+ resume = False
214
+
215
+ # Defaults to use random seed and disable `deterministic`
216
+ randomness = dict(seed=None, deterministic=False)
217
+
218
+ # set log processor
219
+ log_processor = dict(by_epoch=False)
pages/selling_page.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # @Time : 2024.4.16
4
+ # @Author : HinGwenWong
5
+
6
+ import random
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+
10
+ import streamlit as st
11
+
12
+ from utils.web_configs import WEB_CONFIGS
13
+
14
+ # 设置页面配置,包括标题、图标、布局和菜单项
15
+ st.set_page_config(
16
+ page_title="智能医导大模型",
17
+ page_icon="🛒",
18
+ layout="wide",
19
+ initial_sidebar_state="expanded",
20
+ menu_items={
21
+ "Get Help": "https://github.com/nhbdgtgefr/Intelligent-Medical-Guidance-Large-Model/tree/main",
22
+ "About": "# 智能医导大模型",
23
+ },
24
+ )
25
+
26
+ from audiorecorder import audiorecorder
27
+
28
+ from utils.asr.asr_worker import process_asr
29
+ from utils.digital_human.digital_human_worker import show_video
30
+ from utils.infer.lmdeploy_infer import get_turbomind_response
31
+ from utils.model_loader import ASR_HANDLER, LLM_MODEL, RAG_RETRIEVER
32
+ from utils.tools import resize_image
33
+
34
+
35
+ def on_btn_click(*args, **kwargs):
36
+ """
37
+ 处理按钮点击事件的函数。
38
+ """
39
+ if kwargs["info"] == "清除对话历史":
40
+ st.session_state.messages = []
41
+ elif kwargs["info"] == "返回科室页":
42
+ st.session_state.page_switch = "app.py"
43
+ else:
44
+ st.session_state.button_msg = kwargs["info"]
45
+
46
+
47
+ def init_sidebar():
48
+ """
49
+ 初始化侧边栏界面,展示商品信息,并提供操作按钮。
50
+ """
51
+ asr_text = ""
52
+ with st.sidebar:
53
+ # 标题
54
+ st.markdown("## 智能医导大模型")
55
+ st.markdown("[智能医导大模型](https://github.com/nhbdgtgefr/Intelligent-Medical-Guidance-Large-Model)")
56
+ st.subheader("功能点:", divider="grey")
57
+ # st.markdown(
58
+ # "1. 📜 **主播文案一键生成**\n2. 🚀 KV cache + Turbomind **推理加速**\n3. 📚 RAG **检索增强生成**\n4. 🔊 TTS **文字转语音**\n5. 🦸 **数字人生成**\n6. 🌐 **Agent 网络查询**\n7. 🎙️ **ASR 语音转文字**"
59
+ # )
60
+
61
+ st.subheader("目前讲解")
62
+ with st.container(height=400, border=True):
63
+ st.subheader(st.session_state.product_name)
64
+
65
+ image = resize_image(st.session_state.image_path, max_height=100)
66
+ st.image(image, channels="bgr")
67
+
68
+ st.subheader("科室特点", divider="grey")
69
+ st.markdown(st.session_state.hightlight)
70
+
71
+ want_to_buy_list = [
72
+ "我打算买了。",
73
+ "我准备入手了。",
74
+ "我决定要买了。",
75
+ "我准备下单了。",
76
+ "我将要购买这款产品。",
77
+ "我准备买下来了。",
78
+ "我准备将这个买下。",
79
+ "我准备要购买了。",
80
+ "我决定买下它。",
81
+ "我准备将其买下。",
82
+ ]
83
+ buy_flag = st.button("加入信息🛒", on_click=on_btn_click, kwargs={"info": random.choice(want_to_buy_list)})
84
+
85
+ # TODO 加入卖货信息
86
+ # 卖出 xxx 个
87
+ # 成交额
88
+
89
+ if WEB_CONFIGS.ENABLE_ASR:
90
+ Path(WEB_CONFIGS.ASR_WAV_SAVE_PATH).mkdir(parents=True, exist_ok=True)
91
+
92
+ st.subheader(f"语音输入", divider="grey")
93
+ audio = audiorecorder(
94
+ start_prompt="开始录音", stop_prompt="停止录音", pause_prompt="", show_visualizer=True, key=None
95
+ )
96
+
97
+ if len(audio) > 0:
98
+
99
+ # 将录音保存 wav 文件
100
+ save_tag = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + ".wav"
101
+ wav_path = str(Path(WEB_CONFIGS.ASR_WAV_SAVE_PATH).joinpath(save_tag).absolute())
102
+
103
+ # st.audio(audio.export().read()) # 前端显示
104
+ audio.export(wav_path, format="wav") # 使用 pydub 保存到 wav 文件
105
+
106
+ # To get audio properties, use pydub AudioSegment properties:
107
+ # st.write(
108
+ # f"Frame rate: {audio.frame_rate}, Frame width: {audio.frame_width}, Duration: {audio.duration_seconds} seconds"
109
+ # )
110
+
111
+ # 语音识别
112
+ asr_text = process_asr(ASR_HANDLER, wav_path)
113
+
114
+ # 删除过程文件
115
+ # Path(wav_path).unlink()
116
+
117
+ # 是否生成 TTS
118
+ if WEB_CONFIGS.ENABLE_TTS:
119
+ st.subheader("TTS 配置", divider="grey")
120
+ st.session_state.gen_tts_checkbox = st.toggle("生成语音", value=st.session_state.gen_tts_checkbox)
121
+
122
+ if WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
123
+ # 是否生成 数字人
124
+ st.subheader(f"数字人 配置", divider="grey")
125
+ st.session_state.gen_digital_human_checkbox = st.toggle(
126
+ "生成数字人视频", value=st.session_state.gen_digital_human_checkbox
127
+ )
128
+
129
+ if WEB_CONFIGS.ENABLE_AGENT:
130
+ # 是否使用 agent
131
+ st.subheader(f"Agent 配置", divider="grey")
132
+ with st.container(border=True):
133
+ st.markdown("**插件列表**")
134
+ st.button("结合天气查询到货时间", type="primary")
135
+ st.session_state.enable_agent_checkbox = st.toggle("使用 Agent 能力", value=st.session_state.enable_agent_checkbox)
136
+
137
+ st.subheader("页面切换", divider="grey")
138
+ st.button("返回科室页", on_click=on_btn_click, kwargs={"info": "返回科室页"})
139
+
140
+ st.subheader("对话设置", divider="grey")
141
+ st.button("清除对话历史", on_click=on_btn_click, kwargs={"info": "清除对话历史"})
142
+
143
+ # 模型配置
144
+ # st.markdown("## 模型配置")
145
+ # max_length = st.slider("Max Length", min_value=8, max_value=32768, value=32768)
146
+ # top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
147
+ # temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
148
+
149
+ return asr_text
150
+
151
+
152
+ def init_message_block(meta_instruction, user_avator, robot_avator):
153
+
154
+ # 在应用重新运行时显示聊天历史消息
155
+ for message in st.session_state.messages:
156
+ with st.chat_message(message["role"], avatar=message.get("avatar")):
157
+ st.markdown(message["content"])
158
+
159
+ if message.get("wav") is not None:
160
+ # 展示语音
161
+ print(f"Load wav {message['wav']}")
162
+ with open(message["wav"], "rb") as f_wav:
163
+ audio_bytes = f_wav.read()
164
+ st.audio(audio_bytes, format="audio/wav")
165
+
166
+ # 如果聊天历史为空,则显示产品介绍
167
+ if len(st.session_state.messages) == 0:
168
+ # 直接产品介绍
169
+ get_turbomind_response(
170
+ st.session_state.first_input,
171
+ meta_instruction,
172
+ user_avator,
173
+ robot_avator,
174
+ LLM_MODEL,
175
+ session_messages=st.session_state.messages,
176
+ add_session_msg=False,
177
+ first_input_str="",
178
+ enable_agent=False,
179
+ )
180
+
181
+ # 初始化按钮消息状态
182
+ if "button_msg" not in st.session_state:
183
+ st.session_state.button_msg = "x-x"
184
+
185
+
186
+ def process_message(user_avator, prompt, meta_instruction, robot_avator):
187
+ # Display user message in chat message container
188
+ with st.chat_message("user", avatar=user_avator):
189
+ st.markdown(prompt)
190
+
191
+ get_turbomind_response(
192
+ prompt,
193
+ meta_instruction,
194
+ user_avator,
195
+ robot_avator,
196
+ LLM_MODEL,
197
+ session_messages=st.session_state.messages,
198
+ add_session_msg=True,
199
+ first_input_str=st.session_state.first_input,
200
+ rag_retriever=RAG_RETRIEVER,
201
+ product_name=st.session_state.product_name,
202
+ enable_agent=st.session_state.enable_agent_checkbox,
203
+ # departure_place=st.session_state.departure_place,
204
+ # delivery_company_name=st.session_state.delivery_company_name,
205
+ )
206
+
207
+
208
+ def main(meta_instruction):
209
+
210
+ # 检查页面切换状态并进行切换
211
+ if st.session_state.page_switch != st.session_state.current_page:
212
+ st.switch_page(st.session_state.page_switch)
213
+
214
+ # 页面标题
215
+ st.title("智能医导大模型")
216
+
217
+ # 说明
218
+ st.info(
219
+ "本项目是基于人工智能的文字、语音、视频生成领域搭建的智能医导大模型。用户被授予使用此工具创建文字、语音、视频的自由,但用户在使用过程中应该遵守当地法律,并负责任地使用。开发人员不对用户可能的不当使用承担任何责任。",
220
+ icon="❗",
221
+ )
222
+
223
+ # 初始化侧边栏
224
+ asr_text = init_sidebar()
225
+
226
+ # 初始化聊天历史记录
227
+ if "messages" not in st.session_state:
228
+ st.session_state.messages = []
229
+
230
+ message_col = None
231
+ if st.session_state.gen_digital_human_checkbox and WEB_CONFIGS.ENABLE_DIGITAL_HUMAN:
232
+
233
+ with st.container():
234
+ message_col, video_col = st.columns([0.6, 0.4])
235
+
236
+ with video_col:
237
+ # 创建 empty 控件
238
+ st.session_state.video_placeholder = st.empty()
239
+ with st.session_state.video_placeholder.container():
240
+ show_video(st.session_state.digital_human_video_path, autoplay=True, loop=True, muted=True)
241
+
242
+ with message_col:
243
+ init_message_block(meta_instruction, WEB_CONFIGS.USER_AVATOR, WEB_CONFIGS.ROBOT_AVATOR)
244
+ else:
245
+ init_message_block(meta_instruction, WEB_CONFIGS.USER_AVATOR, WEB_CONFIGS.ROBOT_AVATOR)
246
+
247
+ # 输入框显示提示信息
248
+ hint_msg = "你好,你可以向我提出任何关于就诊的问题,我将竭诚为您服务"
249
+ if st.session_state.button_msg != "x-x":
250
+ prompt = st.session_state.button_msg
251
+ st.session_state.button_msg = "x-x"
252
+ st.chat_input(hint_msg)
253
+ elif asr_text != "" and st.session_state.asr_text_cache != asr_text:
254
+ prompt = asr_text
255
+ st.chat_input(hint_msg)
256
+ st.session_state.asr_text_cache = asr_text
257
+ else:
258
+ prompt = st.chat_input(hint_msg)
259
+
260
+ # 接收用户输入
261
+ if prompt:
262
+
263
+ if message_col is None:
264
+ process_message(WEB_CONFIGS.USER_AVATOR, prompt, meta_instruction, WEB_CONFIGS.ROBOT_AVATOR)
265
+ else:
266
+ # 数字人启动,页面会分块,放入信息块中
267
+ with message_col:
268
+ process_message(WEB_CONFIGS.USER_AVATOR, prompt, meta_instruction, WEB_CONFIGS.ROBOT_AVATOR)
269
+
270
+
271
+ # st.sidebar.page_link("app.py", label="商品页")
272
+ # st.sidebar.page_link("./pages/selling_page.py", label="主播卖货", disabled=True)
273
+
274
+ # META_INSTRUCTION
275
+ print("into sales page")
276
+ st.session_state.current_page = "pages/selling_page.py"
277
+
278
+ if "sales_info" not in st.session_state or st.session_state.sales_info == "":
279
+ st.session_state.page_switch = "app.py"
280
+ st.switch_page("app.py")
281
+
282
+ main((st.session_state.sales_info))