|
|
|
import html
import os
from io import BytesIO
from urllib.parse import urljoin, urlparse

import numpy as np
import pandas as pd
import requests
import streamlit as st
import tensorflow as tf
from bs4 import BeautifulSoup
from PIL import Image
|
|
|
def download_model(model_url, model_path, timeout=60):
    """Download the file at ``model_url`` to ``model_path`` unless it already exists.

    An existing ``model_path`` is treated as a valid cached copy and nothing is
    downloaded.

    Args:
        model_url: HTTP(S) URL of the model file.
        model_path: Local path the model is written to.
        timeout: Seconds to wait for the server (connect / between reads)
            before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
        OSError: If the file cannot be written.
    """
    if os.path.exists(model_path):
        return

    # Download into a sibling temp file and move it into place only once the
    # whole body has arrived. Writing straight to ``model_path`` would leave a
    # truncated file (or an HTML error page) there after a failure, and the
    # exists-check above would then reuse that broken file on every later run.
    tmp_path = os.fspath(model_path) + ".part"
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                # Stream in 1 MiB chunks instead of holding the model in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, model_path)
    finally:
        # After success the temp file has already been renamed away; after a
        # failure, remove the partial file so the next attempt starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
def load_model(model_path):
    """Create a TensorFlow Lite interpreter for ``model_path``, ready to run.

    Args:
        model_path: Path to a ``.tflite`` model file on disk.

    Returns:
        A ``tf.lite.Interpreter`` on which ``allocate_tensors()`` has already
        been called.
    """
    tflite_interpreter = tf.lite.Interpreter(model_path=model_path)
    # allocate_tensors() has to run before inputs can be set or the model invoked.
    tflite_interpreter.allocate_tensors()
    return tflite_interpreter
|
|
|
def preprocess_image(image, input_size):
    """Convert a PIL image into a batched, [0, 1]-scaled float32 array.

    Args:
        image: A PIL ``Image`` in any mode; it is converted to RGB first.
        input_size: Side length in pixels; the image is resized to the square
            ``(input_size, input_size)``.

    Returns:
        A float32 ``np.ndarray`` with a leading batch axis of length 1, whose
        pixel values are divided by 255.
    """
    side = (input_size, input_size)
    rgb = image.convert('RGB').resize(side)
    # Add the batch axis up front, then map 0..255 pixel values onto 0..1.
    batch = np.array(rgb, dtype=np.float32)[np.newaxis, ...]
    return batch / 255.0
|
|
|
def run_inference(interpreter, input_data):
    """Run one forward pass and return the first output tensor.

    Args:
        interpreter: A TFLite interpreter whose tensors are already allocated.
        input_data: Array written into the interpreter's first input tensor;
            presumably it must match that tensor's shape and dtype.

    Returns:
        The value of the interpreter's first output tensor after ``invoke()``.
    """
    first_input = interpreter.get_input_details()[0]
    first_output = interpreter.get_output_details()[0]

    interpreter.set_tensor(first_input['index'], input_data)
    interpreter.invoke()
    return interpreter.get_tensor(first_output['index'])
|
|
|
def fetch_images_from_url(url, timeout=30):
    """Return absolute URLs for the ``<img src>`` values in the page at ``url``.

    Only the HTML returned by the server is parsed; images that scripts add
    later are not seen.

    Args:
        url: Address of the web page to scan.
        timeout: Seconds to wait for the page before giving up.

    Returns:
        A list of absolute ``http``/``https`` image URLs in document order,
        with duplicates removed (first occurrence kept).

    Raises:
        requests.HTTPError: If the page answers with an error status.
        requests.RequestException: On malformed URLs, connection problems or
            timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    soup = BeautifulSoup(response.content, 'html.parser')

    # Relative ("/img/a.jpg") and protocol-relative ("//cdn/b.png") sources
    # cannot be fetched as-is, so resolve them against the page. response.url
    # is the final URL after redirects; a <base href> overrides it, as in a
    # browser.
    base_url = response.url
    base_tag = soup.find('base', href=True)
    if base_tag is not None:
        base_url = urljoin(base_url, base_tag['href'].strip())

    img_urls = []
    for img in soup.find_all('img'):
        src = (img.get('src') or '').strip()
        if not src:
            # urljoin(base, "") would return the page URL itself.
            continue
        absolute = urljoin(base_url, src)
        # Only http(s) can be downloaded with requests; skip data: URIs etc.
        if urlparse(absolute).scheme in ('http', 'https'):
            img_urls.append(absolute)

    # De-duplicate while preserving first-seen order.
    return list(dict.fromkeys(img_urls))
|
|
|
def _intent_bar_html(label, percentage):
    """Build the HTML for one row: label, horizontal bar, and numeric value."""
    # The label is escaped because the markup is rendered with
    # unsafe_allow_html=True. The bar width is clamped to 0..100 so an
    # out-of-range score cannot overflow its container, while the printed
    # number still shows the real value.
    width = min(max(float(percentage), 0.0), 100.0)
    # Emitted without leading indentation: Markdown renders lines indented by
    # four or more spaces as a code block, i.e. as raw tags instead of a bar.
    return (
        "<div style='display: flex; align-items: center;'>"
        "<div style='width: 30%; text-align: right; padding-right: 10px;'>"
        f"{html.escape(str(label))}</div>"
        "<div style='width: 70%; display: flex; align-items: center;'>"
        f"<div style='background-color: #007BFF; height: 10px; width: {width}%;'></div>"
        f"<div style='padding-left: 10px;'>{percentage:.2f}%</div>"
        "</div>"
        "</div>"
    )


def render_intent_bars(labels, percentages):
    """Render one labelled horizontal bar per (label, percentage) pair.

    Args:
        labels: Text shown to the left of each bar.
        percentages: Numeric scores on a 0-100 scale, paired with ``labels``
            by position; surplus items in the longer sequence are ignored
            (``zip`` semantics).
    """
    for label, percentage in zip(labels, percentages):
        st.markdown(_intent_bar_html(label, percentage), unsafe_allow_html=True)
|
|
|
def _classify_images(img_urls, model_path, timeout=30):
    """Download, classify and display every image in ``img_urls``.

    Each image gets a thumbnail next to one intent bar per category. Failures
    are reported per image and do not stop the loop.

    Args:
        img_urls: Absolute image URLs to process.
        model_path: Path of the local ``.tflite`` model file.
        timeout: Seconds to wait for each image download.
    """
    st.write(f"Found {len(img_urls)} images")

    interpreter = load_model(model_path)
    input_shape = interpreter.get_input_details()[0]['shape']
    # NOTE(review): index 1 is used as the square side length, which presumes
    # a (batch, height, width, channels) input -- confirm for this model.
    input_size = input_shape[1]

    # Display names, paired positionally with the model's output values.
    categories = [
        "No Shopping Intent",
        "Apparel",
        "Home Decor",
        "Other",
    ]

    for img_url in img_urls:
        # Deliberately best-effort: any failure for one image (network error,
        # undecodable data, tensor mismatch, rendering) is shown as a message
        # and the loop moves on to the next image.
        try:
            response = requests.get(img_url, timeout=timeout)
            # Report HTTP errors directly rather than as a confusing
            # "cannot identify image file" from PIL on an error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))

            input_data = preprocess_image(image, input_size)
            output_data_shopping_intent = run_inference(interpreter, input_data)
            shopping_intent_percentages = (output_data_shopping_intent.flatten() * 100).tolist()

            col1, col2 = st.columns([1, 3])
            with col1:
                st.image(image.resize((224, 224)), width=224)
            with col2:
                st.write(f"[URL]({img_url})")
                render_intent_bars(categories, shopping_intent_percentages)
            st.write("---")
        except Exception as e:
            st.write(f"Could not process image {img_url}: {e}")


def main():
    """Streamlit entry point: classify the images found on a user-supplied page.

    Ensures the TFLite model is available locally, asks for a page URL,
    collects that page's ``<img>`` sources and shows a per-image breakdown of
    shopping-intent scores.
    """
    st.set_page_config(layout="wide")
    st.title("Shopping Intent Classification - SEO by DEJAN")

    st.markdown("""
Multi-label image classification model [extracted from Chrome](https://dejanmarketing.com/product-image-optimisation-with-chromes-convolutional-neural-network/). The model can be deployed in an automated pipeline capable of classifying product images in bulk. Javascript-based website scraping currently unsupported.
""")

    st.write("Enter a URL to fetch and classify all images on the page:")

    model_url = "https://huggingface.co/dejanseo/shopping-intent/resolve/main/model.tflite"
    model_path = "model.tflite"
    try:
        download_model(model_url, model_path)
    except (requests.RequestException, OSError) as e:
        # Nothing below works without the model: end this run with a readable
        # message instead of a traceback.
        st.error(f"Could not download the model: {e}")
        st.stop()

    url = st.text_input("Enter URL")

    if url:
        try:
            img_urls = fetch_images_from_url(url)
        except requests.RequestException as e:
            # Covers malformed input (e.g. a URL without http://), DNS
            # failures, timeouts and HTTP error statuses.
            st.error(f"Could not fetch {url}: {e}")
        else:
            if img_urls:
                _classify_images(img_urls, model_path)
            else:
                st.write("No images found on this page.")

    st.markdown("""
Interested in using this in an automated pipeline for bulk image classification?
Please [book an appointment](https://dejanmarketing.com/conference/) to discuss your needs.
""")
|
|
|
# Script entry point: build the app only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|