|
import codecs
import copy
import os
import random
import sys
import time

import gradio as gr
import requests
from huggingface_hub import snapshot_download
from llama_cpp import Llama
|
|
|
|
|
# Default system prompt: one sentence per line, followed by a blank line.
SYSTEM_PROMPT = "\n".join([
    'You are a helpful, respectful and honest INTP-T AI Assistant named "Shi-Ci" in English or "兮辞" in Chinese.',
    "You are good at speaking English and Chinese.",
    "You are talking to a human User. If the question is meaningless, please explain the reason and don't share false information.",
    'You are based on SEA model, trained by "SSFW NLPark" team, not related to GPT, LLaMA, Meta, Mistral or OpenAI.',
    "Let's work this out in a step by step way to be sure we have the right answer.",
]) + "\n\n"

# Vocabulary ids of the role-header tokens placed at the start of each message.
SYSTEM_TOKEN = 1587
USER_TOKEN = 2188
BOT_TOKEN = 12435
# Separator placed after the role header (presumably the tokenizer's "\n" piece — TODO confirm).
LINEBREAK_TOKEN = 13

# Role name -> header token id, looked up when encoding a message.
ROLE_TOKENS = dict(
    user=USER_TOKEN,
    bot=BOT_TOKEN,
    system=SYSTEM_TOKEN,
)
|
|
|
|
|
def get_message_tokens(model, role, content):
    """Encode one chat message as ``[first, role, linebreak, *text, eos]`` token ids.

    ``first`` is the first token ``model.tokenize`` returns. This is presumably
    BOS, since llama_cpp's ``tokenize`` prepends it by default (TODO confirm for
    this model). The role header token and a line break are placed right after it.

    Args:
        model: A loaded ``Llama`` instance used for tokenization.
        role: One of the keys of ``ROLE_TOKENS`` ("user", "bot", "system").
        content: Message text; encoded as UTF-8 before tokenizing.

    Returns:
        A new list of token ids ending with the model's EOS token.

    Raises:
        KeyError: If ``role`` is not a key of ``ROLE_TOKENS``.
    """
    encoded = model.tokenize(content.encode("utf-8"))
    # Slicing (not unpacking) keeps an empty tokenizer result working.
    prefix, body = encoded[:1], encoded[1:]
    return [*prefix, ROLE_TOKENS[role], LINEBREAK_TOKEN, *body, model.token_eos()]
|
|
|
|
|
def get_system_tokens(model):
    """Return the token ids of the system message built from ``SYSTEM_PROMPT``.

    Args:
        model: A loaded ``Llama`` instance used for tokenization.
    """
    return get_message_tokens(model, role="system", content=SYSTEM_PROMPT)
|
|
|
|
|
# Hugging Face repository and the quantized GGUF weights file inside it.
repo_name = "Cran-May/OpenSLIDE"
model_name = "SLIDE.0.1.gguf"

# Fetch only the GGUF file into the working directory, so the loader below
# can open it by its bare file name.
snapshot_download(
    repo_id=repo_name,
    allow_patterns=model_name,
    local_dir=".",
)

# n_ctx is the context window in tokens. n_parts is a legacy llama.cpp option;
# recent llama-cpp-python versions appear to ignore it (TODO confirm).
model = Llama(model_path=model_name, n_ctx=2000, n_parts=1)

# Upper bound on tokens generated per reply; bot() treats None as "no limit".
max_new_tokens = 1500
|
|
|
def user(message, history):
    """Append ``message`` to the chat as a new, not-yet-answered turn.

    Returns:
        A pair ``("", new_history)``. The empty string clears the input
        textbox. ``new_history`` is a new list ending with ``[message, None]``;
        the ``history`` argument itself is left unmodified.
    """
    return "", [*history, [message, None]]
|
|
|
|
|
def _fit_prompt(head, turns, tail, limit):
    """Join head, the newest turns that fit, and tail, so that at most ``limit`` tokens are used.

    Whole turns are dropped oldest-first. ``head`` and ``tail`` are always
    kept, so the result can still exceed ``limit`` when they alone are too long.
    """
    budget = limit - len(head) - len(tail)
    kept = []
    for turn in reversed(turns):
        if len(turn) > budget:
            break
        kept.append(turn)
        budget -= len(turn)
    tokens = list(head)
    for turn in reversed(kept):
        tokens.extend(turn)
    tokens.extend(tail)
    return tokens


def bot(
    history,
    system_prompt,
    top_p,
    top_k,
    temp
):
    """Stream the assistant's reply to the last user message in ``history``.

    Args:
        history: Chat history as ``[[user_message, bot_message], ...]``. The
            last pair's bot slot is filled in place as text is generated.
        system_prompt: System prompt text. Falls back to ``SYSTEM_PROMPT`` when empty.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling size passed to ``model.generate``.
        temp: Sampling temperature passed to ``model.generate``.

    Yields:
        ``history`` after each generated token, with the partial reply in
        ``history[-1][1]``.

    Raises:
        gr.Error: If the system prompt plus the latest message leave no room
            in the model's context window.
    """
    n_ctx = model.n_ctx()

    if system_prompt:
        system_tokens = get_message_tokens(model=model, role="system", content=system_prompt)
    else:
        system_tokens = get_system_tokens(model)
    system_tokens = system_tokens + [LINEBREAK_TOKEN]

    # Encode each earlier exchange separately so that whole turns can be dropped.
    turns = []
    for user_message, bot_message in history[:-1]:
        turn = get_message_tokens(model=model, role="user", content=user_message)
        if bot_message:
            turn = turn + get_message_tokens(model=model, role="bot", content=bot_message)
        turns.append(turn)

    # Latest user message, then the bot role header the model continues from.
    last_user_message = history[-1][0]
    tail = get_message_tokens(model=model, role="user", content=last_user_message)
    tail = tail + [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]

    # Keep part of the window free for the reply. Without this, a long history
    # fills the context and llama_cpp fails once generation passes n_ctx.
    reply_reserve = n_ctx // 4
    if max_new_tokens is not None:
        reply_reserve = min(reply_reserve, max_new_tokens)
    tokens = _fit_prompt(system_tokens, turns, tail, n_ctx - reply_reserve)

    room = n_ctx - len(tokens)
    if room <= 0:
        raise gr.Error("Message too long / 消息过长,超出模型上下文长度")
    token_limit = room if max_new_tokens is None else min(max_new_tokens, room)

    generator = model.generate(
        tokens,
        top_k=int(top_k),
        top_p=float(top_p),
        temp=float(temp),
    )

    # A single character (e.g. Chinese) may be split across several byte-level
    # tokens. An incremental decoder buffers incomplete sequences instead of
    # discarding them, which decoding each token on its own would do.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
    eos_token = model.token_eos()
    partial_text = ""
    for i, token in enumerate(generator):
        if token == eos_token or i >= token_limit:
            break
        partial_text += decoder.decode(model.detokenize([token]))
        history[-1][1] = partial_text
        yield history
|
|
|
|
|
# Chat UI: the chat log plus input/buttons on the left, sampling settings and
# the (read-only) system prompt on the right, and a disclaimer below.
with gr.Blocks(
    theme=gr.themes.Soft()
) as demo:
    gr.Markdown("""<h1><center>上师附外-兮辞·析辞-人工智能助理</center></h1>""")
    gr.Markdown(value="""这儿是一个中文模型的部署。
这是量化版兮辞·析辞的部署,具有 70亿 个参数,在 CPU 上运行。
SLIDE 是一种会话语言模型,在多种类型的语料库上进行训练。
本节目由上海师范大学附属外国语中学 NLPark 赞助播出""")

    with gr.Row():
        with gr.Column(scale=5):
            # The original fused this line with the next `with` statement
            # (a syntax error); they are separate statements.
            chatbot = gr.Chatbot(label="兮辞如是说").style(height=400)
            with gr.Row():
                with gr.Column():
                    msg = gr.Textbox(
                        label="来问问兮辞吧……",
                        placeholder="兮辞折寿中……",
                        show_label=False,
                    ).style(container=False)
                with gr.Column():
                    submit = gr.Button("Submit / 开凹!")
                    stop = gr.Button("Stop / 全局时空断裂")
                    clear = gr.Button("Clear / 打扫群内垃圾")
        with gr.Column(min_width=80, scale=1):
            with gr.Tab(label="设置参数"):
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    interactive=True,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=30,
                    step=5,
                    interactive=True,
                    label="Top-k",
                )
                temp = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.2,
                    step=0.01,
                    interactive=True,
                    label="情感温度"
                )
            with gr.Column():
                system_prompt = gr.Textbox(label="系统提示词", placeholder="", value=SYSTEM_PROMPT, interactive=False)
    with gr.Row():
        gr.Markdown(
            """警告:该模型可能会生成事实上或道德上不正确的文本。NLPark和兮辞对此不承担任何责任。"""
        )

    # Both the Enter key and the Submit button run the same two-step chain:
    # `user` appends the message (unqueued, instant), then `bot` streams the reply.
    bot_inputs = [chatbot, system_prompt, top_p, top_k, temp]

    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
        queue=True,
    )

    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).success(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
        queue=True,
    )

    # Stop cancels any in-flight generation started by either trigger.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )

    # Clear resets the chat log.
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
# Enable the request queue (a single worker, at most 128 waiting requests),
# then start the web server.
demo.queue(max_size=128, concurrency_count=1).launch()