{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3890292a-c99e-4367-955d-5883b93dba36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q peft transformers datasets huggingface_hub\n",
    "%pip install flash-attn --no-build-isolation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f1cc378f-afb6-441f-a4c6-2ec427b4cd4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "\n",
    "# Third-party\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import HfApi, notebook_login\n",
    "from peft import (\n",
    "    PeftType,\n",
    "    PromptTuningConfig,\n",
    "    PromptTuningInit,\n",
    "    TaskType,\n",
    "    get_peft_config,\n",
    "    get_peft_model,\n",
    ")\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    default_data_collator,\n",
    "    get_linear_schedule_with_warmup,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ab50d7-a4c9-4246-acd8-8875b87fe0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interactive Hugging Face Hub login; the upload cell below relies on it.\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8a1cb1f9-b89d-4cac-a595-44e1e0ef85b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/Granther/prompt-tuned-phi3/commit/912e66e469c6dd381daaa1ee25f5284e17c9377a', commit_message='Upload prompt_tune_phi3.ipynb with huggingface_hub', commit_description='', oid='912e66e469c6dd381daaa1ee25f5284e17c9377a', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Publish this notebook file to the Hub model repo.\n",
    "# NOTE(review): this uploads the copy currently saved on disk, and in a\n",
    "# top-to-bottom run it executes before every cell below it. Consider running\n",
    "# it last (after saving) so the published copy reflects the finished run.\n",
    "api = HfApi()\n",
    "api.upload_file(\n",
    "    path_or_fileobj='prompt_tune_phi3.ipynb',\n",
    "    path_in_repo='prompt_tune_phi3.ipynb',\n",
    "    repo_id='Granther/prompt-tuned-phi3',\n",
    "    repo_type='model',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is visible; otherwise fall back to CPU rather than\n",
    "# hard-coding 'cuda'.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "model_id = 'microsoft/Phi-3-mini-128k-instruct'\n",
    "\n",
    "peft_conf = PromptTuningConfig(\n",
    "    peft_type=PeftType.PROMPT_TUNING,  # what kind of peft\n",
    "    task_type=TaskType.CAUSAL_LM,  # config task\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,  # Set to 'TEXT' to use prompt_tuning_init_text\n",
    "    num_virtual_tokens=8,  # number of trainable virtual tokens prepended to the input\n",
    "    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
    "    tokenizer_name_or_path=model_id,  # tokenizer used to encode the init text\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "# '/' in the model id is replaced so the name is usable as a single file name.\n",
    "checkpoint_name = f\"{dataset_name}_{model_id}_{peft_conf.peft_type}_{peft_conf.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "\n",
    "# Dataset column names and training hyperparameters.\n",
    "text_col = 'Tweet text'\n",
    "lab_col = 'text_label'\n",
    "max_len = 64\n",
    "lr = 3e-2\n",
    "epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6f677839-ef23-428a-bcfe-f596590804ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the train split of the `dataset_name` configuration of ought/raft.\n",
    "dataset = load_dataset('ought/raft', dataset_name, split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c0c05613-7941-4959-ada9-49ed1093bec4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Unlabeled', 'complaint', 'no complaint']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class names behind the integer 'Label' feature.\n",
    "dataset.features['Label'].names\n",
    "#>>> ['Unlabeled', 'complaint', 'no complaint']"
   ]
  },
  { "cell_type": "code", "execution_count": 11, "id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d9e958c687dd493880d18d4f1621dad9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=10): 0%| | 0/50 [00:00