fix chat msg
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Baichuan2 13B Chat
-emoji:
+emoji: 🔥
 colorFrom: gray
 colorTo: indigo
 sdk: gradio
app.py CHANGED
@@ -3,269 +3,270 @@ from typing import Iterator

Old version (only partially recoverable here): it imported its helpers from model, kept a commented-out default system prompt beginning "You are a helpful, respectful and honest assistant...", defined multi-line DEFAULT_SYSTEM_PROMPT and LICENSE strings, shipped English example prompts such as 'Hello there! How are you doing?' and 'Explain the plot of Cinderella in a sentence.', and defaulted the Top-k slider to 50.

New version (lines 3-272):

import gradio as gr
import torch

from model import run

DEFAULT_SYSTEM_PROMPT = ""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4000
DESCRIPTION = """
# Baichuan2-13B-Chat
Baichuan 2 is the new generation of open-source large language models launched by Baichuan Intelligent Technology. It was trained on a high-quality corpus with 2.6 trillion tokens.
"""
LICENSE = ""

if not torch.cuda.is_available():
    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'


def clear_and_save_textbox(message: str) -> tuple[str, str]:
    return '', message


def display_input(
    message: str,
    history: list[tuple[str, str]]
) -> list[tuple[str, str]]:
    history.append((message, ''))
    return history


def delete_prev_fn(
    history: list[tuple[str, str]]
) -> tuple[list[tuple[str, str]], str]:
    try:
        message, _ = history.pop()
    except IndexError:
        message = ''
    return history, message or ''


def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]


def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, 1, 0.95, 5)
    for x in generator:
        pass
    return '', x


def check_input_token_length(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str
) -> None:
    a = 1
    # input_token_length = get_input_token_length(message, chat_history, system_prompt)
    # if input_token_length > MAX_INPUT_TOKEN_LENGTH:
    #     raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')


with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value='Duplicate Space for private use',
        elem_id='duplicate-button'
    )

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button(
                'Submit',
                variant='primary',
                scale=1,
                min_width=0
            )

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(
            label='System prompt',
            value=DEFAULT_SYSTEM_PROMPT,
            lines=6
        )
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=5,
        )

    gr.Examples(
        examples=[
            '介绍下你自己',
            '找到下列数组的中位数[3.1,6.2,1.3,8.4,10.5,11.6,2.1],请用python代码完成以上功能',
            '鸡和兔在一个笼子里,共有26个头,68只脚,那么鸡有多少只,兔有多少只?',
            '以下物理常识题目,哪一个是错误的?A.在自然环境下,声音在固体中传播速度最快。B.牛顿第一定律:一个物体如果不受力作用,将保持静止或匀速直线运动的状态。C.牛顿第三定律:对于每个作用力,都有一个相等而反向的反作用力。D.声音在空气中的传播速度为1000m/s。',
        ],
        inputs=textbox,
        outputs=[textbox, chatbot],
        fn=process_example,
        cache_examples=True,
    )

    gr.Markdown(LICENSE)

    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=20).launch()
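Note that the new check_input_token_length is currently a stub: its body is a placeholder (a = 1) with the real guard commented out, so over-long conversations are no longer rejected before generation. A minimal sketch of restoring the guard inside app.py (where gr and MAX_INPUT_TOKEN_LENGTH are already defined), assuming get_input_token_length is imported from model alongside run, could look like the following. Since get_input_token_length still measures a prompt built by get_prompt rather than the message list passed to model.chat, the count would only be an approximation.

from model import get_input_token_length, run  # assumption: the helper import is restored


def check_input_token_length(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
) -> None:
    # Estimate the prompt length and abort the event chain before generate() runs.
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(
            f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). '
            'Clear your chat history and try again.'
        )

Because the event chains attach generate with .success(), raising gr.Error here stops generation from being triggered at all.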
model.py CHANGED
@@ -2,79 +2,77 @@ from threading import Thread

Old version (only partially recoverable here): it loaded the model and tokenizer in much the same way, built a prompt string with get_prompt, and run() streamed output through a streamer configured with timeout=10., skip_prompt=True and skip_special_tokens=True while calling model.generate with top_p, top_k, temperature and num_beams=1 on a background Thread.

New version (lines 2-78):

from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

model_id = 'baichuan-inc/Baichuan2-13B-Chat'

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # device_map='auto',
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    model = model.quantize(4).cuda()
    model.generation_config = GenerationConfig.from_pretrained(model_id)
else:
    model = None
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=False,
    trust_remote_code=True
)

def get_prompt(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str
) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)

def get_input_token_length(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str
) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]

def run(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 1.0,
    top_p: float = 0.95,
    top_k: int = 5
) -> Iterator[str]:
    print(chat_history)

    history = []
    result = ""

    for i in chat_history:
        history.append({"role": "user", "content": i[0]})
        history.append({"role": "assistant", "content": i[1]})

    print(history)

    history.append({"role": "user", "content": message})

    for response in model.chat(tokenizer, history, stream=True):
        print(response)
        if "content" in response["choices"][0]["delta"]:
            result = result + response["choices"][0]["delta"]["content"]
            yield result
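Each value yielded by the new run is the accumulated response so far (result keeps growing), which is why generate in app.py can simply display the latest yield. Also note that system_prompt and the sampling parameters (max_new_tokens, temperature, top_p, top_k) are accepted but not forwarded to model.chat, which relies on the model's generation_config. A minimal usage sketch, assuming a CUDA device so that model is actually loaded, might look like this:

if __name__ == '__main__':
    last = ''
    for partial in run('介绍下你自己', chat_history=[], system_prompt=''):  # "Introduce yourself"
        last = partial  # each yield is the full response generated so far
    print(last)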