vishalkatheriya18 commited on
Commit
94f41f8
1 Parent(s): b1e655a

Create classification.py

Browse files
Files changed (1) hide show
  1. classification.py +81 -0
classification.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import time
7
+ import torch
8
+ import concurrent.futures
9
+
10
+ def hii():
11
+ return "done"
12
+
13
+ # Define image processing and classification functions
14
+ def topwear(encoding, top_wear_model):
15
+ with torch.no_grad():
16
+ outputs = top_wear_model(**encoding)
17
+ logits = outputs.logits
18
+ predicted_class_idx = logits.argmax(-1).item()
19
+ st.write(f"Top Wear: {top_wear_model.config.id2label[predicted_class_idx]}")
20
+ return top_wear_model.config.id2label[predicted_class_idx]
21
+
22
+ def patterns(encoding, pattern_model):
23
+ with torch.no_grad():
24
+ outputs = pattern_model(**encoding)
25
+ logits = outputs.logits
26
+ predicted_class_idx = logits.argmax(-1).item()
27
+ st.write(f"Pattern: {pattern_model.config.id2label[predicted_class_idx]}")
28
+ return pattern_model.config.id2label[predicted_class_idx]
29
+
30
+ def prints(encoding, print_model):
31
+ with torch.no_grad():
32
+ outputs = print_model(**encoding)
33
+ logits = outputs.logits
34
+ predicted_class_idx = logits.argmax(-1).item()
35
+ st.write(f"Print: {print_model.config.id2label[predicted_class_idx]}")
36
+ return print_model.config.id2label[predicted_class_idx]
37
+
38
+ def sleevelengths(encoding, sleeve_length_model):
39
+ with torch.no_grad():
40
+ outputs = sleeve_length_model(**encoding)
41
+ logits = outputs.logits
42
+ predicted_class_idx = logits.argmax(-1).item()
43
+ st.write(f"Sleeve Length: {sleeve_length_model.config.id2label[predicted_class_idx]}")
44
+ return sleeve_length_model.config.id2label[predicted_class_idx]
45
+
46
+ def imageprocessing(image):
47
+ encoding = st.session_state.image_processor(images=image, return_tensors="pt")
48
+ return encoding
49
+
50
+ # Run all models concurrently using threading
51
+ def pipes(image):
52
+ # Process the image once and reuse the encoding
53
+ encoding = imageprocessing(image)
54
+
55
+ # Access models from session state before threading
56
+ top_wear_model = st.session_state.top_wear_model
57
+ pattern_model = st.session_state.pattern_model
58
+ print_model = st.session_state.print_model
59
+ sleeve_length_model = st.session_state.sleeve_length_model
60
+
61
+ # Define functions to run the models in parallel
62
+ with concurrent.futures.ThreadPoolExecutor() as executor:
63
+ futures = {
64
+ executor.submit(topwear, encoding, top_wear_model): "topwear",
65
+ executor.submit(patterns, encoding, pattern_model): "patterns",
66
+ executor.submit(prints, encoding, print_model): "prints",
67
+ executor.submit(sleevelengths, encoding, sleeve_length_model): "sleeve_length"
68
+ }
69
+
70
+ results = {}
71
+ for future in concurrent.futures.as_completed(futures):
72
+ model_name = futures[future]
73
+ try:
74
+ results[model_name] = future.result()
75
+ except Exception as e:
76
+ st.error(f"Error in {model_name}: {str(e)}")
77
+ results[model_name] = None
78
+
79
+ # Display the results
80
+ st.write(results)
81
+ return results