# NOTE(review): the three lines below were Hugging Face Spaces page chrome
# ("Spaces: Running / Running") captured by the scrape — kept here as a
# comment so the file parses; not part of the application code.
import concurrent.futures
import time
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
def hii():
    """Trivial liveness helper; always reports completion."""
    status = "done"
    return status
# Define image processing and classification functions
def topwear(encoding, top_wear_model):
    """Classify the top-wear category of a preprocessed image.

    Args:
        encoding: Processor output (dict of tensors) from imageprocessing().
        top_wear_model: HF image-classification model exposing config.id2label.

    Returns:
        str: The predicted top-wear label.
    """
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = top_wear_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look the label up once (original resolved id2label twice).
    label = top_wear_model.config.id2label[predicted_class_idx]
    st.write(f"Top Wear: {label}")
    return label
def patterns(encoding, pattern_model):
    """Classify the fabric pattern of a preprocessed image.

    Args:
        encoding: Processor output (dict of tensors) from imageprocessing().
        pattern_model: HF image-classification model exposing config.id2label.

    Returns:
        str: The predicted pattern label.
    """
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = pattern_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look the label up once (original resolved id2label twice).
    label = pattern_model.config.id2label[predicted_class_idx]
    st.write(f"Pattern: {label}")
    return label
def prints(encoding, print_model):
    """Classify the print style of a preprocessed image.

    Args:
        encoding: Processor output (dict of tensors) from imageprocessing().
        print_model: HF image-classification model exposing config.id2label.

    Returns:
        str: The predicted print label.
    """
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = print_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look the label up once (original resolved id2label twice).
    label = print_model.config.id2label[predicted_class_idx]
    st.write(f"Print: {label}")
    return label
def sleevelengths(encoding, sleeve_length_model):
    """Classify the sleeve length of a preprocessed image.

    Args:
        encoding: Processor output (dict of tensors) from imageprocessing().
        sleeve_length_model: HF image-classification model exposing config.id2label.

    Returns:
        str: The predicted sleeve-length label.
    """
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = sleeve_length_model(**encoding)
        logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    # Look the label up once (original resolved id2label twice).
    label = sleeve_length_model.config.id2label[predicted_class_idx]
    st.write(f"Sleeve Length: {label}")
    return label
def imageprocessing(image):
    """Run the session-cached image processor over *image*.

    Returns the processor's output (PyTorch tensors, `return_tensors="pt"`),
    ready to be unpacked into any of the classification models.
    """
    processor = st.session_state.image_processor
    return processor(images=image, return_tensors="pt")
# Run all models concurrently using threading
def pipes(image):
    """Fan the four classifiers out over a thread pool and gather their labels.

    Args:
        image: PIL image to classify.

    Returns:
        dict: model name ("topwear", "patterns", "prints", "sleeve_length")
        mapped to its predicted label, or None for a model that raised.
    """
    # Preprocess once; all four models share the same encoding.
    encoding = imageprocessing(image)

    # Resolve models from session state on the calling thread before fanning out.
    # NOTE(review): the workers call st.write from pool threads — Streamlit
    # generally expects script-thread calls; confirm this renders as intended.
    tasks = [
        ("topwear", topwear, st.session_state.top_wear_model),
        ("patterns", patterns, st.session_state.pattern_model),
        ("prints", prints, st.session_state.print_model),
        ("sleeve_length", sleevelengths, st.session_state.sleeve_length_model),
    ]

    results = {}
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = {pool.submit(fn, encoding, model): name for name, fn, model in tasks}
        for done in concurrent.futures.as_completed(pending):
            name = pending[done]
            try:
                results[name] = done.result()
            except Exception as e:
                # Report the failure but keep whatever the other models produced.
                st.error(f"Error in {name}: {str(e)}")
                results[name] = None

    # Display the results
    st.write(results)
    return results