FEAT: add tagging support to axolotl (#1004)
Browse files* add tagging support to axolotl
* chore: lint
* fix method w self
---------
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/core/trainer_builder.py
CHANGED
@@ -9,7 +9,7 @@ import math
|
|
9 |
import sys
|
10 |
from abc import abstractmethod
|
11 |
from dataclasses import dataclass, field
|
12 |
-
from functools import partial
|
13 |
from pathlib import Path
|
14 |
from typing import Optional
|
15 |
|
@@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer):
|
|
120 |
"""
|
121 |
|
122 |
args = None # type: AxolotlTrainingArguments
|
|
|
123 |
|
124 |
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
125 |
self.num_epochs = num_epochs
|
@@ -290,12 +291,41 @@ class AxolotlTrainer(Trainer):
|
|
290 |
# return (loss, outputs) if return_outputs else loss
|
291 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
|
294 |
class AxolotlMambaTrainer(AxolotlTrainer):
|
295 |
"""
|
296 |
Mamba specific trainer to handle loss calculation
|
297 |
"""
|
298 |
|
|
|
|
|
299 |
def compute_loss(
|
300 |
self,
|
301 |
model,
|
@@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
|
322 |
Trainer subclass that uses the OneCycleLR scheduler
|
323 |
"""
|
324 |
|
|
|
|
|
325 |
def __init__(self, *args, **kwargs):
|
326 |
super().__init__(*args, **kwargs)
|
327 |
self.lr_scheduler = None
|
@@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
|
|
351 |
Trainer subclass that uses the OneCycleLR scheduler
|
352 |
"""
|
353 |
|
|
|
|
|
354 |
def __init__(self, *args, **kwargs):
|
355 |
super().__init__(*args, **kwargs)
|
356 |
self.lr_scheduler = None
|
|
|
9 |
import sys
|
10 |
from abc import abstractmethod
|
11 |
from dataclasses import dataclass, field
|
12 |
+
from functools import partial, wraps
|
13 |
from pathlib import Path
|
14 |
from typing import Optional
|
15 |
|
|
|
120 |
"""
|
121 |
|
122 |
args = None # type: AxolotlTrainingArguments
|
123 |
+
tag_names = ["axolotl"]
|
124 |
|
125 |
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
126 |
self.num_epochs = num_epochs
|
|
|
291 |
# return (loss, outputs) if return_outputs else loss
|
292 |
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
293 |
|
294 |
+
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
    """
    Merge the default tags into the ``tags`` entry of ``kwargs``.

    Args:
        tag_names: a single tag (``str``) or an iterable of tags to add.
            The object passed in is never modified.
        kwargs: keyword arguments to update, or ``None``.

    Returns:
        The same ``kwargs`` object with ``kwargs["tags"]`` set to a new list
        (``None`` is returned unchanged when ``kwargs`` is ``None``). A
        ``tags`` value that is neither a list nor a string is left untouched.
    """
    # Always work on a private copy: callers pass the class-level
    # ``tag_names`` list, and appending to it (or storing it in kwargs)
    # would leak tags into every later call and every other instance.
    if isinstance(tag_names, str):
        tag_names = [tag_names]
    else:
        tag_names = list(tag_names or ())

    if kwargs is None:
        return kwargs

    if "tags" not in kwargs:
        kwargs["tags"] = tag_names
        return kwargs

    existing = kwargs["tags"]
    if isinstance(existing, list):
        # Caller-supplied tags come first; build a new list so the caller's
        # list is not mutated, and skip tags that are already present.
        kwargs["tags"] = existing + [
            tag for tag in tag_names if tag not in existing
        ]
    elif isinstance(existing, str):
        # Default tags come first, followed by the caller's single tag.
        if existing not in tag_names:
            tag_names.append(existing)
        kwargs["tags"] = tag_names

    return kwargs
|
308 |
+
|
309 |
+
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
    """
    Push the model to the Hub with this trainer's ``tag_names`` merged into
    the ``tags`` keyword argument, then defer to the parent implementation.
    See `~transformers.Trainer.push_to_hub` for the accepted arguments.
    """
    tagged_kwargs = self._sanitize_kwargs_for_tagging(
        tag_names=self.tag_names,
        kwargs=kwargs,
    )
    return super().push_to_hub(*args, **tagged_kwargs)
|
320 |
+
|
321 |
|
322 |
class AxolotlMambaTrainer(AxolotlTrainer):
|
323 |
"""
|
324 |
Mamba specific trainer to handle loss calculation
|
325 |
"""
|
326 |
|
327 |
+
tag_names = ["axolotl", "mamba"]
|
328 |
+
|
329 |
def compute_loss(
|
330 |
self,
|
331 |
model,
|
|
|
352 |
Trainer subclass that uses the OneCycleLR scheduler
|
353 |
"""
|
354 |
|
355 |
+
tag_names = ["axolotl", "onecycle"]
|
356 |
+
|
357 |
def __init__(self, *args, **kwargs):
|
358 |
super().__init__(*args, **kwargs)
|
359 |
self.lr_scheduler = None
|
|
|
383 |
Trainer subclass that uses the OneCycleLR scheduler
|
384 |
"""
|
385 |
|
386 |
+
tag_names = ["axolotl", "relora"]
|
387 |
+
|
388 |
def __init__(self, *args, **kwargs):
|
389 |
super().__init__(*args, **kwargs)
|
390 |
self.lr_scheduler = None
|