"""Gradio chat UI that forwards the conversation to a hosted Zephyr-7B model.

The page is titled as a RAG chatbot, but ``run_rag`` is only imported here and
is not called yet; every answer comes straight from the chat-completion
endpoint.
"""

import os

import gradio as gr
from huggingface_hub import InferenceClient

from rag import run_rag

# =============================================================================
# Configuration and shared objects
# =============================================================================

TOKEN = os.getenv("HF_TOKEN")

# Pass the token explicitly so the HF_TOKEN secret is actually used; when the
# variable is unset this is ``token=None``, which is the client's default.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=TOKEN)

system_message = "You are a capable and freindly assistant."

history = []

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)


# =============================================================================
# Conversation state
# =============================================================================


class Int_State:
    """In-memory conversation log stored as chat-completion message dicts."""

    def __init__(self):
        # Each entry is {"role": <one of self.roles>, "content": <text>}.
        self.history = []
        self.current_query = ""
        self.current_response = ""
        self.roles = ["user", "assistant"]
        print("State has been initialise")

    def save_question(self, question):
        """Record *question* as the newest user turn and reset the pending reply."""
        self.current_query = question
        self.current_response = ""
        self.history.append({"role": "user", "content": question})
        print("Question added ")

    def save_response(self, assistant_message):
        """Record *assistant_message* as the reply to the current question."""
        self.current_response = assistant_message
        # Replies are stored under the "assistant" role so that replayed
        # history alternates user/assistant turns instead of being replayed
        # as extra system prompts.
        self.history.append({"role": "assistant", "content": assistant_message})
        print("Response saved ")

    def get_history(self):
        """Return a shallow copy of the message list.

        A copy is returned so callers can extend it while building a request
        without mutating the stored conversation.
        """
        return list(self.history)


# NOTE: this instance lives at module level, so every call to ``chat`` reads
# and writes the same history regardless of which browser session made it.
state = Int_State()


# =============================================================================
# Chat handler
# =============================================================================


def chat(
    chatbot,
    message,
    max_tokens,
    temperature,
    top_p,
):
    """Send *message* plus the stored history to the model and show the reply.

    Args:
        chatbot: Current list of ``(question, answer)`` pairs shown in the
            chat widget; ``None`` is treated as an empty list.
        message: Text entered by the user.
        max_tokens: Forwarded to ``chat_completion`` as ``max_tokens``.
        temperature: Forwarded to ``chat_completion`` as ``temperature``.
        top_p: Forwarded to ``chat_completion`` as ``top_p``.

    Yields:
        A ``(textbox_value, chatbot)`` pair: an empty string to clear the
        input box and the updated list of chat pairs.
    """
    print("Message: ", message)
    print("System Message: ", system_message)
    print("Max Tokens: ", max_tokens)
    print("Temperature: ", temperature)
    print("Top P: ", top_p)

    if chatbot is None:
        chatbot = []

    # Nothing to ask: leave the conversation untouched and skip the API call.
    if not message or not message.strip():
        yield "", chatbot
        return

    # Build the request from a snapshot of the history taken BEFORE the new
    # question is stored, so the question is sent exactly once.
    messages = [
        {"role": "system", "content": system_message},
        *state.get_history(),
        {"role": "user", "content": message},
    ]

    pieces = []
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        # Stream chunks may carry no text (content is None); skip those
        # rather than appending the literal string "None".
        if token:
            pieces.append(token)
    response = "".join(pieces)

    chatbot.append((message, response))

    # Persist the turn only once the model has answered, so a failed request
    # does not leave an unanswered user message in the history.
    state.save_question(message)
    state.save_response(response)

    yield "", chatbot


# =============================================================================
# UI definition
# =============================================================================

theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont('Libre Franklin'),
        gr.themes.GoogleFont('Public Sans'),
        'system-ui',
        'sans-serif',
    ],
)

EXAMPLES = [
    ["Tell me about the latest news in the world ?"],
    ["Tell me about the increase in the price of Bitcoin ?"],
    ["Tell me about the actual situation in Ukraine ?"],
    ["Tell me about current situation in palestine ?"],
]

block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""

textbox = gr.Textbox(
    show_label=False,
    placeholder="Enter text and press ENTER",
    container=False,
    show_copy_button=True,
)

with gr.Blocks(title="RAG", theme=theme, css=block_css) as demo:
    gr.Markdown("Retrieval Augmented Generation (RAG) Chatbot")

    with gr.Row():
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Retrieval Augmented Generation (RAG) Chatbot",
                height=400,
                layout="bubble",
            )
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=100):
                    submit_btn = gr.Button(
                        value="Submit", variant="primary", interactive=True
                    )
            # TODO: these feedback/control buttons are rendered disabled and
            # have no click handlers yet.
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        with gr.Column(scale=3):
            # Reuse the module-level EXAMPLES list instead of a second,
            # slightly different copy of the same prompts.
            gr.Examples(examples=EXAMPLES, inputs=[textbox], label="Examples")

            with gr.Accordion("Parameters", open=False) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.2,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )
                max_output_tokens = gr.Slider(
                    minimum=0,
                    maximum=1024,
                    value=512,
                    step=64,
                    interactive=True,
                    label="Max output tokens",
                )

    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]

    # Input order must match the parameter order of ``chat``.
    submit_btn.click(
        chat,
        [chatbot, textbox, max_output_tokens, temperature, top_p],
        [textbox, chatbot],
    )

demo.launch()