# pipeline/classification.py
# Source: vishalkatheriya18, commit c096e3a ("Update classification.py").
import streamlit as st
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import requests
from io import BytesIO
import time
import torch
import concurrent.futures
import torch.nn.functional as F
def inferring(encoding, model):
    """Run one classifier on a preprocessed image and describe its top prediction.

    Args:
        encoding: Mapping of keyword arguments that is unpacked into ``model``.
        model: Callable whose result exposes ``logits`` and which carries a
            ``config.id2label`` lookup from class index to label.

    Returns:
        The predicted label concatenated with ``" prob:<probability>"``.
    """
    with torch.no_grad():
        logits = model(**encoding).logits
        # ``.item()`` only works on a one-element tensor, so this expects a
        # single prediction per call.
        best_idx = logits.argmax(-1).item()
        # Softmax over the last axis turns the logits into probabilities;
        # pick out the winner's probability for the first (only) row.
        best_prob = F.softmax(logits, dim=-1)[0, best_idx].item()
    label = model.config.id2label[best_idx]
    return label + f" prob:{best_prob}"
# #testing
# def inferring(encoding,model):
# with torch.no_grad():
# outputs = model(**encoding)
# logits = outputs.logits
# predicted_class_idx = logits.argmax(-1).item()
# # st.write(f"Top Wear: {top_wear_model.config.id2label[predicted_class_idx]}")
# return model.config.id2label[predicted_class_idx]
def imageprocessing(image):
    """Preprocess ``image`` with the processor stored in Streamlit session state.

    Args:
        image: Image object handed to the processor's ``images`` argument.

    Returns:
        Whatever ``st.session_state.image_processor`` returns when asked for
        PyTorch tensors (``return_tensors="pt"``).
    """
    processor = st.session_state.image_processor
    return processor(images=image, return_tensors="pt")
# For each supported category: (result key, ``st.session_state`` attribute
# holding the model) pairs, in the order the results are reported.
_CATEGORY_MODELS = {
    "UpperBody": (
        ("topwear", "top_wear_model"),
        ("patterns", "pattern_model"),
        ("prints", "print_model"),
        ("sleeve_length", "sleeve_length_model"),
    ),
    "Wholebody": (
        ("fullwear", "fullwear"),
        ("patterns", "pattern_model"),
        ("prints", "print_model"),
        ("sleeve_length", "sleeve_length_model"),
    ),
    "Lowerbody": (
        ("lowerwear", "bottomwear_model"),
    ),
    "Neck": (
        ("Neckstyle", "neck_style_model"),
    ),
}


def _run_models(encoding, models):
    """Run ``inferring`` for each ``(key, model)`` pair on a thread pool.

    A failing model is reported through ``st.error`` and recorded as ``None``
    so that the remaining models still contribute their results.
    """
    results = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = {
            key: executor.submit(inferring, encoding, model)
            for key, model in models
        }
        # Collect in submission order so the result keys come back in a
        # stable order regardless of which thread finishes first.
        for key, future in futures.items():
            try:
                results[key] = future.result()
            except Exception as e:
                st.error(f"Error in {key}: {str(e)}")
                results[key] = None
    return results


# Run all models concurrently using threading
def pipes(image, categories):
    """Classify ``image`` with every model registered for ``categories``.

    Args:
        image: Image passed to ``imageprocessing``; it is encoded once and the
            same encoding is shared by all models of the category.
        categories: One of ``"UpperBody"``, ``"Wholebody"``, ``"Lowerbody"``
            or ``"Neck"`` (case-sensitive).

    Returns:
        A dict mapping each result key of the category (for example
        ``"topwear"`` or ``"patterns"``) to the string returned by
        ``inferring``, or to ``None`` when that model raised. For an
        unsupported category, a single-entry dict under the key
        ``"invalid categorie"`` describing the problem.
    """
    # Non-string values can never match a category; the isinstance guard also
    # keeps unhashable inputs from raising inside the dict lookup.
    spec = _CATEGORY_MODELS.get(categories) if isinstance(categories, str) else None
    if spec is None:
        # Reject unknown categories before doing any preprocessing work.
        return {"invalid categorie": f"{categories} categorie not in process!"}

    # Process the image once and reuse the encoding
    encoding = imageprocessing(image)

    # Access models from session state before threading, and only the ones
    # this category needs, so an unrelated missing model cannot abort the call.
    models = [(key, getattr(st.session_state, attr)) for key, attr in spec]
    return _run_models(encoding, models)