pipeline / app.py
vishalkatheriya18's picture
Update app.py
a66e03f verified
raw
history blame
No virus
4.92 kB
import streamlit as st
from ultralytics import YOLO
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import cv2
import concurrent.futures
from classification import hii
# Mapping from garment body region to the clothing subcategories it contains.
categories_dict = {
    "UpperBody": ["top", "t-shirt", "sweatshirt", "blouse", "sweater", "cardigan", "jacket", "vest"],
    "Lowerbody": ["pants", "shorts", "skirt"],
    "Wholebody": ["coat", "dress", "jumpsuit", "cape"],
    "Head": ["glasses", "hat", "headband", "head covering", "hair accessory"],
    "Neck": ["tie", "neckline"],
    "Arms and Hands": ["glove", "watch", "sleeve"],
    "Waist": ["belt"],
    "Legs and Feet": ["leg warmer", "tights", "stockings", "sock", "shoe"],
    "Others": ["bag", "wallet", "scarf", "umbrella"],
    "Garment parts": ["hood", "collar", "lapel", "epaulette", "pocket"],
    "Closures": ["buckle", "zipper"],
    "Decorations": ["applique", "bead", "bow", "flower", "fringe", "ribbon", "rivet", "ruffle", "sequin", "tassel"]
}


def find_category(subcategory):
    """Return the body-region category that lists *subcategory*.

    Falls back to the literal string "Subcategory not found." when no
    category contains the given subcategory (callers display this string).
    """
    match = next(
        (region for region, subs in categories_dict.items() if subcategory in subs),
        None,
    )
    return match if match is not None else "Subcategory not found."
# Load models and processor only once using Streamlit session state
# (Streamlit re-runs this script on every interaction; the session-state
# flag prevents reloading the YOLO weights each time).
if 'models_loaded' not in st.session_state:
    # NOTE(review): "best.pt" must sit next to this script in the working
    # directory — presumably a fine-tuned segmentation checkpoint; confirm.
    st.session_state.segment_model = YOLO("best.pt")
    st.write("Model loaded!")
    st.session_state.models_loaded = True
# Streamlit app UI: fetch an image by URL, run YOLO segmentation, draw the
# detections, and show each crop with its mapped body-region category.
st.title("Clothing Classification Pipeline")
url = st.sidebar.text_input("Paste image URL here...")
if url:
    try:
        # Bounded timeout so an unreachable host cannot hang the app forever.
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            # Normalize to 3-channel RGB: RGBA/palette/grayscale downloads
            # would otherwise break the OpenCV drawing calls below and the
            # channels="RGB" display.
            image = Image.open(BytesIO(response.content)).convert("RGB")
            st.sidebar.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            # Convert image to numpy array for YOLO model
            image_np = np.array(image)
            # Perform inference
            results = st.session_state.segment_model(image_np)
            # Copy of the original image to draw bounding boxes and labels on
            output_image = image_np.copy()
            cropped_images = []  # (cropped image, caption) pairs
            img_h, img_w = image_np.shape[:2]
            # Visualize the segmentation results
            for result in results:
                boxes = result.boxes  # Bounding boxes
                classes = result.names  # Class names of the detected objects
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                    # Clamp to image bounds: YOLO can emit slightly
                    # out-of-range coordinates, and a degenerate box would
                    # make cv2.resize crash on an empty crop.
                    x1, y1 = max(0, x1), max(0, y1)
                    x2, y2 = min(img_w, x2), min(img_h, y2)
                    if x2 <= x1 or y2 <= y1:
                        continue  # nothing to draw or crop
                    # Draw the bounding box on the original image
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
                    # Class label and confidence score for the detection
                    class_label = classes[box.cls[0].int().item()]
                    confidence = box.conf[0].item()
                    label_text = f'{class_label}: {confidence:.2f}'
                    # Put text label on the original image
                    cv2.putText(output_image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
                    # Crop the detection and resize for uniform display
                    cropped_image = cv2.resize(image_np[y1:y2, x1:x2].copy(), (200, 200))
                    category_name = find_category(class_label)
                    cropped_images.append((cropped_image, f'Class: {category_name}, Confidence: {confidence:.2f}'))
            # Display the original image with bounding boxes and labels
            st.sidebar.image(output_image, caption="Segmented Image", channels="RGB", use_column_width=True)
            # Display cropped images row-wise, num_columns per row
            num_columns = 3
            for start in range(0, len(cropped_images), num_columns):
                cols = st.columns(num_columns)
                for col, (cropped_image, title) in zip(cols, cropped_images[start:start + num_columns]):
                    with col:
                        st.image(cropped_image, caption=title, use_column_width=True)
                        st.write(hii())
        else:
            st.write("URL Invalid...!")
    except Exception as e:
        # Top-level boundary: surface the error in the UI rather than crash
        # the Streamlit run.
        st.write(f"An error occurred: {e}")