Ariel Hsieh commited on
Commit
b001f36
1 Parent(s): 7be12d3

Created using Colaboratory

Browse files
Files changed (1) hide show
  1. AI_Milestone_3.ipynb +714 -0
AI_Milestone_3.ipynb ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "view-in-github",
7
+ "colab_type": "text"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/arielhsieh8/cs-uy-4613-project/blob/milestone-3/AI_Milestone_3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {
17
+ "colab": {
18
+ "base_uri": "https://localhost:8080/"
19
+ },
20
+ "id": "MCO9jo5gyX2c",
21
+ "outputId": "b3fc4262-aa28-4363-d56e-b85a8fb29d3c"
22
+ },
23
+ "outputs": [
24
+ {
25
+ "output_type": "stream",
26
+ "name": "stdout",
27
+ "text": [
28
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
29
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.9/dist-packages (4.28.1)\n",
30
+ "Requirement already satisfied: requests in /usr/local/lib/python3.9/dist-packages (from transformers) (2.27.1)\n",
31
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.3)\n",
32
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (23.1)\n",
33
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from transformers) (6.0)\n",
34
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (1.22.4)\n",
35
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.9/dist-packages (from transformers) (4.65.0)\n",
36
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers) (3.11.0)\n",
37
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.9/dist-packages (from transformers) (0.13.4)\n",
38
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers) (2022.10.31)\n",
39
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.9/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers) (4.5.0)\n",
40
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (3.4)\n",
41
+ "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2.0.12)\n",
42
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (1.26.15)\n",
43
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/dist-packages (from requests->transformers) (2022.12.7)\n",
44
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
45
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (1.5.3)\n",
46
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas) (2.8.2)\n",
47
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas) (2022.7.1)\n",
48
+ "Requirement already satisfied: numpy>=1.20.3 in /usr/local/lib/python3.9/dist-packages (from pandas) (1.22.4)\n",
49
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/dist-packages (from python-dateutil>=2.8.1->pandas) (1.16.0)\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "!pip install transformers\n",
55
+ "!pip install pandas\n"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {
62
+ "id": "GHSa0Qb1xTvJ"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "from sklearn.model_selection import train_test_split\n",
67
+ "import torch \n",
68
+ "from torch.utils.data import Dataset \n",
69
+ "from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification\n",
70
+ "from transformers import Trainer, TrainingArguments \n",
71
+ "import pandas as pd\n",
72
+ "import numpy as np\n",
73
+ "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {
80
+ "id": "Sl36PcY2rxGX"
81
+ },
82
+ "outputs": [],
83
+ "source": [
84
+ "model_name = \"distilbert-base-uncased\"\n",
85
+ "\n",
86
+ "train_data = pd.read_csv('train.csv')\n",
87
+ "\n",
88
+ "train_data.drop([\"id\"], inplace=True, axis=1)\n",
89
+ "train_data.dropna()\n",
90
+ "\n",
91
+ "train_texts = train_data['comment_text'].tolist()\n",
92
+ "train_labels = train_data[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].values.tolist()\n",
93
+ "\n",
94
+ "train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts[:100000],train_labels[:100000],test_size=0.2,random_state=42)\n",
95
+ "\n",
96
+ "class textDataset(Dataset):\n",
97
+ "\n",
98
+ " def __init__(self, encodings, labels):\n",
99
+ " self.encodings = encodings\n",
100
+ " self.labels = torch.tensor(labels).float()\n",
101
+ "\n",
102
+ " def __getitem__(self,index):\n",
103
+ " item = {key: torch.tensor(val[index]) for key, val in self.encodings.items()}\n",
104
+ " item['labels'] = torch.tensor(self.labels[index])\n",
105
+ " return item\n",
106
+ "\n",
107
+ " def __len__(self): \n",
108
+ " return len(self.labels)\n",
109
+ "\n",
110
+ "\n",
111
+ "tokenizer = DistilBertTokenizerFast.from_pretrained(model_name,num_labels=6,problem_type=\"multi_label_classification\")\n",
112
+ "\n",
113
+ "train_encodings = tokenizer(train_texts,truncation=True,padding=True)\n",
114
+ "val_encodings = tokenizer(val_texts,truncation=True,padding=True)\n",
115
+ "\n",
116
+ "train_dataset = textDataset(train_encodings,train_labels)\n",
117
+ "val_dataset = textDataset(val_encodings,val_labels)\n",
118
+ "\n"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {
125
+ "colab": {
126
+ "base_uri": "https://localhost:8080/"
127
+ },
128
+ "id": "8uyVppYpxJ7r",
129
+ "outputId": "6c7feff9-2b63-47fc-8fc8-78999b8a2d74"
130
+ },
131
+ "outputs": [
132
+ {
133
+ "output_type": "stream",
134
+ "name": "stderr",
135
+ "text": [
136
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n",
137
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
138
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
139
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
140
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
141
+ ]
142
+ }
143
+ ],
144
+ "source": [
145
+ "training_args = TrainingArguments(\n",
146
+ " output_dir='./results',\n",
147
+ " num_train_epochs=2,\n",
148
+ " per_device_train_batch_size=16,\n",
149
+ " per_device_eval_batch_size=16,\n",
150
+ " warmup_steps=500,\n",
151
+ " learning_rate=5e-5,\n",
152
+ " weight_decay=0.01,\n",
153
+ " logging_dir='./logs',\n",
154
+ " logging_steps=100,\n",
155
+ ")\n",
156
+ "\n",
157
+ "model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=6,problem_type=\"multi_label_classification\")\n",
158
+ "\n",
159
+ "trainer = Trainer(\n",
160
+ " model=model,\n",
161
+ " args=training_args,\n",
162
+ " train_dataset=train_dataset,\n",
163
+ " eval_dataset=val_dataset,\n",
164
+ ")\n",
165
+ "\n"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "metadata": {
172
+ "colab": {
173
+ "base_uri": "https://localhost:8080/",
174
+ "height": 1000
175
+ },
176
+ "id": "lGigZhWtV0ld",
177
+ "outputId": "b2081e70-ed7c-4007-e231-3c9d269f398b"
178
+ },
179
+ "outputs": [
180
+ {
181
+ "output_type": "stream",
182
+ "name": "stderr",
183
+ "text": [
184
+ "/usr/local/lib/python3.9/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
185
+ " warnings.warn(\n",
186
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
187
+ " item['labels'] = torch.tensor(self.labels[index])\n"
188
+ ]
189
+ },
190
+ {
191
+ "output_type": "display_data",
192
+ "data": {
193
+ "text/plain": [
194
+ "<IPython.core.display.HTML object>"
195
+ ],
196
+ "text/html": [
197
+ "\n",
198
+ " <div>\n",
199
+ " \n",
200
+ " <progress value='10000' max='10000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
201
+ " [10000/10000 2:00:23, Epoch 2/2]\n",
202
+ " </div>\n",
203
+ " <table border=\"1\" class=\"dataframe\">\n",
204
+ " <thead>\n",
205
+ " <tr style=\"text-align: left;\">\n",
206
+ " <th>Step</th>\n",
207
+ " <th>Training Loss</th>\n",
208
+ " </tr>\n",
209
+ " </thead>\n",
210
+ " <tbody>\n",
211
+ " <tr>\n",
212
+ " <td>100</td>\n",
213
+ " <td>0.522000</td>\n",
214
+ " </tr>\n",
215
+ " <tr>\n",
216
+ " <td>200</td>\n",
217
+ " <td>0.169400</td>\n",
218
+ " </tr>\n",
219
+ " <tr>\n",
220
+ " <td>300</td>\n",
221
+ " <td>0.088900</td>\n",
222
+ " </tr>\n",
223
+ " <tr>\n",
224
+ " <td>400</td>\n",
225
+ " <td>0.058000</td>\n",
226
+ " </tr>\n",
227
+ " <tr>\n",
228
+ " <td>500</td>\n",
229
+ " <td>0.068900</td>\n",
230
+ " </tr>\n",
231
+ " <tr>\n",
232
+ " <td>600</td>\n",
233
+ " <td>0.051600</td>\n",
234
+ " </tr>\n",
235
+ " <tr>\n",
236
+ " <td>700</td>\n",
237
+ " <td>0.057400</td>\n",
238
+ " </tr>\n",
239
+ " <tr>\n",
240
+ " <td>800</td>\n",
241
+ " <td>0.049300</td>\n",
242
+ " </tr>\n",
243
+ " <tr>\n",
244
+ " <td>900</td>\n",
245
+ " <td>0.048100</td>\n",
246
+ " </tr>\n",
247
+ " <tr>\n",
248
+ " <td>1000</td>\n",
249
+ " <td>0.062500</td>\n",
250
+ " </tr>\n",
251
+ " <tr>\n",
252
+ " <td>1100</td>\n",
253
+ " <td>0.051300</td>\n",
254
+ " </tr>\n",
255
+ " <tr>\n",
256
+ " <td>1200</td>\n",
257
+ " <td>0.050700</td>\n",
258
+ " </tr>\n",
259
+ " <tr>\n",
260
+ " <td>1300</td>\n",
261
+ " <td>0.049000</td>\n",
262
+ " </tr>\n",
263
+ " <tr>\n",
264
+ " <td>1400</td>\n",
265
+ " <td>0.047100</td>\n",
266
+ " </tr>\n",
267
+ " <tr>\n",
268
+ " <td>1500</td>\n",
269
+ " <td>0.041500</td>\n",
270
+ " </tr>\n",
271
+ " <tr>\n",
272
+ " <td>1600</td>\n",
273
+ " <td>0.049000</td>\n",
274
+ " </tr>\n",
275
+ " <tr>\n",
276
+ " <td>1700</td>\n",
277
+ " <td>0.052800</td>\n",
278
+ " </tr>\n",
279
+ " <tr>\n",
280
+ " <td>1800</td>\n",
281
+ " <td>0.049300</td>\n",
282
+ " </tr>\n",
283
+ " <tr>\n",
284
+ " <td>1900</td>\n",
285
+ " <td>0.043500</td>\n",
286
+ " </tr>\n",
287
+ " <tr>\n",
288
+ " <td>2000</td>\n",
289
+ " <td>0.047700</td>\n",
290
+ " </tr>\n",
291
+ " <tr>\n",
292
+ " <td>2100</td>\n",
293
+ " <td>0.046600</td>\n",
294
+ " </tr>\n",
295
+ " <tr>\n",
296
+ " <td>2200</td>\n",
297
+ " <td>0.045900</td>\n",
298
+ " </tr>\n",
299
+ " <tr>\n",
300
+ " <td>2300</td>\n",
301
+ " <td>0.045900</td>\n",
302
+ " </tr>\n",
303
+ " <tr>\n",
304
+ " <td>2400</td>\n",
305
+ " <td>0.042200</td>\n",
306
+ " </tr>\n",
307
+ " <tr>\n",
308
+ " <td>2500</td>\n",
309
+ " <td>0.043100</td>\n",
310
+ " </tr>\n",
311
+ " <tr>\n",
312
+ " <td>2600</td>\n",
313
+ " <td>0.044200</td>\n",
314
+ " </tr>\n",
315
+ " <tr>\n",
316
+ " <td>2700</td>\n",
317
+ " <td>0.043900</td>\n",
318
+ " </tr>\n",
319
+ " <tr>\n",
320
+ " <td>2800</td>\n",
321
+ " <td>0.042400</td>\n",
322
+ " </tr>\n",
323
+ " <tr>\n",
324
+ " <td>2900</td>\n",
325
+ " <td>0.051700</td>\n",
326
+ " </tr>\n",
327
+ " <tr>\n",
328
+ " <td>3000</td>\n",
329
+ " <td>0.049700</td>\n",
330
+ " </tr>\n",
331
+ " <tr>\n",
332
+ " <td>3100</td>\n",
333
+ " <td>0.045700</td>\n",
334
+ " </tr>\n",
335
+ " <tr>\n",
336
+ " <td>3200</td>\n",
337
+ " <td>0.047400</td>\n",
338
+ " </tr>\n",
339
+ " <tr>\n",
340
+ " <td>3300</td>\n",
341
+ " <td>0.042800</td>\n",
342
+ " </tr>\n",
343
+ " <tr>\n",
344
+ " <td>3400</td>\n",
345
+ " <td>0.042400</td>\n",
346
+ " </tr>\n",
347
+ " <tr>\n",
348
+ " <td>3500</td>\n",
349
+ " <td>0.045200</td>\n",
350
+ " </tr>\n",
351
+ " <tr>\n",
352
+ " <td>3600</td>\n",
353
+ " <td>0.047600</td>\n",
354
+ " </tr>\n",
355
+ " <tr>\n",
356
+ " <td>3700</td>\n",
357
+ " <td>0.044800</td>\n",
358
+ " </tr>\n",
359
+ " <tr>\n",
360
+ " <td>3800</td>\n",
361
+ " <td>0.045100</td>\n",
362
+ " </tr>\n",
363
+ " <tr>\n",
364
+ " <td>3900</td>\n",
365
+ " <td>0.041900</td>\n",
366
+ " </tr>\n",
367
+ " <tr>\n",
368
+ " <td>4000</td>\n",
369
+ " <td>0.039300</td>\n",
370
+ " </tr>\n",
371
+ " <tr>\n",
372
+ " <td>4100</td>\n",
373
+ " <td>0.039500</td>\n",
374
+ " </tr>\n",
375
+ " <tr>\n",
376
+ " <td>4200</td>\n",
377
+ " <td>0.044500</td>\n",
378
+ " </tr>\n",
379
+ " <tr>\n",
380
+ " <td>4300</td>\n",
381
+ " <td>0.042700</td>\n",
382
+ " </tr>\n",
383
+ " <tr>\n",
384
+ " <td>4400</td>\n",
385
+ " <td>0.039600</td>\n",
386
+ " </tr>\n",
387
+ " <tr>\n",
388
+ " <td>4500</td>\n",
389
+ " <td>0.040300</td>\n",
390
+ " </tr>\n",
391
+ " <tr>\n",
392
+ " <td>4600</td>\n",
393
+ " <td>0.044700</td>\n",
394
+ " </tr>\n",
395
+ " <tr>\n",
396
+ " <td>4700</td>\n",
397
+ " <td>0.040700</td>\n",
398
+ " </tr>\n",
399
+ " <tr>\n",
400
+ " <td>4800</td>\n",
401
+ " <td>0.036900</td>\n",
402
+ " </tr>\n",
403
+ " <tr>\n",
404
+ " <td>4900</td>\n",
405
+ " <td>0.046200</td>\n",
406
+ " </tr>\n",
407
+ " <tr>\n",
408
+ " <td>5000</td>\n",
409
+ " <td>0.040300</td>\n",
410
+ " </tr>\n",
411
+ " <tr>\n",
412
+ " <td>5100</td>\n",
413
+ " <td>0.031600</td>\n",
414
+ " </tr>\n",
415
+ " <tr>\n",
416
+ " <td>5200</td>\n",
417
+ " <td>0.029200</td>\n",
418
+ " </tr>\n",
419
+ " <tr>\n",
420
+ " <td>5300</td>\n",
421
+ " <td>0.031900</td>\n",
422
+ " </tr>\n",
423
+ " <tr>\n",
424
+ " <td>5400</td>\n",
425
+ " <td>0.030200</td>\n",
426
+ " </tr>\n",
427
+ " <tr>\n",
428
+ " <td>5500</td>\n",
429
+ " <td>0.035700</td>\n",
430
+ " </tr>\n",
431
+ " <tr>\n",
432
+ " <td>5600</td>\n",
433
+ " <td>0.028500</td>\n",
434
+ " </tr>\n",
435
+ " <tr>\n",
436
+ " <td>5700</td>\n",
437
+ " <td>0.034600</td>\n",
438
+ " </tr>\n",
439
+ " <tr>\n",
440
+ " <td>5800</td>\n",
441
+ " <td>0.027400</td>\n",
442
+ " </tr>\n",
443
+ " <tr>\n",
444
+ " <td>5900</td>\n",
445
+ " <td>0.034700</td>\n",
446
+ " </tr>\n",
447
+ " <tr>\n",
448
+ " <td>6000</td>\n",
449
+ " <td>0.038600</td>\n",
450
+ " </tr>\n",
451
+ " <tr>\n",
452
+ " <td>6100</td>\n",
453
+ " <td>0.028500</td>\n",
454
+ " </tr>\n",
455
+ " <tr>\n",
456
+ " <td>6200</td>\n",
457
+ " <td>0.030100</td>\n",
458
+ " </tr>\n",
459
+ " <tr>\n",
460
+ " <td>6300</td>\n",
461
+ " <td>0.028300</td>\n",
462
+ " </tr>\n",
463
+ " <tr>\n",
464
+ " <td>6400</td>\n",
465
+ " <td>0.029900</td>\n",
466
+ " </tr>\n",
467
+ " <tr>\n",
468
+ " <td>6500</td>\n",
469
+ " <td>0.035500</td>\n",
470
+ " </tr>\n",
471
+ " <tr>\n",
472
+ " <td>6600</td>\n",
473
+ " <td>0.031800</td>\n",
474
+ " </tr>\n",
475
+ " <tr>\n",
476
+ " <td>6700</td>\n",
477
+ " <td>0.029200</td>\n",
478
+ " </tr>\n",
479
+ " <tr>\n",
480
+ " <td>6800</td>\n",
481
+ " <td>0.031500</td>\n",
482
+ " </tr>\n",
483
+ " <tr>\n",
484
+ " <td>6900</td>\n",
485
+ " <td>0.029700</td>\n",
486
+ " </tr>\n",
487
+ " <tr>\n",
488
+ " <td>7000</td>\n",
489
+ " <td>0.030000</td>\n",
490
+ " </tr>\n",
491
+ " <tr>\n",
492
+ " <td>7100</td>\n",
493
+ " <td>0.038800</td>\n",
494
+ " </tr>\n",
495
+ " <tr>\n",
496
+ " <td>7200</td>\n",
497
+ " <td>0.030200</td>\n",
498
+ " </tr>\n",
499
+ " <tr>\n",
500
+ " <td>7300</td>\n",
501
+ " <td>0.024700</td>\n",
502
+ " </tr>\n",
503
+ " <tr>\n",
504
+ " <td>7400</td>\n",
505
+ " <td>0.034300</td>\n",
506
+ " </tr>\n",
507
+ " <tr>\n",
508
+ " <td>7500</td>\n",
509
+ " <td>0.030400</td>\n",
510
+ " </tr>\n",
511
+ " <tr>\n",
512
+ " <td>7600</td>\n",
513
+ " <td>0.029200</td>\n",
514
+ " </tr>\n",
515
+ " <tr>\n",
516
+ " <td>7700</td>\n",
517
+ " <td>0.035600</td>\n",
518
+ " </tr>\n",
519
+ " <tr>\n",
520
+ " <td>7800</td>\n",
521
+ " <td>0.033100</td>\n",
522
+ " </tr>\n",
523
+ " <tr>\n",
524
+ " <td>7900</td>\n",
525
+ " <td>0.028300</td>\n",
526
+ " </tr>\n",
527
+ " <tr>\n",
528
+ " <td>8000</td>\n",
529
+ " <td>0.027900</td>\n",
530
+ " </tr>\n",
531
+ " <tr>\n",
532
+ " <td>8100</td>\n",
533
+ " <td>0.031400</td>\n",
534
+ " </tr>\n",
535
+ " <tr>\n",
536
+ " <td>8200</td>\n",
537
+ " <td>0.038500</td>\n",
538
+ " </tr>\n",
539
+ " <tr>\n",
540
+ " <td>8300</td>\n",
541
+ " <td>0.034400</td>\n",
542
+ " </tr>\n",
543
+ " <tr>\n",
544
+ " <td>8400</td>\n",
545
+ " <td>0.030400</td>\n",
546
+ " </tr>\n",
547
+ " <tr>\n",
548
+ " <td>8500</td>\n",
549
+ " <td>0.033000</td>\n",
550
+ " </tr>\n",
551
+ " <tr>\n",
552
+ " <td>8600</td>\n",
553
+ " <td>0.034100</td>\n",
554
+ " </tr>\n",
555
+ " <tr>\n",
556
+ " <td>8700</td>\n",
557
+ " <td>0.027100</td>\n",
558
+ " </tr>\n",
559
+ " <tr>\n",
560
+ " <td>8800</td>\n",
561
+ " <td>0.029500</td>\n",
562
+ " </tr>\n",
563
+ " <tr>\n",
564
+ " <td>8900</td>\n",
565
+ " <td>0.025700</td>\n",
566
+ " </tr>\n",
567
+ " <tr>\n",
568
+ " <td>9000</td>\n",
569
+ " <td>0.029900</td>\n",
570
+ " </tr>\n",
571
+ " <tr>\n",
572
+ " <td>9100</td>\n",
573
+ " <td>0.024000</td>\n",
574
+ " </tr>\n",
575
+ " <tr>\n",
576
+ " <td>9200</td>\n",
577
+ " <td>0.028500</td>\n",
578
+ " </tr>\n",
579
+ " <tr>\n",
580
+ " <td>9300</td>\n",
581
+ " <td>0.031400</td>\n",
582
+ " </tr>\n",
583
+ " <tr>\n",
584
+ " <td>9400</td>\n",
585
+ " <td>0.028300</td>\n",
586
+ " </tr>\n",
587
+ " <tr>\n",
588
+ " <td>9500</td>\n",
589
+ " <td>0.030500</td>\n",
590
+ " </tr>\n",
591
+ " <tr>\n",
592
+ " <td>9600</td>\n",
593
+ " <td>0.025900</td>\n",
594
+ " </tr>\n",
595
+ " <tr>\n",
596
+ " <td>9700</td>\n",
597
+ " <td>0.033600</td>\n",
598
+ " </tr>\n",
599
+ " <tr>\n",
600
+ " <td>9800</td>\n",
601
+ " <td>0.030300</td>\n",
602
+ " </tr>\n",
603
+ " <tr>\n",
604
+ " <td>9900</td>\n",
605
+ " <td>0.028700</td>\n",
606
+ " </tr>\n",
607
+ " <tr>\n",
608
+ " <td>10000</td>\n",
609
+ " <td>0.022900</td>\n",
610
+ " </tr>\n",
611
+ " </tbody>\n",
612
+ "</table><p>"
613
+ ]
614
+ },
615
+ "metadata": {}
616
+ },
617
+ {
618
+ "output_type": "stream",
619
+ "name": "stderr",
620
+ "text": [
621
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
622
+ " item['labels'] = torch.tensor(self.labels[index])\n",
623
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
624
+ " item['labels'] = torch.tensor(self.labels[index])\n",
625
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
626
+ " item['labels'] = torch.tensor(self.labels[index])\n",
627
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
628
+ " item['labels'] = torch.tensor(self.labels[index])\n",
629
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
630
+ " item['labels'] = torch.tensor(self.labels[index])\n",
631
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
632
+ " item['labels'] = torch.tensor(self.labels[index])\n",
633
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
634
+ " item['labels'] = torch.tensor(self.labels[index])\n",
635
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
636
+ " item['labels'] = torch.tensor(self.labels[index])\n",
637
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
638
+ " item['labels'] = torch.tensor(self.labels[index])\n",
639
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
640
+ " item['labels'] = torch.tensor(self.labels[index])\n",
641
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
642
+ " item['labels'] = torch.tensor(self.labels[index])\n",
643
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
644
+ " item['labels'] = torch.tensor(self.labels[index])\n",
645
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
646
+ " item['labels'] = torch.tensor(self.labels[index])\n",
647
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
648
+ " item['labels'] = torch.tensor(self.labels[index])\n",
649
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
650
+ " item['labels'] = torch.tensor(self.labels[index])\n",
651
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
652
+ " item['labels'] = torch.tensor(self.labels[index])\n",
653
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
654
+ " item['labels'] = torch.tensor(self.labels[index])\n",
655
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
656
+ " item['labels'] = torch.tensor(self.labels[index])\n",
657
+ "<ipython-input-3-a55db56b85e8>:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
658
+ " item['labels'] = torch.tensor(self.labels[index])\n"
659
+ ]
660
+ },
661
+ {
662
+ "output_type": "execute_result",
663
+ "data": {
664
+ "text/plain": [
665
+ "TrainOutput(global_step=10000, training_loss=0.045082428359985355, metrics={'train_runtime': 7226.7408, 'train_samples_per_second': 22.14, 'train_steps_per_second': 1.384, 'total_flos': 2.119629570048e+16, 'train_loss': 0.045082428359985355, 'epoch': 2.0})"
666
+ ]
667
+ },
668
+ "metadata": {},
669
+ "execution_count": 5
670
+ }
671
+ ],
672
+ "source": [
673
+ "trainer.train()"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": null,
679
+ "metadata": {
680
+ "id": "lowGDIRRV2Kk"
681
+ },
682
+ "outputs": [],
683
+ "source": [
684
+ "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
685
+ "\n",
686
+ "save_directory = \"saved\"\n",
687
+ "tokenizer.save_pretrained(save_directory)\n",
688
+ "model.save_pretrained(save_directory)\n",
689
+ "\n",
690
+ "tokenizer = AutoTokenizer.from_pretrained(save_directory)\n",
691
+ "model = AutoModelForSequenceClassification.from_pretrained(save_directory)"
692
+ ]
693
+ }
694
+ ],
695
+ "metadata": {
696
+ "colab": {
697
+ "provenance": [],
698
+ "mount_file_id": "1SI5wXUWiK-4VnrwWn6Pq2r2e3pzK15mn",
699
+ "authorship_tag": "ABX9TyOWwkZmPEdojeBmja70X/+z",
700
+ "include_colab_link": true
701
+ },
702
+ "kernelspec": {
703
+ "display_name": "Python 3",
704
+ "name": "python3"
705
+ },
706
+ "language_info": {
707
+ "name": "python"
708
+ },
709
+ "accelerator": "GPU",
710
+ "gpuClass": "standard"
711
+ },
712
+ "nbformat": 4,
713
+ "nbformat_minor": 0
714
+ }