StevenChen16 commited on
Commit
c39e972
·
verified ·
1 Parent(s): 6c2ef5e

update chat_llama3_8b function

Browse files
Files changed (1) hide show
  1. app.py +90 -56
app.py CHANGED
@@ -165,67 +165,101 @@ def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8)
165
 
166
  @spaces.GPU(duration=120)
167
  def chat_llama3_8b(message: str,
168
- history: list,
169
- temperature=0.6,
170
- max_new_tokens=4096
171
- ) -> str:
172
  """
173
- Generate a streaming response using the llama3-8b model.
174
- Will display citations after the response if citations are available.
175
- """
176
- # Get citations from vector store
177
- citation = query_vector_store(vector_store, message, 4, 0.7)
178
-
179
- # Build conversation history
180
- conversation = []
181
- for user, assistant in history:
182
- conversation.extend([
183
- {"role": "user", "content": user},
184
- {"role": "assistant", "content": assistant}
185
- ])
186
-
187
- # Construct the final message with background prompt and citations
188
- if citation:
189
- message = f"{background_prompt}Based on these citations: {citation}\nPlease answer question: {message}"
190
- else:
191
- message = f"{background_prompt}{message}"
192
-
193
- conversation.append({"role": "user", "content": message})
194
-
195
- # Generate response
196
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
197
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
198
-
199
- generate_kwargs = dict(
200
- input_ids=input_ids,
201
- streamer=streamer,
202
- max_new_tokens=max_new_tokens,
203
- do_sample=True,
204
- temperature=temperature,
205
- eos_token_id=terminators,
206
- )
207
 
208
- if temperature == 0:
209
- generate_kwargs['do_sample'] = False
 
 
 
210
 
211
- t = Thread(target=model.generate, kwargs=generate_kwargs)
212
- t.start()
213
-
214
- outputs = []
215
- for text in streamer:
216
- outputs.append(text)
217
- current_output = "".join(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
- # If we have citations, append them at the end
220
- if citation and text == streamer[-1]: # On the last chunk
221
- citation_display = "\n\nReferences:\n" + "\n".join(
222
- f"[{i+1}] {cite.strip()}"
223
- for i, cite in enumerate(citation.split('\n'))
224
- if cite.strip()
225
- )
226
- current_output += citation_display
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
- yield current_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
  # Gradio block
 
165
 
166
  @spaces.GPU(duration=120)
167
  def chat_llama3_8b(message: str,
168
+ history: list,
169
+ temperature=0.6,
170
+ max_new_tokens=4096
171
+ ) -> str:
172
  """
173
+ Generate a streaming response using the LLaMA model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
+ Args:
176
+ message (str): The current user message
177
+ history (list): List of previous conversation turns
178
+ temperature (float): Sampling temperature (0.0 to 1.0)
179
+ max_new_tokens (int): Maximum number of tokens to generate
180
 
181
+ Returns:
182
+ str: Generated response with citations if available
183
+ """
184
+ try:
185
+ # 1. Get relevant citations from vector store
186
+ citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
187
+
188
+ # 2. Format conversation history
189
+ conversation = []
190
+ for user, assistant in history:
191
+ conversation.extend([
192
+ {"role": "user", "content": str(user)},
193
+ {"role": "assistant", "content": str(assistant)}
194
+ ])
195
+
196
+ # 3. Construct the final prompt
197
+ final_message = ""
198
+ if citation:
199
+ final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
200
+ else:
201
+ final_message = f"{background_prompt}\n{message}"
202
+
203
+ conversation.append({"role": "user", "content": final_message})
204
+
205
+ # 4. Prepare model inputs
206
+ input_ids = tokenizer.apply_chat_template(
207
+ conversation,
208
+ return_tensors="pt"
209
+ ).to(model.device)
210
+
211
+ # 5. Setup streamer
212
+ streamer = TextIteratorStreamer(
213
+ tokenizer,
214
+ timeout=10.0,
215
+ skip_prompt=True,
216
+ skip_special_tokens=True
217
+ )
218
 
219
+ # 6. Configure generation parameters
220
+ generation_config = {
221
+ "input_ids": input_ids,
222
+ "streamer": streamer,
223
+ "max_new_tokens": max_new_tokens,
224
+ "do_sample": temperature > 0,
225
+ "temperature": temperature,
226
+ "eos_token_id": terminators
227
+ }
228
+
229
+ # 7. Generate in a separate thread
230
+ thread = Thread(target=model.generate, kwargs=generation_config)
231
+ thread.start()
232
+
233
+ # 8. Stream the output
234
+ accumulated_text = []
235
+ final_chunk = False
236
+
237
+ for text_chunk in streamer:
238
+ accumulated_text.append(text_chunk)
239
+ current_response = "".join(accumulated_text)
240
+
241
+ # Check if this is the last chunk
242
+ try:
243
+ next_chunk = next(iter(streamer))
244
+ accumulated_text.append(next_chunk)
245
+ except (StopIteration, RuntimeError):
246
+ final_chunk = True
247
 
248
+ # Add citations on the final chunk if they exist
249
+ if final_chunk and citation:
250
+ formatted_citations = "\n\nReferences:\n" + "\n".join(
251
+ f"[{i+1}] {cite.strip()}"
252
+ for i, cite in enumerate(citation.split('\n'))
253
+ if cite.strip()
254
+ )
255
+ current_response += formatted_citations
256
+
257
+ yield current_response
258
+
259
+ except Exception as e:
260
+ error_message = f"An error occurred: {str(e)}"
261
+ print(error_message) # For logging
262
+ yield error_message
263
 
264
 
265
  # Gradio block