# Clothing classification Streamlit app (detection + per-category attribute classification).
# Standard library
import concurrent.futures
import time
from io import BytesIO

# Third-party
import cv2
import numpy as np
import requests
import streamlit as st
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
from ultralytics import YOLO

# Project-local
from classification import pipes

# Wall-clock timings per pipeline stage, keyed by stage name (seconds).
time_taken = {}
# Maps each body-zone category to the detector's subcategory labels it covers.
# Note: some labels intentionally contain commas (e.g. "shirt, blouse") because
# they are single multi-word class names from the detection model's vocabulary.
categories_dict = {
    "UpperBody": ["shirt, blouse", "top, t-shirt, sweatshirt", "blouse", "sweater", "cardigan", "jacket", "vest"],
    "Lowerbody": ["pants", "shorts", "skirt"],
    "Wholebody": ["coat", "dress", "jumpsuit", "cape"],
    "Head": ["glasses", "hat", "headband", "head covering", "hair accessory"],
    "Neck": ["tie", "neckline", "collar"],
    "Arms and Hands": ["glove", "watch", "sleeve"],
    "Waist": ["belt"],
    "Legs and Feet": ["leg warmer", "tights", "stockings", "sock", "shoe"],
    "Others": ["bag", "wallet", "scarf", "umbrella"],
    "Garment parts": ["hood", "lapel", "epaulette", "pocket"],
    "Closures": ["buckle", "zipper"],
    "Decorations": ["applique", "bead", "bow", "flower", "fringe", "ribbon", "rivet", "ruffle", "sequin", "tassel"],
}
def find_category(subcategory, mapping=None):
    """Return the body-zone category that contains *subcategory*.

    Parameters
    ----------
    subcategory : str
        A detector class label, e.g. ``"belt"`` or ``"dress"``.
    mapping : dict[str, list[str]] | None
        Category -> labels table to search. Defaults to the module-level
        ``categories_dict`` (backward-compatible with the original
        single-argument call).

    Returns
    -------
    str
        The matching category key, or the sentinel string
        ``"Subcategory not found."`` when the label is unknown.
    """
    table = categories_dict if mapping is None else mapping
    for category, subcategories in table.items():
        if subcategory in subcategories:
            return category
    # Keep the original string sentinel so existing callers that treat the
    # return value as display text keep working unchanged.
    return "Subcategory not found."
# Load models and processor only once per Streamlit session: st.session_state
# survives script reruns, so the expensive weight downloads happen a single time.
if 'models_loaded' not in st.session_state:
    sm = time.time()
    # Garment localization model (local YOLO weights).
    st.session_state.segment_model = YOLO("best.pt")
    # Shared image preprocessor for the ConvNeXtV2 classifiers below.
    st.session_state.image_processor = AutoImageProcessor.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    # Per-garment-type classifiers.
    st.session_state.top_wear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-topwear")
    st.session_state.bottomwear_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-bottomwear")
    st.session_state.fullwear = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-fullwear")
    # Attribute classifiers applied to full wear and top wear.
    st.session_state.pattern_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-pattern-rgb")
    st.session_state.print_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-print")
    st.session_state.sleeve_length_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-sleeve-length")
    st.session_state.neck_style_model = AutoModelForImageClassification.from_pretrained("vishalkatheriya18/convnextv2-tiny-1k-224-finetuned-neck-style")
    st.session_state.models_loaded = True
    # Only recorded on the first (cold) run; reruns skip this branch entirely.
    time_taken["model loading"] = time.time() - sm
# Streamlit app UI
st.title("Clothing Classification Pipeline")
url = st.sidebar.text_input("Paste image URL here...")

if url:
    try:
        # Bounded timeout so an unreachable host cannot hang the app forever.
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content))
            st.sidebar.image(image.resize((200, 200)), caption="Uploaded Image", use_column_width=False)

            # YOLO consumes an array, not a PIL image.
            image_np = np.array(image)
            outputs = {}

            # Run garment detection and time the inference.
            ss = time.time()
            results = st.session_state.segment_model(image_np)
            time_taken["yolo model result time"] = time.time() - ss

            # Annotated copy for drawing; crops are taken from the pristine original.
            output_image = image_np.copy()
            cropped_images_list = []  # (resized crop, caption) pairs for the grid below

            for result in results:
                boxes = result.boxes    # detected bounding boxes
                classes = result.names  # class-id -> label mapping
                for i, box in enumerate(boxes):
                    box_coords = box.xyxy[0].cpu().numpy().astype(int)
                    x1, y1, x2, y2 = box_coords

                    # Draw the box and its "<label>: <conf>" caption on the annotated image.
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=1)
                    class_label = classes[box.cls[0].int().item()]
                    confidence = box.conf[0].item()
                    label_text = f'{class_label}: {confidence:.2f}'
                    cv2.putText(output_image, label_text, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

                    # Crop the detection; the resized copy is display-only.
                    cropped_image = image_np[y1:y2, x1:x2].copy()
                    cropped_image_resized = cv2.resize(cropped_image, (200, 200))

                    category_name = find_category(class_label)
                    cs = time.time()
                    # NOTE(review): only the "Neck" pipeline receives the crop; every other
                    # category is classified on the FULL image, and a later detection of the
                    # same category overwrites the earlier result — confirm both are intentional.
                    if category_name == "Neck":
                        outputs[category_name] = pipes(cropped_image, category_name)
                    else:
                        outputs[category_name] = pipes(image, category_name)
                    time_taken[f"{category_name} prediction time"] = time.time() - cs

                    cropped_images_list.append((cropped_image_resized, f'Class: {category_name}, Confidence: {confidence:.2f}' + ">>subcat>>" + label_text))

            time_taken["whole process time"] = time.time() - ss

            # Show the annotated image, then the crops in a 3-wide grid.
            st.sidebar.image(output_image, caption="Segmented Image", channels="RGB", use_column_width=True)
            num_columns = 3
            num_rows = (len(cropped_images_list) + num_columns - 1) // num_columns  # ceiling division
            for i in range(num_rows):
                cols = st.columns(num_columns)
                for j in range(num_columns):
                    idx = i * num_columns + j
                    if idx < len(cropped_images_list):
                        cropped_image, title = cropped_images_list[idx]
                        with cols[j]:
                            st.image(cropped_image, caption=title, use_column_width=True)

            st.header("Output")
            st.json(outputs)
            st.header("taken time")
            st.json(time_taken)
        else:
            st.write("URL Invalid...!")
    except Exception as e:
        # Broad catch keeps the app alive on any failure; surface the message to the user.
        st.write(f"An error occurred: {e}")