nebiyu29 committed
Commit dd12fa7
1 Parent(s): 855ff50

Update app.py

Files changed (1)
  1. app.py +23 -60
app.py CHANGED
@@ -54,74 +54,37 @@ def extract_predictions(outputs):
 
  # a function that classifies text
 
-# def classify_text(text):
-#     # Define labels
-#     labels = ["depression", "anxiety", "bipolar disorder", "schizophrenia", "PTSD", "OCD", "ADHD", "autism", "eating disorder", "personality disorder", "phobia"]
+def classify_text(text):
 
-#     # Split text into segments using split_text
-#     segments = split_text(text)
+    # Split text into segments using split_text
+    segments = split_text(text)
 
-#     # Initialize empty list for predictions
-#     predictions = []
+    # Initialize empty list for predictions
+    predictions = []
 
-#     # Move device to GPU if available
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#     model = model.to(device)
+    # Move device to GPU if available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
 
-#     # Loop through segments, process, and store predictions
-#     for segment in segments:
-#         inputs = tokenizer([segment], padding=True, return_tensors="pt")
-#         input_ids = inputs["input_ids"].to(device)
-#         attention_mask = inputs["attention_mask"].to(device)
+    # Loop through segments, process, and store predictions
+    for segment in segments:
+        inputs = tokenizer([segment], padding=True, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        attention_mask = inputs["attention_mask"].to(device)
 
-#         with torch.no_grad():
-#             outputs = model(input_ids, attention_mask=attention_mask)
+        with torch.no_grad():
+            outputs = model(input_ids, attention_mask=attention_mask)
 
-#         # Extract predictions for each segment
-#         probs, preds = extract_predictions(outputs)  # Define this function based on your model's output
+        # Extract predictions for each segment
+        probs, preds = extract_predictions(outputs)  # Define this function based on your model's output
 
-#         # Append predictions for this segment
-#         predictions.append({
-#             "segment_text": segment,
-#             "label": preds[0],  # Assuming single label prediction
-#             "probability": probs[preds[0]]  # Access probability for the predicted label
-#         })
+        # Append predictions for this segment
+        predictions.append({
+            "segment_text": segment,
+            "label": preds[0],  # Assuming single label prediction
+            "probability": probs[preds[0]]  # Access probability for the predicted label
+        })
 
-def classify_text(text):
-
-
-    segments=split_text(text)
-
-    predictions = []
-    for segment in segments:
-        inputs = tokenizer([segment], padding=True, return_tensors="pt")
-        input_ids = inputs["input_ids"].to(device)
-        attention_mask = inputs["attention_mask"].to(device)
-
-        with torch.no_grad():
-            outputs = model(input_ids, attention_mask=attention_mask)
-
-        probs, preds = extract_predictions(outputs)
-
-        predictions.append({
-            "segment_text": segment,
-            "label": model.config.id2label[preds[0]],  # assuming single label prediction
-            "probability": probs[preds[0]]
-        })
-
-    return predictions
-
-
-
-# def classify_text(text):
-#     """
-#     This function preprocesses, feeds text to the model, and outputs the predicted class.
-#     """
-#     inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
-#     outputs = model(**inputs)
-#     logits = outputs.logits  # Access logits instead of pipeline output
-#     predictions = torch.argmax(logits, dim=-1)  # Apply argmax for prediction
-#     return model.config.id2label[predictions.item()]  # Map index to class label
 
 interface = gr.Interface(
     fn=classify_text,
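
For reference, below is a minimal self-contained sketch of how the un-commented classify_text could run end to end. The checkpoint name and the split_text / extract_predictions bodies are placeholder assumptions (the real ones are defined elsewhere in app.py), and the sketch keeps the explicit return predictions that the committed version of the function ends without.

# Sketch only -- not the committed app.py. Checkpoint name and helper bodies are
# placeholder assumptions; the real split_text / extract_predictions / model setup
# live elsewhere in app.py.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def split_text(text, max_words=128):
    # Naive word-count chunking; the real app.py may split differently.
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)] or [text]


def extract_predictions(outputs):
    # Softmax over the logits of a single segment -> (per-class probabilities, predicted ids).
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    pred = int(torch.argmax(probs).item())
    return probs.tolist(), [pred]


def classify_text(text):
    segments = split_text(text)
    predictions = []
    for segment in segments:
        inputs = tokenizer([segment], padding=True, truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        probs, preds = extract_predictions(outputs)
        predictions.append({
            "segment_text": segment,
            "label": model.config.id2label[preds[0]],  # maps id -> label name; the commit stores the raw id
            "probability": probs[preds[0]],
        })
    return predictions  # the committed function ends without this return

With these pieces in place, the existing interface = gr.Interface(fn=classify_text, ...) call would display one label/probability dictionary per text segment.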