vishalkatheriya18 commited on
Commit
1a35476
1 Parent(s): 1fb0946

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +4 -4
classification.py CHANGED
@@ -17,7 +17,7 @@ def topwear(encoding, top_wear_model):
17
  outputs = top_wear_model(**encoding)
18
  logits = outputs.logits
19
  predicted_class_idx = logits.argmax(-1).item()
20
- st.write(f"Top Wear: {top_wear_model.config.id2label[predicted_class_idx]}")
21
  return top_wear_model.config.id2label[predicted_class_idx]
22
 
23
  def patterns(encoding, pattern_model):
@@ -25,7 +25,7 @@ def patterns(encoding, pattern_model):
25
  outputs = pattern_model(**encoding)
26
  logits = outputs.logits
27
  predicted_class_idx = logits.argmax(-1).item()
28
- st.write(f"Pattern: {pattern_model.config.id2label[predicted_class_idx]}")
29
  return pattern_model.config.id2label[predicted_class_idx]
30
 
31
  def prints(encoding, print_model):
@@ -33,7 +33,7 @@ def prints(encoding, print_model):
33
  outputs = print_model(**encoding)
34
  logits = outputs.logits
35
  predicted_class_idx = logits.argmax(-1).item()
36
- st.write(f"Print: {print_model.config.id2label[predicted_class_idx]}")
37
  return print_model.config.id2label[predicted_class_idx]
38
 
39
  def sleevelengths(encoding, sleeve_length_model):
@@ -41,7 +41,7 @@ def sleevelengths(encoding, sleeve_length_model):
41
  outputs = sleeve_length_model(**encoding)
42
  logits = outputs.logits
43
  predicted_class_idx = logits.argmax(-1).item()
44
- st.write(f"Sleeve Length: {sleeve_length_model.config.id2label[predicted_class_idx]}")
45
  return sleeve_length_model.config.id2label[predicted_class_idx]
46
 
47
  def imageprocessing(image):
 
17
  outputs = top_wear_model(**encoding)
18
  logits = outputs.logits
19
  predicted_class_idx = logits.argmax(-1).item()
20
+ # st.write(f"Top Wear: {top_wear_model.config.id2label[predicted_class_idx]}")
21
  return top_wear_model.config.id2label[predicted_class_idx]
22
 
23
  def patterns(encoding, pattern_model):
 
25
  outputs = pattern_model(**encoding)
26
  logits = outputs.logits
27
  predicted_class_idx = logits.argmax(-1).item()
28
+ # st.write(f"Pattern: {pattern_model.config.id2label[predicted_class_idx]}")
29
  return pattern_model.config.id2label[predicted_class_idx]
30
 
31
  def prints(encoding, print_model):
 
33
  outputs = print_model(**encoding)
34
  logits = outputs.logits
35
  predicted_class_idx = logits.argmax(-1).item()
36
+ # st.write(f"Print: {print_model.config.id2label[predicted_class_idx]}")
37
  return print_model.config.id2label[predicted_class_idx]
38
 
39
  def sleevelengths(encoding, sleeve_length_model):
 
41
  outputs = sleeve_length_model(**encoding)
42
  logits = outputs.logits
43
  predicted_class_idx = logits.argmax(-1).item()
44
+ # st.write(f"Sleeve Length: {sleeve_length_model.config.id2label[predicted_class_idx]}")
45
  return sleeve_length_model.config.id2label[predicted_class_idx]
46
 
47
  def imageprocessing(image):