File size: 2,966 Bytes
a43a412 1bd5c96 a43a412 1bd5c96 a43a412 1bd5c96 a43a412 1bd5c96 a43a412 1bd5c96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Supported models: UI display name -> Hugging Face Hub repository id.
MODELS = {
    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Llama 2 (7B)": "meta-llama/Llama-2-7b-chat-hf",
    "DBRX Instruct": "databricks/dbrx-instruct",
    "XVERSE-MoE-A4.2B": "xverse/XVERSE-MoE-A4.2B",
    "Gemma (7B)": "google/gemma-7b",
    "CPM-Bee 10B": "openbmb/cpm-bee-10b",
    "ChatGLM3-6B": "THUDM/chatglm3-6b",
    "Yi-34B-Chat": "01-ai/Yi-34B-Chat",
    "Mistral-7B": "mistralai/Mistral-7B-v0.1",
    # "microsoft/phi-3" is not a published repository id, so this entry could
    # never load; point it at the Phi-3 mini instruct checkpoint instead.
    # NOTE(review): pick a different Phi-3 variant here if one was intended.
    "Phi-3": "microsoft/Phi-3-mini-4k-instruct",
}

# Fallback model (a key of MODELS) used when the requested model fails to load.
DEFAULT_MODEL = "Mistral-7B"
# Single-slot cache of the most recently loaded model:
# {display name: (tokenizer, model)}.  It is emptied before a different model
# is loaded so that at most one checkpoint is held in memory at a time.
_LOADED_MODEL_CACHE = {}


def _load_pair(model_name):
    """Instantiate the tokenizer and model registered under *model_name*."""
    model_path = MODELS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model


# Load the model and tokenizer.
def load_model(model_name):
    """Return the tokenizer and model for *model_name*, with fallback.

    A successfully loaded model is cached, so repeated requests for the same
    model do not instantiate it again.  If loading fails for any reason
    (including a name that is not a key of ``MODELS``), the failure is printed
    and ``DEFAULT_MODEL`` is loaded instead.

    Args:
        model_name: Display name of the model, expected to be a key of
            ``MODELS``.

    Returns:
        A ``(tokenizer, model, loaded_name)`` tuple.  ``loaded_name`` equals
        ``model_name`` on success and ``DEFAULT_MODEL`` after a fallback.

    Raises:
        Exception: Whatever loading ``DEFAULT_MODEL`` raises when the default
            model itself cannot be loaded.
    """
    cached = _LOADED_MODEL_CACHE.get(model_name)
    if cached is not None:
        tokenizer, model = cached
        return tokenizer, model, model_name

    # Drop the previously cached model first so two checkpoints never have to
    # fit in memory at the same time.
    _LOADED_MODEL_CACHE.clear()

    try:
        tokenizer, model = _load_pair(model_name)
    except Exception as e:
        # Deliberately broad: any load failure triggers the fallback.
        print(f"加载模型 {model_name} 失败: {e}")
        if model_name == DEFAULT_MODEL:
            # Nothing left to fall back to; repeating the identical load
            # would only fail again.
            raise
        # Fall back to the default model.
        return load_model(DEFAULT_MODEL)

    _LOADED_MODEL_CACHE[model_name] = (tokenizer, model)
    return tokenizer, model, model_name
# Text generation callback.
def generate_text(model_name, system_prompt, user_input):
    """Generate text from the selected model for the given prompts.

    Args:
        model_name: Display name of the model to use (a key of ``MODELS``).
        system_prompt: System prompt text; may be empty.
        user_input: User message text; may be empty.

    Returns:
        The decoded first output sequence with special tokens skipped, or an
        empty string when both prompts are empty.
    """
    # Skip empty parts so a missing system prompt does not leave the prompt
    # starting with two blank lines.
    prompt_parts = [text for text in (system_prompt, user_input) if text]
    if not prompt_parts:
        # Nothing to generate from; avoid loading a model for an empty request.
        gr.Warning("请输入系统提示词或用户输入。")
        return ""
    full_prompt = "\n\n".join(prompt_parts)

    # Load the model and tokenizer.
    tokenizer, model, loaded_model_name = load_model(model_name)

    # Tell the user when the requested model was replaced by the default one.
    if loaded_model_name != model_name:
        gr.Warning(f"模型 {model_name} 加载失败,已回退到默认模型 {DEFAULT_MODEL}。")

    inputs = tokenizer(full_prompt, return_tensors="pt")
    # max_new_tokens bounds only the generated continuation; max_length would
    # also count the prompt tokens, leaving no room to generate after a long
    # prompt.
    outputs = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Build the Gradio interface.
def create_interface():
    """Assemble and return the Gradio Blocks app.

    The layout is a title followed by one row each for the model dropdown,
    the system prompt box, the user input box, the output box and the submit
    button.  Clicking the button calls ``generate_text`` with the three input
    values and writes its result to the output box.
    """
    model_choices = list(MODELS)

    with gr.Blocks() as app:
        gr.Markdown("# 多模型文本生成器")

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=model_choices,
                value="Mixtral 8x7B",
                label="选择模型",
            )

        with gr.Row():
            system_box = gr.Textbox(
                placeholder="Enter the system prompt here...",
                label="系统提示词",
            )

        with gr.Row():
            user_box = gr.Textbox(
                placeholder="Enter your input here...",
                label="用户输入",
            )

        with gr.Row():
            result_box = gr.Textbox(label="生成结果")

        with gr.Row():
            generate_button = gr.Button("生成")

        # Wire the button to the generation callback.
        callback_inputs = [model_dropdown, system_box, user_box]
        generate_button.click(
            fn=generate_text,
            inputs=callback_inputs,
            outputs=result_box,
        )

    return app
# Launch the app when this file is run as a script.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()