Update app.py
app.py CHANGED
@@ -10,8 +10,8 @@ import sentencepiece
 model = AutoModelForCausalLM.from_pretrained("01-ai/Yi-34B-200K", device_map="auto", torch_dtype="auto", trust_remote_code=True)
 tokenizer = YiTokenizer(vocab_file="./tokenizer.model")
 
-def run(message, chat_history,
-    prompt = get_prompt(message, chat_history
+def run(message, chat_history, max_new_tokens=100000, temperature=3.5, top_p=0.9, top_k=800):
+    prompt = get_prompt(message, chat_history)
 
     # Encode the prompt to tensor
     input_ids = tokenizer.encode(prompt, return_tensors='pt')
@@ -32,16 +32,16 @@ def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0
     response = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
     return response
 
-def get_prompt(message, chat_history
-    texts = [
+def get_prompt(message, chat_history):
+    texts = []
 
     do_strip = False
     for user_input, response in chat_history:
         user_input = user_input.strip() if do_strip else user_input
         do_strip = True
-        texts.append(f"
+        texts.append(f" {response.strip()} {user_input} ")
     message = message.strip() if do_strip else message
-    texts.append(f"{message}
+    texts.append(f"{message}")
     return ''.join(texts)
 
 DESCRIPTION = """
@@ -51,14 +51,6 @@ You can also use 🧑🏻🚀YI-200K🚀 by cloning this space. 🧬🔬🔍
 Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community on 👻Discord: [Discord](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
 """
 
-DEFAULT_SYSTEM_PROMPT = """
-You are Yi. You are an AI assistant, you are moderately-polite and give only true information.
-You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
-If you think there might not be a correct answer, you say so. Since you are autoregressive,
-each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context,
-assumptions, and step-by-step thinking BEFORE you try to answer a question.
-"""
-
 MAX_MAX_NEW_TOKENS = 200000
 DEFAULT_MAX_NEW_TOKENS = 100000
 MAX_INPUT_TOKEN_LENGTH = 100000
@@ -76,12 +68,12 @@ def delete_prev_fn(history=[]):
     message = ''
     return history, message or ''
 
-def generate(message, history_with_input,
+def generate(message, history_with_input, max_new_tokens, temperature, top_p, top_k):
     if max_new_tokens > MAX_MAX_NEW_TOKENS:
         raise ValueError
 
     history = history_with_input[:-1]
-    generator = run(message, history,
+    generator = run(message, history, max_new_tokens, temperature, top_p, top_k)
     try:
         first_response = next(generator)
         yield history + [(message, first_response)]
@@ -91,12 +83,12 @@ def generate(message, history_with_input, system_prompt, max_new_tokens, tempera
         yield history + [(message, response)]
 
 def process_example(message):
-    generator = generate(message, [],
+    generator = generate(message, [], 1024, 2.5, 0.95, 900)
     for x in generator:
         pass
     return '', x
 
-def check_input_token_length(message, chat_history
+def check_input_token_length(message, chat_history):
     input_token_length = len(message) + len(chat_history)
     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
@@ -125,7 +117,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
     saved_input = gr.State()
 
     with gr.Accordion(label='Advanced options', open=False):
-        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
+        # system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
@@ -145,7 +137,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
        queue=False,
    ).then(
        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
+        inputs=[saved_input, chatbot],
        api_name=False,
        queue=False,
    ).success(
@@ -153,7 +145,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
        inputs=[
            saved_input,
            chatbot,
-            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
@@ -177,7 +168,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
        queue=False,
    ).then(
        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
+        inputs=[saved_input, chatbot],
        api_name=False,
        queue=False,
    ).success(
@@ -185,7 +176,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
        inputs=[
            saved_input,
            chatbot,
-            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
@@ -212,7 +202,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
        inputs=[
            saved_input,
            chatbot,
-            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
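
For reference, a minimal usage sketch of the updated signatures. This is not part of the commit: it assumes the updated app.py is importable (importing it loads the model), that the 01-ai/Yi-34B-200K weights and ./tokenizer.model are available locally, and that the messages and sampling values below are purely illustrative.

# Hypothetical usage sketch (not part of the commit): calling the updated run(),
# which no longer takes a system_prompt argument. Assumes app.py is importable
# and importing it loads the model and tokenizer; all values are illustrative.
from app import run

# (user, assistant) pairs, in the order the Gradio chatbot stores them.
chat_history = [
    ("Hello!", "Hi, how can I help?"),
]

reply = run(
    "What does a 200K context window mean in practice?",
    chat_history,
    max_new_tokens=512,   # kept well below the 200000 cap enforced in generate()
    temperature=0.7,
    top_p=0.9,
    top_k=800,
)
print(reply)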