# Islam YAHIAOUI
# Update UI
# 96f677c
# raw
# history blame
# 8.69 kB
import json
import gradio as gr
from huggingface_hub import InferenceClient
import os
import requests
from rag import run_rag
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Hugging Face access token taken from the environment (None when HF_TOKEN is unset).
TOKEN = os.environ.get("HF_TOKEN")

# Inference API client bound to the Zephyr 7B beta model.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=TOKEN)

# System prompt placed first in every chat-completion request.
system_message = "You are a capable and freindly assistant."

# Module-level list of past (user, assistant) turns that chat() reads;
# nothing in the visible code appends to it.
history = list()

# Button instances that the event handlers return as update values for the
# action buttons.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
# ================================================================================================================================
# ================================================================================================================================
def chat(
    state,
    message,
    # history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
):
    """Stream the assistant's reply to the current user message.

    The conversation is kept as a plain list of ``[user_text, assistant_text]``
    pairs. That list is stored in the session state and is also what the
    chatbot component displays. (The previous code called conversation-object
    methods such as ``to_gradio_chatbot()`` on a ``gr.State`` component, which
    does not provide them.)

    Args:
        state: Session value; a list of ``[user, assistant]`` pairs, or
            anything else (e.g. ``None``) to start a fresh conversation.
        message: Text currently in the input box. When it is empty, the user
            text of a trailing unanswered pair in ``state`` is used instead.
        max_tokens: Forwarded to ``client.chat_completion``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.

    Yields:
        8-tuples matching the wired outputs ``[state, chatbot, textbox]`` plus
        the five action buttons: the updated pairs twice, ``""`` for the
        textbox, then one button update per action button.
    """
    print("Message: ", message)
    print("History: ", history)
    print("System Message: ", system_message)
    print("Max Tokens: ", max_tokens)
    print("Temperature: ", temperature)
    print("Top P: ", top_p)

    # Work on a copy so the incoming session value is never mutated in place.
    turns = [list(pair) for pair in state] if isinstance(state, list) else []

    # A trailing pair without an assistant reply is a user turn that was queued
    # before this handler ran (the input box may already have been cleared).
    has_pending = bool(turns) and not turns[-1][1]
    if message:
        if not (has_pending and turns[-1][0] == message):
            turns.append([message, None])
    elif has_pending:
        message = turns[-1][0]

    if not message:
        # Nothing to send: leave the conversation and the buttons untouched.
        yield (turns, turns, "") + (no_change_btn,) * 5
        return

    messages = [{"role": "system", "content": system_message}]
    # NOTE(review): `history` is the module-level list; nothing in the visible
    # code appends to it, so earlier turns are not sent to the model.
    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    # message =run_rag(message, history)
    messages.append({"role": "user", "content": run_rag(message)})

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if not token:
            # Skip empty/None deltas instead of appending the text "None".
            continue
        response += token
        # Show the reply accumulated so far, not just the latest token.
        turns[-1][1] = response
        yield (turns, turns, "") + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    yield (turns, turns, "") + (enable_btn,) * 5
# ================================================================================================================================
# Base Gradio theme with a custom font stack: two Google fonts first, then
# generic fallbacks.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("Libre Franklin"),
        gr.themes.GoogleFont("Public Sans"),
        "system-ui",
        "sans-serif",
    ],
)
# Example prompts, one single-element row per prompt.
EXAMPLES = [
    [prompt]
    for prompt in (
        "Tell me about the latest news in the world ?",
        "Tell me about the increase in the price of Bitcoin ?",
        "Tell me about the actual situation in Ukraine ?",
        "Tell me about current situation in palestine ?",
    )
]
# Module-level generation controls. They are referenced by the commented-out
# ChatInterface below; the names `temperature` and `top_p` are later rebound
# to different sliders inside the Blocks layout.
max_new_tokens = gr.Slider(
    label="Max new tokens",
    minimum=1, maximum=2048, value=1024, step=1,
    interactive=True,
)
temperature = gr.Slider(
    label="Temperature",
    info="Higher values will produce more diverse outputs.",
    minimum=0.1, maximum=0.9, value=0.6, step=0.1,
    visible=True,
    interactive=True,
)
top_p = gr.Slider(
    label="Top-p (nucleus sampling)",
    info="Higher values is equivalent to sampling more low-probability tokens.",
    minimum=0.1, maximum=1, value=0.9, step=0.05,
    visible=True,
    interactive=True,
)

# Input box created up front; the Blocks layout places it with textbox.render().
textbox = gr.Textbox(
    placeholder="Enter text and press ENTER",
    show_label=False,
    container=False,
)
# ================================================================================================================================
# with gr.Blocks(
# fill_height=True,
# css=""".gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}""",
# ) as main:
# gr.ChatInterface(
# chat,
# chatbot=chatbot,
# title="Retrieval Augmented Generation (RAG) Chatbot",
# examples=EXAMPLES,
# theme=theme,
# fill_height=True,
# additional_inputs=[
# max_new_tokens,
# temperature,
# top_p,
# ],
# )
# with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="RAG") as demo:
# gr.TabbedInterface([main ] , tab_names=["Chatbot"] )
# demo.launch()
def upvote_last_response(state):
    """Handle a click on the upvote button.

    ``state`` is accepted but not used. Returns a 4-tuple: an empty string
    followed by ``disable_btn`` three times.
    """
    locked = [disable_btn] * 3
    return ("", *locked)
def downvote_last_response(state):
    """Handle a click on the downvote button.

    ``state`` is accepted but not used. Returns a 4-tuple: an empty string
    followed by ``disable_btn`` three times.
    """
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response(state):
    """Handle a click on the flag button.

    ``state`` is accepted but not used. Returns a 4-tuple: an empty string
    followed by ``disable_btn`` three times.
    """
    cleared_text = ""
    return (cleared_text,) + tuple(disable_btn for _ in range(3))
def add_text(state, textbox):
    """Queue the submitted text as a new user turn and clear the input box.

    The conversation is kept as a plain list of ``[user_text, assistant_text]``
    pairs; a pair whose assistant part is ``None`` is a turn still waiting for
    a reply. (The previous code called ``append_message`` / ``roles`` on a
    ``gr.State`` component, which does not provide them, and sent ``None`` to
    the chatbot output.)

    Args:
        state: Session value; a list of ``[user, assistant]`` pairs, or
            anything else (e.g. ``None``) to start a fresh conversation.
        textbox: Text submitted by the user.

    Yields:
        One 8-tuple matching the wired outputs ``[state, chatbot, textbox]``
        plus the five action buttons: the updated pairs twice, ``""`` to clear
        the textbox, then one button update per action button.
    """
    print("textbox: ", textbox)

    # Work on a copy so the incoming session value is never mutated in place.
    turns = [list(pair) for pair in state] if isinstance(state, list) else []
    if textbox:
        # None as the assistant part marks the turn as not yet answered.
        turns.append([textbox, None])

    yield (turns, turns, "") + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
# Custom CSS passed to gr.Blocks: sets a minimum width on the buttons inside
# the element with id "buttons".
block_css = (
    "\n"
    "#buttons button {\n"
    "min-width: min(120px,100%);\n"
    "}\n"
)
# ================================================================================================================================
# Page layout and event wiring.
# NOTE(review): the source had no indentation; the nesting below (chat column
# holding the chatbot, input row and button row; side column holding examples
# and parameters) is reconstructed and should be confirmed.
# NOTE(review): the page title is "CuMo" while the visible heading is the RAG
# chatbot -- confirm which title is intended.
with gr.Blocks(title="CuMo", theme=theme, css=block_css) as demo:
    # Per-session value passed into and returned from the handlers.
    state = gr.State()
    gr.Markdown("Retrieval Augmented Generation (RAG) Chatbot")
    with gr.Row():
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Retrieval Augmented Generation (RAG) Chatbot",
                height=400,
                layout="bubble",
            )
            with gr.Row():
                with gr.Column(scale=8):
                    # The textbox was created at module level; place it here.
                    textbox.render()
                with gr.Column(scale=1, min_width=100):
                    submit_btn = gr.Button(value="Submit", variant="primary")
            with gr.Row(elem_id="buttons") as button_row:
                # Labels restored from mis-decoded UTF-8 emoji.
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                # NOTE(review): no click handler is attached to the next two
                # buttons anywhere in the visible code.
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        with gr.Column(scale=3):
            gr.Examples(
                examples=[
                    ["Tell me about the latest news in the world ?"],
                    ["Tell me about the increase in the price of Bitcoin ?"],
                    ["Tell me about the actual situation in Ukraine ?"],
                    ["Tell me about current situation in palestinian ?"],
                ],
                inputs=[textbox],
                label="Examples",
            )
            with gr.Accordion("Parameters", open=False) as parameter_row:
                # These rebind the module-level `temperature` / `top_p` names.
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
    # ================================================================================================================================
    # Order matters: handlers return one update per button in this order.
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]

    # Feedback buttons: each handler returns a textbox value plus updates for
    # the three feedback buttons.
    upvote_btn.click(
        upvote_last_response,
        [state],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    downvote_btn.click(
        downvote_last_response,
        [state],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )
    flag_btn.click(
        flag_last_response,
        [state],
        [textbox, upvote_btn, downvote_btn, flag_btn],
    )

    # Pressing ENTER in the textbox and clicking Submit run the same chain:
    # add_text first, then chat.
    textbox.submit(
        add_text,
        [state, textbox],
        [state, chatbot, textbox] + btn_list,
    ).then(
        chat,
        [state, textbox, max_output_tokens, temperature, top_p],
        [state, chatbot, textbox] + btn_list,
    )
    submit_btn.click(
        add_text,
        [state, textbox],
        [state, chatbot, textbox] + btn_list,
    ).then(
        chat,
        [state, textbox, max_output_tokens, temperature, top_p],
        [state, chatbot, textbox] + btn_list,
    )
# ================================================================================================================================
# Start the Gradio app; this runs at module level, i.e. whenever the file is
# executed or imported.
demo.launch()
# ================================================================================================================================