Commit
·
241ab60
1
Parent(s):
7e2da68
update phrasing
Browse files
examples/fine-tune-smollm2-on-synthetic-data.ipynb
CHANGED
@@ -75,18 +75,9 @@
|
|
75 |
},
|
76 |
{
|
77 |
"cell_type": "code",
|
78 |
-
"execution_count":
|
79 |
"metadata": {},
|
80 |
-
"outputs": [
|
81 |
-
{
|
82 |
-
"name": "stderr",
|
83 |
-
"output_type": "stream",
|
84 |
-
"text": [
|
85 |
-
"/Users/davidberenstein/Documents/programming/argilla/synthetic-data-generator/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
86 |
-
" from .autonotebook import tqdm as notebook_tqdm\n"
|
87 |
-
]
|
88 |
-
}
|
89 |
-
],
|
90 |
"source": [
|
91 |
"# Import necessary libraries\n",
|
92 |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
@@ -229,7 +220,16 @@
|
|
229 |
"cell_type": "markdown",
|
230 |
"metadata": {},
|
231 |
"source": [
|
232 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
"\n",
|
234 |
"### Run inference\n",
|
235 |
"\n",
|
@@ -238,12 +238,28 @@
|
|
238 |
},
|
239 |
{
|
240 |
"cell_type": "code",
|
241 |
-
"execution_count":
|
242 |
"metadata": {},
|
243 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
"source": [
|
245 |
-
"
|
246 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
]
|
248 |
},
|
249 |
{
|
|
|
75 |
},
|
76 |
{
|
77 |
"cell_type": "code",
|
78 |
+
"execution_count": 5,
|
79 |
"metadata": {},
|
80 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
"source": [
|
82 |
"# Import necessary libraries\n",
|
83 |
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
|
|
220 |
"cell_type": "markdown",
|
221 |
"metadata": {},
|
222 |
"source": [
|
223 |
+
"```\n",
|
224 |
+
"# {'loss': 1.4498, 'grad_norm': 2.3919131755828857, 'learning_rate': 4e-05, 'epoch': 0.1}\n",
|
225 |
+
"# {'loss': 1.362, 'grad_norm': 1.6650595664978027, 'learning_rate': 3e-05, 'epoch': 0.19}\n",
|
226 |
+
"# {'loss': 1.3778, 'grad_norm': 1.4778285026550293, 'learning_rate': 2e-05, 'epoch': 0.29}\n",
|
227 |
+
"# {'loss': 1.3735, 'grad_norm': 2.1424977779388428, 'learning_rate': 1e-05, 'epoch': 0.39}\n",
|
228 |
+
"# {'loss': 1.3512, 'grad_norm': 2.3498542308807373, 'learning_rate': 0.0, 'epoch': 0.48}\n",
|
229 |
+
"# {'train_runtime': 1911.514, 'train_samples_per_second': 1.046, 'train_steps_per_second': 0.262, 'train_loss': 1.3828572998046875, 'epoch': 0.48}\n",
|
230 |
+
"```\n",
|
231 |
+
"\n",
|
232 |
+
"For the example, we did not use a specific validation set but we can see the loss is decreasing, so we assume the model is generalizing well on the training data. To get a better understanding of the model's performance, let's test it again with the same prompt.\n",
|
233 |
"\n",
|
234 |
"### Run inference\n",
|
235 |
"\n",
|
|
|
238 |
},
|
239 |
{
|
240 |
"cell_type": "code",
|
241 |
+
"execution_count": 12,
|
242 |
"metadata": {},
|
243 |
+
"outputs": [
|
244 |
+
{
|
245 |
+
"name": "stderr",
|
246 |
+
"output_type": "stream",
|
247 |
+
"text": [
|
248 |
+
"Device set to use mps\n"
|
249 |
+
]
|
250 |
+
}
|
251 |
+
],
|
252 |
"source": [
|
253 |
+
"from transformers import pipeline\n",
|
254 |
+
"prompt = \"What is the primary function of mitochondria within a cell?\"\n",
|
255 |
+
"generator = pipeline(\n",
|
256 |
+
" \"text-generation\",\n",
|
257 |
+
" model=\"argilla/SmolLM2-360M-synthetic-concise-reasoning\",\n",
|
258 |
+
" device=\"mps\",\n",
|
259 |
+
")\n",
|
260 |
+
"generator(\n",
|
261 |
+
" [{\"role\": \"user\", \"content\": prompt}], max_new_tokens=128, return_full_text=False\n",
|
262 |
+
")[0][\"generated_text\"]"
|
263 |
]
|
264 |
},
|
265 |
{
|