---
language:
- en
---

# morizon/llm-jp-3-13b-instruct2-grpo-0215_lora

This is a model with a LoRA adapter, optimized for Japanese text-generation tasks.

- **Developed by:** morizon
- **License:** apache-2.0

This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
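
For plain Hugging Face Transformers + PEFT inference, a minimal sketch is shown below (it assumes the adapter repo id matches this model card's name; the full Unsloth/vLLM evaluation workflow follows under "Sample Use"):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model, then attach this LoRA adapter on top.
base = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-3-13b-instruct2", device_map="auto")
model = PeftModel.from_pretrained(base, "morizon/llm-jp-3-13b-instruct2-grpo-0215_lora")
tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-13b-instruct2")
```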

## Sample Use

First, install the dependencies (the snippet below targets Colab):

```python
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow
# If you are running this notebook locally, you also need to install `diffusers`
# !pip install diffusers
# Temporarily install a specific TRL nightly version
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
```

Next, patch Unsloth for GRPO and import the libraries used below:

```python
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
```

Load the base model with Unsloth and apply a LoRA adapter:

```python
model_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-MATH-lighteval_step1000_lora"

# --- Load the model and apply LoRA ---
max_seq_length = 1024  # Maximum length of the reasoning trace
lora_rank = 64         # LoRA rank (recommended: 64)

# Load the model and tokenizer via FastLanguageModel
# (adjust the model name to the one you are using)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=True,    # 4-bit quantization (mind this setting when fine-tuning with LoRA)
    fast_inference=True,  # Enable fast inference with vLLM
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.7,
)

# Apply LoRA (PEFT)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
```
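
As a quick sanity check that the model and the vLLM backend are wired up, one option is a single generation from the base model before attaching the trained adapter (an optional sketch; the prompt is illustrative):

```python
from vllm import SamplingParams

# Build a chat-formatted prompt and generate once, without a LoRA request.
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "1 + 1 = ?"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(model.fast_generate(
    text,
    sampling_params=SamplingParams(temperature=0.6, max_tokens=64),
)[0].outputs[0].text)
```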

Finally, run the evaluation. The prompt folds all instructions into the user message, and a small helper extracts the final `\boxed{...}` answer (handling nested braces):

```python
# --- Prepare the prompt and dataset ---
# Recommended: drop the system prompt and fold all instructions into the user prompt
USER_INSTRUCTION = (
    "Please ensure your response begins with \"<reasoning>\n\". "
    "Please reason step by step, and put your final answer within \\boxed{}. "
)

# Example test data (as a list)
test_data = [
    # "Find the number of positive integers $x$ satisfying $x^{-1} > x$."
    {"id": 0, "text": "$x^{-1}>x$を満たす正の整数$x$の個数を求めなさい。", "gold": "0", "response": "", "type": "Algebra", "level": "Level 2"},
    # Add the test data you want to evaluate here
]

def extract_boxed_answer_rev(text: str) -> str:
    r"""
    Extract the contents of the first \boxed{...} in the text (handling nesting).
    Example: r"\boxed{\frac{\pi}{6}}" -> "\frac{\pi}{6}"
    """
    key = r"\boxed{"
    start_idx = text.find(key)
    if start_idx == -1:
        return ""
    # Start right after "\boxed{"
    start_idx += len(key)
    brace_count = 1  # The first { is already counted
    i = start_idx
    while i < len(text) and brace_count > 0:
        if text[i] == "{":
            brace_count += 1
        elif text[i] == "}":
            brace_count -= 1
        i += 1
    # i-1 is the position of the matching closing brace
    return text[start_idx:i-1].strip()
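
# Optional self-check of the helper (added for illustration; not part of
# the original script):
assert extract_boxed_answer_rev(r"answer: \boxed{\frac{\pi}{6}}") == r"\frac{\pi}{6}"
assert extract_boxed_answer_rev("no boxed answer") == ""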

from vllm import SamplingParams

correct = 0
total = len(test_data)

# Lists that record correct and incorrect cases
correct_cases = []
incorrect_cases = []

for item in test_data:
    # Build the prompt (prepend USER_INSTRUCTION)
    prompt = USER_INSTRUCTION + item["text"]
    text = tokenizer.apply_chat_template([
        {"role": "user", "content": prompt},
    ], tokenize=False, add_generation_prompt=True)

    # Run inference
    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=2048,
    )
    output = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora(adapter_id),
        # lora_request=model.load_lora("grpo_saved_lora"),
    )[0].outputs[0].text

    # Get the answer via the \boxed{...} extraction helper
    boxed_answer = extract_boxed_answer_rev(output)

    # Display the results
    print("\n----------Test ID:", item["id"], "----------")
    print("Prompt:")
    print(prompt)
    print("\nLLM Output:")
    print(output)
    print("\nExtracted Answer:")
    print(boxed_answer)
    print("Gold Answer:", item["gold"])

    # Count as correct when the extracted answer matches the gold answer
    if boxed_answer == item["gold"]:
        correct += 1
        correct_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })
    else:
        incorrect_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })

# Show the correct cases
print("\n========== Correct cases ==========")
for case in correct_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

# Show the incorrect cases
print("\n========== Incorrect cases ==========")
for case in incorrect_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

accuracy = correct / total * 100
print("\nOverall Accuracy: {}/{} ({:.2f}%)".format(correct, total, accuracy))
```
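
To keep a record of the run, the per-case results collected above can be written out afterwards (an optional sketch; the file names are illustrative):

```python
import json

# Persist the recorded cases for later inspection.
with open("correct_cases.json", "w", encoding="utf-8") as f:
    json.dump(correct_cases, f, ensure_ascii=False, indent=2)
with open("incorrect_cases.json", "w", encoding="utf-8") as f:
    json.dump(incorrect_cases, f, ensure_ascii=False, indent=2)
```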