Spaces:
Running
Running
File size: 3,033 Bytes
6985586 a01e989 90de990 a01e989 d05a223 6668c84 0466d84 6985586 a9e905c bca94f8 a9e905c d05a223 a9e905c d05a223 bca94f8 a01e989 d05a223 a9e905c 90de990 6668c84 d05a223 a01e989 94e7a7b a01e989 a9e905c a01e989 705e8fa a01e989 a9e905c a01e989 d05a223 6c67d85 a01e989 90de990 a01e989 0466d84 90de990 a01e989 90de990 a01e989 6668c84 a01e989 367d052 a01e989 624b826 6668c84 6101644 0466d84 90de990 6101644 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import streamlit as st
from text2image import get_model, get_tokenizer, get_image_transform
from utils import text_encoder, image_encoder
from PIL import Image
from jax import numpy as jnp
from io import BytesIO
import pandas as pd
import requests
import jax
import gc
# Browser-like User-Agent passed as the HTTP headers of the image download
# requests made by this page.
headers = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/70.0.3538.102 Safari/537.36 Edge/18.19582"
    ),
}
# Upper bound, in seconds, on how long an image download may block the page.
_IMAGE_REQUEST_TIMEOUT = 10


def _load_image(url):
    """Download ``url`` and decode the payload as an RGB PIL image.

    Returns:
        The decoded image, or ``None`` when the download or the decoding
        fails. In that case the problem has already been reported to the
        user through ``st.error``.
    """
    try:
        response = requests.get(
            url, headers=headers, timeout=_IMAGE_REQUEST_TIMEOUT
        )
        # Fail on 4xx/5xx instead of trying to decode an HTML error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        # OSError covers PIL's "cannot identify image file" and
        # truncated-file errors.
        st.error(f"Could not load an image from the given URL: {err}")
        return None


def app():
    """Render the "From Image to Text" zero-shot classification page.

    The page asks for an image URL and up to ``MAX_CAP`` free-text labels.
    When "CLASSIFY" is pressed, the image and every non-blank label are
    embedded, the image-label dot products are passed through a softmax, and
    the resulting scores are shown as a bar chart next to the image. When the
    button has not been pressed, the image at the URL is simply previewed.
    """
    st.title("From Image to Text")
    st.markdown(
        """
        ### 👋 Ciao!
        Here you can find the captions or the labels that are most related to a given image. It is a zero-shot
        image classification task!
        🤌 Italian mode on! 🤌
        For example, try typing "gatto" (cat) in the space for label1 and "cane" (dog) in the space for label2 and click
        "classify"!
        """
    )

    image_url = st.text_input(
        "You can input the URL of an image",
        value="https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Ragdoll%2C_blue_mitted.JPG/1280px-Ragdoll%2C_blue_mitted.JPG",
    )

    MAX_CAP = 4
    col1, col2 = st.columns([0.75, 0.25])

    with col2:
        # index=1 preselects the second option, i.e. two labels.
        captions_count = st.selectbox(
            "Number of labels", options=range(1, MAX_CAP + 1), index=1
        )
        compute = st.button("CLASSIFY")

    with col1:
        captions = [
            st.text_input(f"Insert label {idx+1}")
            for idx in range(min(MAX_CAP, captions_count))
        ]

    if compute:
        # Ignore inputs that were left empty or contain only whitespace.
        captions = [c for c in captions if c.strip()]

        if not captions or not image_url:
            st.error("Please choose one image and at least one label")
        else:
            with st.spinner("Computing..."):
                # Fetch the image first so a bad URL fails before any model
                # work is done.
                image = _load_image(image_url)
                if image is not None:
                    model = get_model()
                    tokenizer = get_tokenizer()

                    # The first value returned by text_encoder is iterated and
                    # its items collected -- presumably one embedding row per
                    # caption; TODO confirm against utils.text_encoder.
                    text_embeds = list()
                    for caption in captions:
                        text_embeds.extend(
                            text_encoder(caption, model, tokenizer)[0]
                        )
                    text_embeds = jnp.array(text_embeds)

                    transform = get_image_transform(
                        model.config.vision_config.image_size
                    )
                    image_embed, _ = image_encoder(transform(image), model)

                    # Softmax turns the image-label dot products into scores
                    # that sum to 1 across the labels.
                    probabilities = jax.nn.softmax(
                        jnp.matmul(image_embed, text_embeds.T)
                    )
                    chart_data = pd.Series(probabilities[0], index=captions)

                    col1, col2 = st.columns(2)
                    with col1:
                        st.bar_chart(chart_data)
                    with col2:
                        st.image(image, use_column_width=True)
        gc.collect()

    elif image_url:
        # No classification requested: just preview the image.
        image = _load_image(image_url)
        if image is not None:
            st.image(image)
|