philipp-zettl committed on
Commit
f0e697b
1 Parent(s): f8a8106

Attempt adding seed + make optimization optional

Browse files
Files changed (1) hide show
  1. app.py +25 -19
app.py CHANGED
@@ -185,9 +185,9 @@ def find_best_parameters(eval_data, model, tokenizer, max_length=85):
185
 
186
 
187
 
188
- def run_model(inputs, tokenizer, model, num_beams=2, num_beam_groups=2, temperature=0.5, num_return_sequences=1, max_length=85):
189
  all_outputs = []
190
- torch.manual_seed(42069)
191
  for input_text in inputs:
192
  model_inputs = tokenizer([input_text], max_length=512, padding=True, truncation=True)
193
  input_ids = torch.tensor(model_inputs['input_ids']).to(device)
@@ -232,7 +232,7 @@ def run_model(inputs, tokenizer, model, num_beams=2, num_beam_groups=2, temperat
232
 
233
 
234
  @spaces.GPU
235
- def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_qg=1, num_return_sequences_qa=1, max_length=85):
236
  inputs = [
237
  f'context: {content}'
238
  ]
@@ -244,21 +244,24 @@ def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_q
244
  num_beam_groups=num_return_sequences_qg,
245
  temperature=temperature_qg,
246
  num_return_sequences=num_return_sequences_qg,
247
- max_length=max_length
 
248
  )
249
 
250
- q_params = find_best_parameters(list(chain.from_iterable(question)), qg_model, tokenizer, max_length=max_length)
251
-
252
- question = run_model(
253
- inputs,
254
- tokenizer,
255
- qg_model,
256
- num_beams=q_params[0],
257
- num_beam_groups=q_params[1],
258
- temperature=temperature_qg,
259
- num_return_sequences=num_return_sequences_qg,
260
- max_length=max_length
261
- )
 
 
262
 
263
  inputs = list(chain.from_iterable([
264
  [f'question: {q} context: {content}' for q in q_set] for q_set in question
@@ -271,7 +274,8 @@ def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_q
271
  num_beam_groups=num_return_sequences_qa,
272
  temperature=temperature_qa,
273
  num_return_sequences=num_return_sequences_qa,
274
- max_length=max_length
 
275
  )
276
 
277
  questions = list(chain.from_iterable(question))
@@ -338,6 +342,8 @@ with gr.Blocks(css='.hidden_input {display: none;}') as demo:
338
  max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
339
  num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
340
  num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
 
 
341
 
342
  with gr.Row():
343
  gen_btn = gr.Button("Generate")
@@ -345,14 +351,14 @@ with gr.Blocks(css='.hidden_input {display: none;}') as demo:
345
  @gr.render(
346
  inputs=[
347
  content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
348
- max_length
349
  ],
350
  triggers=[gen_btn.click]
351
  )
352
  def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length):
353
  qnas = gen(
354
  content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
355
- max_length
356
  )
357
  df = gr.Dataframe(
358
  value=[u.values() for u in qnas],
 
185
 
186
 
187
 
188
+ def run_model(inputs, tokenizer, model, num_beams=2, num_beam_groups=2, temperature=0.5, num_return_sequences=1, max_length=85, seed=42069):
189
  all_outputs = []
190
+ torch.manual_seed(seed)
191
  for input_text in inputs:
192
  model_inputs = tokenizer([input_text], max_length=512, padding=True, truncation=True)
193
  input_ids = torch.tensor(model_inputs['input_ids']).to(device)
 
232
 
233
 
234
  @spaces.GPU
235
+ def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_qg=1, num_return_sequences_qa=1, max_length=85, seed=42069, optimize_questions=False):
236
  inputs = [
237
  f'context: {content}'
238
  ]
 
244
  num_beam_groups=num_return_sequences_qg,
245
  temperature=temperature_qg,
246
  num_return_sequences=num_return_sequences_qg,
247
+ max_length=max_length,
248
+ seed=seed
249
  )
250
 
251
+ if optimize_questions:
252
+ q_params = find_best_parameters(list(chain.from_iterable(question)), qg_model, tokenizer, max_length=max_length)
253
+
254
+ question = run_model(
255
+ inputs,
256
+ tokenizer,
257
+ qg_model,
258
+ num_beams=q_params[0],
259
+ num_beam_groups=q_params[1],
260
+ temperature=temperature_qg,
261
+ num_return_sequences=num_return_sequences_qg,
262
+ max_length=max_length,
263
+ seed=seed
264
+ )
265
 
266
  inputs = list(chain.from_iterable([
267
  [f'question: {q} context: {content}' for q in q_set] for q_set in question
 
274
  num_beam_groups=num_return_sequences_qa,
275
  temperature=temperature_qa,
276
  num_return_sequences=num_return_sequences_qa,
277
+ max_length=max_length,
278
+ seed=seed
279
  )
280
 
281
  questions = list(chain.from_iterable(question))
 
342
  max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
343
  num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
344
  num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
345
+ seed = gr.Number(label="seed", value=42069)
346
+ optimize_questions = gr.Checkbox(label="Optimize questions?", value=False)
347
 
348
  with gr.Row():
349
  gen_btn = gr.Button("Generate")
 
351
  @gr.render(
352
  inputs=[
353
  content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
354
+ max_length, seed, optimize_questions
355
  ],
356
  triggers=[gen_btn.click]
357
  )
358
  def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length):
359
  qnas = gen(
360
  content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
361
+ max_length, seed, optimize_questions
362
  )
363
  df = gr.Dataframe(
364
  value=[u.values() for u in qnas],