KameronB commited on
Commit
b37695a
1 Parent(s): 43648e3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +59 -1
README.md CHANGED
@@ -2,4 +2,62 @@
2
  license: mit
3
  language:
4
  - en
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  license: mit
3
  language:
4
  - en
5
+ ---
6
+
7
+
8
+ ```python
9
+ import torch
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
12
+ from sklearn.model_selection import train_test_split
13
+ import pandas as pd
14
+
15
+ # Load the tokenizer
16
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
17
+
18
+ # Load RoBERTa pre-trained model
19
+ model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
20
+ model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
21
+
22
+
23
+
24
+
25
+ ```
26
+
27
+ ```python
28
+
29
+ def predict_description(model, tokenizer, text, max_length=512):
30
+ model.eval() # Set the model to evaluation mode
31
+
32
+ # Ensure model is on the correct device
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model = model.to(device)
35
+
36
+ # Encode the input text
37
+ inputs = tokenizer.encode_plus(
38
+ text,
39
+ None,
40
+ add_special_tokens=True,
41
+ max_length=max_length,
42
+ padding='max_length',
43
+ return_token_type_ids=False,
44
+ return_tensors='pt',
45
+ truncation=True
46
+ )
47
+
48
+ # Move tensors to the correct device
49
+ inputs = {key: value.to(device) for key, value in inputs.items()}
50
+
51
+ # Make prediction
52
+ with torch.no_grad():
53
+ outputs = model(**inputs)
54
+ logits = outputs.logits
55
+ probabilities = torch.softmax(logits, dim=-1)
56
+ predicted_class_id = torch.argmax(probabilities, dim=-1).item()
57
+
58
+ return predicted_class_id
59
+
60
+
61
+ (['INCIDENT', 'REQUEST'])[predict_description(model, tokenizer, """My ID card is not being detected.""")]
62
+
63
+ ```