Spaces:
Running
Running
File size: 3,076 Bytes
94f41f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import time
import torch
import concurrent.futures
def hii():
    """Return the constant status string ``"done"``."""
    status = "done"
    return status
# Per-attribute classifiers. Each runs one image-classification model on an
# already-computed processor encoding, writes "<Attribute>: <label>" to the
# Streamlit page, and returns the predicted label.
def _classify(encoding, model, display_name):
    """Run ``model`` on ``encoding``, write the top-1 label to the page, return it.

    Args:
        encoding: Keyword inputs for the model (as produced by the image
            processor); passed as ``model(**encoding)``.
        model: Classification model whose output exposes ``.logits`` and whose
            ``config.id2label`` maps class indices to label names.
        display_name: Prefix for the line written to the page, e.g. "Pattern".

    Returns:
        The ``id2label`` name of the highest-scoring class.

    Note:
        ``.item()`` requires the argmax to contain a single element, i.e. a
        single-image batch.
    """
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**encoding).logits
    label = model.config.id2label[logits.argmax(-1).item()]
    # Streamlit only renders this when called from the script thread; from a
    # plain worker thread (no ScriptRunContext) the output is not shown.
    st.write(f"{display_name}: {label}")
    return label


def topwear(encoding, top_wear_model):
    """Classify the top-wear type; writes "Top Wear: <label>" and returns the label."""
    return _classify(encoding, top_wear_model, "Top Wear")


def patterns(encoding, pattern_model):
    """Classify the garment pattern; writes "Pattern: <label>" and returns the label."""
    return _classify(encoding, pattern_model, "Pattern")


def prints(encoding, print_model):
    """Classify the garment print; writes "Print: <label>" and returns the label."""
    return _classify(encoding, print_model, "Print")


def sleevelengths(encoding, sleeve_length_model):
    """Classify the sleeve length; writes "Sleeve Length: <label>" and returns the label."""
    return _classify(encoding, sleeve_length_model, "Sleeve Length")
def imageprocessing(image):
    """Encode ``image`` with the image processor stored in session state.

    Looks up ``st.session_state.image_processor`` and calls it with
    ``return_tensors="pt"``. The processor is presumably an
    ``AutoImageProcessor`` instance, so the result would be PyTorch-tensor model
    inputs; confirm where ``image_processor`` is put into session state.
    """
    processor = st.session_state.image_processor
    return processor(images=image, return_tensors="pt")
# Parallel inference: run every attribute model on one shared image encoding.
def _predict_label(encoding, model):
    """Return the top-1 ``id2label`` name that ``model`` predicts for ``encoding``.

    Makes no Streamlit calls, so it is safe to run on a worker thread.
    ``.item()`` assumes a single-image batch.
    """
    with torch.no_grad():
        logits = model(**encoding).logits
    return model.config.id2label[logits.argmax(-1).item()]


def pipes(image):
    """Classify ``image`` with the four attribute models in parallel.

    The image is encoded once via ``imageprocessing`` and that encoding is
    shared by every model. Inference runs on a thread pool. All Streamlit
    output (per-attribute lines, errors, the summary dict) is written afterwards
    from the calling script thread. This is necessary because Streamlit commands
    issued from a plain worker thread are not rendered, so the worker threads
    must not call ``topwear``/``patterns``/``prints``/``sleevelengths``
    (which write to the page themselves).

    Args:
        image: Image handed to the session's image processor (presumably a
            PIL image; confirm against callers).

    Returns:
        dict mapping "topwear", "patterns", "prints" and "sleeve_length" to
        the predicted label, or to None when that model raised.
    """
    encoding = imageprocessing(image)
    # (result key, display prefix, model). session_state is read here, on the
    # script thread, before any worker starts.
    tasks = (
        ("topwear", "Top Wear", st.session_state.top_wear_model),
        ("patterns", "Pattern", st.session_state.pattern_model),
        ("prints", "Print", st.session_state.print_model),
        ("sleeve_length", "Sleeve Length", st.session_state.sleeve_length_model),
    )
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {
            key: executor.submit(_predict_label, encoding, model)
            for key, _prefix, model in tasks
        }
    # Leaving the with-block waited for every future. Report in a fixed order
    # so the page and the results dict do not depend on completion order.
    results = {}
    for key, prefix, _model in tasks:
        try:
            label = futures[key].result()
        except Exception as e:  # UI boundary: show the failure, keep the other results
            st.error(f"Error in {key}: {e}")
            results[key] = None
        else:
            st.write(f"{prefix}: {label}")
            results[key] = label
    st.write(results)
    return results