jykoh commited on
Commit
d904dbe
Β·
1 Parent(s): 5ffd817

Fix gradio bug

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -45,6 +45,7 @@ decision_model_path = huggingface_hub.hf_hub_download(
45
  args_path = huggingface_hub.hf_hub_download(
46
  repo_id='jykoh/gill', filename='model_args.json')
47
  model = models.load_gill('./', args_path, ckpt_path, decision_model_path)
 
48
 
49
 
50
  def upload_image(state, image_input):
@@ -74,7 +75,7 @@ def save_image_to_local(image: Image.Image):
74
  return filename
75
 
76
 
77
- def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature, generator):
78
  # Ignore empty inputs.
79
  if len(input_text) == 0:
80
  return state, state[0], gr.update(visible=True)
@@ -98,7 +99,7 @@ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperat
98
  model_outputs = model.generate_for_images_and_texts(model_inputs,
99
  num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
100
  temperature=temperature, max_num_rets=1,
101
- num_inference_steps=50, generator=generator)
102
  print('model_outputs', model_outputs, ret_scale_factor, flush=True)
103
 
104
  response = ''
@@ -157,7 +158,6 @@ with gr.Blocks(css=css) as demo:
157
  """)
158
 
159
  gr_state = gr.State([[], []]) # conversation, chat_history
160
- g_cuda = torch.Generator(device='cuda').manual_seed(1337)
161
 
162
  with gr.Row():
163
  with gr.Column(scale=0.7, min_width=500):
@@ -194,10 +194,10 @@ with gr.Blocks(css=css) as demo:
194
  ).style(grid=[2], height="auto")
195
 
196
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
197
- gr_max_len, gr_temperature, g_cuda], [gr_state, chatbot, share_group, save_group])
198
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
199
  submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
200
- gr_max_len, gr_temperature, g_cuda], [gr_state, chatbot, share_group, save_group])
201
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
202
 
203
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
 
45
  args_path = huggingface_hub.hf_hub_download(
46
  repo_id='jykoh/gill', filename='model_args.json')
47
  model = models.load_gill('./', args_path, ckpt_path, decision_model_path)
48
+ g_cuda = torch.Generator(device='cuda').manual_seed(1337)
49
 
50
 
51
  def upload_image(state, image_input):
 
75
  return filename
76
 
77
 
78
+ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
79
  # Ignore empty inputs.
80
  if len(input_text) == 0:
81
  return state, state[0], gr.update(visible=True)
 
99
  model_outputs = model.generate_for_images_and_texts(model_inputs,
100
  num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
101
  temperature=temperature, max_num_rets=1,
102
+ num_inference_steps=50, generator=g_cuda)
103
  print('model_outputs', model_outputs, ret_scale_factor, flush=True)
104
 
105
  response = ''
 
158
  """)
159
 
160
  gr_state = gr.State([[], []]) # conversation, chat_history
 
161
 
162
  with gr.Row():
163
  with gr.Column(scale=0.7, min_width=500):
 
194
  ).style(grid=[2], height="auto")
195
 
196
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
197
+ gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
198
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
199
  submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
200
+ gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
201
  submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
202
 
203
  image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])