File size: 4,659 Bytes
b376f12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8d64ca
 
 
 
 
 
 
 
 
 
da09cca
b376f12
 
 
 
 
 
 
 
 
da09cca
 
 
 
b376f12
 
 
 
03c2ae6
 
 
 
 
 
 
 
 
b376f12
 
 
 
 
 
 
 
da09cca
 
 
 
 
03c2ae6
da09cca
b376f12
 
 
 
 
 
03c2ae6
b376f12
03c2ae6
 
 
b376f12
03c2ae6
b376f12
 
 
 
 
 
 
 
 
 
 
 
 
 
03c2ae6
 
 
 
 
 
 
 
 
da09cca
b376f12
 
 
da09cca
0a8d079
da09cca
 
 
 
 
 
 
 
 
8d34f84
da09cca
 
 
03c2ae6
 
 
 
 
 
 
 
da09cca
 
b376f12
da09cca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Template Demo for IBM Granite Hugging Face spaces."""

from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from themes.carbon import carbon_theme

# Human-readable current date, e.g. "April 5, 2024". Built with str.format
# instead of strftime("%-d"): "%-d" (day without zero padding) is a glibc
# extension that other C libraries (notably Windows') reject with ValueError.
# ``{0.day}`` renders the day as a plain int, so it is never zero-padded.
# NOTE(review): evaluated once at import, so a long-running process keeps
# reporting the date it started on.
today_date = "{0:%B} {0.day}, {0:%Y}".format(datetime.today())  # noqa: DTZ002

SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
# Page title: used as the browser-tab title (gr.Blocks ``title``) and in the
# page header.
TITLE = "IBM Granite 3.1 8b Instruct"
# HTML blurb rendered under the header; a CPU-only notice is appended to it
# further down when no CUDA device is available.
DESCRIPTION = """
<p>Granite 3.1 is a general purpose large language model released in the open under an Apache 2.0 license. Granite
models support a 128k context length.</p>

<p>Try one of the sample prompts below or write your own. Remember, AI models can make mistakes.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation</a> <i class="fa fa-external-link"></i>
</span>
</p>
"""
# Token budget shared by prompt and reply: generate() limits the prompt to
# MAX_INPUT_TOKEN_LENGTH - max_new_tokens tokens.
MAX_INPUT_TOKEN_LENGTH = 128_000
# Default generation settings. They serve both as generate()'s parameter
# defaults and as the initial values of the "Advanced Settings" sliders.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

if not torch.cuda.is_available():
    # Wrapped in <p> so the notice renders as its own paragraph, matching the
    # rest of the HTML description instead of trailing it as bare text.
    DESCRIPTION += "\n<p>This demo does not work on CPU.</p>"

# Hugging Face Hub id shared by the weights and the tokenizer, so the two can
# never drift apart.
MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# generate() always prepends SYS_PROMPT itself. Presumably this flag stops the
# tokenizer from injecting a default system prompt of its own (legacy
# attribute — TODO confirm it is still honored by the installed transformers).
tokenizer.use_default_system_prompt = False


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a Granite reply to ``message`` given the earlier chat turns.

    Args:
        message: The new user message.
        chat_history: Earlier turns in Gradio "messages" format (dicts with
            ``role`` and ``content`` keys), appended after the system prompt.
        temperature: Sampling temperature; a value ``<= 0`` selects greedy
            (non-sampled) decoding.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Number of top tokens kept when sampling (used only when sampling).
        repetition_penalty: Repetition penalty; must be strictly positive.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The cumulative reply text after each streamed chunk.

    Raises:
        gr.Error: If ``repetition_penalty`` is not strictly positive.
    """
    # Slider values can arrive as int or float depending on whether they are
    # whole numbers, and transformers' logits processors type-check them
    # (e.g. a repetition penalty of 2 as an int is rejected). Any such error
    # would be raised inside the worker thread and only surface here as a
    # streamer timeout, so normalize and validate up front.
    temperature = float(temperature)
    top_p = float(top_p)
    top_k = int(top_k)
    repetition_penalty = float(repetition_penalty)
    max_new_tokens = int(max_new_tokens)
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    # Build messages: system prompt, prior turns, then the new user turn.
    conversation = [{"role": "system", "content": SYS_PROMPT}]
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    # Convert messages to prompt format (ends with the assistant generation prompt).
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
    )

    # Reserve room for the reply. Trim from the LEFT so the newest user turn
    # and the trailing generation prompt survive; tokenizer truncation cuts
    # from the right by default and would drop exactly those tokens.
    # NOTE: left-trimming can also drop the leading system prompt.
    max_input_tokens = max(1, MAX_INPUT_TOKEN_LENGTH - max_new_tokens)
    if input_ids.shape[-1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_input_tokens} tokens.")

    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        # Single unpadded sequence: every position is a real token.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    else:
        # The temperature slider can reach 0, which the sampling temperature
        # warper rejects; treat it as a request for deterministic greedy decoding.
        generate_kwargs["do_sample"] = False

    # model.generate blocks until done, so run it in a worker thread and
    # consume its output incrementally through the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Stylesheet and extra <head> markup that live next to this script.
# ``Path / str`` already yields a Path, so no extra Path() wrapping is needed.
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

# advanced settings (displayed in Accordion)
temperature_slider = gr.Slider(minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature")
top_p_slider = gr.Slider(minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P")
top_k_slider = gr.Slider(minimum=0, maximum=100, value=TOP_K, step=1, label="Top K")
# transformers requires a strictly positive repetition penalty, so 0 must not
# be selectable.
repetition_penalty_slider = gr.Slider(
    minimum=0.1, maximum=2.0, value=REPETITION_PENALTY, step=0.1, label="Repetition Penalty"
)
max_new_tokens_slider = gr.Slider(minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1, label="Max New Tokens")
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)

with gr.Blocks(
    title=TITLE,
    theme=carbon_theme,
    css_paths=css_file_path,
    head_paths=head_file_path,
    fill_height=True,
) as demo:
    # Header: Granite pictogram followed by the page title.
    gr.HTML(
        "<img src='https://www.ibm.com/granite/docs/images/granite-pictogram.svg'/>" f"<h1>{TITLE}</h1>",
        elem_classes=["gr_title"],
    )
    # Introductory HTML blurb.
    gr.HTML(DESCRIPTION)
    # Chat UI backed by generate(). The sliders are listed in the same order
    # as generate()'s parameters after (message, chat_history).
    chat_interface = gr.ChatInterface(
        generate,
        type="messages",
        cache_examples=False,
        examples=[
            [prompt]
            for prompt in (
                "Explain quantum computing",
                "What is OpenShift?",
                "Importance of low latency inference",
                "Write a binary search in Python",
            )
        ],
        additional_inputs_accordion=chat_interface_accordion,
        additional_inputs=[
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            max_new_tokens_slider,
        ],
    )

if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start serving it.
    queued_app = demo.queue()
    queued_app.launch()