nebiyu29 committed on
Commit
2f6ade8
1 Parent(s): 1f8178b

added more text capability

Browse files
Files changed (1) hide show
  1. app.py +107 -10
app.py CHANGED
@@ -4,19 +4,116 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  # Load model directly
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
 
 
 
 
7
  tokenizer = AutoTokenizer.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
8
  model = AutoModelForSequenceClassification.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
9
 
10
 
11
def classify_text(text):
    """Return the predicted class label for ``text``.

    The text is tokenized with the module-level ``tokenizer``, passed through
    the module-level ``model``, and the index of the largest logit is mapped
    to its label through ``model.config.id2label``.

    Args:
        text: The input string to classify.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: disable gradient tracking so no autograd graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits
    # ``.item()`` requires exactly one prediction, i.e. a single input text.
    predicted_index = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  interface = gr.Interface(
22
  fn=classify_text,
@@ -24,7 +121,7 @@ interface = gr.Interface(
24
  outputs="text",
25
  title="Text Classification Demo",
26
  description="Enter some text, and the model will classify it.",
27
- choices=["positive", "negative", "neutral"] # Adjust class names
28
  )
29
 
30
  interface.launch()
 
4
  # Load model directly
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
 
7
+ import torch
8
+ import transformers
9
+
10
  tokenizer = AutoTokenizer.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
11
  model = AutoModelForSequenceClassification.from_pretrained("nebiyu29/fintunned-v2-roberta_GA")
12
 
13
 
14
+ # Load the model and tokenizer
15
+ # model = transformers.AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
16
+
17
+ # tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
18
+
19
+ # Define a function to split a text into segments of 512 tokens
20
def split_text(text, max_tokens=512, tok=None):
    """Split ``text`` into consecutive segments of at most ``max_tokens`` tokens.

    The text is tokenized, the token list is cut into fixed-size chunks, and
    each chunk is converted back into a string.

    Args:
        text: The input string to split.
        max_tokens: Maximum number of tokens per segment (default 512).
        tok: Optional tokenizer exposing ``tokenize`` and
            ``convert_tokens_to_string``. When omitted, the module-level
            ``tokenizer`` is used.

    Returns:
        A list of segment strings in their original order. The list is empty
        when the text produces no tokens.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be at least 1")
    if tok is None:
        tok = tokenizer
    tokens = tok.tokenize(text)
    # Chunk by position. Detecting the end of the text by comparing each token
    # to the last token's *value* would flush a segment early whenever that
    # token also occurs earlier in the text (e.g. a repeated full stop).
    return [
        tok.convert_tokens_to_string(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]
44
+
45
def classify(text, model):
    """Score ``text`` with ``model`` and report the most likely labels.

    The text is split into segments with ``split_text``; each segment is
    encoded with the module-level ``tokenizer`` and passed through ``model``.
    The per-segment logits are averaged, converted to probabilities with a
    softmax, and the highest-scoring labels (at most three) are returned.

    Args:
        text: The input string to classify.
        model: A sequence-classification model whose output exposes
            ``.logits``. It is moved to the GPU when one is available.

    Returns:
        A dict with keys ``"sequence"`` (the original text), ``"top_labels"``
        and ``"top_probabilities"`` (parallel lists, best first). Both lists
        are empty when the text yields no segments.
    """
    # NOTE(review): assumes output index i of the model corresponds to
    # labels[i] — confirm against model.config.id2label.
    labels = ["depression", "anxiety", "bipolar disorder", "schizophrenia", "PTSD", "OCD", "ADHD", "autism", "eating disorder", "personality disorder", "phobia"]

    segments = split_text(text)
    if not segments:
        # Nothing to score; averaging an empty list of logits would raise.
        return {"sequence": text, "top_labels": [], "top_probabilities": []}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    segment_logits = []
    for segment in segments:
        # Encode only the segment: the label strings are not model inputs, and
        # batching them in only produced rows that were never read.
        # Re-encoding adds special tokens, so truncate to stay within the
        # tokenizer's configured maximum length.
        # NOTE(review): presumably that maximum matches the model's limit — confirm.
        inputs = tokenizer(segment, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = model(
                inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
            )
        # One input sequence per call, so row 0 holds this segment's logits.
        segment_logits.append(outputs.logits[0])

    # Average the logits across segments, then convert to probabilities.
    avg_logits = torch.mean(torch.stack(segment_logits), dim=0)
    probabilities = torch.softmax(avg_logits, dim=-1)[:len(labels)]

    # Take the top three, or fewer when fewer classes are available.
    top = torch.topk(probabilities, k=min(3, probabilities.numel()))
    top_labels = [labels[index] for index in top.indices.tolist()]
    top_probabilities = top.values.tolist()

    return {
        "sequence": text,
        "top_labels": top_labels,
        "top_probabilities": top_probabilities,
    }
103
+
104
+
105
+
106
+
107
+
108
+ # def classify_text(text):
109
+ # """
110
+ # This function preprocesses, feeds text to the model, and outputs the predicted class.
111
+ # """
112
+ # inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
113
+ # outputs = model(**inputs)
114
+ # logits = outputs.logits # Access logits instead of pipeline output
115
+ # predictions = torch.argmax(logits, dim=-1) # Apply argmax for prediction
116
+ # return model.config.id2label[predictions.item()] # Map index to class label
117
 
118
  interface = gr.Interface(
119
  fn=classify_text,
 
121
  outputs="text",
122
  title="Text Classification Demo",
123
  description="Enter some text, and the model will classify it.",
124
+ choices=["depression", "anxiety", "bipolar disorder", "schizophrenia", "PTSD", "OCD", "ADHD", "autism", "eating disorder", "personality disorder", "phobia"] # Adjust class names
125
  )
126
 
127
  interface.launch()