jykoh commited on
Commit
0b63ed7
·
1 Parent(s): 55e476e

Fix truncation

Browse files
Files changed (4) hide show
  1. app.py +5 -5
  2. fromage/models.py +3 -2
  3. fromage/utils.py +2 -1
  4. share_btn.py +22 -2
app.py CHANGED
@@ -102,10 +102,10 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
102
  elif type(output) == list:
103
  for image in output:
104
  filename = save_image_to_local(image)
105
- response += f'<br/><img src="/file={filename}">'
106
  elif type(output) == Image.Image:
107
  filename = save_image_to_local(output)
108
- response += f'<br/><img src="/file={filename}">'
109
 
110
  # TODO(jykoh): Persist image inputs.
111
  chat_history = model_inputs + [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
@@ -150,7 +150,7 @@ with gr.Blocks(css=css) as demo:
150
  with gr.Column(scale=0.33):
151
  clear_last_btn = gr.Button("Clear Last Round")
152
  with gr.Column(scale=0.33):
153
- clear_btn = gr.Button("Clear History")
154
 
155
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
156
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
@@ -163,5 +163,5 @@ with gr.Blocks(css=css) as demo:
163
  share_button.click(None, [], [], _js=share_js)
164
 
165
 
166
- # demo.queue(concurrency_count=1, api_open=False, max_size=16)
167
- demo.launch(debug=True, server_name="127.0.0.1")
 
102
  elif type(output) == list:
103
  for image in output:
104
  filename = save_image_to_local(image)
105
+ response += f'<img src="/file={filename}">'
106
  elif type(output) == Image.Image:
107
  filename = save_image_to_local(output)
108
+ response += f'<img src="/file={filename}">'
109
 
110
  # TODO(jykoh): Persist image inputs.
111
  chat_history = model_inputs + [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
 
150
  with gr.Column(scale=0.33):
151
  clear_last_btn = gr.Button("Clear Last Round")
152
  with gr.Column(scale=0.33):
153
+ clear_btn = gr.Button("Clear All")
154
 
155
  text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
156
  text_input.submit(lambda: "", None, text_input) # Reset chatbox.
 
163
  share_button.click(None, [], [], _js=share_js)
164
 
165
 
166
+ demo.queue(concurrency_count=1, api_open=False, max_size=16)
167
+ demo.launch(debug=True, server_name="0.0.0.0")
fromage/models.py CHANGED
@@ -540,6 +540,7 @@ class Fromage(nn.Module):
540
  generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words,
541
  temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
542
  embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
 
543
 
544
  # Truncate to newline.
545
  newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
@@ -554,7 +555,7 @@ class Fromage(nn.Module):
554
  else:
555
  raise ValueError
556
 
557
- print('L557 called')
558
  # Save outputs as an interleaved list.
559
  return_outputs = []
560
  # Find up to max_num_rets [RET] tokens, and their corresponding scores.
@@ -635,7 +636,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
635
  assert len(ret_token_idx) == 1, ret_token_idx
636
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
637
 
638
- debug = True
639
  if debug:
640
  model_kwargs['opt_version'] = 'facebook/opt-125m'
641
  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
 
540
  generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words,
541
  temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
542
  embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
543
+ print('L543 generated_ids', generated_ids)
544
 
545
  # Truncate to newline.
546
  newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
 
555
  else:
556
  raise ValueError
557
 
558
+ print('L557 generated_ids', generated_ids)
559
  # Save outputs as an interleaved list.
560
  return_outputs = []
561
  # Find up to max_num_rets [RET] tokens, and their corresponding scores.
 
636
  assert len(ret_token_idx) == 1, ret_token_idx
637
  model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
638
 
639
+ debug = False
640
  if debug:
641
  model_kwargs['opt_version'] = 'facebook/opt-125m'
642
  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
fromage/utils.py CHANGED
@@ -35,7 +35,8 @@ def truncate_caption(caption: str) -> str:
35
  trunc_index = caption.find('\n') + 1
36
  if trunc_index <= 0:
37
  trunc_index = caption.find('.') + 1
38
- caption = caption[:trunc_index]
 
39
  return caption
40
 
41
 
 
35
  trunc_index = caption.find('\n') + 1
36
  if trunc_index <= 0:
37
  trunc_index = caption.find('.') + 1
38
+ if trunc_index > 0:
39
+ caption = caption[:trunc_index]
40
  return caption
41
 
42
 
share_btn.py CHANGED
@@ -10,7 +10,8 @@ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" cl
10
  "
11
  xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
12
 
13
- share_js = """async () => {
 
14
  async function uploadFile(file){
15
  const UPLOAD_URL = 'https://huggingface.co/uploads';
16
  const response = await fetch(UPLOAD_URL, {
@@ -25,8 +26,27 @@ share_js = """async () => {
25
  return url;
26
  }
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
29
- const outputChat = gradioEl.querySelector('#chatbot').value;
 
30
  const inputPromptEl = "test test";
31
  //const outputChatbot = await getInputVideoFile(outputVideoEl);
32
  //const urlChatbotImage = await uploadFile(outputChatbot);
 
10
  "
11
  xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
12
 
13
+ share_js = """
14
+ async () => {
15
  async function uploadFile(file){
16
  const UPLOAD_URL = 'https://huggingface.co/uploads';
17
  const response = await fetch(UPLOAD_URL, {
 
26
  return url;
27
  }
28
 
29
+ //Creating dynamic link that automatically click
30
+ function downloadURI(uri, name) {
31
+ var link = document.createElement("a");
32
+ link.download = name;
33
+ link.href = uri;
34
+ link.click();
35
+ }
36
+
37
+ //Your modified code.
38
+ function printToFile(div) {
39
+ html2canvas(div, {
40
+ onrendered: function (canvas) {
41
+ var myImage = canvas.toDataURL("image/png");
42
+ downloadURI("data:" + myImage, "fromage_chat.png");
43
+ }
44
+ });
45
+ }
46
+
47
  const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
48
+ const chatbotEl = gradioEl.querySelector('#chatbot')
49
+ printToFile(chatbotEl)
50
  const inputPromptEl = "test test";
51
  //const outputChatbot = await getInputVideoFile(outputVideoEl);
52
  //const urlChatbotImage = await uploadFile(outputChatbot);