Granther commited on
Commit
fd1f3e8
Β·
verified Β·
1 Parent(s): ab5911d

Upload prompt_tune_phi3.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. prompt_tune_phi3.ipynb +122 -71
prompt_tune_phi3.ipynb CHANGED
@@ -36,13 +36,13 @@
36
  },
37
  {
38
  "cell_type": "code",
39
- "execution_count": 2,
40
  "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
41
  "metadata": {},
42
  "outputs": [],
43
  "source": [
44
  "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
45
- "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
46
  "import torch\n",
47
  "from datasets import load_dataset\n",
48
  "import os\n",
@@ -54,27 +54,42 @@
54
  },
55
  {
56
  "cell_type": "code",
57
- "execution_count": null,
58
  "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
59
  "metadata": {},
60
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  "source": [
62
  "notebook_login()"
63
  ]
64
  },
65
  {
66
  "cell_type": "code",
67
- "execution_count": 25,
68
  "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
69
  "metadata": {},
70
  "outputs": [
71
  {
72
  "data": {
73
  "text/plain": [
74
- "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/7ea57da9a4eccf3794c58bb4317df1c97a0fe2c8', commit_message='Upload prompt_tune_phi3.ipynb with huggingface_hub', commit_description='', oid='7ea57da9a4eccf3794c58bb4317df1c97a0fe2c8', pr_url=None, pr_revision=None, pr_num=None)"
75
  ]
76
  },
77
- "execution_count": 25,
78
  "metadata": {},
79
  "output_type": "execute_result"
80
  }
@@ -90,7 +105,7 @@
90
  },
91
  {
92
  "cell_type": "code",
93
- "execution_count": 54,
94
  "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
95
  "metadata": {},
96
  "outputs": [],
@@ -103,7 +118,7 @@
103
  " peft_type=PeftType.PROMPT_TUNING, # what kind of peft\n",
104
  " task_type=TaskType.CAUSAL_LM, # config task\n",
105
  " prompt_tuning_init=PromptTuningInit.TEXT, # Set to 'TEXT' to use prompt_tuning_init_text\n",
106
- " num_virtual_tokens=8, # x times the number of hidden transformer layers\n",
107
  " prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
108
  " tokenizer_name_or_path=model_id\n",
109
  ")\n",
@@ -123,7 +138,7 @@
123
  },
124
  {
125
  "cell_type": "code",
126
- "execution_count": 28,
127
  "id": "6f677839-ef23-428a-bcfe-f596590804ca",
128
  "metadata": {},
129
  "outputs": [],
@@ -133,7 +148,7 @@
133
  },
134
  {
135
  "cell_type": "code",
136
- "execution_count": 30,
137
  "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
138
  "metadata": {},
139
  "outputs": [
@@ -143,7 +158,7 @@
143
  "['Unlabeled', 'complaint', 'no complaint']"
144
  ]
145
  },
146
- "execution_count": 30,
147
  "metadata": {},
148
  "output_type": "execute_result"
149
  }
@@ -155,24 +170,10 @@
155
  },
156
  {
157
  "cell_type": "code",
158
- "execution_count": 32,
159
  "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
160
  "metadata": {},
161
  "outputs": [
162
- {
163
- "data": {
164
- "application/vnd.jupyter.widget-view+json": {
165
- "model_id": "11da1eb81527428a95c41816f5bf459f",
166
- "version_major": 2,
167
- "version_minor": 0
168
- },
169
- "text/plain": [
170
- "Map (num_proc=10): 0%| | 0/3399 [00:00<?, ? examples/s]"
171
- ]
172
- },
173
- "metadata": {},
174
- "output_type": "display_data"
175
- },
176
  {
177
  "data": {
178
  "text/plain": [
@@ -182,7 +183,7 @@
182
  " 'text_label': 'no complaint'}"
183
  ]
184
  },
185
- "execution_count": 32,
186
  "metadata": {},
187
  "output_type": "execute_result"
188
  }
@@ -201,7 +202,7 @@
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 41,
205
  "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
206
  "metadata": {},
207
  "outputs": [
@@ -218,7 +219,7 @@
218
  "[1, 853, 29880, 24025]"
219
  ]
220
  },
221
- "execution_count": 41,
222
  "metadata": {},
223
  "output_type": "execute_result"
224
  }
@@ -250,7 +251,7 @@
250
  },
251
  {
252
  "cell_type": "code",
253
- "execution_count": 26,
254
  "id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
255
  "metadata": {},
256
  "outputs": [],
@@ -291,14 +292,14 @@
291
  },
292
  {
293
  "cell_type": "code",
294
- "execution_count": 33,
295
  "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
296
  "metadata": {},
297
  "outputs": [
298
  {
299
  "data": {
300
  "application/vnd.jupyter.widget-view+json": {
301
- "model_id": "05958c1cf67d413b9085622ace0cb799",
302
  "version_major": 2,
303
  "version_minor": 0
304
  },
@@ -312,7 +313,7 @@
312
  {
313
  "data": {
314
  "application/vnd.jupyter.widget-view+json": {
315
- "model_id": "05e7c3181c20464492f2ec4ced190fd4",
316
  "version_major": 2,
317
  "version_minor": 0
318
  },
@@ -337,7 +338,7 @@
337
  },
338
  {
339
  "cell_type": "code",
340
- "execution_count": 43,
341
  "id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
342
  "metadata": {},
343
  "outputs": [],
@@ -362,14 +363,21 @@
362
  },
363
  {
364
  "cell_type": "code",
365
- "execution_count": 51,
366
  "id": "a4c529e4-d8ae-42b2-a658-f76d183bb264",
367
  "metadata": {},
368
  "outputs": [
 
 
 
 
 
 
 
369
  {
370
  "data": {
371
  "application/vnd.jupyter.widget-view+json": {
372
- "model_id": "58f2ef57b8ea49c2a26d4361ce4a5983",
373
  "version_major": 2,
374
  "version_minor": 0
375
  },
@@ -391,7 +399,7 @@
391
  "name": "stdout",
392
  "output_type": "stream",
393
  "text": [
394
- "trainable params: 24,576 || all params: 3,821,104,128 || trainable%: 0.0006\n",
395
  "None\n"
396
  ]
397
  }
@@ -406,7 +414,7 @@
406
  },
407
  {
408
  "cell_type": "code",
409
- "execution_count": 52,
410
  "id": "3289e4e3-9b9a-4256-921b-5df21d18344e",
411
  "metadata": {},
412
  "outputs": [],
@@ -421,7 +429,7 @@
421
  },
422
  {
423
  "cell_type": "code",
424
- "execution_count": 55,
425
  "id": "e7939d75-c6b9-47a8-b1a3-88f7c33ff121",
426
  "metadata": {},
427
  "outputs": [
@@ -429,8 +437,9 @@
429
  "name": "stderr",
430
  "output_type": "stream",
431
  "text": [
432
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 10.97it/s]\n",
433
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 31.61it/s]\n"
 
434
  ]
435
  },
436
  {
@@ -444,8 +453,8 @@
444
  "name": "stderr",
445
  "output_type": "stream",
446
  "text": [
447
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 12.02it/s]\n",
448
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 31.35it/s]\n"
449
  ]
450
  },
451
  {
@@ -459,8 +468,8 @@
459
  "name": "stderr",
460
  "output_type": "stream",
461
  "text": [
462
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 12.70it/s]\n",
463
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 31.66it/s]\n"
464
  ]
465
  },
466
  {
@@ -474,8 +483,8 @@
474
  "name": "stderr",
475
  "output_type": "stream",
476
  "text": [
477
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 11.85it/s]\n",
478
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 32.45it/s]\n"
479
  ]
480
  },
481
  {
@@ -489,8 +498,8 @@
489
  "name": "stderr",
490
  "output_type": "stream",
491
  "text": [
492
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 12.53it/s]\n",
493
- "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:13<00:00, 32.38it/s]"
494
  ]
495
  },
496
  {
@@ -546,7 +555,7 @@
546
  },
547
  {
548
  "cell_type": "code",
549
- "execution_count": 59,
550
  "id": "806d36f8-499e-4af8-b717-68e5d849866d",
551
  "metadata": {},
552
  "outputs": [],
@@ -556,14 +565,14 @@
556
  },
557
  {
558
  "cell_type": "code",
559
- "execution_count": 1,
560
- "id": "13db780a-fe20-4b23-b6cb-17118f7b695e",
561
  "metadata": {},
562
  "outputs": [
563
  {
564
  "data": {
565
  "application/vnd.jupyter.widget-view+json": {
566
- "model_id": "d8f94426025f4ad89847ac7e983cec42",
567
  "version_major": 2,
568
  "version_minor": 0
569
  },
@@ -573,48 +582,90 @@
573
  },
574
  "metadata": {},
575
  "output_type": "display_data"
576
- },
577
- {
578
- "name": "stderr",
579
- "output_type": "stream",
580
- "text": [
581
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
582
- ]
583
  }
584
  ],
585
  "source": [
586
- "from transformers import pipeline\n",
587
- "device = 'cuda'\n",
588
- "pipe = pipeline(model='model', device=device, max_length=100)"
 
 
 
 
 
589
  ]
590
  },
591
  {
592
  "cell_type": "code",
593
- "execution_count": 2,
594
- "id": "26438301-3601-44f4-bbe4-3c573a1c28be",
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  "metadata": {},
596
  "outputs": [
597
  {
598
  "name": "stderr",
599
  "output_type": "stream",
600
  "text": [
601
- "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
602
- "You are not running the flash-attention implementation, expect numerical differences.\n"
603
  ]
604
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  {
606
  "data": {
607
  "text/plain": [
608
- "[{'generated_text': \"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\\n\\n### response\\nI understand your situation and I'm here to help. First, it's important to clarify that as an AI developed by Microsoft, I don't have the authority to directly intervene with your utility bills or the National Grid. However, I can guide you through the steps you should take to address this issue.\\n\\n1\"}]"
609
  ]
610
  },
611
- "execution_count": 2,
612
  "metadata": {},
613
  "output_type": "execute_result"
614
  }
615
  ],
616
  "source": [
617
- "pipe(\"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\")"
618
  ]
619
  },
620
  {
 
36
  },
37
  {
38
  "cell_type": "code",
39
+ "execution_count": 3,
40
  "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
41
  "metadata": {},
42
  "outputs": [],
43
  "source": [
44
  "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
45
+ "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType, PeftConfig\n",
46
  "import torch\n",
47
  "from datasets import load_dataset\n",
48
  "import os\n",
 
54
  },
55
  {
56
  "cell_type": "code",
57
+ "execution_count": 17,
58
  "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
59
  "metadata": {},
60
+ "outputs": [
61
+ {
62
+ "data": {
63
+ "application/vnd.jupyter.widget-view+json": {
64
+ "model_id": "7f03fcf3844743fcb41f8bfc9c6c9b70",
65
+ "version_major": 2,
66
+ "version_minor": 0
67
+ },
68
+ "text/plain": [
69
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
70
+ ]
71
+ },
72
+ "metadata": {},
73
+ "output_type": "display_data"
74
+ }
75
+ ],
76
  "source": [
77
  "notebook_login()"
78
  ]
79
  },
80
  {
81
  "cell_type": "code",
82
+ "execution_count": 3,
83
  "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
84
  "metadata": {},
85
  "outputs": [
86
  {
87
  "data": {
88
  "text/plain": [
89
+ "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/ab5911db092a8e53ea24c33f170e8013a8b172aa', commit_message='Upload prompt_tune_phi3.ipynb with huggingface_hub', commit_description='', oid='ab5911db092a8e53ea24c33f170e8013a8b172aa', pr_url=None, pr_revision=None, pr_num=None)"
90
  ]
91
  },
92
+ "execution_count": 3,
93
  "metadata": {},
94
  "output_type": "execute_result"
95
  }
 
105
  },
106
  {
107
  "cell_type": "code",
108
+ "execution_count": 4,
109
  "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
110
  "metadata": {},
111
  "outputs": [],
 
118
  " peft_type=PeftType.PROMPT_TUNING, # what kind of peft\n",
119
  " task_type=TaskType.CAUSAL_LM, # config task\n",
120
  " prompt_tuning_init=PromptTuningInit.TEXT, # Set to 'TEXT' to use prompt_tuning_init_text\n",
121
+ " num_virtual_tokens=100, # x times the number of hidden transformer layers\n",
122
  " prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
123
  " tokenizer_name_or_path=model_id\n",
124
  ")\n",
 
138
  },
139
  {
140
  "cell_type": "code",
141
+ "execution_count": 5,
142
  "id": "6f677839-ef23-428a-bcfe-f596590804ca",
143
  "metadata": {},
144
  "outputs": [],
 
148
  },
149
  {
150
  "cell_type": "code",
151
+ "execution_count": 11,
152
  "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
153
  "metadata": {},
154
  "outputs": [
 
158
  "['Unlabeled', 'complaint', 'no complaint']"
159
  ]
160
  },
161
+ "execution_count": 11,
162
  "metadata": {},
163
  "output_type": "execute_result"
164
  }
 
170
  },
171
  {
172
  "cell_type": "code",
173
+ "execution_count": 7,
174
  "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
175
  "metadata": {},
176
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  {
178
  "data": {
179
  "text/plain": [
 
183
  " 'text_label': 'no complaint'}"
184
  ]
185
  },
186
+ "execution_count": 7,
187
  "metadata": {},
188
  "output_type": "execute_result"
189
  }
 
202
  },
203
  {
204
  "cell_type": "code",
205
+ "execution_count": 8,
206
  "id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
207
  "metadata": {},
208
  "outputs": [
 
219
  "[1, 853, 29880, 24025]"
220
  ]
221
  },
222
+ "execution_count": 8,
223
  "metadata": {},
224
  "output_type": "execute_result"
225
  }
 
251
  },
252
  {
253
  "cell_type": "code",
254
+ "execution_count": 14,
255
  "id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
256
  "metadata": {},
257
  "outputs": [],
 
292
  },
293
  {
294
  "cell_type": "code",
295
+ "execution_count": 15,
296
  "id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
297
  "metadata": {},
298
  "outputs": [
299
  {
300
  "data": {
301
  "application/vnd.jupyter.widget-view+json": {
302
+ "model_id": "5494bc1fbce24646b61e60e119ae1cb2",
303
  "version_major": 2,
304
  "version_minor": 0
305
  },
 
313
  {
314
  "data": {
315
  "application/vnd.jupyter.widget-view+json": {
316
+ "model_id": "857675d314254672964cafc522e3869f",
317
  "version_major": 2,
318
  "version_minor": 0
319
  },
 
338
  },
339
  {
340
  "cell_type": "code",
341
+ "execution_count": 16,
342
  "id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
343
  "metadata": {},
344
  "outputs": [],
 
363
  },
364
  {
365
  "cell_type": "code",
366
+ "execution_count": 17,
367
  "id": "a4c529e4-d8ae-42b2-a658-f76d183bb264",
368
  "metadata": {},
369
  "outputs": [
370
+ {
371
+ "name": "stderr",
372
+ "output_type": "stream",
373
+ "text": [
374
+ "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
375
+ ]
376
+ },
377
  {
378
  "data": {
379
  "application/vnd.jupyter.widget-view+json": {
380
+ "model_id": "1d09f75f23894968a6acd482a53fc92b",
381
  "version_major": 2,
382
  "version_minor": 0
383
  },
 
399
  "name": "stdout",
400
  "output_type": "stream",
401
  "text": [
402
+ "trainable params: 307,200 || all params: 3,821,386,752 || trainable%: 0.0080\n",
403
  "None\n"
404
  ]
405
  }
 
414
  },
415
  {
416
  "cell_type": "code",
417
+ "execution_count": 18,
418
  "id": "3289e4e3-9b9a-4256-921b-5df21d18344e",
419
  "metadata": {},
420
  "outputs": [],
 
429
  },
430
  {
431
  "cell_type": "code",
432
+ "execution_count": 19,
433
  "id": "e7939d75-c6b9-47a8-b1a3-88f7c33ff121",
434
  "metadata": {},
435
  "outputs": [
 
437
  "name": "stderr",
438
  "output_type": "stream",
439
  "text": [
440
+ " 0%| | 0/7 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
441
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:01<00:00, 5.36it/s]\n",
442
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:29<00:00, 14.23it/s]\n"
443
  ]
444
  },
445
  {
 
453
  "name": "stderr",
454
  "output_type": "stream",
455
  "text": [
456
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 7.66it/s]\n",
457
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:29<00:00, 14.26it/s]\n"
458
  ]
459
  },
460
  {
 
468
  "name": "stderr",
469
  "output_type": "stream",
470
  "text": [
471
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 7.76it/s]\n",
472
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:29<00:00, 14.25it/s]\n"
473
  ]
474
  },
475
  {
 
483
  "name": "stderr",
484
  "output_type": "stream",
485
  "text": [
486
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 7.72it/s]\n",
487
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:29<00:00, 14.24it/s]\n"
488
  ]
489
  },
490
  {
 
498
  "name": "stderr",
499
  "output_type": "stream",
500
  "text": [
501
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 7/7 [00:00<00:00, 7.77it/s]\n",
502
+ "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 425/425 [00:29<00:00, 14.18it/s]"
503
  ]
504
  },
505
  {
 
555
  },
556
  {
557
  "cell_type": "code",
558
+ "execution_count": 20,
559
  "id": "806d36f8-499e-4af8-b717-68e5d849866d",
560
  "metadata": {},
561
  "outputs": [],
 
565
  },
566
  {
567
  "cell_type": "code",
568
+ "execution_count": 10,
569
+ "id": "cff41965-fa71-420b-80d8-ce597510f1d3",
570
  "metadata": {},
571
  "outputs": [
572
  {
573
  "data": {
574
  "application/vnd.jupyter.widget-view+json": {
575
+ "model_id": "821777d6daa442c7a5779f3aff695739",
576
  "version_major": 2,
577
  "version_minor": 0
578
  },
 
582
  },
583
  "metadata": {},
584
  "output_type": "display_data"
 
 
 
 
 
 
 
585
  }
586
  ],
587
  "source": [
588
+ "from peft import PeftModel, PeftConfig\n",
589
+ "from transformers import AutoModelForCausalLM, AutoTokenizer \n",
590
+ "\n",
591
+ "#tokenizer = AutoTokenizer.from_pretrained('model')\n",
592
+ "\n",
593
+ "config = PeftConfig.from_pretrained('model')\n",
594
+ "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
595
+ "model = PeftModel.from_pretrained(model, 'model')"
596
  ]
597
  },
598
  {
599
  "cell_type": "code",
600
+ "execution_count": 11,
601
+ "id": "d8a432c9-9ddb-4bb7-a7f0-c4cadd612535",
602
+ "metadata": {},
603
+ "outputs": [],
604
+ "source": [
605
+ "inputs = tokenizer(\n",
606
+ " f'{text_col} : {\"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\"} Label : ',\n",
607
+ " return_tensors=\"pt\",\n",
608
+ ")"
609
+ ]
610
+ },
611
+ {
612
+ "cell_type": "code",
613
+ "execution_count": 15,
614
+ "id": "66cfaab3-dc63-4a1e-ab4d-2a687695993d",
615
  "metadata": {},
616
  "outputs": [
617
  {
618
  "name": "stderr",
619
  "output_type": "stream",
620
  "text": [
621
+ "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1249: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
622
+ " warnings.warn(\n"
623
  ]
624
  },
625
+ {
626
+ "ename": "ValueError",
627
+ "evalue": "Input length of input_ids is 32, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.",
628
+ "output_type": "error",
629
+ "traceback": [
630
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
631
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
632
+ "Cell \u001b[0;32mIn[15], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 4\u001b[0m inputs \u001b[38;5;241m=\u001b[39m {k: v\u001b[38;5;241m.\u001b[39mto(device) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m inputs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m----> 5\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
633
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1493\u001b[0m, in \u001b[0;36mPeftModelForCausalLM.generate\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1491\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model\u001b[38;5;241m.\u001b[39mgenerate(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1492\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1493\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1494\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[1;32m 1495\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model\u001b[38;5;241m.\u001b[39mprepare_inputs_for_generation \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model_prepare_inputs_for_generation\n",
634
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
635
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1786\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1783\u001b[0m model_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpast_key_values\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m DynamicCache\u001b[38;5;241m.\u001b[39mfrom_legacy_cache(past)\n\u001b[1;32m 1784\u001b[0m use_dynamic_cache_by_default \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m-> 1786\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_generated_length\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_ids_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhas_default_max_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1788\u001b[0m \u001b[38;5;66;03m# 7. determine generation mode\u001b[39;00m\n\u001b[1;32m 1789\u001b[0m generation_mode \u001b[38;5;241m=\u001b[39m generation_config\u001b[38;5;241m.\u001b[39mget_generation_mode(assistant_model)\n",
636
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1257\u001b[0m, in \u001b[0;36mGenerationMixin._validate_generated_length\u001b[0;34m(self, generation_config, input_ids_length, has_default_max_length)\u001b[0m\n\u001b[1;32m 1255\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m input_ids_length \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m generation_config\u001b[38;5;241m.\u001b[39mmax_length:\n\u001b[1;32m 1256\u001b[0m input_ids_string \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdecoder_input_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1257\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1258\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInput length of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minput_ids_string\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is \u001b[39m\u001b[38;5;132;01m{\u001b[39;00minput_ids_length\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, but `max_length` is set to\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1259\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgeneration_config\u001b[38;5;241m.\u001b[39mmax_length\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This can lead to unexpected behavior. You should consider\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1260\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m increasing `max_length` or, better yet, setting `max_new_tokens`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1261\u001b[0m )\n\u001b[1;32m 1263\u001b[0m \u001b[38;5;66;03m# 2. Min length warnings due to unfeasible parameter combinations\u001b[39;00m\n\u001b[1;32m 1264\u001b[0m min_length_error_suffix \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m Generation will stop at the defined maximum length. You should decrease the minimum length and/or \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mincrease the maximum length.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1267\u001b[0m )\n",
637
+ "\u001b[0;31mValueError\u001b[0m: Input length of input_ids is 32, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`."
638
+ ]
639
+ }
640
+ ],
641
+ "source": [
642
+ "model.to(device)\n",
643
+ "\n",
644
+ "with torch.no_grad():\n",
645
+ " inputs = {k: v.to(device) for k, v in inputs.items()}\n",
646
+ " out = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"])#, max_new_tokens=10) #, eos_token_id=3)\n",
647
+ " #print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": 24,
653
+ "id": "26438301-3601-44f4-bbe4-3c573a1c28be",
654
+ "metadata": {},
655
+ "outputs": [
656
  {
657
  "data": {
658
  "text/plain": [
659
+ "[{'generated_text': '@HMRCcustomers No this is my first job and I am not sure what to do. I have been told that I need to register with HMRC but I am not sure how to do this. Can you please help me?\\n\\n### response\\nTo register with HMRC for your first job, you need to complete a Self Assessment tax return if you are self-employed or have income to report. For employees, you may need to complete'}]"
660
  ]
661
  },
662
+ "execution_count": 24,
663
  "metadata": {},
664
  "output_type": "execute_result"
665
  }
666
  ],
667
  "source": [
668
+ "pipe(\"@HMRCcustomers No this is my first job\")"
669
  ]
670
  },
671
  {