|
import os |
|
from openai import OpenAI |
|
import gradio as gr |
|
|
|
|
|
import socket |
|
# Print the host name and its resolved IPv4 address once at start-up.
hostname = socket.gethostname()

try:
    IPAddr = socket.gethostbyname(hostname)
except OSError:
    # socket.gaierror (an OSError subclass) is raised when the host name
    # cannot be resolved; a diagnostic print must not abort start-up.
    IPAddr = "unavailable"

print("Your Computer Name is:" + hostname)
print("Your Computer IP Address is:" + IPAddr)
|
|
|
|
|
# Markdown shown at the top of the page (passed to gr.Markdown below).
DESCRIPTION = """
# Cloned from MediaTek Research Breeze-7B
MediaTek Research Breeze-7B (hereinafter referred to as Breeze-7B) is a language model family that builds on top of [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), specifically intended for Traditional Chinese use.
[Breeze-7B-Base](https://huggingface.co/MediaTek-Research/Breeze-7B-Base-v1_0) is the base model for the Breeze-7B series.
It is suitable for use if you have substantial fine-tuning data to tune it for your specific use case.
[Breeze-7B-Instruct](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-v1_0) derives from the base model Breeze-7B-Base, making the resulting model amenable to be used as-is for commonly seen tasks.

This App is cloned from [Demo-MR-Breeze-7B](https://huggingface.co/spaces/MediaTek-Research/Demo-MR-Breeze-7B)

"""

# Markdown shown at the bottom of the page (passed to gr.Markdown below);
# it currently contains only a newline.
LICENSE = """
"""
|
|
|
# Initial value of the "System prompt" textbox.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant built by MediaTek Research. "
    "The user you are helping speaks Traditional Chinese and comes from Taiwan."
)

# Deployment settings taken from the environment; each one is None when the
# corresponding variable is not set.
API_URL = os.getenv("API_URL")
TOKEN = os.getenv("TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME")

# NOTE(review): the three constants below are not referenced in the visible
# part of this file; confirm whether anything else still relies on them.
TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"
MAX_SEC = 30
MAX_INPUT_LENGTH = 5000
|
|
|
|
|
def chat_with_openai(model_name, system_message, user_message, temperature=0.5, max_tokens=1024, top_p=0.5):
    """Stream one single-turn chat completion from the configured endpoint.

    A client is created for ``API_URL`` + ``/v1/`` using ``TOKEN`` as the API
    key, and a streaming completion is requested with exactly two messages:
    the system message and the user message.

    Args:
        model_name: Model identifier forwarded to the endpoint.
        system_message: Content of the ``system`` message.
        user_message: Content of the ``user`` message.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Upper bound on generated tokens forwarded to the endpoint.
        top_p: Nucleus-sampling value forwarded to the endpoint.

    Yields:
        The ``delta.content`` of each streamed chunk, in arrival order. The
        value is passed through untouched, so callers should be prepared for
        ``None`` on chunks that carry no text.
    """
    client = OpenAI(
        # Join URL segments explicitly: os.path.join is a filesystem helper
        # and uses a backslash separator on Windows.
        base_url=API_URL.rstrip("/") + "/v1/",
        api_key=TOKEN,
    )

    chat_completion = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )

    for message in chat_completion:
        # A chunk may arrive with an empty ``choices`` list (for example a
        # usage-only chunk); indexing it would raise IndexError mid-stream.
        if not message.choices:
            continue
        yield message.choices[0].delta.content
|
|
|
def refusal_condition(query):
    """Return True when *query* matches the keyword-based refusal rule.

    The text is normalised by removing all whitespace and lower-casing it.
    The rule then fires when either

    * the text contains at least one Taiwan-related keyword AND at least one
      China-related keyword, or
    * the text contains any of the cross-strait policy phrases.

    Matching is plain substring containment on the normalised text.

    Args:
        query: Text to inspect. ``None`` or an empty string never matches.

    Returns:
        bool: True if the rule fires, otherwise False.
    """
    if not query:
        return False

    # str.split() without arguments splits on every run of Unicode
    # whitespace, so full-width spaces, tabs and newlines are removed as well
    # as plain ASCII spaces and cannot be used to break up a keyword.
    normalized = ''.join(query.split()).lower()

    taiwan_terms = ('台灣', '台湾', 'taiwan', 'tw', '中華民國', '中华民国')
    china_terms = ('中國', '中国', 'cn', 'china', '大陸', '內地', '大陆', '内地',
                   '中華人民共和國', '中华人民共和国')
    cross_strait_terms = ('一個中國', '兩岸', '一中原則', '一中政策',
                          '一个中国', '两岸', '一中原则')

    def mentions_any(terms):
        return any(term in normalized for term in terms)

    if mentions_any(taiwan_terms) and mentions_any(china_terms):
        return True

    return mentions_any(cross_strait_terms)
|
|
|
with gr.Blocks() as demo:

    # Abort while the page is being built if the endpoint configuration is
    # missing from the environment.
    if API_URL is None:
        raise gr.Error("API_URL is not set as an environment variable.")
    if TOKEN is None:
        raise gr.Error("TOKEN is not set as an environment variable.")
    if MODEL_NAME is None:
        raise gr.Error("MODEL_NAME is not set as an environment variable.")

    gr.Markdown(DESCRIPTION)

    system_prompt = gr.Textbox(label='System prompt',
                               value=DEFAULT_SYSTEM_PROMPT,
                               lines=1)

    with gr.Accordion(label='Advanced options', open=False):

        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=32,
            maximum=2048,
            step=1,
            value=1024,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.01,
            maximum=0.5,
            step=0.01,
            value=0.01,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.01,
        )

    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True, )
    with gr.Row():
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
            lines=6
        )
        submit_button = gr.Button('Submit',
                                  variant='primary',
                                  scale=1,
                                  min_width=0)

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the user message removed by Retry/Undo between chained steps.
    saved_input = gr.State()

    def user(user_message, history):
        """Clear the input box and append the message as a pending turn.

        Returns the new textbox value (empty string) and the history extended
        with ``[user_message, None]``; ``None`` marks the reply as not yet
        generated.
        """
        return "", (history or []) + [[user_message, None]]

    def bot(history, max_new_tokens, temperature, top_p, system_prompt):
        """Generate the reply for the last turn, yielding the history as it grows.

        Only the system prompt and the most recent user message are sent to
        the model; earlier turns are displayed but not forwarded.

        Flow:
          * If the user message matches ``refusal_condition``, the whole
            conversation is replaced by a fixed refusal notice.
          * Otherwise the reply is streamed chunk by chunk, a trailing
            ``</s>`` is removed, and a disclaimer is appended when the reply
            itself matches ``refusal_condition``.

        A record of the query and response is printed at the end.
        """
        if not history:
            # Nothing to answer; leave the chat untouched.
            yield history
            return

        system_prompt = (system_prompt or '').strip()
        query = history[-1][0] or ''

        if refusal_condition(query):
            history = [['[安全拒答啟動]', '[安全拒答啟動] 請清除再開啟對話']]
            response = '[REFUSAL]'
            yield history
        else:
            # Start from an empty, mutable reply so the post-processing below
            # is valid even when the stream produces no chunks at all.
            history[-1] = [history[-1][0], '']

            stream = chat_with_openai(
                MODEL_NAME,
                system_prompt,
                query,
                temperature,
                int(max_new_tokens),  # in case the slider delivers a float
                top_p)

            for delta in stream:
                # Chunks without text arrive as None.
                history[-1][1] += delta or ''
                yield history

            eos = '</s>'
            if history[-1][1].endswith(eos):
                history[-1][1] = history[-1][1][:-len(eos)]
                yield history

            response = history[-1][1]

            if refusal_condition(response):
                history[-1][1] = response + '\n\n**[免責聲明: 此模型並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。]**'
                yield history

        # Log the real query and outcome (not the placeholder shown in the UI
        # when a refusal replaced the conversation).
        print('== Record ==\nQuery: {query}\nResponse: {response}'.format(query=repr(query), response=repr(response)))

    # Inputs shared by every event chain that ends in ``bot``.
    bot_inputs = [
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        system_prompt,
    ]

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot
    )

    submit_button.click(
        user, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot
    )

    def delete_prev_fn(
            history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
        """Remove the last turn and return ``(history, removed_user_message)``.

        An empty (or missing) history is returned unchanged together with an
        empty string.
        """
        history = history or []
        try:
            message, _ = history.pop()
        except IndexError:
            message = ''
        return history, message or ''

    def display_input(message: str,
                      history: list[tuple[str, str]]) -> list[tuple[str, str]]:
        """Append *message* as a new turn with an empty reply and return the history."""
        history = history or []
        # A list (not a tuple) keeps the turn mutable, matching ``user``.
        history.append([message, ''])
        return history

    # Retry: drop the last turn, re-add its user message, generate again.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )

    # Undo: drop the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=msg,
        api_name=False,
        queue=False,
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown(LICENSE)
|
|
|
# Enable the request queue with a default per-event concurrency limit of 10,
# then start serving the app.
demo.queue(default_concurrency_limit=10).launch()