AIR-hl committed on
Commit 5b51b63 · verified · 1 Parent(s): b7bb402

Update README.md

Files changed (1): README.md +135 -3

README.md CHANGED
---
license: apache-2.0
datasets:
- HuggingFaceH4/ultrafeedback_binarized
base_model:
- AIR-hl/Qwen2.5-1.5B-ultrachat200k
pipeline_tag: text-generation
tags:
- trl
- qwen
- dpo
- alignment
- transformers
- custom
- chat
---
# Qwen2.5-1.5B-WPO

## Model Details

- **Model type:** aligned model
- **License:** Apache License 2.0
- **Finetuned from model:** [AIR-hl/Qwen2.5-1.5B-ultrachat200k](https://huggingface.co/AIR-hl/Qwen2.5-1.5B-ultrachat200k)
- **Training data:** [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- **Training framework:** [trl](https://github.com/huggingface/trl)

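## Quick Start

A minimal usage sketch with `transformers` (the repo id `AIR-hl/Qwen2.5-1.5B-WPO` is inferred from the card title; adjust it if the actual id differs):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "AIR-hl/Qwen2.5-1.5B-WPO"  # assumption: repo id follows the card title

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

messages = [{"role": "user", "content": "What does preference alignment do?"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=256)
# Decode only the newly generated tokens
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```
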
## Training Details

devices: 4 * NPU 910B-64GB \
precision: bf16 mixed-precision \
global_batch_size: 128

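The global batch size is consistent with the per-device settings listed under Training Hyperparameters below:

```python
# devices × per_device_train_batch_size × gradient_accumulation_steps
assert 4 * 8 * 4 == 128  # matches global_batch_size
```
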
### Training Hyperparameters
`attn_implementation`: None \
`beta`: 0.01 \
`bf16`: True \
`learning_rate`: 1e-6 \
`lr_scheduler_type`: cosine \
`per_device_train_batch_size`: 8 \
`gradient_accumulation_steps`: 4 \
`torch_dtype`: bfloat16 \
`num_train_epochs`: 1 \
`max_prompt_length`: 512 \
`max_length`: 1024 \
`warmup_ratio`: 0.05

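For reference, a sketch of how these values map onto trl's `DPOConfig`; the `output_dir` is illustrative, and `use_weighting=True` (trl's WPO-style loss weighting, suggested by the model name) is an assumption the card does not state:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="qwen2.5-1.5b-wpo",  # illustrative, not from the card
    beta=0.01,
    bf16=True,
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    max_prompt_length=512,
    max_length=1024,
    warmup_ratio=0.05,
    use_weighting=True,  # assumption: WPO weighting, inferred from the model name
)
```
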
### Results

`init_train_loss`: 0.2410 \
`final_train_loss`: 0.1367 \
`accuracy`: 0.65 \
`reward_margin`: 0.2402

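`accuracy` and `reward_margin` presumably correspond to trl's `rewards/accuracies` and `rewards/margins` logs, computed from the implicit DPO rewards (the beta-scaled policy/reference log-prob ratios) on chosen and rejected responses; a sketch:

```python
import torch

def reward_metrics(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
    """Batch preference metrics as logged during DPO training (sketch)."""
    accuracy = (chosen_rewards > rejected_rewards).float().mean()  # rewards/accuracies
    margin = (chosen_rewards - rejected_rewards).mean()            # rewards/margins
    return accuracy.item(), margin.item()
```
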
### Training script

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if __name__ == "__main__":
    # Parse command-line flags / config file into script, training, and model arguments
    parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Resolve the requested dtype; "auto" and None pass through unchanged
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    quantization_config = get_quantization_config(model_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Policy model to be aligned
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )

    # Without PEFT, a frozen copy of the policy serves as the reference model;
    # with PEFT, the trainer falls back to the frozen base weights instead
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    if script_args.ignore_bias_buffers:
        # Boolean buffers cannot be all-reduced by DDP; exclude them
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split)
    dataset = dataset.select_columns(["chosen", "prompt", "rejected"])

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
```