Fix truncation
Files changed:
- app.py +5 -5
- fromage/models.py +3 -2
- fromage/utils.py +2 -1
- share_btn.py +22 -2
app.py CHANGED
@@ -102,10 +102,10 @@ def generate_for_prompt(input_text, state, ret_scale_factor, max_nm_rets, num_wo
         elif type(output) == list:
             for image in output:
                 filename = save_image_to_local(image)
-                response += f'<
+                response += f'<img src="/file={filename}">'
         elif type(output) == Image.Image:
             filename = save_image_to_local(output)
-            response += f'<
+            response += f'<img src="/file={filename}">'
 
     # TODO(jykoh): Persist image inputs.
     chat_history = model_inputs + [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
@@ -150,7 +150,7 @@ with gr.Blocks(css=css) as demo:
         with gr.Column(scale=0.33):
             clear_last_btn = gr.Button("Clear Last Round")
         with gr.Column(scale=0.33):
-            clear_btn = gr.Button("Clear
+            clear_btn = gr.Button("Clear All")
 
     text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
     text_input.submit(lambda: "", None, text_input)  # Reset chatbox.
@@ -163,5 +163,5 @@ with gr.Blocks(css=css) as demo:
     share_button.click(None, [], [], _js=share_js)
 
 
-
-demo.launch(debug=True, server_name="
+demo.queue(concurrency_count=1, api_open=False, max_size=16)
+demo.launch(debug=True, server_name="0.0.0.0")
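For context, a minimal sketch of the pattern these app.py changes rely on (hypothetical helper app, assuming Gradio 3.x): a generated PIL image is saved to disk, embedded in the chatbot response as an <img> tag served through Gradio's /file= route, and the demo is queued before launch.

# Hypothetical minimal app illustrating the pattern above; not the FROMAGe demo itself.
import uuid
import gradio as gr
from PIL import Image

def save_image_to_local(image: Image.Image) -> str:
    # Save the image under a unique filename so Gradio can serve it via /file=.
    filename = f'{uuid.uuid4()}.png'
    image.save(filename)
    return filename

def respond(message, history):
    # Placeholder "model output": a solid gray image.
    image = Image.new('RGB', (64, 64), 'gray')
    filename = save_image_to_local(image)
    # Embed the saved file as HTML in the chatbot response.
    history = history + [(message, f'<img src="/file={filename}">')]
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(elem_id='chatbot')
    text_input = gr.Textbox()
    text_input.submit(respond, [text_input, chatbot], [chatbot])

demo.queue(concurrency_count=1, api_open=False, max_size=16)
demo.launch(debug=True, server_name='0.0.0.0')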
fromage/models.py CHANGED
@@ -540,6 +540,7 @@ class Fromage(nn.Module):
         generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words,
             temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
         embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
+        print('L543 generated_ids', generated_ids)
 
         # Truncate to newline.
         newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
@@ -554,7 +555,7 @@ class Fromage(nn.Module):
         else:
             raise ValueError
 
-        print('L557
+        print('L557 generated_ids', generated_ids)
         # Save outputs as an interleaved list.
         return_outputs = []
         # Find up to max_num_rets [RET] tokens, and their corresponding scores.
@@ -635,7 +636,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
     assert len(ret_token_idx) == 1, ret_token_idx
     model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
 
-    debug =
+    debug = False
    if debug:
        model_kwargs['opt_version'] = 'facebook/opt-125m'
        model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
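The "# Truncate to newline" context above depends on finding the tokenizer's newline id in the generated sequence. A hedged sketch of that truncation step (an illustration only, not the exact Fromage implementation):

import torch

def truncate_at_newline(generated_ids: torch.Tensor, newline_token_id: int) -> torch.Tensor:
    # generated_ids: (1, seq_len) tensor of token ids for a single sequence.
    # Keep only the tokens before the first newline, if one was generated.
    hits = (generated_ids[0] == newline_token_id).nonzero()
    if hits.numel() > 0:
        generated_ids = generated_ids[:, :hits[0].item()]
    return generated_ids

# Usage, assuming a HuggingFace-style tokenizer as in load_fromage:
# newline_token_id = tokenizer('\n', add_special_tokens=False).input_ids[0]
# generated_ids = truncate_at_newline(generated_ids, newline_token_id)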
fromage/utils.py CHANGED
@@ -35,7 +35,8 @@ def truncate_caption(caption: str) -> str:
     trunc_index = caption.find('\n') + 1
     if trunc_index <= 0:
         trunc_index = caption.find('.') + 1
-
+    if trunc_index > 0:
+        caption = caption[:trunc_index]
     return caption
 
 
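Read back from the diff above, the fixed helper now slices the caption only when a cut point was actually found, so captions without a newline or period are returned unchanged:

# Assembled from the new lines in the diff above (signature from the hunk header).
def truncate_caption(caption: str) -> str:
    # Prefer truncating after the first newline; fall back to the first period.
    trunc_index = caption.find('\n') + 1
    if trunc_index <= 0:
        trunc_index = caption.find('.') + 1
    if trunc_index > 0:
        caption = caption[:trunc_index]
    return caption

# Example: truncate_caption('a photo of a dog. it is brown') -> 'a photo of a dog.'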
share_btn.py CHANGED
@@ -10,7 +10,8 @@ loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" cl
 "
 xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
 
-share_js = """async () => {
+share_js = """
+async () => {
     async function uploadFile(file){
         const UPLOAD_URL = 'https://huggingface.co/uploads';
         const response = await fetch(UPLOAD_URL, {
@@ -25,8 +26,27 @@ share_js = """async () => {
         return url;
     }
 
+    //Creating dynamic link that automatically click
+    function downloadURI(uri, name) {
+        var link = document.createElement("a");
+        link.download = name;
+        link.href = uri;
+        link.click();
+    }
+
+    //Your modified code.
+    function printToFile(div) {
+        html2canvas(div, {
+            onrendered: function (canvas) {
+                var myImage = canvas.toDataURL("image/png");
+                downloadURI("data:" + myImage, "fromage_chat.png");
+            }
+        });
+    }
+
     const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
-    const
+    const chatbotEl = gradioEl.querySelector('#chatbot')
+    printToFile(chatbotEl)
     const inputPromptEl = "test test";
     //const outputChatbot = await getInputVideoFile(outputVideoEl);
     //const urlChatbotImage = await uploadFile(outputChatbot);