Islam YAHIAOUI committed on
Commit 96f677c · 1 Parent(s): 31e6eb8
Files changed (2)
  1. app.py +166 -40
  2. example.py +0 -102
app.py CHANGED
@@ -9,16 +9,33 @@ from rag import run_rag
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
 
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+TOKEN = os.getenv("HF_TOKEN")
 
+client = InferenceClient("HuggingFaceH4/zephyr-7b-beta" , token=TOKEN)
+system_message ="You are a capable and freindly assistant."
+history = []
+no_change_btn = gr.Button()
+enable_btn = gr.Button(interactive=True)
+disable_btn = gr.Button(interactive=False)
+
+# ================================================================================================================================
+# ================================================================================================================================
+
 def chat(
+    state,
     message,
-    history: list[tuple[str, str]],
-    system_message,
+    # history: list[tuple[str, str]],
     max_tokens,
     temperature,
     top_p,
 ):
+    print("Message: ", message)
+    print("History: ", history)
+    print("System Message: ", system_message)
+    print("Max Tokens: ", max_tokens)
+    print("Temperature: ", temperature)
+    print("Top P: ", top_p)
+
     messages = [{"role": "system", "content": system_message}]
 
     for val in history:
@@ -26,14 +43,17 @@ def chat(
         messages.append({"role": "user", "content": val[0]})
         if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
-    message =run_rag(message, history)
+    # message =run_rag(message, history)
 
-    messages.append({"role": "user", "content": message})
-
+    messages.append({"role": "user", "content": run_rag(message)})
     response = ""
-
+    if state is None:
+        state = gr.State()
+        state.messages = [[("assistant", "")]]
+
     for message in client.chat_completion(
         messages,
+
         max_tokens=max_tokens,
         stream=True,
         temperature=temperature,
@@ -41,18 +61,13 @@
     ):
         token = message.choices[0].delta.content
         response += str(token)
-
-        yield response
+        state.messages[-1][-1] = str(token)
+        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+
+    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
+
+# ================================================================================================================================
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.Chatbot(
-    label="Retrieval Augmented Generation News & Finance",
-    # avatar_images=[None, BOT_AVATAR],
-    show_copy_button=True,
-    likeable=True,
-    layout="bubble")
 theme = gr.themes.Base(
     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
 )
@@ -65,7 +80,7 @@ EXAMPLES = [
 max_new_tokens = gr.Slider(
     minimum=1,
     maximum=2048,
-    value=512,
+    value=1024,
     step=1,
     interactive=True,
     label="Max new tokens",
@@ -90,28 +105,139 @@ top_p = gr.Slider(
     label="Top-p (nucleus sampling)",
     info="Higher values is equivalent to sampling more low-probability tokens.",
 )
+textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+# ================================================================================================================================
+
+# with gr.Blocks(
+#     fill_height=True,
+#     css=""".gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}""",
+# ) as main:
+#     gr.ChatInterface(
+#         chat,
+#         chatbot=chatbot,
+#         title="Retrieval Augmented Generation (RAG) Chatbot",
+#         examples=EXAMPLES,
+#         theme=theme,
+#         fill_height=True,
+#         additional_inputs=[
+
+#             max_new_tokens,
+#             temperature,
+#             top_p,
+#         ],
+#     )
+
+
+
+# with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="RAG") as demo:
+#     gr.TabbedInterface([main ] , tab_names=["Chatbot"] )
+
+# demo.launch()
+
+
+def upvote_last_response(state):
+    return ("",) + (disable_btn,) * 3
 
-with gr.Blocks(
-    fill_height=True,
-    css=""".gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}""",
-) as main:
-    gr.ChatInterface(
+def downvote_last_response(state):
+    return ("",) + (disable_btn,) * 3
+
+def flag_last_response(state):
+    return ("",) + (disable_btn,) * 3
+
+def add_text(state ,textbox ):
+    print("textbox: ", textbox)
+    if state is None:
+        state = gr.State()
+        state.messages = [[("assistant", "")]]
+    state.text = textbox
+    history=""
+    state.append_message(state.roles[0], textbox)#
+    state.append_message(state.roles[1], "")
+    yield (state, None, history) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+
+block_css = """
+#buttons button {
+    min-width: min(120px,100%);
+}
+"""
+# ================================================================================================================================
+
+with gr.Blocks(title="CuMo", theme=theme, css=block_css) as demo:
+    state = gr.State()
+    gr.Markdown("Retrieval Augmented Generation (RAG) Chatbot" )
+    with gr.Row():
+        with gr.Column(scale=8):
+            chatbot = gr.Chatbot(
+                elem_id="chatbot",
+                label="Retrieval Augmented Generation (RAG) Chatbot",
+                height=400,
+                layout="bubble",
+
+            )
+            with gr.Row():
+                with gr.Column(scale=8):
+                    textbox.render()
+                with gr.Column(scale=1, min_width=100):
+                    submit_btn = gr.Button(value="Submit", variant="primary" )
+            with gr.Row(elem_id="buttons") as button_row:
+                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
+                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
+                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
+                #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
+                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
+
+        with gr.Column(scale=3):
+            gr.Examples(examples=[
+                [f"Tell me about the latest news in the world ?"],
+                [f"Tell me about the increase in the price of Bitcoin ?"],
+                [f"Tell me about the actual situation in Ukraine ?"],
+                [f"Tell me about current situation in palestinian ?"],
+            ],inputs=[textbox], label="Examples")
+            with gr.Accordion("Parameters", open=False) as parameter_row:
+                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+    # ================================================================================================================================
+
+    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+    upvote_btn.click(
+        upvote_last_response,
+        [state],
+        [textbox, upvote_btn, downvote_btn, flag_btn]
+    )
+    downvote_btn.click(
+        downvote_last_response,
+        [state],
+        [textbox, upvote_btn, downvote_btn, flag_btn]
+    )
+    flag_btn.click(
+        flag_last_response,
+        [state],
+        [textbox, upvote_btn, downvote_btn, flag_btn]
+    )
+
+
+    textbox.submit(
+        add_text,
+        [state, textbox],
+        [state, chatbot, textbox] + btn_list,
+    ).then(
+        chat,
+        [state, textbox,max_output_tokens, temperature, top_p],
+        [state, chatbot, textbox] + btn_list,
+    )
+
+    submit_btn.click(
+        add_text,
+        [state , textbox],
+        [state,chatbot, textbox] + btn_list,
+    ).then(
         chat,
-        chatbot=chatbot,
-        title="Retrieval Augmented Generation (RAG) Chatbot",
-        description="A chatbot that uses a RAG model to generate responses based on the input query.",
-        examples=EXAMPLES,
-        theme=theme,
-        fill_height=True,
-        multimodal=True,
-        additional_inputs=[
-            max_new_tokens,
-            temperature,
-            top_p,
-        ],
+        [state, textbox, max_output_tokens , temperature, top_p ],
+        [state,chatbot, textbox] + btn_list,
     )
-with gr.Blocks(theme=theme, css="footer {visibility: hidden}textbox{resize:none}", title="RAG") as demo:
-    gr.TabbedInterface([main] , tab_names=["Chatbot"] )
+    # ================================================================================================================================
 demo.launch()
-
-
+# ================================================================================================================================
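Note for readers: the new `chat` and `add_text` generators call `state.append_message(...)`, `state.roles`, and `state.to_gradio_chatbot()` on the value held in `gr.State()`, methods that a bare `gr.State` object does not define, so the commit implies a conversation-state helper that is not shown here. A minimal sketch of one possible helper with that interface, assuming a user/assistant role pair and the `[[user, bot], ...]` pair format that `gr.Chatbot` renders (the class name and all details below are hypothetical, not part of this commit):

# Hypothetical sketch only, not part of this commit.
class ConversationState:
    def __init__(self):
        self.roles = ("user", "assistant")  # assumed role order
        self.messages = []                  # flat list of [role, text] pairs
        self.text = ""

    def append_message(self, role, text):
        self.messages.append([role, text])

    def to_gradio_chatbot(self):
        # Fold the flat role/text list into the [[user_msg, bot_msg], ...]
        # pairs that gr.Chatbot expects.
        pairs = []
        for role, text in self.messages:
            if role == self.roles[0]:
                pairs.append([text, None])
            elif pairs:
                pairs[-1][1] = text
        return pairs

Under this assumption, the app would seed the component with `state = gr.State(ConversationState())` so that the streaming loop's `state.messages[-1][-1] = str(token)` updates the last assistant turn in place.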
 
example.py DELETED
@@ -1,102 +0,0 @@
-import gradio as gr
-import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-    BitsAndBytesConfig,
-)
-import os
-from threading import Thread
-import spaces
-import time
-
-token = os.environ["HF_TOKEN"]
-
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
-)
-
-model = AutoModelForCausalLM.from_pretrained(
-    "NousResearch/Hermes-2-Pro-Llama-3-8B", quantization_config=quantization_config, token=token
-)
-tok = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B", token=token)
-terminators = [
-    tok.eos_token_id,
-    tok.convert_tokens_to_ids("<|eot_id|>")
-]
-
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
-else:
-    device = torch.device("cpu")
-    print("Using CPU")
-
-# model = model.to(device)
-# Dispatch Errors
-
-
-@spaces.GPU(duration=150)
-def chat(message, history, temperature,do_sample, max_tokens):
-    chat = []
-    for item in history:
-        chat.append({"role": "user", "content": item[0]})
-        if item[1] is not None:
-            chat.append({"role": "assistant", "content": item[1]})
-    chat.append({"role": "user", "content": message})
-    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-    model_inputs = tok([messages], return_tensors="pt").to(device)
-    streamer = TextIteratorStreamer(
-        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
-
-    if temperature == 0:
-        generate_kwargs['do_sample'] = False
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    partial_text = ""
-    for new_text in streamer:
-        partial_text += new_text
-        yield partial_text
-
-    tokens = len(tok.tokenize(partial_text))
-    yield partial_text
-
-
-demo = gr.ChatInterface(
-    fn=chat,
-    examples=[["Write me a poem about Machine Learning."]],
-    # multimodal=False,
-    additional_inputs_accordion=gr.Accordion(
-        label="⚙️ Parameters", open=False, render=False
-    ),
-    additional_inputs=[
-        gr.Slider(
-            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
-        ),
-        gr.Checkbox(label="Sampling",value=True),
-        gr.Slider(
-            minimum=128,
-            maximum=4096,
-            step=1,
-            value=512,
-            label="Max new tokens",
-            render=False,
-        ),
-    ],
-    stop_btn="Stop Generation",
-    title="Chat With LLMs",
-    description="Now Running [NousResearch/Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B) in 4bit"
-)
-demo.launch()
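The deleted file relied on the standard `transformers` streaming pattern: `model.generate` runs on a background thread while a `TextIteratorStreamer` is consumed as an iterator in the foreground. A minimal sketch of that pattern, with `model` and `tok` standing in for any causal LM and its tokenizer (the function name and defaults below are illustrative, not from the repository):

from threading import Thread

from transformers import TextIteratorStreamer

def stream_reply(model, tok, prompt, max_new_tokens=256):
    # Encode the prompt and move it to the model's device.
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    # The streamer yields decoded text chunks as tokens are produced.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # generate() blocks, so run it on a worker thread and iterate here.
    Thread(target=model.generate, kwargs=kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # growing partial response, as a Gradio chat callback expects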