import streamlit as st
from ultralytics import YOLO
from PIL import Image
import requests
from io import BytesIO
import numpy as np
import cv2
import concurrent.futures

# Mapping from a broad body-region category to the garment/accessory
# subcategories (detector class labels) that belong to it.
categories_dict = {
    "UpperBody": ["top", "t-shirt", "sweatshirt", "blouse", "sweater", "cardigan", "jacket", "vest"],
    "Lowerbody": ["pants", "shorts", "skirt"],
    "Wholebody": ["coat", "dress", "jumpsuit", "cape"],
    "Head": ["glasses", "hat", "headband", "head covering", "hair accessory"],
    "Neck": ["tie", "neckline"],
    "Arms and Hands": ["glove", "watch", "sleeve"],
    "Waist": ["belt"],
    "Legs and Feet": ["leg warmer", "tights", "stockings", "sock", "shoe"],
    "Others": ["bag", "wallet", "scarf", "umbrella"],
    "Garment parts": ["hood", "collar", "lapel", "epaulette", "pocket"],
    "Closures": ["buckle", "zipper"],
    "Decorations": ["applique", "bead", "bow", "flower", "fringe", "ribbon", "rivet", "ruffle", "sequin", "tassel"],
}

# Reverse lookup table built once at import time (subcategory -> category).
# Replaces a linear scan over every category on each find_category() call,
# which runs once per detected bounding box.
_subcategory_to_category = {
    sub: cat
    for cat, subs in categories_dict.items()
    for sub in subs
}


def find_category(subcategory):
    """Return the broad category a garment subcategory belongs to.

    Args:
        subcategory: A detector class label, e.g. "belt" or "t-shirt".

    Returns:
        The category name (e.g. "Waist"), or the sentinel string
        "Subcategory not found." when the label is unknown.
    """
    return _subcategory_to_category.get(subcategory, "Subcategory not found.")
# Load the segmentation model only once per session via Streamlit state,
# so reruns triggered by widget interaction don't reload the weights.
if 'models_loaded' not in st.session_state:
    st.session_state.segment_model = YOLO("best.pt")
    st.write("Model loaded!")
    st.session_state.models_loaded = True

# Streamlit app UI
st.title("Clothing Classification Pipeline")

url = st.sidebar.text_input("Paste image URL here...")

if url:
    try:
        # Bound the request so an unreachable host cannot hang the app.
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            # Force RGB: PNGs may be RGBA/palette/grayscale, which would
            # break the 3-channel assumptions of the drawing calls below.
            image = Image.open(BytesIO(response.content)).convert("RGB")
            st.sidebar.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            # Convert image to numpy array for the YOLO model
            image_np = np.array(image)

            # Perform inference
            results = st.session_state.segment_model(image_np)

            # Annotated copy for boxes/labels; crops are taken from the
            # untouched original so drawn rectangles don't leak into them.
            output_image = image_np.copy()
            cropped_images = []  # (cropped array, caption) pairs

            for result in results:
                boxes = result.boxes    # detected bounding boxes
                classes = result.names  # class-id -> label mapping
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)

                    # Draw the bounding box and its label on the annotated copy.
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
                    class_label = classes[box.cls[0].int().item()]
                    confidence = box.conf[0].item()
                    label_text = f'{class_label}: {confidence:.2f}'
                    cv2.putText(output_image, label_text, (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

                    # Crop the detection; skip degenerate boxes whose crop is
                    # empty, since cv2.resize raises on zero-sized input.
                    cropped_image = image_np[y1:y2, x1:x2].copy()
                    if cropped_image.size == 0:
                        continue
                    cropped_image = cv2.resize(cropped_image, (200, 200))
                    category_name = find_category(class_label)
                    cropped_images.append((cropped_image, f'Class: {category_name}, Confidence: {confidence:.2f}'))

            # Display the original image with bounding boxes and labels
            st.sidebar.image(output_image, caption="Segmented Image", channels="RGB", use_column_width=True)

            # Display cropped images row-wise, num_columns per row.
            num_columns = 3
            num_rows = (len(cropped_images) + num_columns - 1) // num_columns
            for i in range(num_rows):
                cols = st.columns(num_columns)
                for j in range(num_columns):
                    idx = i * num_columns + j
                    if idx < len(cropped_images):
                        cropped_image, title = cropped_images[idx]
                        with cols[j]:
                            st.image(cropped_image, caption=title, use_column_width=True)
        else:
            st.write("URL Invalid...!")
    except Exception as e:
        # Surface any failure (bad URL, decode error, inference error) in the UI.
        st.write(f"An error occurred: {e}")