# Author: Islam YAHIAOUI
# Last commit: c41e5ec "Update UI" (file size 9.16 kB)
import gradio as gr
from huggingface_hub import InferenceClient
import os
from rag import run_rag
# ================================================================================================================================
# Hugging Face access token from the environment (None when HF_TOKEN is unset,
# in which case huggingface_hub uses its default credential lookup).
TOKEN = os.getenv("HF_TOKEN")
# Chat model served through the Hugging Face Inference API. The token is passed
# explicitly; previously TOKEN was read but never handed to the client.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=TOKEN)
# System prompt prepended to every chat request.
system_message = "You are a capable and friendly assistant."
# NOTE(review): not referenced by the code visible in this file; kept in case
# something else imports it.
history = []
# Pre-built button values (no-op / enabled / disabled).
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
# ================================================================================================================================
class Int_State:
    """In-memory log of a chat conversation.

    Every turn is stored as a chat-completion message dict
    (``{"role": ..., "content": ...}``), together with the most recent
    question and response.
    """

    def __init__(self):
        # Conversation so far: a list of {"role": str, "content": str} dicts,
        # the same shape chat-completion APIs accept.
        self.history = []
        self.current_query = ""
        self.current_response = ""
        # Roles used when recording turns. Replies are "assistant" (not
        # "system") so the history can be replayed to a chat model correctly.
        self.roles = ["user", "assistant"]
        print("State has been initialise")

    def save_question(self, question):
        """Record a user question and reset the pending response."""
        self.current_query = question
        self.current_response = ""
        self.history.append({"role": "user", "content": question})
        print("Question added ")

    def save_response(self, assistant_message):
        """Record the model's reply to the current question."""
        self.current_response = assistant_message
        # Stored under the "assistant" role; "system" would make the model
        # treat its own earlier replies as instructions.
        self.history.append({"role": "assistant", "content": assistant_message})
        print("Response saved ")

    def get_history(self):
        """Return the live history list (the same object, not a copy)."""
        return self.history
# ================================================================================================================================
# Module-level conversation log written to by chat().
# NOTE(review): as a single module-level instance it is shared by every session
# served by this process, so turns from different users end up in one history —
# confirm this is intended.
state = Int_State()
# ================================================================================================================================
# def clear_chat(state):
# state.history = []
# return ("",) + (disable_btn,) * 3
# ================================================================================================================================
# def enable_buttons(btn_list ,upvote=False, downvote=False, flag=False, regenerate=False, clear=False):
# return [gr.Button(interactive=upvote), gr.Button(interactive=downvote), gr.Button(interactive=flag), gr.Button(interactive=regenerate), gr.Button(interactive=clear)]
# def disable_buttons(btn_list, upvote=True, downvote=True, flag=True, regenerate=False, clear=False):
# if upvote:
# btn_list[0] = disable_btn
# if downvote:
# btn_list[1] = disable_btn
# if flag:
# btn_list[2] = disable_btn
# if regenerate:
# btn_list[3] = disable_btn
# if clear:
# btn_list[4] = disable_btn
# return btn_list
# def upvote_last_response(btn_list):
# # upvote the last response
# print("Upvoted")
# return disable_buttons(btn_list)
# def downvote_last_response(btn_list):
# # downvote the last response
# print("Downvoted")
# return disable_buttons(btn_list)
# def flag_last_response(btn_list):
# # flag the last response
# print("Flagged")
# return disable_buttons(btn_list)
def _history_to_messages(chatbot):
    """Convert Gradio ``(user, assistant)`` chat pairs into chat-completion messages."""
    messages = []
    for user_turn, assistant_turn in chatbot or []:
        # Empty/None turns (e.g. a reply that never arrived) are skipped.
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    return messages


def chat(
    chatbot,
    message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model reply to ``message`` into the chat window.

    Gradio event-handler generator; every ``yield`` is
    ``(textbox_value, chatbot_value)``.

    Args:
        chatbot: Current Gradio chat history, a list of ``(user, assistant)``
            pairs (may be None or empty).
        message: The new user message.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        ``("", chatbot)`` after each streamed token, with the last pair holding
        the reply so far. For a blank message, ``(message, chatbot)`` once and
        the model is not called.
    """
    print("Message: ", message)
    print("System Message: ", system_message)
    print("Max Tokens: ", max_tokens)
    print("Temperature: ", temperature)
    print("Top P: ", top_p)

    chatbot = list(chatbot or [])
    if not message or not message.strip():
        yield message, chatbot
        return

    # Build the prompt from this session's own chat window rather than the
    # module-level `state`, which is shared by all sessions. (The old code also
    # read `state` after save_question had appended the message, so the
    # question was sent twice.)
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(chatbot))
    messages.append({"role": "user", "content": message})

    state.save_question(message)

    # One row for this exchange, updated in place while tokens stream in
    # (previously a new row was appended per token).
    chatbot.append((message, ""))
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=int(max_tokens),
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        # The final streamed chunk can carry no content (None); str(None)
        # would otherwise append the text "None" to the reply.
        if token:
            response += token
            chatbot[-1] = (message, response)
            yield "", chatbot

    # Record the finished reply once, not once per streamed token.
    state.save_response(response)
    chatbot[-1] = (message, response)
    yield "", chatbot
# ================================================================================================================================
# Font stack for the UI: two Google fonts, then generic system fallbacks.
_FONT_STACK = [
    gr.themes.GoogleFont("Libre Franklin"),
    gr.themes.GoogleFont("Public Sans"),
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Base(font=_FONT_STACK)
# Starter prompts; each inner list is one example row holding a single value.
_EXAMPLE_QUESTIONS = (
    "Tell me about the latest news in the world ?",
    "Tell me about the increase in the price of Bitcoin ?",
    "Tell me about the actual situation in Ukraine ?",
    "Tell me about current situation in palestine ?",
)
EXAMPLES = [[question] for question in _EXAMPLE_QUESTIONS]
# max_new_tokens = gr.Slider(
# minimum=1,
# maximum=2048,
# value=1024,
# step=1,
# interactive=True,
# label="Max new tokens",
# )
# temperature = gr.Slider(
# minimum=0.1,
# maximum=0.9,
# value=0.6,
# step=0.1,
# visible=True,
# interactive=True,
# label="Temperature",
# info="Higher values will produce more diverse outputs.",
# )
# top_p = gr.Slider(
# minimum=0.1,
# maximum=1,
# value=0.9,
# step=0.05,
# visible=True,
# interactive=True,
# label="Top-p (nucleus sampling)",
# info="Higher values is equivalent to sampling more low-probability tokens.",
# )
# ================================================================================================================================
# Extra CSS passed to gr.Blocks(css=...): every button inside the element with
# id "buttons" gets a minimum width of the smaller of 120px and 100% of its
# container.
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
# ================================================================================================================================
# Free-text input for the user's question. It is created outside the Blocks
# context and placed into the layout later with ``textbox.render()``.
textbox = gr.Textbox(
    placeholder="Enter text and press ENTER",
    show_label=False,
    show_copy_button=True,
    container=False,
)
with gr.Blocks(title="RAG", theme=theme, css=block_css) as demo:
    gr.Markdown("Retrieval Augmented Generation (RAG) Chatbot")
    with gr.Row():
        # Left column: conversation, input box and feedback buttons.
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Retrieval Augmented Generation (RAG) Chatbot",
                height=400,
                layout="bubble",
            )
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=100):
                    submit_btn = gr.Button(value="Submit", variant="primary", interactive=True)
            # Feedback/maintenance buttons. No callbacks are wired to them yet,
            # so they are created disabled. (Labels were mis-encoded emoji.)
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
        # Right column: example prompts and generation parameters.
        with gr.Column(scale=3):
            # Reuse the shared EXAMPLES list instead of a duplicated inline copy.
            gr.Examples(examples=EXAMPLES, inputs=[textbox], label="Examples")
            with gr.Accordion("Parameters", open=False) as parameter_row:
                # Minimums are non-zero: a temperature/top_p of 0 or a token
                # budget of 0 are not valid generation settings.
                temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
                max_output_tokens = gr.Slider(minimum=64, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")
    # ================================================================================================================================
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    # TODO: wire the upvote/downvote/flag/regenerate/clear handlers.

    chat_inputs = [chatbot, textbox, max_output_tokens, temperature, top_p]
    chat_outputs = [textbox, chatbot]
    submit_btn.click(chat, chat_inputs, chat_outputs)
    # The placeholder promises "press ENTER", so Enter must trigger chat too.
    textbox.submit(chat, chat_inputs, chat_outputs)
# ================================================================================================================================
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module (e.g. Gradio reload mode) does not launch a second app.
if __name__ == "__main__":
    demo.launch()
# ================================================================================================================================