import random
import re
import time
import numpy as np
import streamlit as st
import torch
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")

# Global CSS overrides (button styling) injected at the top of the page.
# NOTE(review): the CSS payload below is empty in this copy — it appears the
# original <style> block was stripped during extraction; restore from upstream.
st.markdown("""
""", unsafe_allow_html=True)

# No system prompt by default; generation prepends this list to the chat turns.
system_prompt = []
# All tensors and model weights are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
def process_assistant_content(content):
    """Render ``<think>...</think>`` reasoning spans as collapsible HTML.

    Only R1 (reasoning) models emit think-tags; for any other model the
    content is returned untouched. Three streaming states are handled:
    a completed span, an open span still being generated, and a dangling
    close tag (the open tag fell out of the decoded window).

    NOTE(review): the original tag literals and inline CSS were stripped
    from this copy; the ``<think>``/``</think>`` markers and wrapper HTML
    below are reconstructed — confirm against the upstream source.
    """
    # Non-R1 models never produce think-tags: pass through unchanged.
    if 'R1' not in MODEL_PATHS[selected_model][1]:
        return content

    # Completed reasoning span: collapse it behind a <details> element.
    if '<think>' in content and '</think>' in content:
        content = re.sub(r'(<think>)(.*?)(</think>)',
                         r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight: bold;">推理内容(展开)</summary>\2</details>',
                         content,
                         flags=re.DOTALL)

    # Open span, close tag not yet streamed: show it expanded as "in progress".
    if '<think>' in content and '</think>' not in content:
        content = re.sub(r'<think>(.*?)$',
                         r'<details open style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight: bold;">推理中...</summary>\1</details>',
                         content,
                         flags=re.DOTALL)

    # Dangling close tag: treat everything before it as the reasoning span.
    if '<think>' not in content and '</think>' in content:
        content = re.sub(r'(.*?)</think>',
                         r'<details style="font-style: italic; background: rgba(222, 222, 222, 0.5); padding: 10px; border-radius: 10px;"><summary style="font-weight: bold;">推理内容(展开)</summary>\1</details>',
                         content,
                         flags=re.DOTALL)
    return content
@st.cache_resource
def load_model_tokenizer(model_path):
    """Load the causal LM and its tokenizer from *model_path*.

    Cached by Streamlit so repeated reruns reuse the same objects.
    Returns ``(model, tokenizer)`` with the model in eval mode on ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    return model.eval().to(device), tokenizer
def clear_chat_messages():
    """Drop the cached chat history from Streamlit's session state."""
    # pop(..., None) instead of `del`: a bare `del` raises KeyError/AttributeError
    # when the keys were never created (e.g. clearing before the first message).
    st.session_state.pop("messages", None)
    st.session_state.pop("chat_messages", None)
def init_chat_messages():
    """Render the stored conversation and return the message list.

    On first call the two history lists are created empty. Each assistant
    message gets a delete button that removes the answer together with the
    user prompt that produced it.
    """
    if "messages" in st.session_state:
        for i, message in enumerate(st.session_state.messages):
            if message["role"] == "assistant":
                with st.chat_message("assistant", avatar=image_url):
                    st.markdown(process_assistant_content(message["content"]),
                                unsafe_allow_html=True)
                    # Delete button under the message content.
                    if st.button("🗑", key=f"delete_{i}"):
                        # Pop index i first so index i-1 is still valid.
                        st.session_state.messages.pop(i)
                        st.session_state.messages.pop(i - 1)
                        st.session_state.chat_messages.pop(i)
                        st.session_state.chat_messages.pop(i - 1)
                        st.rerun()
            else:
                # NOTE(review): the original right-aligned user-bubble HTML was
                # stripped from this copy; reconstructed minimally — confirm
                # styling against the upstream source.
                st.markdown(
                    f'<div style="display: flex; justify-content: flex-end;">'
                    f'<div style="background: #ddd; padding: 8px 12px; '
                    f'border-radius: 10px;">{message["content"]}</div></div>',
                    unsafe_allow_html=True)
    else:
        st.session_state.messages = []
        st.session_state.chat_messages = []
    return st.session_state.messages
# Helper functions for regenerating and deleting conversation turns
def regenerate_answer(index):
    """Drop the most recent reply and rerun so it gets generated again.

    NOTE(review): ``index`` is unused — the pops always target the tail of
    both history lists; verify call sites before relying on the parameter.
    """
    st.session_state.messages.pop()
    st.session_state.chat_messages.pop()
    st.rerun()
def delete_conversation(index):
    """Remove the Q/A pair ending at *index* from both history lists.

    ``index`` is the assistant message and ``index - 1`` the user prompt
    that produced it; popping ``index`` first keeps ``index - 1`` valid.
    """
    st.session_state.messages.pop(index)
    st.session_state.messages.pop(index - 1)
    st.session_state.chat_messages.pop(index)
    st.session_state.chat_messages.pop(index - 1)
    st.rerun()
# Sidebar: model selection and sampling settings
st.sidebar.title("模型设定调整")
st.sidebar.text("【注】训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
# Number of past turns kept as context (0/2/4/6). Training-data bias makes
# long multi-turn chats degrade, hence the warning text above.
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
# st.session_state.history_chat_num = 0
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)

# Display name -> [local checkpoint path, model label]. The label is also
# used by process_assistant_content to detect R1 (reasoning) models.
MODEL_PATHS = {
    "MiniMind2-R1 (0.1B)": ["./MiniMind2-R1", "MiniMind2-R1"],
    "MiniMind2 (0.1B)": ["./MiniMind2", "MiniMind2"],
}
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=0)  # index=0 -> first entry (MiniMind2-R1)
model_path = MODEL_PATHS[selected_model][0]
slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"
image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"
# Page header: logo + slogan + AI-content disclaimer.
# NOTE(review): the original HTML inside this f-string was stripped from this
# copy (only the text fragments survived) — the markup below is a minimal
# reconstruction; confirm layout/styling against the upstream source.
st.markdown(
    f'<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
    f'<div style="font-style: italic; font-weight: 900; display: flex; align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
    f'<img src="{image_url}" style="width: 45px; height: 45px;">'
    f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
    f'</div>'
    f'<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">'
    f'内容完全由AI生成,请务必仔细甄别<br>'
    f'Content AI-generated, please discern with care'
    f'</span>'
    f'</div>',
    unsafe_allow_html=True
)
def setup_seed(seed):
    """Seed every RNG the app touches so a single generation is reproducible."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def main():
    """Streamlit entry point: render history, accept input, stream a reply."""
    model, tokenizer = load_model_tokenizer(model_path)

    # Initialise the session-held histories on first page load.
    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.chat_messages = []

    messages = st.session_state.messages

    # Replay the stored conversation.
    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            with st.chat_message("assistant", avatar=image_url):
                st.markdown(process_assistant_content(message["content"]),
                            unsafe_allow_html=True)
                if st.button("×", key=f"delete_{i}"):
                    # Delete this answer, its prompt, and everything after them.
                    st.session_state.messages = st.session_state.messages[:i - 1]
                    st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]
                    st.rerun()
        else:
            # NOTE(review): the original user-bubble HTML was stripped from
            # this copy; reconstructed minimally — confirm against upstream.
            st.markdown(
                f'<div style="display: flex; justify-content: flex-end;">'
                f'<div style="background: #ddd; padding: 8px 12px; '
                f'border-radius: 10px;">{message["content"]}</div></div>',
                unsafe_allow_html=True)

    # New input, or a replayed one from a pending "regenerate" request.
    prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")

    if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:
        prompt = st.session_state.last_user_message
        regenerate_index = st.session_state.regenerate_index  # position being regenerated
        # Clear all regenerate-related state so the request fires only once.
        delattr(st.session_state, 'regenerate')
        delattr(st.session_state, 'last_user_message')
        delattr(st.session_state, 'regenerate_index')

    if prompt:
        # Echo the user's message (same reconstructed bubble as above).
        st.markdown(
            f'<div style="display: flex; justify-content: flex-end;">'
            f'<div style="background: #ddd; padding: 8px 12px; '
            f'border-radius: 10px;">{prompt}</div></div>',
            unsafe_allow_html=True)
        messages.append({"role": "user", "content": prompt})
        st.session_state.chat_messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant", avatar=image_url):
            placeholder = st.empty()
            # Fresh random seed per generation so regenerations differ.
            setup_seed(random.randint(0, 2 ** 32 - 1))

            # Keep the system prompt plus only the most recent turns.
            st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[
                -(st.session_state.history_chat_num + 1):]
            # Character-level tail truncation of the templated prompt.
            new_prompt = tokenizer.apply_chat_template(
                st.session_state.chat_messages,
                tokenize=False,
                add_generation_prompt=True
            )[-(st.session_state.max_new_tokens - 1):]

            x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)
            # Initialise so the post-loop code is safe even if the stream
            # yields nothing (original raised NameError in that case).
            answer = ""
            with torch.no_grad():
                res_y = model.generate(x, tokenizer.eos_token_id,
                                       max_new_tokens=st.session_state.max_new_tokens,
                                       temperature=st.session_state.temperature,
                                       top_p=st.session_state.top_p, stream=True)
                try:
                    for y in res_y:
                        answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
                        # Skip frames ending in an incomplete UTF-8 sequence.
                        if (answer and answer[-1] == '�') or not answer:
                            continue
                        placeholder.markdown(process_assistant_content(answer),
                                             unsafe_allow_html=True)
                except StopIteration:
                    print("No answer")

            assistant_answer = answer.replace(new_prompt, "")
            messages.append({"role": "assistant", "content": assistant_answer})
            st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})

            with st.empty():
                if st.button("×", key=f"delete_{len(messages) - 1}"):
                    # Remove the just-generated Q/A pair.
                    st.session_state.messages = st.session_state.messages[:-2]
                    st.session_state.chat_messages = st.session_state.chat_messages[:-2]
                    st.rerun()
if __name__ == "__main__":
    # NOTE(review): transformers is imported here rather than at the top —
    # presumably to defer the heavy import until the script actually runs;
    # these names become module globals that load_model_tokenizer relies on.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    main()