Tonic committed
Commit f8c306d · 1 Parent(s): 40afce0

Update app.py

Files changed (1): app.py (+13, -24)
app.py CHANGED
@@ -10,8 +10,8 @@ import sentencepiece
 model = AutoModelForCausalLM.from_pretrained("01-ai/Yi-34B-200K", device_map="auto", torch_dtype="auto", trust_remote_code=True)
 tokenizer = YiTokenizer(vocab_file="./tokenizer.model")
 
-def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.3, top_p=0.9, top_k=50):
-    prompt = get_prompt(message, chat_history, system_prompt)
+def run(message, chat_history, max_new_tokens=100000, temperature=3.5, top_p=0.9, top_k=800):
+    prompt = get_prompt(message, chat_history)
 
     # Encode the prompt to tensor
     input_ids = tokenizer.encode(prompt, return_tensors='pt')
@@ -32,16 +32,16 @@ def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0
     response = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
     return response
 
-def get_prompt(message, chat_history, system_prompt):
-    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
+def get_prompt(message, chat_history):
+    texts = []
 
     do_strip = False
     for user_input, response in chat_history:
         user_input = user_input.strip() if do_strip else user_input
         do_strip = True
-        texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
+        texts.append(f" {response.strip()} {user_input} ")
     message = message.strip() if do_strip else message
-    texts.append(f"{message} [/INST]")
+    texts.append(f"{message}")
     return ''.join(texts)
 
 DESCRIPTION = """
@@ -51,14 +51,6 @@ You can also use 🧑🏻‍🚀YI-200K🚀 by cloning this space. 🧬🔬🔍
 Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community on 👻Discord: [Discord](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
 """
 
-DEFAULT_SYSTEM_PROMPT = """
-You are Yi. You are an AI assistant, you are moderately-polite and give only true information.
-You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
-If you think there might not be a correct answer, you say so. Since you are autoregressive,
-each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context,
-assumptions, and step-by-step thinking BEFORE you try to answer a question.
-"""
-
 MAX_MAX_NEW_TOKENS = 200000
 DEFAULT_MAX_NEW_TOKENS = 100000
 MAX_INPUT_TOKEN_LENGTH = 100000
@@ -76,12 +68,12 @@ def delete_prev_fn(history=[]):
     message = ''
     return history, message or ''
 
-def generate(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
+def generate(message, history_with_input, max_new_tokens, temperature, top_p, top_k):
     if max_new_tokens > MAX_MAX_NEW_TOKENS:
         raise ValueError
 
     history = history_with_input[:-1]
-    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    generator = run(message, history, max_new_tokens, temperature, top_p, top_k)
     try:
         first_response = next(generator)
         yield history + [(message, first_response)]
@@ -91,12 +83,12 @@ def generate(message, history_with_input, system_prompt, max_new_tokens, tempera
         yield history + [(message, response)]
 
 def process_example(message):
-    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 2.5, 0.95, 900)
+    generator = generate(message, [], 1024, 2.5, 0.95, 900)
     for x in generator:
         pass
     return '', x
 
-def check_input_token_length(message, chat_history, system_prompt):
+def check_input_token_length(message, chat_history):
     input_token_length = len(message) + len(chat_history)
     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
@@ -125,7 +117,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
     saved_input = gr.State()
 
     with gr.Accordion(label='Advanced options', open=False):
-        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
+        # system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
         max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
         temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
         top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
@@ -145,7 +137,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
         queue=False,
     ).then(
         fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
+        inputs=[saved_input, chatbot],
         api_name=False,
         queue=False,
     ).success(
@@ -153,7 +145,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
         inputs=[
             saved_input,
             chatbot,
-            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
@@ -177,7 +168,7 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
         queue=False,
     ).then(
         fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
+        inputs=[saved_input, chatbot],
         api_name=False,
         queue=False,
     ).success(
@@ -185,7 +176,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
         inputs=[
             saved_input,
             chatbot,
-            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
@@ -212,7 +202,6 @@ with gr.Blocks(theme='ParityError/Anime') as demo:
         inputs=[
             saved_input,
             chatbot,
-            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
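For reference, the rewritten get_prompt above drops the Llama-style <<SYS>>/[INST] template and simply concatenates turns. A minimal standalone sketch of what the new version produces; the one-turn history below is hypothetical, not from the commit:

def get_prompt(message, chat_history):
    # Copy of the post-commit get_prompt, reproduced so this sketch runs alone.
    texts = []
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        # Note: the commit emits the assistant response before the user input.
        texts.append(f" {response.strip()} {user_input} ")
    message = message.strip() if do_strip else message
    texts.append(f"{message}")
    return ''.join(texts)

history = [("Hi", "Hello! How can I help?")]  # hypothetical one-turn history
print(get_prompt("Tell me a joke", history))
# -> " Hello! How can I help? Hi Tell me a joke"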
 
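The unchanged context line at file line 32 decodes only the newly generated tokens: generate() returns the prompt plus the completion, so the first input_ids.shape[-1] ids are sliced off before decoding. A self-contained sketch of the same pattern, substituting the small gpt2 checkpoint for 01-ai/Yi-34B-200K so it runs anywhere:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-ins for the Yi model and tokenizer loaded at the top of app.py.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tok.encode("Hello, my name is", return_tensors="pt")
response_ids = model.generate(
    input_ids,
    max_new_tokens=20,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    pad_token_id=tok.eos_token_id,
)
# Drop the prompt tokens before decoding -- the same slice app.py uses at line 32.
response = tok.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
print(response)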
 
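All of the UI hunks edit the same Gradio event chains: a submit handler fires, .then() runs check_input_token_length, and .success() only launches generation if that check did not raise gr.Error. A minimal sketch of the pattern with hypothetical handlers (save_textbox and echo are stand-ins, not the app's real wiring):

import gradio as gr

MAX_INPUT_TOKEN_LENGTH = 100000  # same limit as app.py

def save_textbox(message):
    # Stash the message in state and clear the textbox.
    return '', message

def check_input_token_length(message, chat_history):
    if len(message) + len(chat_history) > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error("The accumulated input is too long. Clear your chat history and try again.")

def echo(message, history):
    # Hypothetical stand-in for generate(); just echoes the message.
    return history + [(message, f"you said: {message}")]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    saved_input = gr.State()

    textbox.submit(
        fn=save_textbox, inputs=textbox, outputs=[textbox, saved_input], queue=False,
    ).then(
        fn=check_input_token_length, inputs=[saved_input, chatbot], queue=False,
    ).success(
        fn=echo, inputs=[saved_input, chatbot], outputs=chatbot,
    )

demo.launch()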