File size: 37,471 Bytes
583d10e |
1 |
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Qwen2-1.5B QLoRA 파인튜닝 (KRX 금융 데이터)\n",
    "\n",
    "Fine-tunes `Qwen/Qwen2-1.5B` with 4-bit quantization + LoRA on `amphora/krx-sample-instructions` and `Cartinoe5930/web_text_synthetic_dataset_50k`, saves the LoRA adapter and tokenizer to Google Drive, then reloads them for a sample generation. Written for a Colab T4 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-84y3O9audxh"
   },
   "outputs": [],
   "source": [
    "# Versions pinned to the ones used in the recorded Colab run.\n",
    "%pip install -q transformers==4.44.2 datasets==3.0.2 trl==0.11.4 bitsandbytes==0.44.1 peft==0.13.2 accelerate==1.0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3QxeXUHqyZSL"
   },
   "source": [
    "# 라이브러리 임포트 및 환경 변수 설정 (W&B 비활성화, GPU 메모리 관리)\n",
    "\n",
    "`PYTORCH_CUDA_ALLOC_CONF` is read only once, when PyTorch initialises its CUDA caching allocator, so it is set here, before any model is loaded onto the GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dCTa8Ekcs2ZB"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import os\n",
    "\n",
    "# Set before torch touches CUDA: the caching allocator reads PYTORCH_CUDA_ALLOC_CONF\n",
    "# only once, at CUDA initialisation.\n",
    "os.environ[\"WANDB_MODE\"] = \"disabled\"  # disable Weights & Biases logging\n",
    "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "torchvision.disable_beta_transforms_warning()\n",
    "from datasets import load_dataset, concatenate_datasets\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "from trl import SFTConfig, SFTTrainer\n",
    "from peft import LoraConfig, PeftModel, get_peft_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dvrZxfhRoSX_"
   },
   "outputs": [],
   "source": [
    "# Sanity check of the GPU allocator state before anything is loaded.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()\n",
    "    print(torch.cuda.memory_summary(device=None, abbreviated=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hvemiqXbtPVo"
   },
   "source": [
    "# 환경 변수 로드 및 Google Colab 환경 설정\n",
    "\n",
    "Also defines the run configuration: model, sequence length, output path, seed and prompt template."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2nEoHW7itR3t"
   },
   "outputs": [],
   "source": [
    "# Optional local .env file holding HF_TOKEN (only relevant outside Colab).\n",
    "# Override the machine-specific default path with the KRX_ENV_PATH environment variable.\n",
    "ENV_PATH = os.environ.get(\"KRX_ENV_PATH\", \"C:/Users/yd170/OneDrive/바탕 화면/Coding/KRX.env\")\n",
    "if os.path.exists(ENV_PATH):\n",
    "    from dotenv import load_dotenv  # python-dotenv is only needed for the local .env file\n",
    "    load_dotenv(ENV_PATH)\n",
    "\n",
    "# None falls back to the cached login / Colab secret / anonymous access, which is enough\n",
    "# for public models; a placeholder string is not a valid token.\n",
    "hf_token = os.getenv(\"HF_TOKEN\")\n",
    "\n",
    "MODEL_NAME = \"Qwen/Qwen2-1.5B\"\n",
    "MAX_SEQ_LENGTH = 1024  # single source of truth for the SFT sequence length\n",
    "OUTPUT_DIR = \"/content/drive/My Drive/KRX_Qwen2_1_5B\"  # where the adapter + tokenizer are saved\n",
    "SEED = 42\n",
    "\n",
    "# Alpaca-style prompt template (no-input variant: the data only has an instruction and a response).\n",
    "# Used to build the training texts and, with an empty response, the inference prompt.\n",
    "PROMPT_FORMAT = \"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
    "\n",
    "### Instruction:\n",
    "{}\n",
    "\n",
    "### Response:\n",
    "{}\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XPOupITbtTe1"
   },
   "source": [
    "# 양자화 설정 (4비트 양자화 사용)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LxQopNuFtTOy"
   },
   "outputs": [],
   "source": [
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,  # apply 4-bit quantization\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.float16,  # matches fp16=True in the training config\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dcD5t1OstWbx"
   },
   "source": [
    "# 모델 및 토크나이저 로드 (GPU 사용하도록 설정)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uXH2KLTytX0A"
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    quantization_config=bnb_config,\n",
    "    device_map=\"auto\",\n",
    "    token=hf_token,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "230hvGJmtZU6"
   },
   "source": [
    "# LoRA 설정 추가"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RM4Xvc74tZD8"
   },
   "outputs": [],
   "source": [
    "lora_config = LoraConfig(\n",
    "    r=16,\n",
    "    lora_alpha=32,\n",
    "    target_modules=[\"q_proj\", \"v_proj\"],\n",
    "    lora_dropout=0.05,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    ")\n",
    "model = get_peft_model(model, lora_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y-jCFuCdtb7i"
   },
   "source": [
    "# 두 개의 데이터셋 로드 및 병합"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FNBsMfZWtdYm"
   },
   "outputs": [],
   "source": [
    "first_dataset = load_dataset(\"amphora/krx-sample-instructions\", split=\"train\")\n",
    "second_dataset = load_dataset(\"Cartinoe5930/web_text_synthetic_dataset_50k\", split=\"train\")\n",
    "\n",
    "EOS_TOKEN = tokenizer.eos_token\n",
    "\n",
    "\n",
    "def formatting_prompts_func(examples):\n",
    "    \"\"\"Render a batch of rows as training texts.\n",
    "\n",
    "    The instruction is read from the ``prompt`` column, or from ``question`` when there is\n",
    "    no ``prompt`` column; the answer is read from ``response``. Returns\n",
    "    ``{\"formatted_text\": [...]}`` with the EOS token appended to every example.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if the batch lacks a ``response`` column or an instruction column.\n",
    "    \"\"\"\n",
    "    if \"prompt\" in examples and \"response\" in examples:\n",
    "        instructions = examples[\"prompt\"]\n",
    "    elif \"question\" in examples and \"response\" in examples:\n",
    "        instructions = examples[\"question\"]\n",
    "    else:\n",
    "        raise KeyError(\"The dataset fields do not match the expected format.\")\n",
    "    outputs = examples[\"response\"]\n",
    "\n",
    "    texts = [PROMPT_FORMAT.format(instr, output) + EOS_TOKEN for instr, output in zip(instructions, outputs)]\n",
    "    return {\"formatted_text\": texts}\n",
    "\n",
    "\n",
    "# Format each source separately, keeping only the rendered text, and concatenate afterwards.\n",
    "# Each row is then formatted from its own source's columns, even if the two datasets name\n",
    "# the instruction field differently (prompt vs question).\n",
    "formatted_parts = [\n",
    "    ds.map(formatting_prompts_func, batched=True, remove_columns=ds.column_names)\n",
    "    for ds in (first_dataset, second_dataset)\n",
    "]\n",
    "dataset = concatenate_datasets(formatted_parts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yhoxZxuBtmuu"
   },
   "source": [
    "# 모델 학습 설정"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ijdYp1Wxtmgi"
   },
   "outputs": [],
   "source": [
    "# SFTConfig carries dataset_text_field / max_seq_length; passing them to SFTTrainer\n",
    "# directly is deprecated in trl.\n",
    "training_args = SFTConfig(\n",
    "    output_dir=\"./output\",\n",
    "    dataset_text_field=\"formatted_text\",  # column produced by formatting_prompts_func\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=8,  # 1 x 8 examples per optimizer step on one GPU\n",
    "    max_steps=100,\n",
    "    logging_steps=10,\n",
    "    learning_rate=2e-5,\n",
    "    seed=SEED,\n",
    "    save_steps=100,\n",
    "    fp16=True,  # mixed precision to reduce memory use\n",
    "    report_to=\"none\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bPlQN84Nto4j"
   },
   "source": [
    "# SFTTrainer 초기화"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O4MNZHjNtp11"
   },
   "outputs": [],
   "source": [
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=dataset,  # merged, pre-formatted dataset\n",
    "    args=training_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qbLwnuqYtqYh"
   },
   "source": [
    "# 학습"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "t586zh7owDHf"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "42X2FaIPtsdj"
   },
   "outputs": [],
   "source": [
    "print(\"모델 학습 시작...\")\n",
    "trainer.train()\n",
    "print(\"모델 학습 완료.\")\n",
    "\n",
    "# Save the trained LoRA adapter (model is a PeftModel) and the tokenizer to Google Drive.\n",
    "print(\"모델 저장 중...\")\n",
    "model.save_pretrained(OUTPUT_DIR)\n",
    "print(\"모델 저장 완료.\")\n",
    "\n",
    "print(\"토크나이저 저장 중...\")\n",
    "tokenizer.save_pretrained(OUTPUT_DIR)\n",
    "print(\"토크나이저 저장 완료.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Previous run** (Colab T4, 100 steps, 4:37 wall time): training loss 1.821 at step 10 → 1.709 at step 100."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TiLfumsA1TVr"
   },
   "source": [
    "# 학습된 모델 로드"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "onqwou0O1VU8"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "\n",
    "# Mount Google Drive so OUTPUT_DIR is reachable (only prints a notice if already mounted).\n",
    "drive.mount('/content/drive')\n",
    "\n",
    "# Release the training-time objects, if this session still holds them, before loading\n",
    "# the fine-tuned model for inference.\n",
    "for _name in (\"trainer\", \"model\"):\n",
    "    globals().pop(_name, None)\n",
    "gc.collect()\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "# OUTPUT_DIR holds the LoRA adapter (plus the tokenizer): load the base model, then\n",
    "# attach the adapter on top of it.\n",
    "inference_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
    "tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype=inference_dtype, device_map=\"auto\", token=hf_token\n",
    ")\n",
    "model = PeftModel.from_pretrained(base_model, OUTPUT_DIR)\n",
    "model.eval()\n",
    "\n",
    "print(\"모델과 토크나이저가 성공적으로 로드되었습니다.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J3YP2p4b1YMF"
   },
   "source": [
    "# 추론을 위한 프롬프트 설정\n",
    "\n",
    "Inference reuses the training template (`PROMPT_FORMAT`) with an empty response slot, so the model sees the same format it was fine-tuned on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bEuDvZ9N1Z-l"
   },
   "outputs": [],
   "source": [
    "def build_prompt(instruction: str) -> str:\n",
    "    \"\"\"Return the training template filled with `instruction` and an empty response.\"\"\"\n",
    "    return PROMPT_FORMAT.format(instruction, \"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WAcjaaz91av1"
   },
   "source": [
    "# 예시 프롬프트"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5rCUMNMT1cKw"
   },
   "outputs": [],
   "source": [
    "instruction = \"선물옵션에 대해 설명해줘.\"\n",
    "prompt = build_prompt(instruction)\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OT6wiZby1dvy"
   },
   "source": [
    "# 텍스트 생성"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZQEsBUq21eyc"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(SEED)  # sampling is random; fix the seed so the output is reproducible\n",
    "outputs = model.generate(\n",
    "    **inputs,\n",
    "    max_new_tokens=256,\n",
    "    do_sample=True,  # required: temperature/top_k are ignored by greedy decoding\n",
    "    temperature=0.7,  # diversity control\n",
    "    top_k=50,  # sample only from the K most likely tokens\n",
    "    repetition_penalty=1.2,  # penalise repeated tokens\n",
    "    pad_token_id=tokenizer.eos_token_id,  # no pad id in the generation config; use EOS explicitly\n",
    "    use_cache=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RG3dENu01gNC"
   },
   "source": [
    "# 결과 출력\n",
    "\n",
    "Previous run (greedy decoding, inference prompt different from the training template) produced an off-topic answer ('Orange price at the port of origin? ...')."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IgPVmSqQ1hlC"
   },
   "outputs": [],
   "source": [
    "# Decode only the newly generated tokens. Slicing token ids is exact, whereas slicing the\n",
    "# decoded string by len(prompt) breaks whenever decoding does not reproduce the prompt text\n",
    "# character for character.\n",
    "prompt_length = inputs[\"input_ids\"].shape[1]\n",
    "response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)\n",
    "print(response.strip())"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPWZ8Jb4Kxe+LPy00eQFSll",
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
} |