weihuang11 committed on
Commit 0d50143 · 1 Parent(s): e764a5d

Add application file

Files changed (1)
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ from transformers import AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+ from threading import Thread
+ import gradio as gr
+ import transformers
+ import torch
+
+ # Run the entire app with `python app.py`
+
+ """ The messages list should be of the following format:
+
+ messages =
+
+ [
+     {"role": "user", "content": "User's first message"},
+     {"role": "assistant", "content": "Assistant's first response"},
+     {"role": "user", "content": "User's second message"},
+     {"role": "assistant", "content": "Assistant's second response"},
+     {"role": "user", "content": "User's third message"}
+ ]
+
+ """
+ """ The `format_chat_history` function below formats the dialogue history into a prompt that can be fed to the Mixtral model, so the model sees the full context of the conversation and can generate an appropriate response.
+ The function takes the dialogue history as input: a list of lists, where each sublist is a [user_message, assistant_response] pair.
+ """
+
+
+ def format_chat_history(history) -> str:
+     messages = [{"role": ("user" if j == 0 else "assistant"), "content": dialog[j]}
+                 for dialog in history for j in (0, 1) if dialog[j]]
+     # The conditional `if dialog[j]` ensures that messages that are None
+     # (like the latest assistant response in an ongoing conversation)
+     # are not included.
+     return pipeline.tokenizer.apply_chat_template(
+         messages, tokenize=False,
+         add_generation_prompt=True)
+
+
+ def model_loading_pipeline():
+     model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
+
+     pipeline = transformers.pipeline(
+         "text-generation",
+         model=model_id,
+         model_kwargs={"torch_dtype": torch.float16,
+                       "quantization_config": BitsAndBytesConfig(
+                           load_in_4bit=True,
+                           bnb_4bit_compute_dtype=torch.float16)},
+         streamer=streamer
+     )
+     return pipeline, streamer
+
+
+ def launch_gradio_app(pipeline, streamer):
+     with gr.Blocks() as demo:
+         chatbot = gr.Chatbot()
+         msg = gr.Textbox()
+         clear = gr.Button("Clear")
+
+         def user(user_message, history):
+             return "", history + [[user_message, None]]
+
+         def bot(history):
+             prompt = format_chat_history(history)
+
+             history[-1][1] = ""
+             kwargs = dict(text_inputs=prompt, max_new_tokens=2048,
+                           do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
+             thread = Thread(target=pipeline, kwargs=kwargs)
+             thread.start()
+
+             for token in streamer:
+                 history[-1][1] += token
+                 yield history
+
+         msg.submit(user, [msg, chatbot], [msg, chatbot],
+                    queue=False).then(bot, chatbot, chatbot)
+         clear.click(lambda: None, None, chatbot, queue=False)
+
+     demo.queue()
+     demo.launch(share=True, debug=True)
+
+
+ if __name__ == '__main__':
+     pipeline, streamer = model_loading_pipeline()
+     launch_gradio_app(pipeline, streamer)
+
+ # Run the entire app with `python app.py`
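
For reference, a minimal standalone sketch of the history-to-prompt conversion the docstrings above describe, using the same [[user_message, assistant_response], ...] pair format that Gradio's Chatbot passes around. The sample history values are made up for illustration, and the tokenizer is loaded directly instead of through the pipeline:

from transformers import AutoTokenizer

# Hypothetical example history in Gradio's pair format; the trailing None is
# the assistant reply that has not been generated yet.
history = [
    ["Hello!", "Hi there, how can I help?"],
    ["Summarize Mixtral in one sentence.", None],
]

# Flatten the pairs into role-tagged messages, skipping None entries.
messages = [{"role": ("user" if j == 0 else "assistant"), "content": dialog[j]}
            for dialog in history for j in (0, 1) if dialog[j]]

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
prompt = tokenizer.apply_chat_template(messages, tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # Mixtral-style [INST] ... [/INST] prompt, ready for the pipeline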