Commit 219175f
Parent(s): e795394

enhanced model loading process

Files changed:
- README.md +38 -125
- config.json +1 -1
- generation_config.json +0 -4
- merges.txt +0 -0
- model.safetensors +0 -3
- mp_pretrain.py +654 -0
- pytorch_model.bin.index.json +217 -0
- vocab.json +0 -0
README.md
CHANGED
@@ -60,149 +60,62 @@ from huggingface_hub import snapshot_download
 snapshot_download(repo_id="PursuitOfDataScience/Argonne-1.0")
 ```
 
-
+You can run the following sample code to use the model for text generation:
 
 ```python
-import os
-import sys
 import torch
-import …
-…
-from …
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# Register the model architecture with AutoModel
+from mp_pretrain import ArgonneConfig, ArgonneModelParallel
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+# Register the model with Hugging Face's Auto classes
+AutoConfig.register("argonne", ArgonneConfig)
+AutoModel.register(ArgonneConfig, ArgonneModelParallel)
+AutoModelForCausalLM.register(ArgonneConfig, ArgonneModelParallel)
 
 def main():
-    …
-    # Print all input arguments
-    print(f"Model directory: {args.model_dir}")
-    print(f"mp_pretrain directory: {args.mp_dir}")
-
-    # Check that directories exist
-    if not os.path.exists(args.model_dir):
-        print(f"Error: Model directory {args.model_dir} does not exist")
-        sys.exit(1)
-    if not os.path.exists(args.mp_dir):
-        print(f"Error: mp_pretrain directory {args.mp_dir} does not exist")
-        sys.exit(1)
-
-    # Check for required files
-    required_files = [
-        os.path.join(args.model_dir, "config.json"),
-        os.path.join(args.model_dir, "tokenizer.json")
-    ]
-
-    for file_path in required_files:
-        if not os.path.exists(file_path):
-            print(f"Error: Required file {file_path} does not exist")
-            sys.exit(1)
-
-    # Check for either pytorch_model.bin or model.safetensors
-    weights_file = None
-    if os.path.exists(os.path.join(args.model_dir, "pytorch_model.bin")):
-        weights_file = os.path.join(args.model_dir, "pytorch_model.bin")
-        print(f"Found PyTorch weights at {weights_file}")
-    elif os.path.exists(os.path.join(args.model_dir, "model.safetensors")):
-        weights_file = os.path.join(args.model_dir, "model.safetensors")
-        print(f"Found safetensors weights at {weights_file}")
-    else:
-        print(f"Error: No model weights found in {args.model_dir}")
-        sys.exit(1)
-
-    # Add mp_pretrain directory to Python path
-    sys.path.insert(0, args.mp_dir)
-
-    # Import required modules
-    try:
-        print("Importing modules from mp_pretrain...")
-        from mp_pretrain import ArgonneModelParallel, ArgonneConfig, load_bpe_tokenizer
-        print("Import successful")
-    except ImportError as e:
-        print(f"Error importing modules from mp_pretrain.py: {e}")
-        sys.exit(1)
-
-    # Load the config
-    print("Loading model config...")
-    with open(os.path.join(args.model_dir, "config.json"), 'r') as f:
-        config_dict = json.load(f)
-    config = ArgonneConfig(**config_dict)
-    print("Config loaded")
-
-    # Load the tokenizer
-    print("Loading tokenizer...")
-    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
-    print("Tokenizer loaded")
+    # Load model and tokenizer using the Auto classes
+    model_dir = "PursuitOfDataScience/Argonne-1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    model = AutoModelForCausalLM.from_pretrained(model_dir)
 
-    # …
-    print("Creating model...")
-    model = ArgonneModelParallel(config)
-    print("Model created")
-
-    # Load weights
-    print(f"Loading weights from {weights_file}...")
-    if weights_file.endswith(".bin"):
-        # Load PyTorch weights
-        state_dict = torch.load(weights_file, map_location="cpu")
-    else:
-        # Load safetensors weights
-        from safetensors.torch import load_file
-        state_dict = load_file(weights_file)
-
-    # Load state dict
-    print("Applying weights to model...")
-    model.load_state_dict(state_dict, strict=False)
-    print("Weights loaded")
-
-    # Move to GPU if available
+    # Setup for inference
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    print(f"Moving model to {device}...")
     model = model.to(device)
 
-    # …
-    model …
+    # Add the 'devices' attribute that model.generate() expects
+    if not hasattr(model, 'devices'):
+        model.devices = [device]
 
-    # …
-    …
-    print("Argonne Model Chat - Type 'exit' to quit")
-    print("="*50 + "\n")
+    # Set pipeline_stages to None if the model was loaded without distribution
+    if not hasattr(model, 'pipeline_stages') or model.pipeline_stages is None:
+        model.pipeline_stages = None
 
-    …
-        input_ids,
-        max_new_tokens=50,
-        temperature=0.7,
-        top_k=50)[0]
-
-    # Decode output
-    response = tokenizer.decode(output_ids, skip_special_tokens=True)
-    print(f"Model: {response}")
-
+    # Generate text from a prompt
+    prompt = "The future of AI research is "
+
+    # Extract just the input_ids from tokenizer output
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+    # Generate text
+    outputs = model.generate(
+        input_ids,
+        max_new_tokens=100,
+        temperature=0.7,
+        top_k=50
+    )
+
+    # Print the result
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    print(f"Generated text:\n{generated_text}")
 
 if __name__ == "__main__":
     main()
 
 ```
 
-```
-python minimal_chat.py --model_dir /path/to/model --mp_dir /path/to/mp_pretrain.py
-```
 
 ### 📝 Example Outputs
 Below are generated examples illustrating Argonne-1.0's style and capability when prompted:
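A note on the snippet above: if several GPUs are visible, the same checkpoint can instead be split across them with the `distribute_model()` helper defined in mp_pretrain.py (shown in full below). A minimal sketch, not part of this commit, assuming at least one CUDA device and mp_pretrain.py on the Python path:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import mp_pretrain  # importing runs the "argonne" Auto-class registration

tokenizer = AutoTokenizer.from_pretrained("PursuitOfDataScience/Argonne-1.0")
model = AutoModelForCausalLM.from_pretrained("PursuitOfDataScience/Argonne-1.0")

# Split the transformer blocks across all visible GPUs (pipeline style).
# distribute_model() also sets model.devices, so the manual 'devices'
# workaround from the README snippet is not needed on this path.
model.distribute_model()

input_ids = tokenizer("The future of AI research is ", return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_new_tokens=100, temperature=0.7, top_k=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```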
config.json
CHANGED
@@ -9,6 +9,6 @@
   "n_head": 12,
   "n_layer": 12,
   "torch_dtype": "float32",
-  "transformers_version": "4.…",
+  "transformers_version": "4.47.0",
   "vocab_size": 12000
 }
generation_config.json
DELETED
@@ -1,4 +0,0 @@
-{
-    "_from_model_config": true,
-    "transformers_version": "4.44.0"
-}
merges.txt
ADDED
The diff for this file is too large to render. See raw diff.
model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c2615b21d9183cf83afc4278f40d465fc128eeaa6237a6f1440e404da555c96c
-size 1304657232
mp_pretrain.py
ADDED
@@ -0,0 +1,654 @@
import os
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import ByteLevelBPETokenizer
from transformers import (
    PreTrainedTokenizerFast,
    PretrainedConfig,
    PreTrainedModel
)
from tqdm import tqdm
from datasets import load_dataset, load_from_disk
import glob

os.environ["HF_DATASETS_CACHE"] = "./.cache"

#####################################
# BPE Tokenizer Utilities
#####################################

def create_text_file_from_arrow(arrow_files, output_file="all_text_for_tokenizer.txt"):
    """
    Given a list of Arrow files, extract the 'text' column and write
    it to a single text file (one text example per line).
    """
    print(f"Creating a combined text file '{output_file}' from Arrow files...")
    with open(output_file, "w", encoding="utf-8") as wf:
        for arrow_path in tqdm(arrow_files):
            # Load the Arrow file in *streaming* mode to avoid large memory usage
            ds = load_dataset("arrow", data_files=[arrow_path], streaming=True)
            # If "train" split exists, use ds["train"], else ds is the dataset
            if "train" in ds:
                ds = ds["train"]
            for example in ds:
                text = example.get("text", "")
                # Write one line of text
                wf.write(text.replace("\n", " ") + "\n")

def train_bpe_tokenizer(text_file, vocab_size=12000):
    """
    Train a ByteLevel BPE tokenizer on a *plain-text file* and save it.
    """
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train(
        files=[text_file],
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=[
            "<|start_of_text|>",
            "<pad>",
            "<|end_of_text|>",
            "<unk>",
            "<mask>"
        ]
    )

    os.makedirs("bpe_tokenizer", exist_ok=True)
    tokenizer.save_model("bpe_tokenizer")

    # Save the full tokenizer JSON representation
    with open(os.path.join("bpe_tokenizer", "tokenizer.json"), "w", encoding="utf-8") as f:
        f.write(tokenizer._tokenizer.to_str())

    # Create a tokenizer configuration
    tokenizer_config = {
        "model_max_length": 2048,
        "bos_token": "<|start_of_text|>",
        "eos_token": "<|end_of_text|>",
        "unk_token": "<unk>",
        "pad_token": "<pad>",
        "mask_token": "<mask>"
    }
    with open(os.path.join("bpe_tokenizer", "tokenizer_config.json"), "w") as f:
        json.dump(tokenizer_config, f)

    # Create a Hugging Face PreTrainedTokenizerFast instance
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=os.path.join("bpe_tokenizer", "tokenizer.json"),
        bos_token="<|start_of_text|>",
        eos_token="<|end_of_text|>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>"
    )
    hf_tokenizer.save_pretrained("bpe_tokenizer")
    return hf_tokenizer


def load_bpe_tokenizer():
    """Load a previously trained BPE tokenizer in Hugging Face format."""
    hf_tokenizer = PreTrainedTokenizerFast.from_pretrained("bpe_tokenizer", use_fast=True)
    return hf_tokenizer

#####################################
# STREAMING MODE
#####################################

def streaming_token_generator(data_files, hf_tokenizer):
    """
    Yields tokenized examples from a streaming dataset (no shuffle).
    data_files should be a list of Arrow files.
    """
    dataset = load_dataset("arrow", data_files=data_files, streaming=True)
    if "train" in dataset:
        dataset = dataset["train"]

    for example in dataset:
        text = example["text"] if "text" in example else ""
        token_ids = hf_tokenizer.encode(text)
        if len(token_ids) > 0:
            yield token_ids

#####################################
# NON-STREAMING: Full Pass
#####################################

def load_nonstream_data(data_files, hf_tokenizer, block_size, num_proc=8):
    """
    Loads the entire dataset in memory either from a cached processed directory
    or processes it in parallel if not yet cached.
    Returns a list of token ID sequences.
    """

    processed_dir = "processed_data/tokenized_data"
    if os.path.exists(processed_dir):
        print(f"Loading cached dataset from '{processed_dir}'...")
        ds = load_from_disk(processed_dir)
        tokenized_data = ds["token_ids"]
        return tokenized_data

    print("No cached dataset found. Processing in parallel...")

    ds_dict = load_dataset("arrow", data_files=data_files, streaming=False)
    if "train" in ds_dict:
        ds = ds_dict["train"]
    else:
        ds = ds_dict

    def tokenize_and_truncate(example):
        text = example["text"] if "text" in example else ""
        token_ids = hf_tokenizer.encode(text)
        if len(token_ids) < block_size + 1:
            return {"token_ids": None}
        token_ids = token_ids[:block_size+1]
        return {"token_ids": token_ids}

    ds = ds.map(
        tokenize_and_truncate,
        batched=False,
        num_proc=num_proc
    )
    ds = ds.filter(lambda ex: ex["token_ids"] is not None,
                   num_proc=num_proc)

    if "text" in ds.column_names:
        ds = ds.remove_columns(["text"])

    os.makedirs(os.path.dirname(processed_dir), exist_ok=True)
    ds.save_to_disk(processed_dir)
    print(f"Processed dataset saved to '{processed_dir}'.")

    tokenized_data = ds["token_ids"]
    return tokenized_data

def collate_batch(token_list_batch, block_size):
    """
    Convert a list of token-ID lists into x,y Tensors for causal LM.
    We'll truncate if longer than block_size+1, skip if shorter.
    """
    x_list, y_list = [], []
    for tokens in token_list_batch:
        if len(tokens) < block_size + 1:
            continue
        tokens = tokens[:block_size+1]
        x_list.append(tokens[:-1])
        y_list.append(tokens[1:])

    if not x_list:
        return None, None

    x_tensor = torch.tensor(x_list, dtype=torch.long)
    y_tensor = torch.tensor(y_list, dtype=torch.long)
    return x_tensor, y_tensor

#####################################
# Model Definition
#####################################

class ArgonneConfig(PretrainedConfig):
    model_type = "argonne"
    def __init__(self, vocab_size=12000, block_size=2048, n_layer=24, n_head=24, n_embd=1296, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by n_head"
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
                 .view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        b, t, c = x.size()
        q = self.query(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        k = self.key(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        v = self.value(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:, :, :t, :t] == 0, float('-inf'))
        att = torch.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        y = self.resid_drop(self.proj(y))
        return y

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(4 * config.n_embd, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class ArgonneModelParallel(PreTrainedModel):
    config_class = ArgonneConfig

    def __init__(self, config):
        super().__init__(config)
        # Create embeddings on CPU initially
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)

        # Build all blocks on CPU
        all_blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])

        # Final LayerNorm + output head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        self.post_init()

        # Keep the blocks on CPU in a single ModuleList
        self.blocks = all_blocks
        # We'll defer pipeline splitting until later:
        self.pipeline_stages = None

    def distribute_model(self, device_ids=None):
        """
        Distribute the model blocks across multiple GPU devices in a pipeline style.
        If 'device_ids' is None, we'll discover all available GPUs.
        """
        if device_ids is None:
            num_gpus = torch.cuda.device_count()
            if num_gpus < 1:
                raise ValueError("No GPUs found—can't do pipeline parallel on CPU only.")
            device_ids = [f"cuda:{i}" for i in range(num_gpus)]

        # Store them so the training loop can keep referencing model.devices
        self.devices = [torch.device(d) for d in device_ids]

        self.pipeline_stages = nn.ModuleList()
        num_gpus = len(device_ids)
        blocks_per_gpu = math.ceil(len(self.blocks) / num_gpus)

        start_idx = 0
        for i in range(num_gpus):
            end_idx = min(start_idx + blocks_per_gpu, len(self.blocks))
            stage_blocks = self.blocks[start_idx:end_idx]
            stage = nn.Sequential(*stage_blocks).to(device_ids[i])
            self.pipeline_stages.append(stage)
            start_idx = end_idx
            if end_idx >= len(self.blocks):
                break

        # Move token_embedding + position_embedding to the first device
        self.token_embedding.to(device_ids[0])
        self.position_embedding.data = self.position_embedding.data.to(device_ids[0])
        self.drop.to(device_ids[0])

        # Move final LayerNorm + head to the last device
        self.ln_f.to(device_ids[-1])
        self.head.to(device_ids[-1])

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        If self.pipeline_stages is None, we do a normal single-device forward
        (whatever device everything is currently on—CPU or a single GPU).
        Otherwise, we do a pipeline parallel forward.
        """
        if self.pipeline_stages is None:
            # Single-device forward pass
            device = self.token_embedding.weight.device
            idx = idx.to(device)
            b, t = idx.size()
            assert t <= self.config.block_size, "Sequence length exceeds block size"

            token_embeddings = self.token_embedding(idx)
            position_embeddings = self.position_embedding[:, :t, :]
            hidden_states = self.drop(token_embeddings + position_embeddings)

            for block in self.blocks:
                hidden_states = block(hidden_states)

            hidden_states = self.ln_f(hidden_states)
            logits = self.head(hidden_states)

            loss = None
            if targets is not None:
                targets = targets.to(device)
                logits = logits.view(-1, logits.size(-1))
                targets = targets.view(-1)
                loss = F.cross_entropy(logits, targets)

            return logits, loss
        else:
            # Pipeline parallel forward
            first_device = next(self.pipeline_stages[0].parameters()).device
            last_device = next(self.pipeline_stages[-1].parameters()).device

            x = idx.to(first_device)
            b, t = x.size()
            assert t <= self.config.block_size, "Sequence length exceeds block size"

            token_embeddings = self.token_embedding(x)
            position_embeddings = self.position_embedding[:, :t, :]
            hidden_states = self.drop(token_embeddings + position_embeddings)

            # Pass through each pipeline stage in sequence
            for stage in self.pipeline_stages:
                device_stage = next(stage.parameters()).device
                hidden_states = hidden_states.to(device_stage)
                hidden_states = stage(hidden_states)

            hidden_states = hidden_states.to(last_device)
            hidden_states = self.ln_f(hidden_states)
            logits = self.head(hidden_states)

            loss = None
            if targets is not None:
                targets = targets.to(last_device)
                logits = logits.view(-1, logits.size(-1))
                targets = targets.view(-1)
                loss = F.cross_entropy(logits, targets)

            return logits, loss

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None):
        self.eval()
        if len(self.devices) == 0:
            raise ValueError("No GPUs available for model parallelism.")

        generated = input_ids.to(self.devices[0])
        for _ in range(max_new_tokens):
            if generated.shape[1] > self.config.block_size:
                generated = generated[:, -self.config.block_size:]

            logits, _ = self.forward(generated)
            logits = logits[:, -1, :].to(self.devices[-1])
            logits = logits / temperature

            if top_k is not None:
                values, _ = torch.topk(logits, top_k)
                logits[logits < values[:, -1:]] = float('-inf')

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            next_token = next_token.to(self.devices[0])
            generated = torch.cat((generated, next_token), dim=1)

        return generated

#####################################
# Training Loop (Streaming OR Full-Pass Non-Streaming)
#####################################

def train_model_parallel(data_files, use_streaming=False):
    """
    data_files should be a list of actual .arrow file paths, e.g.
    ["data/file1.arrow", "data/file2.arrow", ...]

    Includes automatic batch size adjustment when OOM errors occur.
    """
    # Initial batch size settings
    initial_batch_size = 128         # initial batch size
    min_batch_size = 12              # Minimum acceptable batch size
    batch_size = initial_batch_size  # Current working batch size

    # 1) If no tokenizer, train it on text extracted from Arrow
    if not os.path.exists("bpe_tokenizer/vocab.json"):
        print("No existing tokenizer found. Building a text file from Arrow and training one...")
        # Create a text file from Arrow files
        text_file_path = "all_text_for_tokenizer.txt"
        create_text_file_from_arrow(data_files, text_file_path)
        # Now train BPE on that text file
        train_bpe_tokenizer(text_file_path, vocab_size=12000)

    # Load the tokenizer we just created (or found)
    hf_tokenizer = load_bpe_tokenizer()

    block_size = 2048
    epochs = 5
    n_layer = 12
    n_head = 12
    n_embd = 1296
    dropout = 0.1

    config_model = ArgonneConfig(
        vocab_size=12000,
        block_size=block_size,
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        dropout=dropout
    )

    # Load non-streaming dataset once, outside the retry loop
    tokenized_data = None
    if not use_streaming:
        print("=== Loading dataset in memory for a full pass approach ===")
        tokenized_data = load_nonstream_data(data_files, hf_tokenizer, block_size, num_proc=128)
        total_samples = len(tokenized_data)
        print(f"Total tokenized samples: {total_samples}")

    # Main training loop with batch size adjustment
    while True:
        print(f"\n=== Attempting training with batch_size = {batch_size} ===")

        try:
            # Initialize a fresh model for each attempt
            model = ArgonneModelParallel(config_model)
            model.distribute_model()  # chunks across all visible GPUs
            optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
            scaler = torch.amp.GradScaler("cuda")
            global_step = 0

            if use_streaming:
                ########################################################
                # STREAMING MODE
                ########################################################
                steps_per_epoch = 500

                for epoch in tqdm(range(epochs)):
                    print(f"==== Starting epoch {epoch} (STREAMING) with batch_size={batch_size} ====")
                    token_gen = streaming_token_generator(data_files, hf_tokenizer)
                    step_in_epoch = 0
                    token_batch = []

                    while step_in_epoch < steps_per_epoch:
                        try:
                            tokens = next(token_gen)
                            token_batch.append(tokens)

                            if len(token_batch) == batch_size:
                                x_tens, y_tens = collate_batch(token_batch, block_size)
                                token_batch.clear()
                                if x_tens is None:
                                    continue

                                first_device = model.devices[0]
                                x_tens, y_tens = x_tens.to(first_device), y_tens.to(first_device)

                                optimizer.zero_grad()
                                with torch.amp.autocast("cuda"):
                                    logits, loss = model(x_tens, y_tens)

                                scaler.scale(loss).backward()
                                scaler.step(optimizer)
                                scaler.update()

                                global_step += 1
                                step_in_epoch += 1

                                if global_step % 50 == 0:
                                    print(f"Epoch {epoch} | Step {global_step} | Loss: {loss.item():.4f}")
                                    prompt_str = "Long long time ago, "
                                    token_ids = hf_tokenizer.encode(prompt_str)
                                    prompt_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
                                    generated = model.generate(prompt_tensor, max_new_tokens=50)
                                    generated_text = hf_tokenizer.decode(generated[0].tolist())
                                    print(f"\n--- Generated text at step {global_step} ---\n{generated_text}\n")

                                if global_step % 10000 == 0:
                                    checkpoint = {
                                        "epoch": epoch,
                                        "global_step": global_step,
                                        "batch_size": batch_size,  # Save the successful batch size
                                        "model_state_dict": model.state_dict(),
                                        "optimizer_state_dict": optimizer.state_dict(),
                                        "loss": loss.item()
                                    }
                                    os.makedirs("pretrained", exist_ok=True)
                                    torch.save(checkpoint, f"pretrained/checkpoint_step_{global_step}.pth")
                                    print(f"Checkpoint saved at step {global_step}")

                        except StopIteration:
                            print("Reached end of dataset (stream) before finishing this epoch.")
                            break

            else:
                ########################################################
                # NON-STREAMING MODE: full pass each epoch
                ########################################################
                batches_per_epoch = total_samples // batch_size

                for epoch in tqdm(range(epochs)):
                    print(f"==== Starting epoch {epoch} (NON-STREAMING) with batch_size={batch_size} ====")

                    for batch_idx in tqdm(range(batches_per_epoch)):
                        start_idx = batch_idx * batch_size
                        end_idx = start_idx + batch_size
                        batch_token_lists = tokenized_data[start_idx:end_idx]

                        x_tens, y_tens = collate_batch(batch_token_lists, block_size)
                        if x_tens is None:
                            continue

                        first_device = model.devices[0]
                        x_tens = x_tens.to(first_device)
                        y_tens = y_tens.to(first_device)

                        optimizer.zero_grad()
                        with torch.amp.autocast("cuda"):
                            logits, loss = model(x_tens, y_tens)

                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                        global_step += 1

                        if global_step % 100 == 0:
                            print(f"Epoch {epoch} | global_step {global_step} | Loss: {loss.item():.4f}")
                            prompt_str = "Long long time ago, "
                            token_ids = hf_tokenizer.encode(prompt_str)
                            prompt_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
                            generated = model.generate(prompt_tensor, max_new_tokens=50)
                            generated_text = hf_tokenizer.decode(generated[0].tolist())
                            print(f"\n--- Generated text at step {global_step} ---\n{generated_text}\n")

                        if global_step % 2000 == 0:
                            checkpoint = {
                                "epoch": epoch,
                                "global_step": global_step,
                                "batch_size": batch_size,  # Save the successful batch size
                                "model_state_dict": model.state_dict(),
                                "optimizer_state_dict": optimizer.state_dict(),
                                "loss": loss.item()
                            }
                            os.makedirs("pretrained", exist_ok=True)
                            torch.save(checkpoint, f"pretrained/checkpoint_step_{global_step}.pth")
                            print(f"Checkpoint saved at step {global_step}")

            # If we reach here, training completed successfully
            print(f"Training completed successfully with batch_size={batch_size}")
            break

        except torch.cuda.OutOfMemoryError:
            # Free memory
            del model, optimizer, scaler
            torch.cuda.empty_cache()

            # Reduce batch size
            new_batch_size = max(batch_size - 12, min_batch_size)

            if new_batch_size == batch_size:
                print(f"⚠️ Already at minimum batch size ({min_batch_size}). Training failed.")
                break

            print(f"CUDA Out of Memory! Reducing batch size from {batch_size} to {new_batch_size}")
            batch_size = new_batch_size

            # Short pause to ensure memory is freed
            import time
            time.sleep(5)

    # Save final model and tokenizer
    try:
        model.save_pretrained("Argonne_LLM")
        hf_tokenizer.save_pretrained("Argonne_LLM")
        print("Model-parallel training complete; model and tokenizer saved successfully.")
    except:
        print("Failed to save final model, likely due to OOM issues.")

#####################################
# Register with Hugging Face Auto Classes
#####################################

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

# Register the model with Hugging Face's Auto classes
AutoConfig.register("argonne", ArgonneConfig)
AutoModel.register(ArgonneConfig, ArgonneModelParallel)
AutoModelForCausalLM.register(ArgonneConfig, ArgonneModelParallel)


def main():
    # Expand .arrow files via glob
    data_files = glob.glob("data/*.arrow")
    if not data_files:
        raise ValueError("No files matched the pattern 'data/*.arrow'")

    train_model_parallel(data_files=data_files, use_streaming=False)

if __name__ == "__main__":
    main()
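A note on the script above: once training finishes it saves the model and tokenizer to the Argonne_LLM/ directory. A minimal sketch, not part of this commit, of loading those artifacts back, assuming training completed and mp_pretrain is importable (importing it runs the Auto-class registration above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import mp_pretrain  # noqa: F401  (importing registers the "argonne" Auto classes)

tokenizer = AutoTokenizer.from_pretrained("Argonne_LLM")
model = AutoModelForCausalLM.from_pretrained("Argonne_LLM")

# Single-device path: pipeline_stages stays None, so forward() runs
# everything on one device; generate() still reads model.devices
# (see the README workaround in this commit).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.devices = [device]

ids = tokenizer("Long long time ago, ", return_tensors="pt").input_ids.to(device)
out = model.generate(ids, max_new_tokens=50, temperature=0.7, top_k=50)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```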
pytorch_model.bin.index.json
ADDED
@@ -0,0 +1,217 @@
{
  "metadata": {
    "format": "pt",
    "total_size": 1304637312
  },
  "weight_map": {
    "position_embedding": "pytorch_model.bin",
    "token_embedding.weight": "pytorch_model.bin",
    "ln_f.weight": "pytorch_model.bin",
    "ln_f.bias": "pytorch_model.bin",
    "head.weight": "pytorch_model.bin",
    "blocks.0.ln1.weight": "pytorch_model.bin",
    "blocks.0.ln1.bias": "pytorch_model.bin",
    "blocks.0.attn.mask": "pytorch_model.bin",
    "blocks.0.attn.query.weight": "pytorch_model.bin",
    "blocks.0.attn.query.bias": "pytorch_model.bin",
    "blocks.0.attn.key.weight": "pytorch_model.bin",
    "blocks.0.attn.key.bias": "pytorch_model.bin",
    "blocks.0.attn.value.weight": "pytorch_model.bin",
    "blocks.0.attn.value.bias": "pytorch_model.bin",
    "blocks.0.attn.proj.weight": "pytorch_model.bin",
    "blocks.0.attn.proj.bias": "pytorch_model.bin",
    "blocks.0.ln2.weight": "pytorch_model.bin",
    "blocks.0.ln2.bias": "pytorch_model.bin",
    "blocks.0.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.0.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.0.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.0.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.1.ln1.weight": "pytorch_model.bin",
    "blocks.1.ln1.bias": "pytorch_model.bin",
    "blocks.1.attn.mask": "pytorch_model.bin",
    "blocks.1.attn.query.weight": "pytorch_model.bin",
    "blocks.1.attn.query.bias": "pytorch_model.bin",
    "blocks.1.attn.key.weight": "pytorch_model.bin",
    "blocks.1.attn.key.bias": "pytorch_model.bin",
    "blocks.1.attn.value.weight": "pytorch_model.bin",
    "blocks.1.attn.value.bias": "pytorch_model.bin",
    "blocks.1.attn.proj.weight": "pytorch_model.bin",
    "blocks.1.attn.proj.bias": "pytorch_model.bin",
    "blocks.1.ln2.weight": "pytorch_model.bin",
    "blocks.1.ln2.bias": "pytorch_model.bin",
    "blocks.1.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.1.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.1.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.1.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.2.ln1.weight": "pytorch_model.bin",
    "blocks.2.ln1.bias": "pytorch_model.bin",
    "blocks.2.attn.mask": "pytorch_model.bin",
    "blocks.2.attn.query.weight": "pytorch_model.bin",
    "blocks.2.attn.query.bias": "pytorch_model.bin",
    "blocks.2.attn.key.weight": "pytorch_model.bin",
    "blocks.2.attn.key.bias": "pytorch_model.bin",
    "blocks.2.attn.value.weight": "pytorch_model.bin",
    "blocks.2.attn.value.bias": "pytorch_model.bin",
    "blocks.2.attn.proj.weight": "pytorch_model.bin",
    "blocks.2.attn.proj.bias": "pytorch_model.bin",
    "blocks.2.ln2.weight": "pytorch_model.bin",
    "blocks.2.ln2.bias": "pytorch_model.bin",
    "blocks.2.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.2.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.2.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.2.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.3.ln1.weight": "pytorch_model.bin",
    "blocks.3.ln1.bias": "pytorch_model.bin",
    "blocks.3.attn.mask": "pytorch_model.bin",
    "blocks.3.attn.query.weight": "pytorch_model.bin",
    "blocks.3.attn.query.bias": "pytorch_model.bin",
    "blocks.3.attn.key.weight": "pytorch_model.bin",
    "blocks.3.attn.key.bias": "pytorch_model.bin",
    "blocks.3.attn.value.weight": "pytorch_model.bin",
    "blocks.3.attn.value.bias": "pytorch_model.bin",
    "blocks.3.attn.proj.weight": "pytorch_model.bin",
    "blocks.3.attn.proj.bias": "pytorch_model.bin",
    "blocks.3.ln2.weight": "pytorch_model.bin",
    "blocks.3.ln2.bias": "pytorch_model.bin",
    "blocks.3.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.3.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.3.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.3.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.4.ln1.weight": "pytorch_model.bin",
    "blocks.4.ln1.bias": "pytorch_model.bin",
    "blocks.4.attn.mask": "pytorch_model.bin",
    "blocks.4.attn.query.weight": "pytorch_model.bin",
    "blocks.4.attn.query.bias": "pytorch_model.bin",
    "blocks.4.attn.key.weight": "pytorch_model.bin",
    "blocks.4.attn.key.bias": "pytorch_model.bin",
    "blocks.4.attn.value.weight": "pytorch_model.bin",
    "blocks.4.attn.value.bias": "pytorch_model.bin",
    "blocks.4.attn.proj.weight": "pytorch_model.bin",
    "blocks.4.attn.proj.bias": "pytorch_model.bin",
    "blocks.4.ln2.weight": "pytorch_model.bin",
    "blocks.4.ln2.bias": "pytorch_model.bin",
    "blocks.4.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.4.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.4.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.4.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.5.ln1.weight": "pytorch_model.bin",
    "blocks.5.ln1.bias": "pytorch_model.bin",
    "blocks.5.attn.mask": "pytorch_model.bin",
    "blocks.5.attn.query.weight": "pytorch_model.bin",
    "blocks.5.attn.query.bias": "pytorch_model.bin",
    "blocks.5.attn.key.weight": "pytorch_model.bin",
    "blocks.5.attn.key.bias": "pytorch_model.bin",
    "blocks.5.attn.value.weight": "pytorch_model.bin",
    "blocks.5.attn.value.bias": "pytorch_model.bin",
    "blocks.5.attn.proj.weight": "pytorch_model.bin",
    "blocks.5.attn.proj.bias": "pytorch_model.bin",
    "blocks.5.ln2.weight": "pytorch_model.bin",
    "blocks.5.ln2.bias": "pytorch_model.bin",
    "blocks.5.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.5.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.5.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.5.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.6.ln1.weight": "pytorch_model.bin",
    "blocks.6.ln1.bias": "pytorch_model.bin",
    "blocks.6.attn.mask": "pytorch_model.bin",
    "blocks.6.attn.query.weight": "pytorch_model.bin",
    "blocks.6.attn.query.bias": "pytorch_model.bin",
    "blocks.6.attn.key.weight": "pytorch_model.bin",
    "blocks.6.attn.key.bias": "pytorch_model.bin",
    "blocks.6.attn.value.weight": "pytorch_model.bin",
    "blocks.6.attn.value.bias": "pytorch_model.bin",
    "blocks.6.attn.proj.weight": "pytorch_model.bin",
    "blocks.6.attn.proj.bias": "pytorch_model.bin",
    "blocks.6.ln2.weight": "pytorch_model.bin",
    "blocks.6.ln2.bias": "pytorch_model.bin",
    "blocks.6.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.6.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.6.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.6.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.7.ln1.weight": "pytorch_model.bin",
    "blocks.7.ln1.bias": "pytorch_model.bin",
    "blocks.7.attn.mask": "pytorch_model.bin",
    "blocks.7.attn.query.weight": "pytorch_model.bin",
    "blocks.7.attn.query.bias": "pytorch_model.bin",
    "blocks.7.attn.key.weight": "pytorch_model.bin",
    "blocks.7.attn.key.bias": "pytorch_model.bin",
    "blocks.7.attn.value.weight": "pytorch_model.bin",
    "blocks.7.attn.value.bias": "pytorch_model.bin",
    "blocks.7.attn.proj.weight": "pytorch_model.bin",
    "blocks.7.attn.proj.bias": "pytorch_model.bin",
    "blocks.7.ln2.weight": "pytorch_model.bin",
    "blocks.7.ln2.bias": "pytorch_model.bin",
    "blocks.7.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.7.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.7.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.7.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.8.ln1.weight": "pytorch_model.bin",
    "blocks.8.ln1.bias": "pytorch_model.bin",
    "blocks.8.attn.mask": "pytorch_model.bin",
    "blocks.8.attn.query.weight": "pytorch_model.bin",
    "blocks.8.attn.query.bias": "pytorch_model.bin",
    "blocks.8.attn.key.weight": "pytorch_model.bin",
    "blocks.8.attn.key.bias": "pytorch_model.bin",
    "blocks.8.attn.value.weight": "pytorch_model.bin",
    "blocks.8.attn.value.bias": "pytorch_model.bin",
    "blocks.8.attn.proj.weight": "pytorch_model.bin",
    "blocks.8.attn.proj.bias": "pytorch_model.bin",
    "blocks.8.ln2.weight": "pytorch_model.bin",
    "blocks.8.ln2.bias": "pytorch_model.bin",
    "blocks.8.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.8.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.8.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.8.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.9.ln1.weight": "pytorch_model.bin",
    "blocks.9.ln1.bias": "pytorch_model.bin",
    "blocks.9.attn.mask": "pytorch_model.bin",
    "blocks.9.attn.query.weight": "pytorch_model.bin",
    "blocks.9.attn.query.bias": "pytorch_model.bin",
    "blocks.9.attn.key.weight": "pytorch_model.bin",
    "blocks.9.attn.key.bias": "pytorch_model.bin",
    "blocks.9.attn.value.weight": "pytorch_model.bin",
    "blocks.9.attn.value.bias": "pytorch_model.bin",
    "blocks.9.attn.proj.weight": "pytorch_model.bin",
    "blocks.9.attn.proj.bias": "pytorch_model.bin",
    "blocks.9.ln2.weight": "pytorch_model.bin",
    "blocks.9.ln2.bias": "pytorch_model.bin",
    "blocks.9.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.9.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.9.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.9.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.10.ln1.weight": "pytorch_model.bin",
    "blocks.10.ln1.bias": "pytorch_model.bin",
    "blocks.10.attn.mask": "pytorch_model.bin",
    "blocks.10.attn.query.weight": "pytorch_model.bin",
    "blocks.10.attn.query.bias": "pytorch_model.bin",
    "blocks.10.attn.key.weight": "pytorch_model.bin",
    "blocks.10.attn.key.bias": "pytorch_model.bin",
    "blocks.10.attn.value.weight": "pytorch_model.bin",
    "blocks.10.attn.value.bias": "pytorch_model.bin",
    "blocks.10.attn.proj.weight": "pytorch_model.bin",
    "blocks.10.attn.proj.bias": "pytorch_model.bin",
    "blocks.10.ln2.weight": "pytorch_model.bin",
    "blocks.10.ln2.bias": "pytorch_model.bin",
    "blocks.10.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.10.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.10.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.10.mlp.fc2.bias": "pytorch_model.bin",
    "blocks.11.ln1.weight": "pytorch_model.bin",
    "blocks.11.ln1.bias": "pytorch_model.bin",
    "blocks.11.attn.mask": "pytorch_model.bin",
    "blocks.11.attn.query.weight": "pytorch_model.bin",
    "blocks.11.attn.query.bias": "pytorch_model.bin",
    "blocks.11.attn.key.weight": "pytorch_model.bin",
    "blocks.11.attn.key.bias": "pytorch_model.bin",
    "blocks.11.attn.value.weight": "pytorch_model.bin",
    "blocks.11.attn.value.bias": "pytorch_model.bin",
    "blocks.11.attn.proj.weight": "pytorch_model.bin",
    "blocks.11.attn.proj.bias": "pytorch_model.bin",
    "blocks.11.ln2.weight": "pytorch_model.bin",
    "blocks.11.ln2.bias": "pytorch_model.bin",
    "blocks.11.mlp.fc1.weight": "pytorch_model.bin",
    "blocks.11.mlp.fc1.bias": "pytorch_model.bin",
    "blocks.11.mlp.fc2.weight": "pytorch_model.bin",
    "blocks.11.mlp.fc2.bias": "pytorch_model.bin"
  }
}
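A note on the index above: it is plain JSON, so the shard layout can be checked directly. A minimal sketch, not part of this commit, assuming the file sits in the current directory:

```python
import json

# Every tensor in this single-shard index maps to pytorch_model.bin.
with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])    # checkpoint size in bytes
print(set(index["weight_map"].values()))  # {'pytorch_model.bin'}
print(len(index["weight_map"]))           # number of mapped tensors
```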
vocab.json
ADDED
The diff for this file is too large to render. See raw diff.