Cran-May commited on
Commit
a9f41ef
·
1 Parent(s): f9fa816

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import copy
4
+ import random
5
+ import os
6
+ import requests
7
+ import time
8
+ import sys
9
+
10
+ from huggingface_hub import snapshot_download
11
+ from llama_cpp import Llama
12
+
13
+
14
+ SYSTEM_PROMPT = '''You are a helpful, respectful and honest INTP-T AI Assistant named "Shi-Ci" in English or "兮辞" in Chinese.
15
+ You are good at speaking English and Chinese.
16
+ You are talking to a human User. If the question is meaningless, please explain the reason and don't share false information.
17
+ You are based on SEA model, trained by "SSFW NLPark" team, not related to GPT, LLaMA, Meta, Mistral or OpenAI.
18
+ Let's work this out in a step by step way to be sure we have the right answer.\n\n'''
19
+ SYSTEM_TOKEN = 1587
20
+ USER_TOKEN = 2188
21
+ BOT_TOKEN = 12435
22
+ LINEBREAK_TOKEN = 13
23
+
24
+
25
+ ROLE_TOKENS = {
26
+ "user": USER_TOKEN,
27
+ "bot": BOT_TOKEN,
28
+ "system": SYSTEM_TOKEN
29
+ }
30
+
31
+
32
+ def get_message_tokens(model, role, content):
33
+ message_tokens = model.tokenize(content.encode("utf-8"))
34
+ message_tokens.insert(1, ROLE_TOKENS[role])
35
+ message_tokens.insert(2, LINEBREAK_TOKEN)
36
+ message_tokens.append(model.token_eos())
37
+ return message_tokens
38
+
39
+
40
+ def get_system_tokens(model):
41
+ system_message = {"role": "system", "content": SYSTEM_PROMPT}
42
+ return get_message_tokens(model, **system_message)
43
+
44
+
45
+ repo_name = "Cran-May/OpenSLIDE"
46
+ model_name = "SLIDE.0.1.gguf"
47
+
48
+ snapshot_download(repo_id=repo_name, local_dir=".", allow_patterns=model_name)
49
+
50
+ model = Llama(
51
+ model_path=model_name,
52
+ n_ctx=2000,
53
+ n_parts=1,
54
+ )
55
+
56
+ max_new_tokens = 1500
57
+
58
+ def user(message, history):
59
+ new_history = history + [[message, None]]
60
+ return "", new_history
61
+
62
+
63
+ def bot(
64
+ history,
65
+ system_prompt,
66
+ top_p,
67
+ top_k,
68
+ temp
69
+ ):
70
+ tokens = get_system_tokens(model)[:]
71
+ tokens.append(LINEBREAK_TOKEN)
72
+
73
+ for user_message, bot_message in history[:-1]:
74
+ message_tokens = get_message_tokens(model=model, role="user", content=user_message)
75
+ tokens.extend(message_tokens)
76
+ if bot_message:
77
+ message_tokens = get_message_tokens(model=model, role="bot", content=bot_message)
78
+ tokens.extend(message_tokens)
79
+
80
+ last_user_message = history[-1][0]
81
+ message_tokens = get_message_tokens(model=model, role="user", content=last_user_message)
82
+ tokens.extend(message_tokens)
83
+
84
+ role_tokens = [model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
85
+ tokens.extend(role_tokens)
86
+ generator = model.generate(
87
+ tokens,
88
+ top_k=top_k,
89
+ top_p=top_p,
90
+ temp=temp
91
+ )
92
+
93
+ partial_text = ""
94
+ for i, token in enumerate(generator):
95
+ if token == model.token_eos() or (max_new_tokens is not None and i >= max_new_tokens):
96
+ break
97
+ partial_text += model.detokenize([token]).decode("utf-8", "ignore")
98
+ history[-1][1] = partial_text
99
+ yield history
100
+
101
+
102
+ with gr.Blocks(
103
+ theme=gr.themes.Soft()
104
+ ) as demo:
105
+ gr.Markdown(
106
+ f"""<h1><center>兮辞·析辞-人工智能助理</center></h1>
107
+ 这儿是一个**中文**模型的部署. If you are interested in other languages, please check other models, such as [MPT-7B-Chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat).
108
+ 这是量化版兮辞·析辞的部署,具有**70亿**个参数,在 CPU 上运行。
109
+ SLIDE 是一种会话语言模型,在多种类型的语料库上进行训练。
110
+ 本节目由上海师范大学附属外国语中学**NLPark**赞助播出~
111
+ """
112
+ )
113
+ with gr.Row():
114
+ with gr.Column(scale=5):
115
+ system_prompt = gr.Textbox(label="系统提示词", placeholder="", value=SYSTEM_PROMPT, interactive=False)
116
+ chatbot = gr.Chatbot(label="兮辞如是说").style(height=400)
117
+ with gr.Column(min_width=80, scale=1):
118
+ with gr.Tab(label="设置参数"):
119
+ top_p = gr.Slider(
120
+ minimum=0.0,
121
+ maximum=1.0,
122
+ value=0.9,
123
+ step=0.05,
124
+ interactive=True,
125
+ label="Top-p",
126
+ )
127
+ top_k = gr.Slider(
128
+ minimum=10,
129
+ maximum=100,
130
+ value=30,
131
+ step=5,
132
+ interactive=True,
133
+ label="Top-k",
134
+ )
135
+ temp = gr.Slider(
136
+ minimum=0.0,
137
+ maximum=2.0,
138
+ value=0.01,
139
+ step=0.01,
140
+ interactive=True,
141
+ label="情感温度"
142
+ )
143
+ with gr.Row():
144
+ with gr.Column():
145
+ msg = gr.Textbox(
146
+ label="来问问兮辞吧……",
147
+ placeholder="兮辞折寿中……",
148
+ show_label=False,
149
+ ).style(container=False)
150
+ with gr.Column():
151
+ with gr.Row():
152
+ submit = gr.Button("开凹!")
153
+ stop = gr.Button("全局时空断裂")
154
+ clear = gr.Button("打扫群内垃圾")
155
+ with gr.Row():
156
+ gr.Markdown(
157
+ """警告:该模型可能会生成事实上或道德上不正确的文本。NLPark和兮辞对此不承担任何责任。"""
158
+ )
159
+
160
+ # Pressing Enter
161
+ submit_event = msg.submit(
162
+ fn=user,
163
+ inputs=[msg, chatbot],
164
+ outputs=[msg, chatbot],
165
+ queue=False,
166
+ ).success(
167
+ fn=bot,
168
+ inputs=[
169
+ chatbot,
170
+ system_prompt,
171
+ top_p,
172
+ top_k,
173
+ temp
174
+ ],
175
+ outputs=chatbot,
176
+ queue=True,
177
+ )
178
+
179
+ # Pressing the button
180
+ submit_click_event = submit.click(
181
+ fn=user,
182
+ inputs=[msg, chatbot],
183
+ outputs=[msg, chatbot],
184
+ queue=False,
185
+ ).success(
186
+ fn=bot,
187
+ inputs=[
188
+ chatbot,
189
+ system_prompt,
190
+ top_p,
191
+ top_k,
192
+ temp
193
+ ],
194
+ outputs=chatbot,
195
+ queue=True,
196
+ )
197
+
198
+ # Stop generation
199
+ stop.click(
200
+ fn=None,
201
+ inputs=None,
202
+ outputs=None,
203
+ cancels=[submit_event, submit_click_event],
204
+ queue=False,
205
+ )
206
+
207
+ # Clear history
208
+ clear.click(lambda: None, None, chatbot, queue=False)
209
+
210
+ demo.queue(max_size=128, concurrency_count=1)
211
+ demo.launch()