jykoh commited on
Commit
5ffd817
·
1 Parent(s): 3d6dac6

Made outputs deterministic

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -74,7 +74,7 @@ def save_image_to_local(image: Image.Image):
74
  return filename
75
 
76
 
77
- def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
78
  # Ignore empty inputs.
79
  if len(input_text) == 0:
80
  return state, state[0], gr.update(visible=True)
@@ -98,7 +98,7 @@ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperat
98
  model_outputs = model.generate_for_images_and_texts(model_inputs,
99
  num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
100
  temperature=temperature, max_num_rets=1,
101
- num_inference_steps=50)
102
  print('model_outputs', model_outputs, ret_scale_factor, flush=True)
103
 
104
  response = ''
@@ -157,6 +157,7 @@ with gr.Blocks(css=css) as demo:
157
  """)
158
 
159
  gr_state = gr.State([[], []]) # conversation, chat_history
 
160
 
161
  with gr.Row():
162
  with gr.Column(scale=0.7, min_width=500):
@@ -193,10 +194,10 @@ with gr.Blocks(css=css) as demo:
193
  ).style(grid=[2], height="auto")
194
 
195
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
196
- gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
197
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
198
  submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
199
- gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
200
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
201
 
202
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
 
74
  return filename
75
 
76
 
77
+ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature, generator):
78
  # Ignore empty inputs.
79
  if len(input_text) == 0:
80
  return state, state[0], gr.update(visible=True)
 
98
  model_outputs = model.generate_for_images_and_texts(model_inputs,
99
  num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
100
  temperature=temperature, max_num_rets=1,
101
+ num_inference_steps=50, generator=generator)
102
  print('model_outputs', model_outputs, ret_scale_factor, flush=True)
103
 
104
  response = ''
 
157
  """)
158
 
159
  gr_state = gr.State([[], []]) # conversation, chat_history
160
+ g_cuda = torch.Generator(device='cuda').manual_seed(1337)
161
 
162
  with gr.Row():
163
  with gr.Column(scale=0.7, min_width=500):
 
194
  ).style(grid=[2], height="auto")
195
 
196
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
197
+ gr_max_len, gr_temperature, g_cuda], [gr_state, chatbot, share_group, save_group])
198
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
199
  submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
200
+ gr_max_len, gr_temperature, g_cuda], [gr_state, chatbot, share_group, save_group])
201
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
202
 
203
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])