sddwt committed
Commit d27a512 · Parent(s): 7e89fb6

Create app.py

Files changed (1):
  1. app.py +291 -0
app.py ADDED
@@ -0,0 +1,291 @@
# Copyright 2023 MosaicML spaces authors
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import datetime
import os
from threading import Event, Thread
from uuid import uuid4

import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)


model_name = "JosephusCheung/Guanaco"
max_new_tokens = 2048


print(f"Starting to load the model {model_name} into memory")

tok = AutoTokenizer.from_pretrained(model_name)
m = AutoModelForCausalLM.from_pretrained(model_name).eval()

# tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
stop_token_ids = [tok.eos_token_id]

print(f"Successfully loaded the model {model_name} into memory")


class StopOnTokens(StoppingCriteria):
    # Stop generation as soon as the most recently generated token is one of
    # stop_token_ids (here, the tokenizer's EOS token).
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


def generate_input(instruction: Optional[str] = None, input_str: Optional[str] = None) -> str:
    if input_str is None:
        return PROMPT_DICT['prompt_no_input'].format_map({'instruction': instruction})
    else:
        return PROMPT_DICT['prompt_input'].format_map({'instruction': instruction, 'input': input_str})
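
# Illustrative example (the instruction string below is made up):
#   generate_input("Name three colors")
# returns the "prompt_no_input" template with the instruction filled in:
#   "Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.\n\n### Instruction:\nName three
#   colors\n\n### Response:"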


def convert_history_to_text(history):
    # Only the latest user message is used to build the prompt;
    # earlier turns in `history` are not included.
    user_input = history[-1][0]
    text = generate_input(user_input)
    return text


def log_conversation(conversation_id, history, messages, generate_kwargs):
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        requests.post(logging_url, json=data)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")
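
# Illustrative payload shape for the POST above (all values made up; nothing is
# sent unless the LOGGING_URL environment variable is set):
# {
#     "conversation_id": "8f14e45f-...",
#     "timestamp": "2023-05-30T12:00:00",
#     "history": [["Hi", "Hello!"]],
#     "messages": "<the formatted prompt string>",
#     "generate_kwargs": {"top_k": 0, "top_p": 1.0, "temperature": 0.1, "repetition_penalty": 1.1}
# }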


def user(message, history):
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]


def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    print(f"history: {history}")
    # Initialize a StopOnTokens object
    stop = StopOnTokens()

    # Build the model prompt from the latest user message
    # (see convert_history_to_text above)
    messages = convert_history_to_text(history)

    # Tokenize the messages string
    input_ids = tok(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(m.device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    stream_complete = Event()

    def generate_and_signal_complete():
        m.generate(**generate_kwargs)
        stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Stream partial output into the last history entry as tokens arrive
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
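
# Pattern note: model.generate() blocks until generation finishes, so it runs on a
# worker thread while this generator drains the TextIteratorStreamer token by token;
# each `yield history` lets Gradio re-render the chatbot with the partial reply.
# The Event ensures the logging thread only fires after generation completes.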


def get_uuid():
    return str(uuid4())


with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    conversation_id = gr.State(get_uuid)
    # NOTE: this heading does not match model_name ("JosephusCheung/Guanaco")
    # loaded above.
    gr.Markdown(
        """
        ## sambanovasystems/BLOOMChat-176B-v1 Model
        """
    )
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Enter your question",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.1,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=1.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens. Set to 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition. Set to 1.0 to disable.",
                        )
    # with gr.Row():
    #     gr.Markdown(
    #         "demo 2",
    #         elem_classes=["disclaimer"],
    #     )

    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=2)
demo.launch()
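
# Environment note (assumption, not stated in the commit): the .style() calls and
# queue(concurrency_count=...) match the Gradio 3.x API and were removed in Gradio 4,
# so this app is expected to run under gradio 3.x alongside torch, transformers,
# and requests.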