Commit 2c8f495
Parent(s): 405f2d4

Add mask filling app

Files changed:
- app.py            +8  -3
- apps/mlm.py       +49 -49
- apps/utils.py     +1  -0
- apps/vqa.py       +44 -42
- multiapp.py       +10 -3
- resize_images.py  +10 -3
app.py
CHANGED
@@ -1,13 +1,17 @@
from apps import mlm, vqa
import os
import streamlit as st
+from session import _get_state
from multiapp import MultiApp

+
def read_markdown(path, parent="./sections/"):
    with open(os.path.join(parent, path)) as f:
        return f.read()

+
def main():
+    state = _get_state()
    st.set_page_config(
        page_title="Multilingual VQA",
        layout="wide",
@@ -30,7 +34,7 @@ def main():
    st.write(read_markdown("abstract.md"))
    st.write(read_markdown("caveats.md"))
    st.write("## Methodology")
-    col1, col2 = st.beta_columns([1,1])
+    col1, col2 = st.beta_columns([1, 1])
    col1.image(
        "./misc/article/Multilingual-VQA.png",
        caption="Masked LM model for Image-text Pretraining.",
@@ -43,10 +47,11 @@ def main():
    st.write(read_markdown("checkpoints.md"))
    st.write(read_markdown("acknowledgements.md"))

-    app = MultiApp()
+    app = MultiApp(state)
    app.add_app("Visual Question Answering", vqa.app)
    app.add_app("Mask Filling", mlm.app)
    app.run()
+    state.sync()

if __name__ == "__main__":
    main()
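Note: session.py itself is not touched by this commit. The wiring above only relies on two behaviours of the object returned by _get_state(): attributes that were never set read back as None (which is why the apps below can guard on `... is None`), and state.sync() is called once at the end of the script run. A purely illustrative stand-in for that contract (not the repository's actual session.py) could look like:

# Illustrative stand-in only -- the real session.py caches one state object per
# browser session across Streamlit reruns; this sketch just mimics the interface
# that app.py, apps/mlm.py and apps/vqa.py rely on.
class _SessionState:
    def __init__(self):
        self.__dict__["_data"] = {}

    def __getattr__(self, name):
        # Unset attributes read as None, enabling `if state.mlm_image_file is None:` guards.
        return self._data.get(name)

    def __setattr__(self, name, value):
        self._data[name] = value

    def sync(self):
        # In the real helper this flushes widget state so values survive the next rerun;
        # here it is a no-op placeholder.
        pass


def _get_state():
    return _SessionState()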
apps/mlm.py
CHANGED
@@ -1,11 +1,9 @@
-
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
-
-    bert_tokenizer
+    bert_tokenizer,
)

import streamlit as st
@@ -13,97 +11,99 @@ import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
-from session import _get_state
+from mtranslate import translate


from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForMaskedLM,
)

+
def softmax(logits):
    return np.exp(logits) / np.sum(np.exp(logits), axis=0)

-def app():
+def app(state):
+    mlm_state = state

-    @st.cache(persist=False)
+    # @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
    def predict(transformed_image, caption_inputs):
+        outputs = model(pixel_values=transformed_image, **caption_inputs)
+        indices = np.where(caption_inputs["input_ids"] == bert_tokenizer.mask_token_id)[
+            1
+        ][0]
+        preds = outputs.logits[0][indices]
+        scores = np.array(preds)
+        return scores
+
-    @st.cache(persist=False)
+    # @st.cache(persist=False)
    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

+    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")

    first_index = 20
-    # Init Session
+    # Init Session mlm_state
+    if mlm_state.mlm_image_file is None:
+        mlm_state.mlm_image_file = dummy_data.loc[first_index, "image_file"]
        caption = dummy_data.loc[first_index, "caption"].strip("- ")
-        ids = bert_tokenizer(caption)
+        ids = bert_tokenizer.encode(caption)
+        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
+        mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
+        mlm_state.caption_lang_id = dummy_data.loc[first_index, "lang_id"]

+        image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
        image = plt.imread(image_path)
+        mlm_state.mlm_image = image

+    #if model is None:
+    # Display Top-5 Predictions
+    with st.spinner("Loading model..."):
+        model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
+        mlm_state.mlm_image_file = sample.loc[0, "image_file"]
        caption = sample.loc[0, "caption"].strip("- ")
-        ids = bert_tokenizer(caption)
+        ids = bert_tokenizer.encode(caption)
+        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
+        mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
+        mlm_state.caption_lang_id = sample.loc[0, "lang_id"]

+        image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
        image = plt.imread(image_path)
+        mlm_state.mlm_image = image

+    transformed_image = get_transformed_image(mlm_state.mlm_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
+    new_col1.image(mlm_state.mlm_image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
+        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

+    new_col2.markdown(
+        f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
+    )
    caption_inputs = get_text_attributes(caption)

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
+        scores = predict(transformed_image, dict(caption_inputs))
+        scores = softmax(scores)
+        labels, values = get_top_5_predictions(scores)
+    # newer_col1, newer_col2 = st.beta_columns([6,4])
    fig = plotly_express_horizontal_bar_plot(values, labels)
+    st.dataframe(pd.DataFrame({"Tokens":labels, "English Translation": list(map(lambda x: translate(x),labels))}).T)
+    st.plotly_chart(fig, use_container_width=True)
+
apps/utils.py
CHANGED
@@ -40,6 +40,7 @@ def get_transformed_image(image):

bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-uncased")

+
def get_text_attributes(text):
    return bert_tokenizer([text], return_token_type_ids=True, return_tensors="np")

apps/vqa.py
CHANGED
@@ -1,4 +1,3 @@
-
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
@@ -15,29 +14,33 @@ import matplotlib.pyplot as plt
import json

from mtranslate import translate
-from session import _get_state


from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)

+
def softmax(logits):
    return np.exp(logits) / np.sum(np.exp(logits), axis=0)

-def app():
-    state = _get_state()
+def app(state):
+    vqa_state = state

+    # @st.cache(persist=False)
+    def predict(transformed_image, question_inputs):
+        return np.array(
+            model(pixel_values=transformed_image, **question_inputs)[0][0]
+        )

+    # @st.cache(persist=False)
    def load_model(ckpt):
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

+    vqa_checkpoints = [
+        "flax-community/clip-vision-bert-vqa-ft-6k"
+    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
@@ -46,77 +49,76 @@ def app():
        "es": "Spanish",
    }

    with open("answer_reverse_mapping.json") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20
-    # Init Session
+    # Init Session vqa_state
+    if vqa_state.vqa_image_file is None:
+        vqa_state.vqa_image_file = dummy_data.loc[first_index, "image_file"]
+        vqa_state.question = dummy_data.loc[first_index, "question"].strip("- ")
+        vqa_state.answer_label = dummy_data.loc[first_index, "answer_label"]
+        vqa_state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
+        vqa_state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
+
+        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
        image = plt.imread(image_path)
+        vqa_state.vqa_image = image

+    # if model is None:
+
+    # Display Top-5 Predictions
+    with st.spinner("Loading model..."):
+        model = load_model(vqa_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
+        vqa_state.vqa_image_file = sample.loc[0, "image_file"]
+        vqa_state.question = sample.loc[0, "question"].strip("- ")
+        vqa_state.answer_label = sample.loc[0, "answer_label"]
+        vqa_state.question_lang_id = sample.loc[0, "lang_id"]
+        vqa_state.answer_lang_id = sample.loc[0, "lang_id"]

+        image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
        image = plt.imread(image_path)
+        vqa_state.vqa_image = image

+    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
+    new_col1.image(vqa_state.vqa_image, use_column_width="always")

    # Display Question
    question = new_col2.text_input(
        label="Question",
+        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    new_col2.markdown(
+        f"""**English Translation**: {question if vqa_state.question_lang_id == "en" else translate(question, 'en')}"""
    )

    question_inputs = get_text_attributes(question)

    # Select Language
    options = ["en", "de", "es", "fr"]
+    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
+        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

+    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
    new_col2.markdown(
        "**Actual Answer**: "
+        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
        + " ("
        + actual_answer
        + ")"
@@ -126,6 +128,6 @@ def app():
        logits = predict(transformed_image, dict(question_inputs))
        logits = softmax(logits)
        labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
+        translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
+    st.plotly_chart(fig, use_container_width=True)
multiapp.py
CHANGED
@@ -1,10 +1,17 @@
import streamlit as st
+from session import _get_state
+
class MultiApp:
-    def __init__(self):
+    def __init__(self, state):
        self.apps = []
+        self.state = state
+
    def add_app(self, title, func):
        self.apps.append({"title": title, "function": func})
+
    def run(self):
        st.sidebar.header("Tasks")
-        app = st.sidebar.radio(
+        app = st.sidebar.radio(
+            "", self.apps, format_func=lambda app: app["title"]
+        )
+        app["function"](self.state)
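With this change MultiApp no longer dispatches to bare app() functions: every registered page must accept the shared state as its single argument. A minimal sketch of what a new task page would look like under this contract (the example module and its widgets are hypothetical, not part of this commit):

# apps/example.py -- hypothetical page, illustrating the app(state) contract MultiApp now expects
import streamlit as st


def app(state):
    # `state` is the object created by _get_state() in app.py and forwarded by MultiApp.run().
    if state.example_counter is None:  # unset attributes read as None in this state pattern
        state.example_counter = 0
    if st.button("Increment"):
        state.example_counter += 1
    st.write(f"Counter: {state.example_counter}")

Such a page would be registered from app.py with app.add_app("Example", example.app), exactly like the mask-filling and VQA pages above.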
resize_images.py
CHANGED
@@ -7,7 +7,11 @@ def resize_images(path, new_path, num_pixels=300):
    if not os.path.exists(new_path):
        os.makedirs(new_path)
    for filename in os.listdir(path):
+        if not filename.startswith(".") and (
+            filename.endswith(".jpg")
+            or filename.endswith(".jpeg")
+            or filename.endswith(".png")
+        ):
            img = cv2.imread(os.path.join(path, filename))
            height, width, channels = img.shape
            if height > width:
@@ -16,8 +20,11 @@ def resize_images(path, new_path, num_pixels=300):
            else:
                new_width = num_pixels
                new_height = int(height * new_width / width)
+                img = cv2.resize(
+                    img, (new_width, new_height), interpolation=cv2.INTER_CUBIC
+                )
            cv2.imwrite(os.path.join(new_path, filename), img)

+
# resize_images('./images/val2014', './resized_images/val2014')
+resize_images("./misc/article", "./misc/article/resized", 500)