winglian committed on
Commit
796a085
·
unverified ·
1 Parent(s): cb78a36

make sure to save the lora adapter at the end of RL/dpo training (#1573)

Browse files
Files changed (1) hide show
  1. src/axolotl/train.py +4 -0
src/axolotl/train.py CHANGED
@@ -212,6 +212,10 @@ def train(
212
  if cfg.flash_optimum and BetterTransformer:
213
  model = BetterTransformer.reverse(model)
214
 
 
 
 
 
215
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
216
 
217
  if not cfg.hub_model_id:
 
212
  if cfg.flash_optimum and BetterTransformer:
213
  model = BetterTransformer.reverse(model)
214
 
215
+ if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
216
+ trainer.model.save_pretrained(
217
+ cfg.output_dir, safe_serialization=safe_serialization
218
+ )
219
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
220
 
221
  if not cfg.hub_model_id: