make sure to save the lora adapter at the end of RL/dpo training (#1573)
src/axolotl/train.py (+4 -0)
@@ -212,6 +212,10 @@ def train(
     if cfg.flash_optimum and BetterTransformer:
         model = BetterTransformer.reverse(model)
 
+    if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
+        trainer.model.save_pretrained(
+            cfg.output_dir, safe_serialization=safe_serialization
+        )
     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
     if not cfg.hub_model_id:
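With this change, when RL/DPO training runs with a LoRA adapter (and no separate adapter for the reference model), the trainer's PEFT-wrapped model is saved explicitly so the adapter weights end up in `cfg.output_dir`. Below is a minimal sketch of how the saved adapter could be reloaded afterwards; the base model name and output path are placeholders, not values from this commit.

```python
# Sketch: reload the adapter that train.py now saves at the end of RL/DPO training.
# "meta-llama/Llama-2-7b-hf" and "./outputs/dpo-out" are placeholder examples.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# trainer.model.save_pretrained(...) writes adapter_config.json and the adapter
# weights into cfg.output_dir, so PeftModel can load them onto the base model.
model = PeftModel.from_pretrained(base, "./outputs/dpo-out")

# Optionally merge the LoRA weights into the base model for standalone inference.
model = model.merge_and_unload()
```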