|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Display name (shown in the UI dropdown) -> Hugging Face Hub repo id.
MODELS = {

    "Mixtral 8x7B": "mistralai/Mixtral-8x7B-Instruct-v0.1",

    "Llama 2 (7B)": "meta-llama/Llama-2-7b-chat-hf",

    "DBRX Instruct": "databricks/dbrx-instruct",

    "XVERSE-MoE-A4.2B": "xverse/XVERSE-MoE-A4.2B",

    "Gemma (7B)": "google/gemma-7b",

    "CPM-Bee 10B": "openbmb/cpm-bee-10b",

    "ChatGLM3-6B": "THUDM/chatglm3-6b",

    "Yi-34B-Chat": "01-ai/Yi-34B-Chat",

    "Mistral-7B": "mistralai/Mistral-7B-v0.1",

    # NOTE(review): "microsoft/phi-3" does not look like a real Hub repo id —
    # Phi-3 checkpoints are published under names like
    # microsoft/Phi-3-mini-4k-instruct. Verify before shipping.
    "Phi-3": "microsoft/phi-3"

}


# Fallback used by load_model() when the requested model fails to load.
# Must be a key of MODELS.
DEFAULT_MODEL = "Mistral-7B"
|
|
|
|
|
# Cache of already-loaded (tokenizer, model) pairs keyed by display name,
# so repeated generate calls don't re-download / re-instantiate weights.
_MODEL_CACHE = {}


def load_model(model_name):
    """Load (and cache) the tokenizer and model for ``model_name``.

    Falls back to ``DEFAULT_MODEL`` when the requested model cannot be
    loaded (e.g. gated repo, network failure, bad repo id).

    Args:
        model_name: A key of the ``MODELS`` dict.

    Returns:
        ``(tokenizer, model, loaded_model_name)`` where
        ``loaded_model_name`` equals ``model_name`` on success, or
        ``DEFAULT_MODEL`` after a fallback.

    Raises:
        Exception: if the fallback ``DEFAULT_MODEL`` itself fails to load.
    """
    if model_name in _MODEL_CACHE:
        tokenizer, model = _MODEL_CACHE[model_name]
        return tokenizer, model, model_name
    try:
        model_path = MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        _MODEL_CACHE[model_name] = (tokenizer, model)
        return tokenizer, model, model_name
    except Exception as e:
        print(f"加载模型 {model_name} 失败: {e}")
        # Bug fix: the original retried DEFAULT_MODEL unconditionally, so if
        # the *default* model itself failed it was attempted a second time and
        # the duplicate failure propagated unhandled. Re-raise immediately
        # instead of retrying the same broken model.
        if model_name == DEFAULT_MODEL:
            raise
        # Recurse for the fallback so it benefits from the cache too.
        return load_model(DEFAULT_MODEL)
|
|
|
|
|
def generate_text(model_name, system_prompt, user_input):
    """Generate a completion for the combined system prompt and user input.

    Args:
        model_name: Key of ``MODELS`` selecting which model to use.
        system_prompt: Instructions prepended to the user input.
        user_input: The user's message.

    Returns:
        The decoded model output as a string. Note that HF ``generate``
        returns prompt + continuation, so the prompt is echoed back.
    """
    tokenizer, model, loaded_model_name = load_model(model_name)

    # Surface the fallback in the UI instead of silently answering with
    # a different model than the one the user selected.
    if loaded_model_name != model_name:
        gr.Warning(f"模型 {model_name} 加载失败,已回退到默认模型 {DEFAULT_MODEL}。")

    full_prompt = f"{system_prompt}\n\n{user_input}"

    inputs = tokenizer(full_prompt, return_tensors="pt")
    # Bug fix: the original used max_length=200, which caps prompt AND
    # completion together — any prompt near/over 200 tokens produced no
    # generated text. max_new_tokens bounds only the continuation.
    outputs = model.generate(**inputs, max_new_tokens=200)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI for the text generator.

    Returns:
        A ``gr.Blocks`` instance wiring the model dropdown, prompt boxes,
        and output textbox to ``generate_text``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 多模型文本生成器")
        with gr.Row():
            model_selector = gr.Dropdown(
                label="选择模型",
                choices=list(MODELS.keys()),
                # Consistency fix: the original hard-coded "Mixtral 8x7B"
                # here while the declared fallback is DEFAULT_MODEL
                # ("Mistral-7B"); keep the UI default and the fallback
                # logic in agreement.
                value=DEFAULT_MODEL
            )
        with gr.Row():
            system_prompt = gr.Textbox(label="系统提示词", placeholder="Enter the system prompt here...")
        with gr.Row():
            user_input = gr.Textbox(label="用户输入", placeholder="Enter your input here...")
        with gr.Row():
            output = gr.Textbox(label="生成结果")
        with gr.Row():
            submit_button = gr.Button("生成")

        # Clicking the button feeds (model, system prompt, user input)
        # through generate_text into the output textbox.
        submit_button.click(
            fn=generate_text,
            inputs=[model_selector, system_prompt, user_input],
            outputs=output
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_interface().launch()