import html
import random
import re
import time

import numpy as np
import streamlit as st
import torch

# Must be the first Streamlit call in the script.
st.set_page_config(page_title="MiniMind", initial_sidebar_state="collapsed")

# Page-level CSS for the small round per-message action buttons ("×" / "🗑").
# NOTE(review): only an empty string survived extraction of the original
# stylesheet; this is a minimal reconstruction — confirm against upstream.
st.markdown("""
<style>
    .stButton > button {
        border-radius: 50% !important;
        width: 18px !important;
        height: 18px !important;
        min-width: 18px !important;
        min-height: 18px !important;
        padding: 0 !important;
        background-color: transparent !important;
        border: 1px solid #ddd !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
        font-size: 14px !important;
        color: #888 !important;
        cursor: pointer !important;
        margin: 0 2px !important;
    }
    .stButton > button:hover {
        border-color: #999 !important;
        color: #333 !important;
        background-color: #f5f5f5 !important;
    }
</style>
""", unsafe_allow_html=True)

# Messages prepended to every model context (none by default).
system_prompt = []
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inline style shared by the collapsible reasoning panels.
_REASONING_PANEL_STYLE = (
    "font-style: italic; background: rgba(222, 222, 222, 0.5); "
    "padding: 10px; border-radius: 10px;"
)
# A complete <think>...</think> pair.
_THINK_PAIR_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
# An opening tag whose reasoning is still streaming (no closing tag yet).
_THINK_UNCLOSED_RE = re.compile(r"<think>(.*?)$", re.DOTALL)
# A closing tag whose opening tag is absent from the decoded text.
_THINK_UNOPENED_RE = re.compile(r"(.*?)</think>", re.DOTALL)


def _reasoning_panel(body, summary, expanded=False):
    """Wrap *body* in an HTML ``<details>`` panel titled *summary*.

    The panel starts open when *expanded* is true (used while reasoning is
    still streaming) and collapsed otherwise.
    """
    open_attr = " open" if expanded else ""
    return (
        f'<details{open_attr} style="{_REASONING_PANEL_STYLE}">'
        f'<summary style="font-weight:bold;">{summary}</summary>{body}</details>'
    )


def process_assistant_content(content):
    """Render ``<think>`` reasoning in *content* as collapsible HTML panels.

    Applied only when the selected model's short name contains "R1" (read
    from the module-level ``MODEL_PATHS`` / ``selected_model``); otherwise
    *content* is returned unchanged. The checks run in sequence on the
    progressively rewritten text:

    * both tags present    -> finished reasoning, collapsed panel;
    * only ``<think>``     -> reasoning still streaming, expanded panel;
    * only ``</think>``    -> opening tag missing from the decoded text, so the
      text before ``</think>`` is treated as reasoning.
    """
    if 'R1' not in MODEL_PATHS[selected_model][1]:
        return content

    # Callable replacements: the reasoning text is inserted verbatim, never
    # re-parsed as a re.sub template.
    if "<think>" in content and "</think>" in content:
        content = _THINK_PAIR_RE.sub(
            lambda m: _reasoning_panel(m.group(1), "推理内容(展开)"), content)

    if "<think>" in content and "</think>" not in content:
        content = _THINK_UNCLOSED_RE.sub(
            lambda m: _reasoning_panel(m.group(1), "推理中...", expanded=True), content)

    if "<think>" not in content and "</think>" in content:
        content = _THINK_UNOPENED_RE.sub(
            lambda m: _reasoning_panel(m.group(1), "推理内容(展开)"), content)

    return content


@st.cache_resource
def load_model_tokenizer(model_path):
    """Load the model (moved to ``device``, in eval mode) and its tokenizer.

    Cached by Streamlit per *model_path*, so reruns reuse the loaded objects.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True
    )
    model = model.eval().to(device)
    return model, tokenizer


def _render_user_message(text):
    """Render a user message as a right-aligned grey bubble.

    The text is HTML-escaped because it is embedded in raw HTML rendered with
    ``unsafe_allow_html``.
    """
    st.markdown(
        '<div style="display: flex; justify-content: flex-end;">'
        '<div style="display: inline-block; margin: 10px 0; padding: 8px 12px; '
        'background-color: #ddd; border-radius: 10px; color: black;">'
        f'{html.escape(text)}</div></div>',
        unsafe_allow_html=True)


def _sync_chat_messages():
    """Reset ``chat_messages`` to a copy of ``messages`` after a history edit.

    ``chat_messages`` is only a (windowed) mirror of the last model context, so
    editing it by ``messages`` indices would hit the wrong entries.
    """
    st.session_state.chat_messages = list(st.session_state.messages)


def _truncate_history(keep):
    """Keep only the first *keep* messages (clamped at 0) of the chat history."""
    keep = max(keep, 0)
    st.session_state.messages = st.session_state.messages[:keep]
    _sync_chat_messages()


def clear_chat_messages():
    """Remove both chat-history lists from the session state."""
    del st.session_state.messages
    del st.session_state.chat_messages


def init_chat_messages():
    """Replay the stored chat history, or initialise an empty one.

    Each assistant message gets a "🗑" button that deletes it and its prompt.
    Returns ``st.session_state.messages``.
    """
    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.chat_messages = []
        return st.session_state.messages

    for i, message in enumerate(st.session_state.messages):
        if message["role"] == "assistant":
            with st.chat_message("assistant", avatar=image_url):
                st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
                # Delete button shown below the message content.
                if st.button("🗑", key=f"delete_{i}"):
                    delete_conversation(i)
        else:
            _render_user_message(message["content"])
    return st.session_state.messages


def regenerate_answer(index):
    """Drop the latest message and rerun the script.

    *index* is accepted for interface compatibility but is not used.
    NOTE(review): this does not set the ``regenerate`` / ``last_user_message``
    / ``regenerate_index`` keys that main() checks, so no new answer is
    generated automatically — confirm the intended behavior.
    """
    if st.session_state.messages:
        st.session_state.messages.pop()
    _sync_chat_messages()
    st.rerun()


def delete_conversation(index):
    """Delete the message at *index* and its prompt at *index - 1*, then rerun.

    For ``index == 0`` only that message is removed (no index -1 pop, which
    would otherwise delete the last message instead).
    """
    msgs = st.session_state.messages
    msgs.pop(index)
    if index >= 1:
        msgs.pop(index - 1)
    _sync_chat_messages()
    st.rerun()


# ---- Sidebar: model settings ----
st.sidebar.title("模型设定调整")
# Note shown to users: due to training-data bias, multi-turn chats with more
# context tend to degrade compared with single-turn ones.
st.sidebar.text("【注】训练数据偏差,增加上下文记忆时\n多轮对话(较单轮)容易出现能力衰减")
# Number of previous history *messages* fed back as context; step=2 keeps
# whole user/assistant pairs.
st.session_state.history_chat_num = st.sidebar.slider("Number of Historical Dialogues", 0, 6, 0, step=2)
st.session_state.max_new_tokens = st.sidebar.slider("Max Sequence Length", 256, 8192, 8192, step=1)
st.session_state.top_p = st.sidebar.slider("Top-P", 0.8, 0.99, 0.85, step=0.01)
st.session_state.temperature = st.sidebar.slider("Temperature", 0.6, 1.2, 0.85, step=0.01)

# Display name -> [local checkpoint path, short model name used in the UI].
MODEL_PATHS = {
    "MiniMind2-R1 (0.1B)": ["./MiniMind2-R1", "MiniMind2-R1"],
    "MiniMind2 (0.1B)": ["./MiniMind2", "MiniMind2"],
}

# index=0 selects the first entry, MiniMind2-R1.
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=0)
model_path = MODEL_PATHS[selected_model][0]

slogan = f"Hi, I'm {MODEL_PATHS[selected_model][1]}"

image_url = "https://www.modelscope.cn/api/v1/studio/gongjy/MiniMind/repo?Revision=master&FilePath=images%2Flogo2.png&View=true"

# Page header: logo + slogan, then an AI-generated-content disclaimer.
st.markdown(
    '<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 0; padding: 0;">'
    '<div style="font-style: italic; font-weight: 900; margin: 0; padding-top: 4px; display: flex; '
    'align-items: center; justify-content: center; flex-wrap: wrap; width: 100%;">'
    f'<img src="{image_url}" style="width: 45px; height: 45px;"> '
    f'<span style="font-size: 26px; margin-left: 10px;">{slogan}</span>'
    '</div>'
    '<span style="color: #bbb; font-style: italic; margin-top: 6px; margin-bottom: 10px;">'
    '内容完全由AI生成,请务必仔细甄别<br>Content AI-generated, please discern with care</span>'
    '</div>',
    unsafe_allow_html=True
)


def setup_seed(seed):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    """Render the chat history, read a prompt and stream the model's answer."""
    model, tokenizer = load_model_tokenizer(model_path)

    # Initialise the message lists on first run.
    if "messages" not in st.session_state:
        st.session_state.messages = []
        st.session_state.chat_messages = []

    # `messages` is the full history and the single source of truth.
    messages = st.session_state.messages

    # Replay the history; each assistant message gets a "×" button.
    for i, message in enumerate(messages):
        if message["role"] == "assistant":
            with st.chat_message("assistant", avatar=image_url):
                st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)
                if st.button("×", key=f"delete_{i}"):
                    # Delete this answer, its prompt (i - 1) and everything after.
                    _truncate_history(i - 1)
                    st.rerun()
        else:
            _render_user_message(message["content"])

    # New input, or a regeneration request.
    prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")

    # If some code path requested regeneration, replay the stored prompt and
    # clear all regeneration-related state.
    if st.session_state.get("regenerate"):
        prompt = st.session_state.get("last_user_message")
        for key in ("regenerate", "last_user_message", "regenerate_index"):
            st.session_state.pop(key, None)

    if prompt:
        _render_user_message(prompt)
        messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant", avatar=image_url):
            placeholder = st.empty()
            random_seed = random.randint(0, 2 ** 32 - 1)
            setup_seed(random_seed)

            # Model context = system prompt + the last `history_chat_num`
            # history messages + the new prompt. Built from `messages` so that
            # deleted turns cannot leak back in and the system prompt is never
            # prepended twice.
            st.session_state.chat_messages = system_prompt + messages[
                -(st.session_state.history_chat_num + 1):]

            # NOTE(review): this truncates the templated prompt by *characters*,
            # not tokens, and may cut through a template marker — confirm intent.
            new_prompt = tokenizer.apply_chat_template(
                st.session_state.chat_messages,
                tokenize=False,
                add_generation_prompt=True
            )[-(st.session_state.max_new_tokens - 1):]

            x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)

            # Stays "" if the stream yields nothing, so `answer` is always bound.
            answer = ""
            # The stream is consumed lazily, so the loop itself must run inside
            # no_grad(): the forward passes happen during iteration. The
            # positional eos id and `stream=True` follow the checkpoint's custom
            # (trust_remote_code) generate(); each `y` is indexed as y[0],
            # presumably a (1, seq) token tensor — TODO confirm.
            with torch.no_grad():
                res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
                                       temperature=st.session_state.temperature,
                                       top_p=st.session_state.top_p, stream=True)
                for y in res_y:
                    answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
                    # Skip empty output and partial multi-byte characters
                    # (decoded as U+FFFD) until more tokens arrive.
                    if not answer or answer[-1] == "\ufffd":
                        continue
                    placeholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)

            if not answer:
                print("No answer")

            # Strip the prompt in case the decoded stream echoes it.
            assistant_answer = answer.replace(new_prompt, "")
            messages.append({"role": "assistant", "content": assistant_answer})
            st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})

            with st.empty():
                if st.button("×", key=f"delete_{len(messages) - 1}"):
                    # Drop the turn just produced (prompt + answer).
                    _truncate_history(len(st.session_state.messages) - 2)
                    st.rerun()


if __name__ == "__main__":
    main()