import functools

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Supported models: UI display name -> Hugging Face Hub repo id.
# NOTE(review): "microsoft/phi-3" is probably not a valid repo id (the hub uses
# e.g. "microsoft/Phi-3-mini-4k-instruct"); some entries (Llama 2, Gemma) are
# gated and require an access token — verify before deploying.
MODELS = {
    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Llama 2 (7B)": "meta-llama/Llama-2-7b-chat-hf",
    "DBRX Instruct": "databricks/dbrx-instruct",
    "XVERSE-MoE-A4.2B": "xverse/XVERSE-MoE-A4.2B",
    "Gemma (7B)": "google/gemma-7b",
    "CPM-Bee 10B": "openbmb/cpm-bee-10b",
    "ChatGLM3-6B": "THUDM/chatglm3-6b",
    "Yi-34B-Chat": "01-ai/Yi-34B-Chat",
    "Mistral-7B": "mistralai/Mistral-7B-v0.1",
    "Phi-3": "microsoft/phi-3",
}

# Display name of the fallback model used when the requested one fails to load.
DEFAULT_MODEL = "Mistral-7B"


# Cache loaded models so generate_text() does not re-download / re-instantiate
# the weights on every request. maxsize=2 bounds memory: these checkpoints are
# multi-GB, so keeping every selected model resident would exhaust RAM.
@functools.lru_cache(maxsize=2)
def load_model(model_name):
    """Load the tokenizer and model for *model_name*.

    Args:
        model_name: A key of ``MODELS``.

    Returns:
        A ``(tokenizer, model, loaded_model_name)`` tuple. ``loaded_model_name``
        equals ``model_name`` on success, or ``DEFAULT_MODEL`` if loading failed
        and the fallback was used instead.

    Raises:
        Exception: Whatever the fallback load raises if DEFAULT_MODEL itself
        cannot be loaded (chained to the original failure for context).
    """
    try:
        model_path = MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        return tokenizer, model, model_name
    except Exception as e:
        # Best-effort fallback: log the failure and try the default model.
        print(f"加载模型 {model_name} 失败: {e}")
        model_path = MODELS[DEFAULT_MODEL]
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForCausalLM.from_pretrained(model_path)
        except Exception as fallback_err:
            # Chain so the traceback shows both the original and fallback errors.
            raise fallback_err from e
        return tokenizer, model, DEFAULT_MODEL


def generate_text(model_name, system_prompt, user_input):
    """Generate a completion for *user_input* under *system_prompt*.

    Args:
        model_name: A key of ``MODELS`` selecting which model to use.
        system_prompt: Text prepended to the user input as instructions.
        user_input: The user's prompt text.

    Returns:
        The decoded model output (note: includes the prompt text, since the
        full output sequence is decoded).
    """
    tokenizer, model, loaded_model_name = load_model(model_name)

    # Surface the silent fallback to the user in the UI.
    if loaded_model_name != model_name:
        gr.Warning(f"模型 {model_name} 加载失败,已回退到默认模型 {DEFAULT_MODEL}。")

    # Simple concatenation; chat-tuned models may prefer their own chat
    # template (tokenizer.apply_chat_template) — kept plain for generality.
    full_prompt = f"{system_prompt}\n\n{user_input}"

    inputs = tokenizer(full_prompt, return_tensors="pt")
    # max_new_tokens (not max_length) so a long prompt cannot consume the
    # entire generation budget and yield an empty completion.
    outputs = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def create_interface():
    """Build and return the Gradio Blocks UI for the generator."""
    with gr.Blocks() as demo:
        gr.Markdown("# 多模型文本生成器")
        with gr.Row():
            model_selector = gr.Dropdown(
                label="选择模型",
                choices=list(MODELS.keys()),
                value="Mixtral 8x7B",
            )
        with gr.Row():
            system_prompt = gr.Textbox(
                label="系统提示词", placeholder="Enter the system prompt here..."
            )
        with gr.Row():
            user_input = gr.Textbox(
                label="用户输入", placeholder="Enter your input here..."
            )
        with gr.Row():
            output = gr.Textbox(label="生成结果")
        with gr.Row():
            submit_button = gr.Button("生成")

        # Wire the button to the generation function.
        submit_button.click(
            fn=generate_text,
            inputs=[model_selector, system_prompt, user_input],
            outputs=output,
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()