import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Supported models, keyed by their display names in the UI
MODELS = {
    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Llama 2 (7B)": "meta-llama/Llama-2-7b-chat-hf",
    "DBRX Instruct": "databricks/dbrx-instruct",
    "XVERSE-MoE-A4.2B": "xverse/XVERSE-MoE-A4.2B",
    "Gemma (7B)": "google/gemma-7b",
    "CPM-Bee 10B": "openbmb/cpm-bee-10b",
    "ChatGLM3-6B": "THUDM/chatglm3-6b",
    "Yi-34B-Chat": "01-ai/Yi-34B-Chat",
    "Mistral-7B": "mistralai/Mistral-7B-v0.1",
    "Phi-3": "microsoft/phi-3"
}
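
# Deployment notes (assumptions about the Hub, not configured by this script):
# several of these repos are gated (e.g. meta-llama/Llama-2-7b-chat-hf,
# databricks/dbrx-instruct) and require accepting a license plus an access
# token before from_pretrained() succeeds; others (e.g. THUDM/chatglm3-6b)
# ship custom modeling code and need trust_remote_code=True. The larger
# entries (Mixtral 8x7B, DBRX, Yi-34B) also need far more memory than a
# typical single-GPU host provides.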

# Default fallback model
DEFAULT_MODEL = "Mistral-7B"

# Load a model and its tokenizer, falling back to the default on failure
def load_model(model_name):
    try:
        model_path = MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        return tokenizer, model, model_name
    except Exception as e:
        print(f"Failed to load model {model_name}: {e}")
        # Fall back to the default model
        model_path = MODELS[DEFAULT_MODEL]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        return tokenizer, model, DEFAULT_MODEL
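
# Reloading weights on every button click is expensive. Below is a minimal
# in-process cache (a sketch, assuming the host has enough memory to keep
# every model it has served resident); generate_text could call
# load_model_cached in place of load_model.
_MODEL_CACHE = {}

def load_model_cached(model_name):
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = load_model(model_name)
    return _MODEL_CACHE[model_name]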

# Generate text with the selected model
def generate_text(model_name, system_prompt, user_input):
    # Load the model and tokenizer
    tokenizer, model, loaded_model_name = load_model(model_name)

    # If we fell back to the default model, warn the user
    if loaded_model_name != model_name:
        gr.Warning(f"Failed to load model {model_name}; fell back to the default model {DEFAULT_MODEL}.")

    # Combine the system prompt and the user input into a single prompt
    full_prompt = f"{system_prompt}\n\n{user_input}"

    # Generate with the model; max_new_tokens bounds only the continuation,
    # whereas max_length would also count the prompt tokens
    inputs = tokenizer(full_prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=200)
    # generate() returns prompt + continuation, so the decoded string echoes the prompt
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
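
# The plain f-string concatenation above ignores each instruct model's chat
# template. A sketch of template-aware prompting instead; apply_chat_template
# is part of recent transformers tokenizers, but whether a given model's
# template accepts a "system" role varies by model and is an assumption here:
def build_chat_prompt(tokenizer, system_prompt, user_input):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )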

# Build the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# 多模型文本生成器")
        with gr.Row():
            model_selector = gr.Dropdown(
                label="Select model",
                choices=list(MODELS.keys()),
                value="Mixtral 8x7B"
            )
        with gr.Row():
            system_prompt = gr.Textbox(label="System prompt", placeholder="Enter the system prompt here...")
        with gr.Row():
            user_input = gr.Textbox(label="User input", placeholder="Enter your input here...")
        with gr.Row():
            output = gr.Textbox(label="Generated output")
        with gr.Row():
            submit_button = gr.Button("Generate")
        
        # Wire the button to the generation function
        submit_button.click(
            fn=generate_text,
            inputs=[model_selector, system_prompt, user_input],
            outputs=output
        )
    
    return demo

# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
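    # For multi-user deployments, Gradio's request queue keeps long-running
    # generations from blocking one another: demo.queue().launch()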