# Hugging Face file-viewer residue (not part of the program), kept as comments:
#   grahamwhiteuk's picture
#   fix: feedback from latest design review
#   6997fc5
#   raw
#   history blame
#   4.93 kB
"""Template Demo for IBM Granite Hugging Face spaces."""
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from themes.research_monochrome import theme
# Date shown to the model in the system prompt, e.g. "March 5, 2025".
# The day is taken from ``datetime.day`` instead of the "%-d" strftime directive:
# "%-d" is a glibc extension and raises ValueError on platforms such as Windows.
# NOTE: evaluated once at import time, so a long-running process keeps the start date.
_now = datetime.today()  # noqa: DTZ002
today_date = f"{_now:%B} {_now.day}, {_now:%Y}"

# System message that generate() places at the start of every conversation.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
# Page title: used for the browser/page title of the Blocks app and the <h1> header.
TITLE = "IBM Granite 3.1 8b Instruct"
# Introductory HTML rendered under the title via gr.HTML().
# NOTE(review): the text already states that the demo does not work on CPU; the
# module appends the same notice again when CUDA is unavailable — confirm intent.
DESCRIPTION = """
<p>Granite 3.1 is a general purpose large language model released in the open under an Apache 2.0 license. Granite
models support a 128k context length. Try one of the sample prompts below or write your own. Remember, AI models can
make mistakes. Note: This demo does not work on CPU.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
# Token budget for prompt + completion; generate() truncates the prompt to
# MAX_INPUT_TOKEN_LENGTH minus the requested number of new tokens.
MAX_INPUT_TOKEN_LENGTH = 128_000
# Defaults for generate() and initial values of the "Advanced Settings" sliders.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# Tell visitors up front when no CUDA device is present.
_cuda_ready = torch.cuda.is_available()
if not _cuda_ready:
    DESCRIPTION += "\nThis demo does not work on CPU."

# Model weights and tokenizer come from the same checkpoint; name it once.
_MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"

# Half-precision weights, with device placement delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
# The system prompt is always supplied explicitly by generate().
tokenizer.use_default_system_prompt = False
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream the assistant's reply for the chat demo.

    The conversation sent to the model is the system prompt, followed by
    ``chat_history``, followed by ``message`` as the newest user turn.
    Generation runs on a worker thread and is consumed through a
    ``TextIteratorStreamer``.

    Slider values are normalised before they reach ``model.generate`` because
    the logits processors in transformers are strict about types and ranges:

    * ``temperature <= 0`` selects greedy decoding (``do_sample=False``) and the
      sampling-only options are omitted.
    * ``top_k`` is coerced to ``int``; ``temperature``, ``top_p`` and
      ``repetition_penalty`` are coerced to ``float``.
    * a non-positive ``repetition_penalty`` is not forwarded at all.

    Args:
        message: Latest user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature; ``0`` means greedy decoding.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cut-off.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once the stream has been drained.
    """
    # Build messages
    conversation = [
        {"role": "system", "content": SYS_PROMPT},
        *chat_history,
        {"role": "user", "content": message},
    ]

    max_new_tokens = int(max_new_tokens)

    # Convert messages to prompt format, reserving room for the completion.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
    )
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
    }
    if repetition_penalty > 0:
        generate_kwargs["repetition_penalty"] = float(repetition_penalty)
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # A zero temperature is invalid for sampling; fall back to greedy decoding.
        generate_kwargs["do_sample"] = False

    errors: list[Exception] = []

    def _run_generation() -> None:
        """Run model.generate and always release the consumer loop."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # noqa: BLE001 - re-raised on the consumer side
            errors.append(exc)
            # Without the end-of-stream signal the loop below would block forever.
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    response = ""
    for text in streamer:
        response += text
        yield response

    worker.join()
    if errors:
        raise errors[0]
# Static assets (stylesheet and <head> snippet) sit next to this module.
_app_dir = Path(__file__).parent
css_file_path = _app_dir / "app.css"
head_file_path = _app_dir / "app_head.html"


def _accordion_slider(label: str, **slider_options) -> gr.Slider:
    """Create a slider carrying the CSS class used inside the settings accordion."""
    return gr.Slider(label=label, elem_classes=["gr_accordion_element"], **slider_options)


# advanced settings (displayed in Accordion)
temperature_slider = _accordion_slider("Temperature", minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1)
top_p_slider = _accordion_slider("Top P", minimum=0, maximum=1.0, value=TOP_P, step=0.05)
top_k_slider = _accordion_slider("Top K", minimum=0, maximum=100, value=TOP_K, step=1)
repetition_penalty_slider = _accordion_slider(
    "Repetition Penalty",
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.1,
)
max_new_tokens_slider = _accordion_slider(
    "Max New Tokens",
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
)

# Collapsed container that the chat interface uses for the sliders above.
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
# Page layout: title header, description, then the chat interface with its
# "Advanced Settings" accordion.
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(
        f"<img src='https://www.ibm.com/granite/docs/images/granite-pictogram.svg'/><h1>{TITLE}</h1>",
        elem_classes=["gr_title"],
    )
    gr.HTML(DESCRIPTION)
    chat_interface = gr.ChatInterface(
        fn=generate,
        examples=[
            ["Explain quantum computing"],
            ["What is OpenShift?"],
            ["Importance of low latency inference"],
            ["Write a binary search in Python"],
        ],
        cache_examples=False,
        type="messages",
        # Additional inputs are passed to `fn` positionally after (message, history),
        # so this order must mirror generate()'s parameters:
        # temperature, top_p, top_k, repetition_penalty, max_new_tokens.
        additional_inputs=[
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )

# Start the app (with request queuing) only when run as a script.
if __name__ == "__main__":
    demo.queue().launch()