# pipeline/classification.py
# Created by vishalkatheriya18 (commit 94f41f8, "Create classification.py"; 3.08 kB).
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import time
import torch
import concurrent.futures
def hii():
    """Return the constant status string ``"done"``."""
    status = "done"
    return status
# Define image processing and classification functions
def _classify(encoding, model, heading):
    """Run *model* on *encoding*, show the top label in Streamlit, and return it.

    Shared implementation behind ``topwear``, ``patterns``, ``prints`` and
    ``sleevelengths``, which differ only in the model and the heading text.

    Args:
        encoding: Mapping of model inputs (as returned by the image processor);
            it is unpacked into the model call as keyword arguments.
        model: Classification model whose output has ``.logits`` and whose
            config has ``id2label`` (presumably an
            ``AutoModelForImageClassification`` instance -- confirm).
        heading: Text written before the label, e.g. ``"Top Wear"``.

    Returns:
        The label string for the highest-scoring class index.
    """
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoding).logits
    # .item() requires a single-element tensor, i.e. one image per call.
    label = model.config.id2label[logits.argmax(-1).item()]
    # NOTE: st.write is only rendered when called from the Streamlit script
    # thread; calls made from worker threads lack a script-run context.
    st.write(f"{heading}: {label}")
    return label


def topwear(encoding, top_wear_model):
    """Classify top-wear type; writes ``Top Wear: <label>`` and returns the label."""
    return _classify(encoding, top_wear_model, "Top Wear")


def patterns(encoding, pattern_model):
    """Classify pattern; writes ``Pattern: <label>`` and returns the label."""
    return _classify(encoding, pattern_model, "Pattern")


def prints(encoding, print_model):
    """Classify print; writes ``Print: <label>`` and returns the label."""
    return _classify(encoding, print_model, "Print")


def sleevelengths(encoding, sleeve_length_model):
    """Classify sleeve length; writes ``Sleeve Length: <label>`` and returns the label."""
    return _classify(encoding, sleeve_length_model, "Sleeve Length")
def imageprocessing(image):
    """Encode *image* using the processor stored in ``st.session_state``.

    Args:
        image: Input passed as ``images=`` to
            ``st.session_state.image_processor``.

    Returns:
        The processor's output requested with ``return_tensors="pt"``
        (PyTorch tensors).
    """
    processor = st.session_state.image_processor
    return processor(images=image, return_tensors="pt")
# Run all models concurrently using threading
def _predict_label(encoding, model):
    """Return the top-1 label of *model* for a pre-computed *encoding*.

    Makes no Streamlit calls on purpose: it runs on ThreadPoolExecutor worker
    threads, which have no Streamlit script-run context, so ``st.*`` output
    issued there is not rendered in the app.
    """
    # Grad mode is thread-local in PyTorch, so no_grad must be entered here,
    # inside the worker thread.
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]


def pipes(image):
    """Classify *image* with all four attribute models concurrently.

    The image is encoded once via ``imageprocessing`` and that encoding is fed
    to every model on a thread pool. All Streamlit output happens on the
    calling (script) thread after the workers finish:
    - one ``"<Heading>: <label>"`` line per model, in a fixed order;
    - an ``st.error`` for any model that raised;
    - finally, the whole results dict.

    Args:
        image: Input accepted by ``st.session_state.image_processor``
            (presumably a PIL image -- confirm with callers).

    Returns:
        dict mapping ``"topwear"``, ``"patterns"``, ``"prints"`` and
        ``"sleeve_length"`` to the predicted label, or ``None`` for a model
        that raised.
    """
    # Process the image once and reuse the encoding for every model.
    encoding = imageprocessing(image)
    # Read models from session state here, on the script thread; worker
    # threads cannot access st.session_state.
    tasks = (
        ("topwear", "Top Wear", st.session_state.top_wear_model),
        ("patterns", "Pattern", st.session_state.pattern_model),
        ("prints", "Print", st.session_state.print_model),
        ("sleeve_length", "Sleeve Length", st.session_state.sleeve_length_model),
    )
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {
            key: executor.submit(_predict_label, encoding, model)
            for key, _heading, model in tasks
        }
    # Leaving the with-block waits for every future, so all results are ready.
    results = {}
    for key, heading, _model in tasks:
        try:
            label = futures[key].result()
        except Exception as e:  # report per-model failures instead of aborting
            st.error(f"Error in {key}: {str(e)}")
            results[key] = None
        else:
            st.write(f"{heading}: {label}")
            results[key] = label
    # Display the results
    st.write(results)
    return results