iKING-ROC committed
Commit 57f2485
1 Parent(s): f8c5db3

Create app.py

Files changed (1):
  1. app.py +243 -0
app.py ADDED
import argparse
import os

import gradio as gr
import mdtex2html
from gradio.themes.utils import colors, fonts, sizes
import torch
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    BloomForCausalLM,
    BloomTokenizerFast,
    LlamaTokenizer,
    LlamaForCausalLM,
    GenerationConfig,
)

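# Registry of supported architectures: maps the --model_type flag to the
# (model class, tokenizer class) pair used to load the checkpoint.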
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}

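# Custom Gradio theme: blue/sky palette on a light neutral background.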
class OpenGVLab(gr.themes.base.Base):
    def __init__(
            self,
            *,
            primary_hue=colors.blue,
            secondary_hue=colors.sky,
            neutral_hue=colors.gray,
            spacing_size=sizes.spacing_md,
            radius_size=sizes.radius_sm,
            text_size=sizes.text_md,
            font=(
                fonts.GoogleFont("Noto Sans"),
                "ui-sans-serif",
                "sans-serif",
            ),
            font_mono=(
                fonts.GoogleFont("IBM Plex Mono"),
                "ui-monospace",
                "monospace",
            ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            body_background_fill="*neutral_50",
        )


gvlabtheme = OpenGVLab(primary_hue=colors.blue,
                       secondary_hue=colors.sky,
                       neutral_hue=colors.gray,
                       spacing_size=sizes.spacing_md,
                       radius_size=sizes.radius_sm,
                       text_size=sizes.text_md,
                       )

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default="llama", type=str)
    parser.add_argument('--base_model', default=r"/data/wangpeng/JiaotongGPT-main/merged-sft-no-1ep", type=str)
    parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--gpus', default="0", type=str)
    parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    args = parser.parse_args()
    if args.only_cpu:
        args.gpus = ""
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

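    # Example launch (the model path below is illustrative, not from this repo):
    #   python app.py --model_type llama --base_model /path/to/merged-model --gpus 0
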
    def postprocess(self, y):
        # Render each (message, response) pair from Markdown/TeX to HTML
        # before the Chatbot component displays it.
        if y is None:
            return []
        for i, (message, response) in enumerate(y):
            y[i] = (
                None if message is None else mdtex2html.convert(message),
                None if response is None else mdtex2html.convert(response),
            )
        return y

    gr.Chatbot.postprocess = postprocess

    # Note: these defaults are not passed to predict(); the actual generation
    # settings come from the UI sliders defined below.
    generation_config = dict(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        num_beams=1,
        repetition_penalty=1.1,
        max_new_tokens=400
    )
    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')

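    # Resolve the tokenizer path, then load the base model in fp16 with
    # device_map='auto' so weights are spread across available devices.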
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    base_model = model_class.from_pretrained(
        args.base_model,
        load_in_8bit=False,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map='auto',
        trust_remote_code=True,
    )
    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)
    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
        print("Loaded LoRA model")
    else:
        model = base_model

    if device == torch.device('cpu'):
        model.float()

    model.eval()

    def reset_user_input():
        return gr.update(value='')

    def reset_state():
        return [], []

    def generate_prompt(instruction):
        return f"""You are TransGPT, a specialist in the field of transportation. Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response: """

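    # predict() rebuilds the whole prompt on each turn: previous (instruction,
    # response) pairs from `history` are prepended in the same
    # "### Instruction:/### Response:" format, the concatenation is truncated
    # to the last `max_memory` characters, then wrapped by generate_prompt().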
    def predict(
            input,
            chatbot,
            history,
            max_new_tokens=128,
            top_p=0.75,
            temperature=0.1,
            top_k=40,
            num_beams=4,
            repetition_penalty=1.0,
            max_memory=256,
            **kwargs,
    ):
        now_input = input
        chatbot.append((input, ""))
        history = history or []
        if len(history) != 0:
            input = "".join(
                ["### Instruction:\n" + i[0] + "\n\n" + "### Response: " + i[1] + "\n\n" for i in history]) + \
                "### Instruction:\n" + input
            input = input[len("### Instruction:\n"):]
            if len(input) > max_memory:
                input = input[-max_memory:]
        prompt = generate_prompt(input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=False,
                max_new_tokens=max_new_tokens,
                repetition_penalty=float(repetition_penalty),
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s, skip_special_tokens=True)
        output = output.split("### Response:")[-1].strip()
        history.append((now_input, output))
        chatbot[-1] = (now_input, output)
        return chatbot, history
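
    # Build the Gradio Blocks UI: chat panel, input box with submit/clear
    # buttons, and sampling sliders wired into predict().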
    title = """<h1 align="center">Welcome to TransGPT!</h1>"""

    with gr.Blocks(title="DUOMO TransGPT!", theme=gvlabtheme,
                   css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
        gr.Markdown(title)
        # with gr.Blocks() as demo:
        #     gr.HTML("""<h1 align="center">TransGPT</h1>""")
        #     # gr.Markdown(
        #     #     "> To promote open research on large models for the medical industry, this project open-sources the TransGPT medical large model")
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                        container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")
                max_length = gr.Slider(
                    0, 4096, value=128, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01,
                                  label="Top P", interactive=True)
                temperature = gr.Slider(
                    0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)

        history = gr.State([])  # (message, bot_message)

        # The max_length slider feeds predict()'s max_new_tokens argument.
        submitBtn.click(predict, [user_input, chatbot, history, max_length, top_p, temperature], [chatbot, history],
                        show_progress=True)
        submitBtn.click(reset_user_input, [], [user_input])

        emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
    demo.queue().launch(share=True, inbrowser=True, server_name='0.0.0.0', server_port=8080)


if __name__ == '__main__':
    main()