Update README.md
README.md CHANGED
language:
- en
---

# morizon/llm-jp-3-13b-instruct2-grpo-0215_lora

This model ships with a LoRA adapter and is optimized for Japanese text-generation tasks.

- **Developed by:** morizon
- **License:** apache-2.0

This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
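
For plain Hugging Face inference without Unsloth or vLLM, the adapter can typically be attached to the base model with the standard `transformers` + `peft` API. This is a minimal sketch rather than the card's own recipe; the adapter repo id, dtype, and device settings are assumptions.

```python
# Minimal loading sketch (assumptions: this repo hosts the adapter,
# a CUDA device is available, and bfloat16 is acceptable).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-0215_lora"  # assumed adapter repo id

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_id)

# Example prompt in the same style as the evaluation code below
messages = [{"role": "user", "content": "Please reason step by step, and put your final answer within \\boxed{}. What is 1+1?"}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```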

## Sample Use

```python
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow
# If you are running this notebook locally, you also need to install `diffusers`
# !pip install diffusers
# Temporarily install a specific TRL nightly version
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
```

```python
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
```

```python
model_id = "llm-jp/llm-jp-3-13b-instruct2"
adapter_id = "morizon/llm-jp-3-13b-instruct2-grpo-MATH-lighteval_step1000_lora"

# --- Load the model and apply LoRA ---
max_seq_length = 1024  # maximum length of the reasoning trace
lora_rank = 64  # LoRA rank (recommended value: 64)

# Load the model and tokenizer via FastLanguageModel
# Note: set the model name to the one you actually want to use
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # 4-bit quantization (check this setting carefully when LoRA fine-tuning)
    fast_inference=True,  # enable fast inference with vLLM
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.7,
)

# Apply LoRA (PEFT)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
```
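
The evaluation code below can also load a locally saved adapter (the commented-out `model.load_lora("grpo_saved_lora")` line). In the Unsloth GRPO workflow, such a local adapter is typically produced after training with something like the following sketch; it is not part of the original card, and the directory name is only an example.

```python
# Hedged sketch: after GRPO training, save the LoRA weights locally so they can
# later be loaded with model.load_lora("grpo_saved_lora") in the evaluation code.
model.save_lora("grpo_saved_lora")
```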

```python
# --- Prepare the prompt and the dataset ---
# Recommended: drop the system prompt and fold all instructions into the user prompt
USER_INSTRUCTION = (
    "Please ensure your response begins with \"<reasoning>\n\". "
    "Please reason step by step, and put your final answer within \\boxed{}. "
)

# Example test data (as a list)
test_data = [
    {"id": 0, "text": "$x^{-1}>x$を満たす正の整数$x$の個数を求めなさい。", "gold": "0", "response": "", "type": "Algebra", "level": "Level 2"},
    # Add the test data you want to evaluate here
]

def extract_boxed_answer_rev(text: str) -> str:
    r"""
    Extract the contents of the first \boxed{...} in the text (handling nested braces).
    Example: r"\boxed{\frac{\pi}{6}}" -> "\frac{\pi}{6}"
    """
    key = r"\boxed{"
    start_idx = text.find(key)
    if start_idx == -1:
        return ""
    # Start right after "\boxed{"
    start_idx += len(key)
    brace_count = 1  # the opening { is already counted
    i = start_idx
    while i < len(text) and brace_count > 0:
        if text[i] == "{":
            brace_count += 1
        elif text[i] == "}":
            brace_count -= 1
        i += 1
    # i-1 is the position of the matching closing brace
    return text[start_idx:i-1].strip()
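# Quick sanity check of the extractor (illustrative values, not from the original card):
#   extract_boxed_answer_rev(r"The answer is \boxed{\frac{\pi}{6}}.")  -> r"\frac{\pi}{6}"
#   extract_boxed_answer_rev("no boxed answer here")                   -> ""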

from vllm import SamplingParams

correct = 0
total = len(test_data)

# Lists for recording correct and incorrect cases
correct_cases = []
incorrect_cases = []

for item in test_data:
    # Build the prompt (prepend USER_INSTRUCTION)
    prompt = USER_INSTRUCTION + item["text"]
    text = tokenizer.apply_chat_template([
        {"role": "user", "content": prompt},
    ], tokenize=False, add_generation_prompt=True)

    # Run inference
    sampling_params = SamplingParams(
        temperature=0.6,
        max_tokens=2048,
    )
    output = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=model.load_lora(adapter_id),
        # lora_request=model.load_lora("grpo_saved_lora"),
    )[0].outputs[0].text

    # Extract the answer from inside \boxed{...}
    boxed_answer = extract_boxed_answer_rev(output)

    # Display the results
    print("\n----------Test ID:", item["id"], "----------")
    print("Prompt:")
    print(prompt)
    print("\nLLM Output:")
    print(output)
    print("\nExtracted Answer:")
    print(boxed_answer)
    print("Gold Answer:", item["gold"])

    # Count as correct when the extracted answer matches the gold answer
    if boxed_answer == item["gold"]:
        correct += 1
        correct_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })
    else:
        incorrect_cases.append({
            "id": item["id"],
            "prompt": prompt,
            "LLM_output": output,
            "extracted_answer": boxed_answer,
            "gold": item["gold"]
        })

# Show the correct cases
print("\n========== Correct cases ==========")
for case in correct_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

# Show the incorrect cases
print("\n========== Incorrect cases ==========")
for case in incorrect_cases:
    print("\nTest ID:", case["id"])
    print("Prompt:")
    print(case["prompt"])
    print("LLM Output:")
    print(case["LLM_output"])
    print("Extracted Answer:", case["extracted_answer"])
    print("Gold Answer:", case["gold"])
    print("-" * 40)

accuracy = correct / total * 100
print("\nOverall Accuracy: {}/{} ({:.2f}%)".format(correct, total, accuracy))
```
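
If you want to keep the per-case results for later inspection, the `correct_cases` / `incorrect_cases` lists built above can be written to JSON. This is an optional addition, not part of the original card, and the file names are only examples.

```python
import json

# Persist the evaluation results (file names are illustrative)
with open("correct_cases.json", "w", encoding="utf-8") as f:
    json.dump(correct_cases, f, ensure_ascii=False, indent=2)
with open("incorrect_cases.json", "w", encoding="utf-8") as f:
    json.dump(incorrect_cases, f, ensure_ascii=False, indent=2)
```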