import json
import subprocess
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path

import gradio as gr
import requests
from huggingface_hub import hf_hub_download
from themes.research_monochrome import theme
today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
SYS_PROMPT = f"""Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant."""
TITLE = "IBM Granite 3.1 3B A800M MoE Instruct from a local GGUF server"
DESCRIPTION = """
<p>Granite 3.1 3B Instruct is an open-source LLM supporting a 128K context window. This demo uses only a 2K context.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
LLAMA_CPP_SERVER = "http://127.0.0.1:8081"
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# download GGUF into local directory
gguf_path = hf_hub_download(
    repo_id="bartowski/granite-3.1-3b-a800m-instruct-GGUF",
    filename="granite-3.1-3b-a800m-instruct-Q8_0.gguf",
    local_dir=".",
)
# make the llama-server binary executable and start it
subprocess.run(["chmod", "+x", "llama-server"], check=True)
command = [
    "./llama-server",
    "-m", gguf_path,   # model file downloaded above
    "-ngl", "0",       # no GPU layers offloaded; run on CPU
    "--temp", "0.0",   # server-side default; overridden per request below
    "-c", "2048",      # 2K context window, matching the demo description
    "-t", "8",         # CPU threads
    "--port", "8081",
]
process = subprocess.Popen(command)
print(f"llama-server started with PID {process.pid}")
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
"""Generate function for chat demo using Llama.cpp server."""
# Build messages
conversation = []
conversation.append({"role": "system", "content": SYS_PROMPT})
conversation += chat_history
conversation.append({"role": "user", "content": message})
# Prepare the prompt for the Llama.cpp server
prompt = ""
for item in conversation:
if item["role"] == "system":
prompt += f"<|system|>\n{item['content']}\n<|file_separator|>\n"
elif item["role"] == "user":
prompt += f"<|user|>\n{item['content']}\n<|file_separator|>\n"
elif item["role"] == "assistant":
prompt += f"<|model|>\n{item['content']}\n<|file_separator|>\n"
prompt += "<|model|>\n" # Add the beginning token for the assistant
    # Construct the request payload for the llama.cpp /completion endpoint
    payload = {
        "prompt": prompt,
        "stream": True,  # stream tokens back as they are generated
        "n_predict": max_new_tokens,  # llama.cpp's name for the max-new-tokens limit
        "temperature": temperature,
        "repeat_penalty": repetition_penalty,
        "top_p": top_p,
        "top_k": top_k,
        "stop": ["<|end_of_text|>"],  # stop once the model closes its turn
    }
    try:
        # Make the streaming request to the llama.cpp server
        with requests.post(f"{LLAMA_CPP_SERVER}/completion", json=payload, stream=True, timeout=60) as response:
            response.raise_for_status()  # raise HTTPError for bad responses (4xx or 5xx)
            # The server streams server-sent events: lines of the form 'data: {json}'
            outputs = []
            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith("data: "):
                    decoded_line = decoded_line[6:]  # strip the SSE 'data: ' prefix
                try:
                    json_data = json.loads(decoded_line)
                    text = json_data.get("content", "")  # generated text arrives in 'content'
                    if text:
                        outputs.append(text)
                        yield "".join(outputs)
                except json.JSONDecodeError:
                    print(f"Skipping line that is not valid JSON: {decoded_line}")
    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        yield f"Error: {e}"  # surface the error to the user
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        yield f"Error: {e}"
css_file_path = Path(__file__).parent / "app.css"
# advanced settings (displayed in Accordion)
temperature_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
)
top_p_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
)
repetition_penalty_slider = gr.Slider(
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
)
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
with gr.Blocks(fill_height=True, css_paths=css_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h2>{TITLE}</h2>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chat_interface = gr.ChatInterface(
        fn=generate,
        examples=[
            ["Explain the concept of quantum computing to someone with no background in physics or computer science."],
            ["What is OpenShift?"],
            ["What's the importance of low latency inference?"],
            ["Help me boost productivity habits."],
        ],
        example_labels=[
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
        ],
        cache_examples=False,
        type="messages",
        additional_inputs=[
            temperature_slider,
            repetition_penalty_slider,
            top_p_slider,
            top_k_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )
if __name__ == "__main__":
    demo.queue().launch()