nebiyu29 commited on
Commit
8609e41
1 Parent(s): 7768e0f

added the id2labe

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -54,6 +54,11 @@ def extract_predictions(outputs):
54
 
55
  # a function that classifies text
56
 
 
 
 
 
 
57
  def classify_text(text):
58
 
59
  # Split text into segments using split_text
@@ -62,10 +67,7 @@ def classify_text(text):
62
  # Initialize empty list for predictions
63
  predictions = []
64
 
65
- # Move device to GPU if available
66
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
- model = model.to(device)
68
-
69
  # Loop through segments, process, and store predictions
70
  for segment in segments:
71
  inputs = tokenizer([segment], padding=True, return_tensors="pt")
@@ -77,13 +79,15 @@ def classify_text(text):
77
 
78
  # Extract predictions for each segment
79
  probs, preds = extract_predictions(outputs) # Define this function based on your model's output
80
-
81
  # Append predictions for this segment
82
  predictions.append({
83
  "segment_text": segment,
84
- "label": preds[0], # Assuming single label prediction
85
- "probability": probs[preds[0]] # Access probability for the predicted label
86
  })
 
 
87
 
88
 
89
  interface = gr.Interface(
 
54
 
55
  # a function that classifies text
56
 
57
+ # Move device to GPU if available
58
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ model = model.to(device)
60
+ class_names=list(model.config.id2label.values())
61
+
62
  def classify_text(text):
63
 
64
  # Split text into segments using split_text
 
67
  # Initialize empty list for predictions
68
  predictions = []
69
 
70
+
 
 
 
71
  # Loop through segments, process, and store predictions
72
  for segment in segments:
73
  inputs = tokenizer([segment], padding=True, return_tensors="pt")
 
79
 
80
  # Extract predictions for each segment
81
  probs, preds = extract_predictions(outputs) # Define this function based on your model's output
82
+ pred_label=class_names[preds[0].item()]
83
  # Append predictions for this segment
84
  predictions.append({
85
  "segment_text": segment,
86
+ "label": pred_label, # Assuming single label prediction
87
+ "probability": probs[0][preds[0]].item() # Access probability for the predicted label
88
  })
89
+
90
+ return predictions
91
 
92
 
93
  interface = gr.Interface(