Granther commited on
Commit
ab5911d
Β·
verified Β·
1 Parent(s): ecb00f8

Upload prompt_tune_phi3.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. prompt_tune_phi3.ipynb +338 -73
prompt_tune_phi3.ipynb CHANGED
@@ -90,7 +90,7 @@
90
  },
91
  {
92
  "cell_type": "code",
93
- "execution_count": 24,
94
  "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
95
  "metadata": {},
96
  "outputs": [],
@@ -117,23 +117,23 @@
117
  "label_col = 'text_label'\n",
118
  "max_len = 64\n",
119
  "lr = 3e-2\n",
120
- "epochs = 50\n",
121
  "batch_size = 8"
122
  ]
123
  },
124
  {
125
  "cell_type": "code",
126
- "execution_count": 6,
127
  "id": "6f677839-ef23-428a-bcfe-f596590804ca",
128
  "metadata": {},
129
  "outputs": [],
130
  "source": [
131
- "dataset = load_dataset('ought/raft', dataset_name, split='train')"
132
  ]
133
  },
134
  {
135
  "cell_type": "code",
136
- "execution_count": 7,
137
  "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
138
  "metadata": {},
139
  "outputs": [
@@ -143,22 +143,36 @@
143
  "['Unlabeled', 'complaint', 'no complaint']"
144
  ]
145
  },
146
- "execution_count": 7,
147
  "metadata": {},
148
  "output_type": "execute_result"
149
  }
150
  ],
151
  "source": [
152
- "dataset.features['Label'].names\n",
153
  "#>>> ['Unlabeled', 'complaint', 'no complaint']"
154
  ]
155
  },
156
  {
157
  "cell_type": "code",
158
- "execution_count": 8,
159
  "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
160
  "metadata": {},
161
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  {
163
  "data": {
164
  "text/plain": [
@@ -168,26 +182,26 @@
168
  " 'text_label': 'no complaint'}"
169
  ]
170
  },
171
- "execution_count": 8,
172
  "metadata": {},
173
  "output_type": "execute_result"
174
  }
175
  ],
176
  "source": [
177
  "# Create lambda function\n",
178
- "classes = [k.replace('_', ' ') for k in dataset.features['Label'].names]\n",
179
  "dataset = dataset.map(\n",
180
  " lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
181
  " batched=True,\n",
182
  " num_proc=10,\n",
183
  ")\n",
184
  "\n",
185
- "dataset[0]"
186
  ]
187
  },
188
  {
189
  "cell_type": "code",
190
- "execution_count": 9,
191
  "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
192
  "metadata": {},
193
  "outputs": [
@@ -204,7 +218,7 @@
204
  "[1, 853, 29880, 24025]"
205
  ]
206
  },
207
- "execution_count": 9,
208
  "metadata": {},
209
  "output_type": "execute_result"
210
  }
@@ -236,7 +250,7 @@
236
  },
237
  {
238
  "cell_type": "code",
239
- "execution_count": 31,
240
  "id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
241
  "metadata": {},
242
  "outputs": [],
@@ -261,19 +275,30 @@
261
  " #>>> -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000\n",
262
  " # Pad the beginning of the sequence with n -100s (ignore tokens)\n",
263
  " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
264
- " print(model_inputs[\"attention_mask\"][i])"
 
 
 
 
 
 
 
 
 
 
 
265
  ]
266
  },
267
  {
268
  "cell_type": "code",
269
- "execution_count": 32,
270
  "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
271
  "metadata": {},
272
  "outputs": [
273
  {
274
  "data": {
275
  "application/vnd.jupyter.widget-view+json": {
276
- "model_id": "cb9f37c876c548fbbcd07a7b889e1764",
277
  "version_major": 2,
278
  "version_minor": 0
279
  },
@@ -285,60 +310,18 @@
285
  "output_type": "display_data"
286
  },
287
  {
288
- "name": "stdout",
289
- "output_type": "stream",
290
- "text": [
291
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
292
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
293
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
294
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
295
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
296
- "\n",
297
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
298
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
299
- "\n",
300
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
301
- "\n",
302
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
303
- "\n",
304
- "\n",
305
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
306
- "\n",
307
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
308
- "\n",
309
- "\n",
310
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
311
- "\n",
312
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
313
- "\n",
314
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
315
- "\n",
316
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
317
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
318
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
319
- "\n",
320
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
321
- "\n",
322
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
323
- "\n",
324
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
325
- "\n",
326
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
327
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}\n",
328
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
329
- "\n",
330
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
331
- "\n",
332
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
333
- "\n",
334
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}\n",
335
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
336
- "\n",
337
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
338
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
339
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
340
- "{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n"
341
- ]
342
  }
343
  ],
344
  "source": [
@@ -346,7 +329,7 @@
346
  " preproc,\n",
347
  " batched=True, # uses default batch size\n",
348
  " num_proc=10,\n",
349
- " remove_columns=dataset.column_names, # All columns from the original dataset will be removed in the new dataset\n",
350
  " load_from_cache_file=False,\n",
351
  " desc=\"Preprocessing dataset\"\n",
352
  ")"
@@ -354,10 +337,292 @@
354
  },
355
  {
356
  "cell_type": "code",
357
- "execution_count": null,
358
  "id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
359
  "metadata": {},
360
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  "source": []
362
  }
363
  ],
 
90
  },
91
  {
92
  "cell_type": "code",
93
+ "execution_count": 54,
94
  "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
95
  "metadata": {},
96
  "outputs": [],
 
117
  "label_col = 'text_label'\n",
118
  "max_len = 64\n",
119
  "lr = 3e-2\n",
120
+ "epochs = 5\n",
121
  "batch_size = 8"
122
  ]
123
  },
124
  {
125
  "cell_type": "code",
126
+ "execution_count": 28,
127
  "id": "6f677839-ef23-428a-bcfe-f596590804ca",
128
  "metadata": {},
129
  "outputs": [],
130
  "source": [
131
+ "dataset = load_dataset('ought/raft', dataset_name)"
132
  ]
133
  },
134
  {
135
  "cell_type": "code",
136
+ "execution_count": 30,
137
  "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
138
  "metadata": {},
139
  "outputs": [
 
143
  "['Unlabeled', 'complaint', 'no complaint']"
144
  ]
145
  },
146
+ "execution_count": 30,
147
  "metadata": {},
148
  "output_type": "execute_result"
149
  }
150
  ],
151
  "source": [
152
+ "dataset['train'].features['Label'].names\n",
153
  "#>>> ['Unlabeled', 'complaint', 'no complaint']"
154
  ]
155
  },
156
  {
157
  "cell_type": "code",
158
+ "execution_count": 32,
159
  "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
160
  "metadata": {},
161
  "outputs": [
162
+ {
163
+ "data": {
164
+ "application/vnd.jupyter.widget-view+json": {
165
+ "model_id": "11da1eb81527428a95c41816f5bf459f",
166
+ "version_major": 2,
167
+ "version_minor": 0
168
+ },
169
+ "text/plain": [
170
+ "Map (num_proc=10): 0%| | 0/3399 [00:00<?, ? examples/s]"
171
+ ]
172
+ },
173
+ "metadata": {},
174
+ "output_type": "display_data"
175
+ },
176
  {
177
  "data": {
178
  "text/plain": [
 
182
  " 'text_label': 'no complaint'}"
183
  ]
184
  },
185
+ "execution_count": 32,
186
  "metadata": {},
187
  "output_type": "execute_result"
188
  }
189
  ],
190
  "source": [
191
  "# Create lambda function\n",
192
+ "classes = [k.replace('_', ' ') for k in dataset['train'].features['Label'].names]\n",
193
  "dataset = dataset.map(\n",
194
  " lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
195
  " batched=True,\n",
196
  " num_proc=10,\n",
197
  ")\n",
198
  "\n",
199
+ "dataset['train'][0]"
200
  ]
201
  },
202
  {
203
  "cell_type": "code",
204
+ "execution_count": 41,
205
  "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
206
  "metadata": {},
207
  "outputs": [
 
218
  "[1, 853, 29880, 24025]"
219
  ]
220
  },
221
+ "execution_count": 41,
222
  "metadata": {},
223
  "output_type": "execute_result"
224
  }
 
250
  },
251
  {
252
  "cell_type": "code",
253
+ "execution_count": 26,
254
  "id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
255
  "metadata": {},
256
  "outputs": [],
 
275
  " #>>> -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000\n",
276
  " # Pad the beginning of the sequence with n -100s (ignore tokens)\n",
277
  " model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
278
+ "\n",
279
+ " for i in range(batch_size):\n",
280
+ " sample_input_ids = model_inputs[\"input_ids\"][i]\n",
281
+ " label_input_ids = labels[\"input_ids\"][i]\n",
282
+ " model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (target_max_len - len(sample_input_ids)) + sample_input_ids\n",
283
+ " model_inputs[\"attention_mask\"][i] = [0] * (target_max_len - len(sample_input_ids)) + model_inputs[\"attention_mask\"][i]\n",
284
+ " labels[\"input_ids\"][i] = [-100] * (target_max_len - len(sample_input_ids)) + label_input_ids\n",
285
+ " model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:target_max_len])\n",
286
+ " model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:target_max_len])\n",
287
+ " labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:target_max_len])\n",
288
+ " model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
289
+ " return model_inputs"
290
  ]
291
  },
292
  {
293
  "cell_type": "code",
294
+ "execution_count": 33,
295
  "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
296
  "metadata": {},
297
  "outputs": [
298
  {
299
  "data": {
300
  "application/vnd.jupyter.widget-view+json": {
301
+ "model_id": "05958c1cf67d413b9085622ace0cb799",
302
  "version_major": 2,
303
  "version_minor": 0
304
  },
 
310
  "output_type": "display_data"
311
  },
312
  {
313
+ "data": {
314
+ "application/vnd.jupyter.widget-view+json": {
315
+ "model_id": "05e7c3181c20464492f2ec4ced190fd4",
316
+ "version_major": 2,
317
+ "version_minor": 0
318
+ },
319
+ "text/plain": [
320
+ "Preprocessing dataset (num_proc=10): 0%| | 0/3399 [00:00<?, ? examples/s]"
321
+ ]
322
+ },
323
+ "metadata": {},
324
+ "output_type": "display_data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  }
326
  ],
327
  "source": [
 
329
  " preproc,\n",
330
  " batched=True, # uses default batch size\n",
331
  " num_proc=10,\n",
332
+ " remove_columns=dataset[\"train\"].column_names, # All columns from the original dataset will be removed in the new dataset\n",
333
  " load_from_cache_file=False,\n",
334
  " desc=\"Preprocessing dataset\"\n",
335
  ")"
 
337
  },
338
  {
339
  "cell_type": "code",
340
+ "execution_count": 43,
341
  "id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
342
  "metadata": {},
343
  "outputs": [],
344
+ "source": [
345
+ "train_dataset = processed_datasets[\"train\"]\n",
346
+ "eval_dataset = processed_datasets[\"test\"]\n",
347
+ "\n",
348
+ "train_dataloader = DataLoader(train_dataset,\n",
349
+ " shuffle=True, # shuffling is unneccasary since we are not training\n",
350
+ " collate_fn=default_data_collator,\n",
351
+ " batch_size=batch_size,\n",
352
+ " pin_memory=True # pin memory when using a GPU, makes loading data faster\n",
353
+ " )\n",
354
+ "\n",
355
+ "eval_dataloader = DataLoader(eval_dataset,\n",
356
+ " shuffle=False,\n",
357
+ " collate_fn=default_data_collator,\n",
358
+ " batch_size=batch_size,\n",
359
+ " pin_memory=True\n",
360
+ " )"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": 51,
366
+ "id": "a4c529e4-d8ae-42b2-a658-f76d183bb264",
367
+ "metadata": {},
368
+ "outputs": [
369
+ {
370
+ "data": {
371
+ "application/vnd.jupyter.widget-view+json": {
372
+ "model_id": "58f2ef57b8ea49c2a26d4361ce4a5983",
373
+ "version_major": 2,
374
+ "version_minor": 0
375
+ },
376
+ "text/plain": [
377
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
378
+ ]
379
+ },
380
+ "metadata": {},
381
+ "output_type": "display_data"
382
+ },
383
+ {
384
+ "name": "stderr",
385
+ "output_type": "stream",
386
+ "text": [
387
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
388
+ ]
389
+ },
390
+ {
391
+ "name": "stdout",
392
+ "output_type": "stream",
393
+ "text": [
394
+ "trainable params: 24,576 || all params: 3,821,104,128 || trainable%: 0.0006\n",
395
+ "None\n"
396
+ ]
397
+ }
398
+ ],
399
+ "source": [
400
+ "model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=\"flash_attention_2\", torch_dtype=torch.bfloat16)\n",
401
+ "model = get_peft_model(model, peft_conf)\n",
402
+ "\n",
403
+ "# the rest of the model is frozen\n",
404
+ "print(model.print_trainable_parameters())"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": 52,
410
+ "id": "3289e4e3-9b9a-4256-921b-5df21d18344e",
411
+ "metadata": {},
412
+ "outputs": [],
413
+ "source": [
414
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
415
+ "lr_scheduler = get_linear_schedule_with_warmup(\n",
416
+ " optimizer=optimizer,\n",
417
+ " num_warmup_steps=0,\n",
418
+ " num_training_steps=(len(train_dataloader) * epochs),\n",
419
+ ")"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": 55,
425
+ "id": "e7939d75-c6b9-47a8-b1a3-88f7c33ff121",
426
+ "metadata": {},
427
+ "outputs": [
428
+ {
429
+ "name": "stderr",
430
+ "output_type": "stream",
431
+ "text": [
432
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 10.97it/s]\n",
433
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 31.61it/s]\n"
434
+ ]
435
+ },
436
+ {
437
+ "name": "stdout",
438
+ "output_type": "stream",
439
+ "text": [
440
+ "epoch=0: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
441
+ ]
442
+ },
443
+ {
444
+ "name": "stderr",
445
+ "output_type": "stream",
446
+ "text": [
447
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 12.02it/s]\n",
448
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 31.35it/s]\n"
449
+ ]
450
+ },
451
+ {
452
+ "name": "stdout",
453
+ "output_type": "stream",
454
+ "text": [
455
+ "epoch=1: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
456
+ ]
457
+ },
458
+ {
459
+ "name": "stderr",
460
+ "output_type": "stream",
461
+ "text": [
462
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 12.70it/s]\n",
463
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 31.66it/s]\n"
464
+ ]
465
+ },
466
+ {
467
+ "name": "stdout",
468
+ "output_type": "stream",
469
+ "text": [
470
+ "epoch=2: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
471
+ ]
472
+ },
473
+ {
474
+ "name": "stderr",
475
+ "output_type": "stream",
476
+ "text": [
477
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 11.85it/s]\n",
478
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 32.45it/s]\n"
479
+ ]
480
+ },
481
+ {
482
+ "name": "stdout",
483
+ "output_type": "stream",
484
+ "text": [
485
+ "epoch=3: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
486
+ ]
487
+ },
488
+ {
489
+ "name": "stderr",
490
+ "output_type": "stream",
491
+ "text": [
492
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 12.53it/s]\n",
493
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 32.38it/s]"
494
+ ]
495
+ },
496
+ {
497
+ "name": "stdout",
498
+ "output_type": "stream",
499
+ "text": [
500
+ "epoch=4: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
501
+ ]
502
+ },
503
+ {
504
+ "name": "stderr",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "\n"
508
+ ]
509
+ }
510
+ ],
511
+ "source": [
512
+ "model = model.to(device)\n",
513
+ "\n",
514
+ "for epoch in range(epochs):\n",
515
+ " model.train()\n",
516
+ " total_loss = 0\n",
517
+ " for step, batch in enumerate(tqdm(train_dataloader)):\n",
518
+ " batch = {k: v.to(device) for k, v in batch.items()}\n",
519
+ " outputs = model(**batch)\n",
520
+ " loss = outputs.loss\n",
521
+ " total_loss += loss.detach().float()\n",
522
+ " loss.backward()\n",
523
+ " optimizer.step()\n",
524
+ " lr_scheduler.step()\n",
525
+ " optimizer.zero_grad()\n",
526
+ "\n",
527
+ " model.eval()\n",
528
+ " eval_loss = 0\n",
529
+ " eval_preds = []\n",
530
+ " for step, batch in enumerate(tqdm(eval_dataloader)):\n",
531
+ " batch = {k: v.to(device) for k, v in batch.items()}\n",
532
+ " with torch.no_grad():\n",
533
+ " outputs = model(**batch)\n",
534
+ " loss = outputs.loss\n",
535
+ " eval_loss += loss.detach().float()\n",
536
+ " eval_preds.extend(\n",
537
+ " tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
538
+ " )\n",
539
+ "\n",
540
+ " eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
541
+ " eval_ppl = torch.exp(eval_epoch_loss)\n",
542
+ " train_epoch_loss = total_loss / len(train_dataloader)\n",
543
+ " train_ppl = torch.exp(train_epoch_loss)\n",
544
+ " print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
545
+ ]
546
+ },
547
+ {
548
+ "cell_type": "code",
549
+ "execution_count": 59,
550
+ "id": "806d36f8-499e-4af8-b717-68e5d849866d",
551
+ "metadata": {},
552
+ "outputs": [],
553
+ "source": [
554
+ "model.save_pretrained('model')"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": 1,
560
+ "id": "13db780a-fe20-4b23-b6cb-17118f7b695e",
561
+ "metadata": {},
562
+ "outputs": [
563
+ {
564
+ "data": {
565
+ "application/vnd.jupyter.widget-view+json": {
566
+ "model_id": "d8f94426025f4ad89847ac7e983cec42",
567
+ "version_major": 2,
568
+ "version_minor": 0
569
+ },
570
+ "text/plain": [
571
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
572
+ ]
573
+ },
574
+ "metadata": {},
575
+ "output_type": "display_data"
576
+ },
577
+ {
578
+ "name": "stderr",
579
+ "output_type": "stream",
580
+ "text": [
581
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
582
+ ]
583
+ }
584
+ ],
585
+ "source": [
586
+ "from transformers import pipeline\n",
587
+ "device = 'cuda'\n",
588
+ "pipe = pipeline(model='model', device=device, max_length=100)"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": 2,
594
+ "id": "26438301-3601-44f4-bbe4-3c573a1c28be",
595
+ "metadata": {},
596
+ "outputs": [
597
+ {
598
+ "name": "stderr",
599
+ "output_type": "stream",
600
+ "text": [
601
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
602
+ "You are not running the flash-attention implementation, expect numerical differences.\n"
603
+ ]
604
+ },
605
+ {
606
+ "data": {
607
+ "text/plain": [
608
+ "[{'generated_text': \"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\\n\\n### response\\nI understand your situation and I'm here to help. First, it's important to clarify that as an AI developed by Microsoft, I don't have the authority to directly intervene with your utility bills or the National Grid. However, I can guide you through the steps you should take to address this issue.\\n\\n1\"}]"
609
+ ]
610
+ },
611
+ "execution_count": 2,
612
+ "metadata": {},
613
+ "output_type": "execute_result"
614
+ }
615
+ ],
616
+ "source": [
617
+ "pipe(\"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\")"
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": null,
623
+ "id": "f83e960d-ab80-406e-9ba9-e9533fe9d033",
624
+ "metadata": {},
625
+ "outputs": [],
626
  "source": []
627
  }
628
  ],