eaglesarezzo commited on
Commit
e67fd61
·
verified ·
1 Parent(s): 4652799

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -21
app.py CHANGED
@@ -24,6 +24,15 @@ backend = Backend()
24
 
25
  cv2.setNumThreads(1)
26
 
 
 
 
 
 
 
 
 
 
27
  @spaces.GPU(duration=20)
28
  def respond(
29
  message,
@@ -35,28 +44,21 @@ def respond(
35
  top_p,
36
  top_k,
37
  repeat_penalty,
38
- selected_topic
39
  ):
40
  chat_template = MessagesFormatterType.GEMMA_2
41
 
42
- print("HISTORY SO FAR ", history)
43
- print("Selected topic:", selected_topic)
44
 
45
- if selected_topic:
46
- query_engine = backend.create_index_for_query_engine(documents_paths[selected_topic])
47
- message = backend.generate_prompt(query_engine, message)
48
- gr.Info(f"Relevant context indexed from {selected_topic} docs...")
49
- else:
50
- query_engine = backend.load_index_for_query_engine()
51
- message = backend.generate_prompt(query_engine, message)
52
- gr.Info("Relevant context extracted from db...")
53
 
54
  # Load model only if it's not already loaded or if a new model is selected
55
  if backend.llm is None or backend.llm_model != model:
56
  try:
57
  backend.load_model(model)
58
  except Exception as e:
59
- return f"Error loading model: {str(e)}"
60
 
61
  provider = LlamaCppPythonProvider(backend.llm)
62
 
@@ -84,7 +86,7 @@ def respond(
84
 
85
  try:
86
  stream = agent.get_chat_response(
87
- message,
88
  llm_sampling_settings=settings,
89
  chat_history=messages,
90
  returns_streaming_generator=True,
@@ -99,7 +101,16 @@ def respond(
99
  yield history + [[message, f"Error during response generation: {str(e)}"]]
100
 
101
  def select_topic(topic):
102
- return gr.update(visible=True), topic, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
 
 
 
 
 
 
 
 
 
103
 
104
  with gr.Blocks(css="""
105
  .gradio-container {
@@ -114,7 +125,7 @@ with gr.Blocks(css="""
114
  metaverse_btn = gr.Button("🌐 Metaverse", scale=1)
115
  payment_btn = gr.Button("💳 Payment", scale=1)
116
 
117
- selected_topic = gr.State(value="")
118
 
119
  chatbot = gr.Chatbot(
120
  scale=1,
@@ -129,8 +140,9 @@ with gr.Blocks(css="""
129
  show_label=False,
130
  placeholder="Inserisci il tuo messaggio...",
131
  container=False,
 
132
  )
133
- submit_btn = gr.Button("Invia", scale=1)
134
 
135
  with gr.Accordion("Advanced Options", open=False):
136
  model = gr.Dropdown([
@@ -156,19 +168,19 @@ with gr.Blocks(css="""
156
  top_k = gr.Slider(minimum=0, maximum=100, value=30, step=1, label="Top-k")
157
  repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
158
 
159
- blockchain_btn.click(lambda: select_topic("blockchain"), inputs=None, outputs=[chatbot, selected_topic, blockchain_btn, metaverse_btn, payment_btn])
160
- metaverse_btn.click(lambda: select_topic("metaverse"), inputs=None, outputs=[chatbot, selected_topic, blockchain_btn, metaverse_btn, payment_btn])
161
- payment_btn.click(lambda: select_topic("payment"), inputs=None, outputs=[chatbot, selected_topic, blockchain_btn, metaverse_btn, payment_btn])
162
 
163
  submit_btn.click(
164
  respond,
165
- inputs=[msg, chatbot, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty, selected_topic],
166
  outputs=chatbot
167
  )
168
 
169
  msg.submit(
170
  respond,
171
- inputs=[msg, chatbot, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty, selected_topic],
172
  outputs=chatbot
173
  )
174
 
 
24
 
25
  cv2.setNumThreads(1)
26
 
27
+ def load_topic_data(topic):
28
+ if topic:
29
+ query_engine = backend.create_index_for_query_engine(documents_paths[topic])
30
+ gr.Info(f"Data loaded for {topic} topic")
31
+ return query_engine
32
+ else:
33
+ gr.Warning("No topic selected. Please select a topic first.")
34
+ return None
35
+
36
  @spaces.GPU(duration=20)
37
  def respond(
38
  message,
 
44
  top_p,
45
  top_k,
46
  repeat_penalty,
47
+ query_engine
48
  ):
49
  chat_template = MessagesFormatterType.GEMMA_2
50
 
51
+ if query_engine is None:
52
+ return history + [[message, "Please select a topic before asking questions."]]
53
 
54
+ full_prompt = backend.generate_prompt(query_engine, message)
 
 
 
 
 
 
 
55
 
56
  # Load model only if it's not already loaded or if a new model is selected
57
  if backend.llm is None or backend.llm_model != model:
58
  try:
59
  backend.load_model(model)
60
  except Exception as e:
61
+ return history + [[message, f"Error loading model: {str(e)}"]]
62
 
63
  provider = LlamaCppPythonProvider(backend.llm)
64
 
 
86
 
87
  try:
88
  stream = agent.get_chat_response(
89
+ full_prompt,
90
  llm_sampling_settings=settings,
91
  chat_history=messages,
92
  returns_streaming_generator=True,
 
101
  yield history + [[message, f"Error during response generation: {str(e)}"]]
102
 
103
  def select_topic(topic):
104
+ query_engine = load_topic_data(topic)
105
+ return (
106
+ gr.update(interactive=True), # Enable the chat input
107
+ gr.update(interactive=True), # Enable the submit button
108
+ gr.update(visible=True), # Make the chatbot visible
109
+ gr.update(interactive=False), # Disable blockchain button
110
+ gr.update(interactive=False), # Disable metaverse button
111
+ gr.update(interactive=False), # Disable payment button
112
+ query_engine # Return the loaded query engine
113
+ )
114
 
115
  with gr.Blocks(css="""
116
  .gradio-container {
 
125
  metaverse_btn = gr.Button("🌐 Metaverse", scale=1)
126
  payment_btn = gr.Button("💳 Payment", scale=1)
127
 
128
+ query_engine = gr.State(None)
129
 
130
  chatbot = gr.Chatbot(
131
  scale=1,
 
140
  show_label=False,
141
  placeholder="Inserisci il tuo messaggio...",
142
  container=False,
143
+ interactive=False
144
  )
145
+ submit_btn = gr.Button("Invia", scale=1, interactive=False)
146
 
147
  with gr.Accordion("Advanced Options", open=False):
148
  model = gr.Dropdown([
 
168
  top_k = gr.Slider(minimum=0, maximum=100, value=30, step=1, label="Top-k")
169
  repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty")
170
 
171
+ blockchain_btn.click(lambda: select_topic("blockchain"), inputs=None, outputs=[msg, submit_btn, chatbot, blockchain_btn, metaverse_btn, payment_btn, query_engine])
172
+ metaverse_btn.click(lambda: select_topic("metaverse"), inputs=None, outputs=[msg, submit_btn, chatbot, blockchain_btn, metaverse_btn, payment_btn, query_engine])
173
+ payment_btn.click(lambda: select_topic("payment"), inputs=None, outputs=[msg, submit_btn, chatbot, blockchain_btn, metaverse_btn, payment_btn, query_engine])
174
 
175
  submit_btn.click(
176
  respond,
177
+ inputs=[msg, chatbot, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty, query_engine],
178
  outputs=chatbot
179
  )
180
 
181
  msg.submit(
182
  respond,
183
+ inputs=[msg, chatbot, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty, query_engine],
184
  outputs=chatbot
185
  )
186