import os import math import json import torch import torch.nn as nn import torch.nn.functional as F from tokenizers import ByteLevelBPETokenizer from transformers import ( PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel ) from tqdm import tqdm from datasets import load_dataset, load_from_disk import glob os.environ["HF_DATASETS_CACHE"] = "./.cache" ##################################### # BPE Tokenizer Utilities ##################################### def create_text_file_from_arrow(arrow_files, output_file="all_text_for_tokenizer.txt"): """ Given a list of Arrow files, extract the 'text' column and write it to a single text file (one text example per line). """ print(f"Creating a combined text file '{output_file}' from Arrow files...") with open(output_file, "w", encoding="utf-8") as wf: for arrow_path in tqdm(arrow_files): # Load the Arrow file in *streaming* mode to avoid large memory usage ds = load_dataset("arrow", data_files=[arrow_path], streaming=True) # If "train" split exists, use ds["train"], else ds is the dataset if "train" in ds: ds = ds["train"] for example in ds: text = example.get("text", "") # Write one line of text wf.write(text.replace("\n", " ") + "\n") def train_bpe_tokenizer(text_file, vocab_size=12000): """ Train a ByteLevel BPE tokenizer on a *plain-text file* and save it. """ tokenizer = ByteLevelBPETokenizer() tokenizer.train( files=[text_file], vocab_size=vocab_size, min_frequency=2, special_tokens=[ "<|start_of_text|>", "", "<|end_of_text|>", "", "" ] ) os.makedirs("bpe_tokenizer", exist_ok=True) tokenizer.save_model("bpe_tokenizer") # Save the full tokenizer JSON representation with open(os.path.join("bpe_tokenizer", "tokenizer.json"), "w", encoding="utf-8") as f: f.write(tokenizer._tokenizer.to_str()) # Create a tokenizer configuration tokenizer_config = { "model_max_length": 2048, "bos_token": "<|start_of_text|>", "eos_token": "<|end_of_text|>", "unk_token": "", "pad_token": "", "mask_token": "" } with open(os.path.join("bpe_tokenizer", "tokenizer_config.json"), "w") as f: json.dump(tokenizer_config, f) # Create a Hugging Face PreTrainedTokenizerFast instance hf_tokenizer = PreTrainedTokenizerFast( tokenizer_file=os.path.join("bpe_tokenizer", "tokenizer.json"), bos_token="<|start_of_text|>", eos_token="<|end_of_text|>", unk_token="", pad_token="", mask_token="" ) hf_tokenizer.save_pretrained("bpe_tokenizer") return hf_tokenizer def load_bpe_tokenizer(): """Load a previously trained BPE tokenizer in Hugging Face format.""" hf_tokenizer = PreTrainedTokenizerFast.from_pretrained("bpe_tokenizer", use_fast=True) return hf_tokenizer ##################################### # STREAMING MODE ##################################### def streaming_token_generator(data_files, hf_tokenizer): """ Yields tokenized examples from a streaming dataset (no shuffle). data_files should be a list of Arrow files. """ dataset = load_dataset("arrow", data_files=data_files, streaming=True) if "train" in dataset: dataset = dataset["train"] for example in dataset: text = example["text"] if "text" in example else "" token_ids = hf_tokenizer.encode(text) if len(token_ids) > 0: yield token_ids ##################################### # NON-STREAMING: Full Pass ##################################### def load_nonstream_data(data_files, hf_tokenizer, block_size, num_proc=8): """ Loads the entire dataset in memory either from a cached processed directory or processes it in parallel if not yet cached. Returns a list of token ID sequences. """ processed_dir = "processed_data/tokenized_data" if os.path.exists(processed_dir): print(f"Loading cached dataset from '{processed_dir}'...") ds = load_from_disk(processed_dir) tokenized_data = ds["token_ids"] return tokenized_data print("No cached dataset found. Processing in parallel...") ds_dict = load_dataset("arrow", data_files=data_files, streaming=False) if "train" in ds_dict: ds = ds_dict["train"] else: ds = ds_dict def tokenize_and_truncate(example): text = example["text"] if "text" in example else "" token_ids = hf_tokenizer.encode(text) if len(token_ids) < block_size + 1: return {"token_ids": None} token_ids = token_ids[:block_size+1] return {"token_ids": token_ids} ds = ds.map( tokenize_and_truncate, batched=False, num_proc=num_proc ) ds = ds.filter(lambda ex: ex["token_ids"] is not None, num_proc=num_proc) if "text" in ds.column_names: ds = ds.remove_columns(["text"]) os.makedirs(os.path.dirname(processed_dir), exist_ok=True) ds.save_to_disk(processed_dir) print(f"Processed dataset saved to '{processed_dir}'.") tokenized_data = ds["token_ids"] return tokenized_data def collate_batch(token_list_batch, block_size): """ Convert a list of token-ID lists into x,y Tensors for causal LM. We'll truncate if longer than block_size+1, skip if shorter. """ x_list, y_list = [], [] for tokens in token_list_batch: if len(tokens) < block_size + 1: continue tokens = tokens[:block_size+1] x_list.append(tokens[:-1]) y_list.append(tokens[1:]) if not x_list: return None, None x_tensor = torch.tensor(x_list, dtype=torch.long) y_tensor = torch.tensor(y_list, dtype=torch.long) return x_tensor, y_tensor ##################################### # Model Definition ##################################### class ArgonneConfig(PretrainedConfig): model_type = "argonne" def __init__(self, vocab_size=12000, block_size=2048, n_layer=24, n_head=24, n_embd=1296, dropout=0.1, **kwargs): super().__init__(**kwargs) self.vocab_size = vocab_size self.block_size = block_size self.n_layer = n_layer self.n_head = n_head self.n_embd = n_embd self.dropout = dropout class Block(nn.Module): def __init__(self, config): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln2 = nn.LayerNorm(config.n_embd) self.mlp = MLP(config) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by n_head" self.n_head = config.n_head self.head_dim = config.n_embd // config.n_head self.query = nn.Linear(config.n_embd, config.n_embd) self.key = nn.Linear(config.n_embd, config.n_embd) self.value = nn.Linear(config.n_embd, config.n_embd) self.attn_drop = nn.Dropout(config.dropout) self.resid_drop = nn.Dropout(config.dropout) self.proj = nn.Linear(config.n_embd, config.n_embd) self.register_buffer( "mask", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size) ) def forward(self, x): b, t, c = x.size() q = self.query(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2) k = self.key(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2) v = self.value(x).view(b, t, self.n_head, self.head_dim).transpose(1, 2) att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim) att = att.masked_fill(self.mask[:, :, :t, :t] == 0, float('-inf')) att = torch.softmax(att, dim=-1) att = self.attn_drop(att) y = att @ v y = y.transpose(1, 2).contiguous().view(b, t, c) y = self.resid_drop(self.proj(y)) return y class MLP(nn.Module): def __init__(self, config): super().__init__() self.fc1 = nn.Linear(config.n_embd, 4 * config.n_embd) self.act = nn.GELU() self.fc2 = nn.Linear(4 * config.n_embd, config.n_embd) self.drop = nn.Dropout(config.dropout) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class ArgonneModelParallel(PreTrainedModel): config_class = ArgonneConfig def __init__(self, config): super().__init__(config) # Create embeddings on CPU initially self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd) self.position_embedding = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) self.drop = nn.Dropout(config.dropout) # Build all blocks on CPU all_blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) # Final LayerNorm + output head self.ln_f = nn.LayerNorm(config.n_embd) self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) nn.init.normal_(self.position_embedding, mean=0.0, std=0.02) self.post_init() # Keep the blocks on CPU in a single ModuleList self.blocks = all_blocks # We'll defer pipeline splitting until later: self.pipeline_stages = None def distribute_model(self, device_ids=None): """ Distribute the model blocks across multiple GPU devices in a pipeline style. If 'device_ids' is None, we'll discover all available GPUs. """ if device_ids is None: num_gpus = torch.cuda.device_count() if num_gpus < 1: raise ValueError("No GPUs found—can't do pipeline parallel on CPU only.") device_ids = [f"cuda:{i}" for i in range(num_gpus)] # Store them so the training loop can keep referencing model.devices self.devices = [torch.device(d) for d in device_ids] self.pipeline_stages = nn.ModuleList() num_gpus = len(device_ids) blocks_per_gpu = math.ceil(len(self.blocks) / num_gpus) start_idx = 0 for i in range(num_gpus): end_idx = min(start_idx + blocks_per_gpu, len(self.blocks)) stage_blocks = self.blocks[start_idx:end_idx] stage = nn.Sequential(*stage_blocks).to(device_ids[i]) self.pipeline_stages.append(stage) start_idx = end_idx if end_idx >= len(self.blocks): break # Move token_embedding + position_embedding to the first device self.token_embedding.to(device_ids[0]) self.position_embedding.data = self.position_embedding.data.to(device_ids[0]) self.drop.to(device_ids[0]) # Move final LayerNorm + head to the last device self.ln_f.to(device_ids[-1]) self.head.to(device_ids[-1]) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): """ If self.pipeline_stages is None, we do a normal single-device forward (whatever device everything is currently on—CPU or a single GPU). Otherwise, we do a pipeline parallel forward. """ if self.pipeline_stages is None: # Single-device forward pass device = self.token_embedding.weight.device idx = idx.to(device) b, t = idx.size() assert t <= self.config.block_size, "Sequence length exceeds block size" token_embeddings = self.token_embedding(idx) position_embeddings = self.position_embedding[:, :t, :] hidden_states = self.drop(token_embeddings + position_embeddings) for block in self.blocks: hidden_states = block(hidden_states) hidden_states = self.ln_f(hidden_states) logits = self.head(hidden_states) loss = None if targets is not None: targets = targets.to(device) logits = logits.view(-1, logits.size(-1)) targets = targets.view(-1) loss = F.cross_entropy(logits, targets) return logits, loss else: # Pipeline parallel forward first_device = next(self.pipeline_stages[0].parameters()).device last_device = next(self.pipeline_stages[-1].parameters()).device x = idx.to(first_device) b, t = x.size() assert t <= self.config.block_size, "Sequence length exceeds block size" token_embeddings = self.token_embedding(x) position_embeddings = self.position_embedding[:, :t, :] hidden_states = self.drop(token_embeddings + position_embeddings) # Pass through each pipeline stage in sequence for stage in self.pipeline_stages: device_stage = next(stage.parameters()).device hidden_states = hidden_states.to(device_stage) hidden_states = stage(hidden_states) hidden_states = hidden_states.to(last_device) hidden_states = self.ln_f(hidden_states) logits = self.head(hidden_states) loss = None if targets is not None: targets = targets.to(last_device) logits = logits.view(-1, logits.size(-1)) targets = targets.view(-1) loss = F.cross_entropy(logits, targets) return logits, loss @torch.no_grad() def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None): self.eval() if len(self.devices) == 0: raise ValueError("No GPUs available for model parallelism.") generated = input_ids.to(self.devices[0]) for _ in range(max_new_tokens): if generated.shape[1] > self.config.block_size: generated = generated[:, -self.config.block_size:] logits, _ = self.forward(generated) logits = logits[:, -1, :].to(self.devices[-1]) logits = logits / temperature if top_k is not None: values, _ = torch.topk(logits, top_k) logits[logits < values[:, -1:]] = float('-inf') probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) next_token = next_token.to(self.devices[0]) generated = torch.cat((generated, next_token), dim=1) return generated ##################################### # Training Loop (Streaming OR Full-Pass Non-Streaming) ##################################### def train_model_parallel(data_files, use_streaming=False): """ data_files should be a list of actual .arrow file paths, e.g. ["data/file1.arrow", "data/file2.arrow", ...] Includes automatic batch size adjustment when OOM errors occur. """ # Initial batch size settings initial_batch_size = 128 # initial batch size min_batch_size = 12 # Minimum acceptable batch size batch_size = initial_batch_size # Current working batch size # 1) If no tokenizer, train it on text extracted from Arrow if not os.path.exists("bpe_tokenizer/vocab.json"): print("No existing tokenizer found. Building a text file from Arrow and training one...") # Create a text file from Arrow files text_file_path = "all_text_for_tokenizer.txt" create_text_file_from_arrow(data_files, text_file_path) # Now train BPE on that text file train_bpe_tokenizer(text_file_path, vocab_size=12000) # Load the tokenizer we just created (or found) hf_tokenizer = load_bpe_tokenizer() block_size = 2048 epochs = 5 n_layer = 12 n_head = 12 n_embd = 1296 dropout = 0.1 config_model = ArgonneConfig( vocab_size=12000, block_size=block_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd, dropout=dropout ) # Load non-streaming dataset once, outside the retry loop tokenized_data = None if not use_streaming: print("=== Loading dataset in memory for a full pass approach ===") tokenized_data = load_nonstream_data(data_files, hf_tokenizer, block_size, num_proc=128) total_samples = len(tokenized_data) print(f"Total tokenized samples: {total_samples}") # Main training loop with batch size adjustment while True: print(f"\n=== Attempting training with batch_size = {batch_size} ===") try: # Initialize a fresh model for each attempt model = ArgonneModelParallel(config_model) model.distribute_model() # chunks across all visible GPUs optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5) scaler = torch.amp.GradScaler("cuda") global_step = 0 if use_streaming: ######################################################## # STREAMING MODE ######################################################## steps_per_epoch = 500 for epoch in tqdm(range(epochs)): print(f"==== Starting epoch {epoch} (STREAMING) with batch_size={batch_size} ====") token_gen = streaming_token_generator(data_files, hf_tokenizer) step_in_epoch = 0 token_batch = [] while step_in_epoch < steps_per_epoch: try: tokens = next(token_gen) token_batch.append(tokens) if len(token_batch) == batch_size: x_tens, y_tens = collate_batch(token_batch, block_size) token_batch.clear() if x_tens is None: continue first_device = model.devices[0] x_tens, y_tens = x_tens.to(first_device), y_tens.to(first_device) optimizer.zero_grad() with torch.amp.autocast("cuda"): logits, loss = model(x_tens, y_tens) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() global_step += 1 step_in_epoch += 1 if global_step % 50 == 0: print(f"Epoch {epoch} | Step {global_step} | Loss: {loss.item():.4f}") prompt_str = "Long long time ago, " token_ids = hf_tokenizer.encode(prompt_str) prompt_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) generated = model.generate(prompt_tensor, max_new_tokens=50) generated_text = hf_tokenizer.decode(generated[0].tolist()) print(f"\n--- Generated text at step {global_step} ---\n{generated_text}\n") if global_step % 10000 == 0: checkpoint = { "epoch": epoch, "global_step": global_step, "batch_size": batch_size, # Save the successful batch size "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "loss": loss.item() } os.makedirs("pretrained", exist_ok=True) torch.save(checkpoint, f"pretrained/checkpoint_step_{global_step}.pth") print(f"Checkpoint saved at step {global_step}") except StopIteration: print("Reached end of dataset (stream) before finishing this epoch.") break else: ######################################################## # NON-STREAMING MODE: full pass each epoch ######################################################## batches_per_epoch = total_samples // batch_size for epoch in tqdm(range(epochs)): print(f"==== Starting epoch {epoch} (NON-STREAMING) with batch_size={batch_size} ====") for batch_idx in tqdm(range(batches_per_epoch)): start_idx = batch_idx * batch_size end_idx = start_idx + batch_size batch_token_lists = tokenized_data[start_idx:end_idx] x_tens, y_tens = collate_batch(batch_token_lists, block_size) if x_tens is None: continue first_device = model.devices[0] x_tens = x_tens.to(first_device) y_tens = y_tens.to(first_device) optimizer.zero_grad() with torch.amp.autocast("cuda"): logits, loss = model(x_tens, y_tens) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() global_step += 1 if global_step % 100 == 0: print(f"Epoch {epoch} | global_step {global_step} | Loss: {loss.item():.4f}") prompt_str = "Long long time ago, " token_ids = hf_tokenizer.encode(prompt_str) prompt_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) generated = model.generate(prompt_tensor, max_new_tokens=50) generated_text = hf_tokenizer.decode(generated[0].tolist()) print(f"\n--- Generated text at step {global_step} ---\n{generated_text}\n") if global_step % 2000 == 0: checkpoint = { "epoch": epoch, "global_step": global_step, "batch_size": batch_size, # Save the successful batch size "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "loss": loss.item() } os.makedirs("pretrained", exist_ok=True) torch.save(checkpoint, f"pretrained/checkpoint_step_{global_step}.pth") print(f"Checkpoint saved at step {global_step}") # If we reach here, training completed successfully print(f"Training completed successfully with batch_size={batch_size}") break except torch.cuda.OutOfMemoryError: # Free memory del model, optimizer, scaler torch.cuda.empty_cache() # Reduce batch size new_batch_size = max(batch_size - 12, min_batch_size) if new_batch_size == batch_size: print(f"⚠️ Already at minimum batch size ({min_batch_size}). Training failed.") break print(f"CUDA Out of Memory! Reducing batch size from {batch_size} to {new_batch_size}") batch_size = new_batch_size # Short pause to ensure memory is freed import time time.sleep(5) # Save final model and tokenizer try: model.save_pretrained("Argonne_LLM") hf_tokenizer.save_pretrained("Argonne_LLM") print("Model-parallel training complete; model and tokenizer saved successfully.") except: print("Failed to save final model, likely due to OOM issues.") ##################################### # Register with Hugging Face Auto Classes ##################################### from transformers import AutoConfig, AutoModel, AutoModelForCausalLM # Register the model with Hugging Face's Auto classes AutoConfig.register("argonne", ArgonneConfig) AutoModel.register(ArgonneConfig, ArgonneModelParallel) AutoModelForCausalLM.register(ArgonneConfig, ArgonneModelParallel) def main(): # Expand .arrow files via glob data_files = glob.glob("data/*.arrow") if not data_files: raise ValueError("No files matched the pattern 'data/*.arrow'") train_model_parallel(data_files=data_files, use_streaming=False) if __name__ == "__main__": main()