File size: 3,263 Bytes
0d50143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from transformers import AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import gradio as gr
import transformers
import torch

# Run the entire app with `python run_mixtral.py`

""" The messages list should be of the following format:

messages =

[
    {"role": "user", "content": "User's first message"},
    {"role": "assistant", "content": "Assistant's first response"},
    {"role": "user", "content": "User's second message"},
    {"role": "assistant", "content": "Assistant's second response"},
    {"role": "user", "content": "User's third message"}
]

"""
""" The `format_chat_history` function below is designed to format the dialogue history into a prompt that can be fed into the Mixtral model. This will help understand the context of the conversation and generate appropriate responses by the Model.
The function takes a history of dialogues as input, which is a list of lists where each sublist represents a pair of user and assistant messages.
"""


def format_chat_history(history) -> str:
    messages = [{"role": ("user" if i % 2 == 0 else "assistant"), "content": dialog[i % 2]}
                for i, dialog in enumerate(history) for _ in (0, 1) if dialog[i % 2]]
    # The conditional `(if dialog[i % 2])` ensures that messages
    # that are None (like the latest assistant response in an ongoing
    # conversation) are not included.
    return pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False,
        add_generation_prompt=True)


def model_loading_pipeline():
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, Timeout=5)

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True,
                      "quantization_config": BitsAndBytesConfig(
                          load_in_4bit=True,
                          bnb_4bit_compute_dtype=torch.float16)},
        streamer=streamer
    )
    return pipeline, streamer


def launch_gradio_app(pipeline, streamer):
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            return "", history + [[user_message, None]]

        def bot(history):
            prompt = format_chat_history(history)

            history[-1][1] = ""
            kwargs = dict(text_inputs=prompt, max_new_tokens=2048,
                          do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
            thread = Thread(target=pipeline, kwargs=kwargs)
            thread.start()

            for token in streamer:
                history[-1][1] += token
                yield history

        msg.submit(user, [msg, chatbot], [msg, chatbot],
                   queue=False).then(bot, chatbot, chatbot)
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch(share=True, debug=True)


if __name__ == '__main__':
    pipeline, streamer = model_loading_pipeline()
    launch_gradio_app(pipeline, streamer)

# Run the entire app with `python run_mixtral.py`