xxyyy123 commited on
Commit
648c219
·
verified ·
1 Parent(s): 5f19c16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -53,7 +53,7 @@ def ovis_chat(chatbot, image_input):
53
  "value": text_input
54
  })
55
  if image_input is not None:
56
- conversations[0]["value"] = image_placeholder + '\n' + conversations[0]["value"]
57
  prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
58
  attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
59
  input_ids = input_ids.unsqueeze(0).to(device=model.device)
@@ -76,18 +76,25 @@ def ovis_chat(chatbot, image_input):
76
  use_cache=True
77
  )
78
  response = ""
79
- thread = Thread(target=model.generate,
80
- kwargs={"inputs": input_ids,
81
- "pixel_values": pixel_values,
82
- "attention_mask": attention_mask,
83
- "streamer": streamer,
84
- **gen_kwargs})
85
- thread.start()
 
 
 
 
 
 
 
86
  for new_text in streamer:
87
  response += new_text
88
  chatbot[-1][1] = response
89
  yield chatbot
90
- thread.join()
91
  # debug
92
  print('*'*60)
93
  print('*'*60)
 
53
  "value": text_input
54
  })
55
  if image_input is not None:
56
+ conversations[1]["value"] = image_placeholder + '\n' + conversations[1]["value"]
57
  prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
58
  attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
59
  input_ids = input_ids.unsqueeze(0).to(device=model.device)
 
76
  use_cache=True
77
  )
78
  response = ""
79
+ # thread = Thread(target=model.generate,
80
+ # kwargs={"inputs": input_ids,
81
+ # "pixel_values": pixel_values,
82
+ # "attention_mask": attention_mask,
83
+ # "streamer": streamer,
84
+ # **gen_kwargs})
85
+ model.generate(
86
+ input_ids,
87
+ pixel_values=pixel_values,
88
+ attention_mask=attention_mask,
89
+ streamer=streamer,
90
+ **gen_kwargs
91
+ )
92
+ # thread.start()
93
  for new_text in streamer:
94
  response += new_text
95
  chatbot[-1][1] = response
96
  yield chatbot
97
+ # thread.join()
98
  # debug
99
  print('*'*60)
100
  print('*'*60)