from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# Load the instruction-tuned model
model = AutoModelForCausalLM.from_pretrained(
    "MediaTek-Research/Breeze-7B-Instruct-v1_0",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("MediaTek-Research/Breeze-7B-Instruct-v1_0")

# Define the system prompt
SYS_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."

# Conversation history
chat_history = []
def generate_response(user_input):
    global chat_history

    # Append the user input to the conversation history
    chat_history.append({"role": "user", "content": user_input})

    # Build the prompt with the chat template; the leading system message
    # supplies SYS_PROMPT, so the templated history is not wrapped in
    # [INST] ... [/INST] a second time
    messages = [{"role": "system", "content": SYS_PROMPT}] + chat_history
    prompt = tokenizer.apply_chat_template(messages, tokenize=False)

    # Generate text
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # pass the attention mask explicitly
        max_new_tokens=128,
        top_p=0.95,
        top_k=50,
        repetition_penalty=1.1,
        temperature=0.7,
        do_sample=True,  # enable sampling-based generation
    )

    # Decode only the newly generated tokens, skipping the prompt
    generated_text = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )

    # Append the generated text to the conversation history
    chat_history.append({"role": "assistant", "content": generated_text})

    # Keep only the five most recent question-answer pairs
    if len(chat_history) > 10:
        chat_history = chat_history[-10:]

    return generated_text
# Build the Gradio interface
def chat_interface(user_input, history):
    response = generate_response(user_input)
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return "", history

iface = gr.Blocks()
with iface:
    gr.Markdown("# 醫療問答助手\n這是一個基於 MediaTek-Research/Breeze-7B-Instruct-v1_0 模型的醫療問答助手。")
    chatbot = gr.Chatbot(type="messages")  # OpenAI-style message dicts
    with gr.Row():
        txt = gr.Textbox(
            show_label=False,
            placeholder="請輸入你的問題...",
            lines=1,
        )
    txt.submit(chat_interface, [txt, chatbot], [txt, chatbot])

# Launch the Gradio interface with a public share link
iface.launch(share=True)