# Copied from a Hugging Face Space (page status at copy time: "Running").
import streamlit as st | |
from ultralytics import YOLO | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
import numpy as np | |
import cv2 | |
import concurrent.futures | |
# Categories dictionary
# Maps each body-region / item-group category to the item labels that belong to it.
categories_dict = {
    "UpperBody": ["top", "t-shirt", "sweatshirt", "blouse", "sweater", "cardigan", "jacket", "vest"],
    "Lowerbody": ["pants", "shorts", "skirt"],
    "Wholebody": ["coat", "dress", "jumpsuit", "cape"],
    "Head": ["glasses", "hat", "headband", "head covering", "hair accessory"],
    "Neck": ["tie", "neckline"],
    "Arms and Hands": ["glove", "watch", "sleeve"],
    "Waist": ["belt"],
    "Legs and Feet": ["leg warmer", "tights", "stockings", "sock", "shoe"],
    "Others": ["bag", "wallet", "scarf", "umbrella"],
    "Garment parts": ["hood", "collar", "lapel", "epaulette", "pocket"],
    "Closures": ["buckle", "zipper"],
    "Decorations": ["applique", "bead", "bow", "flower", "fringe", "ribbon", "rivet", "ruffle", "sequin", "tassel"],
}


def _match_category(label):
    """Return the category whose item list contains ``label`` exactly, else None."""
    for category, subcategories in categories_dict.items():
        if label in subcategories:
            return category
    return None


def find_category(subcategory):
    """Map an item label to its category in ``categories_dict``.

    An exact match is tried first, so single-item labels behave exactly as
    before. If that fails and the label is a string, it is split on commas and
    each stripped, lower-cased part is tried in order. Compound or
    differently-cased labels such as "tights, stockings" or "Hat" therefore
    still resolve.

    Args:
        subcategory: the item label to classify (e.g. a detector class name).

    Returns:
        The matching category name, or the string "Subcategory not found."
        when no part of the label matches (kept as a string for backward
        compatibility; it ends up in the crop caption).
    """
    category = _match_category(subcategory)
    if category is not None:
        return category
    if isinstance(subcategory, str):
        # Fallback for compound labels ("a, b, c") and mixed case; all stored
        # item labels are lower-case single items.
        for part in subcategory.split(","):
            category = _match_category(part.strip().lower())
            if category is not None:
                return category
    return "Subcategory not found."
# Streamlit re-executes this whole script on every user interaction. The YOLO
# segmentation model is therefore stored in the per-session state, so
# "best.pt" is loaded only on a session's first run. The "models_loaded" key
# marks that the load has already happened.
if "models_loaded" not in st.session_state:
    st.session_state["segment_model"] = YOLO("best.pt")
    st.write("Model loaded!")
    st.session_state["models_loaded"] = True
def _annotate_and_crop(image_np, results, crop_size=(200, 200)):
    """Draw detections on a copy of ``image_np`` and crop each detected region.

    Args:
        image_np: RGB image array (H, W, 3), as produced below from an image
            converted with ``convert("RGB")``.
        results: iterable of detector results exposing ``.boxes`` and ``.names``.
        crop_size: (width, height) each crop is resized to for the grid display.

    Returns:
        ``(output_image, cropped_images)``: the annotated copy, and a list of
        ``(crop, caption)`` pairs in detection order. Crops come from the
        un-annotated ``image_np``, so they carry no drawn boxes or labels.
    """
    output_image = image_np.copy()
    cropped_images = []
    img_h, img_w = image_np.shape[:2]
    for result in results:
        class_names = result.names  # class index -> label
        for box in result.boxes:
            # .tolist() yields plain Python ints, which cv2 drawing calls accept.
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int).tolist()
            # Clamp to the image. A negative start index would make the slice
            # below wrap around from the far edge instead of starting at 0.
            x1, x2 = max(0, x1), min(img_w, x2)
            y1, y2 = max(0, y1), min(img_h, y2)
            if x2 <= x1 or y2 <= y1:
                # Zero-area box: the crop would be empty and cv2.resize would
                # raise, aborting the whole page via the outer except.
                continue

            class_label = class_names[box.cls[0].int().item()]
            confidence = box.conf[0].item()

            cv2.rectangle(output_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
            # Keep the label inside the frame for boxes touching the top edge.
            text_y = max(y1 - 10, 12)
            cv2.putText(output_image, f'{class_label}: {confidence:.2f}', (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            crop = cv2.resize(image_np[y1:y2, x1:x2].copy(), crop_size)
            category_name = find_category(class_label)
            cropped_images.append((crop, f'Class: {category_name}, Confidence: {confidence:.2f}'))
    return output_image, cropped_images


def _show_crop_grid(cropped_images, num_columns=3):
    """Display ``(image, caption)`` pairs row-wise, ``num_columns`` per row."""
    for start in range(0, len(cropped_images), num_columns):
        row = cropped_images[start:start + num_columns]
        # Always create num_columns columns so images keep the same width even
        # in a partially filled last row.
        for col, (crop, caption) in zip(st.columns(num_columns), row):
            with col:
                st.image(crop, caption=caption, use_column_width=True)


# Streamlit app UI
st.title("Clothing Classification Pipeline")
url = st.sidebar.text_input("Paste image URL here...")
if url:
    try:
        # Without a timeout, requests waits indefinitely on an unresponsive host.
        response = requests.get(url, timeout=15)
        if response.status_code == 200:
            # Normalise to 3-channel RGB. RGBA, palette and greyscale inputs
            # would otherwise yield arrays with 4 or 1 channel(s).
            image = Image.open(BytesIO(response.content)).convert("RGB")
            st.sidebar.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)
            image_np = np.array(image)  # RGB (H, W, 3); used for drawing and cropping
            # Pass the PIL image to YOLO rather than image_np. Ultralytics treats
            # numpy input as BGR (OpenCV order) but PIL input as RGB, so the RGB
            # array would be inferred on with red and blue swapped.
            results = st.session_state.segment_model(image)
            output_image, cropped_images = _annotate_and_crop(image_np, results)
            # Display the original image with bounding boxes and labels
            st.sidebar.image(output_image, caption="Segmented Image", channels="RGB", use_column_width=True)
            if cropped_images:
                _show_crop_grid(cropped_images)
            else:
                st.write("No items detected.")
        else:
            st.write("URL Invalid...!")
    except Exception as e:
        # Top-level UI boundary: report any failure (network, decoding,
        # inference) on the page instead of crashing the app.
        st.write(f"An error occurred: {e}")