import math
import time
import torch
from torch.nn import functional as F
import gradio as gr
from tqdm import tqdm
import tiktoken
from transformer import GPT, GPTConfig  # Model definition lives in transformer.py
from torch.cuda.amp import autocast, GradScaler


# DataLoader class for handling input.txt
class DataLoaderLite:
    def __init__(self, B, T, config):
        self.B = B
        self.T = T
        self.config = config

        # Load and tokenize input.txt
        with open('input.txt', 'r', encoding='utf-8') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        self.tokens = torch.tensor(enc.encode(text), dtype=torch.long)

        # Pre-slice the token stream into (B*T + 1)-length chunks so each
        # batch can be served with a single view
        self.data = []
        for i in range(0, len(self.tokens) - T, B * T):
            chunk = self.tokens[i:i + B * T + 1]
            if len(chunk) == B * T + 1:
                self.data.append(chunk)

        print(f'Loaded {len(self.tokens)} tokens')
        print(f'Created {len(self.data)} batches')
        self.current_idx = 0

    def next_batch(self):
        chunk = self.data[self.current_idx]
        x = chunk[:-1].view(self.B, self.T)  # inputs
        y = chunk[1:].view(self.B, self.T)   # targets, shifted by one token
        self.current_idx = (self.current_idx + 1) % len(self.data)
        # Pinned memory only helps (and only works) when copying to a GPU
        if self.config.pin_memory and self.config.device == 'cuda':
            x = x.pin_memory()
            y = y.pin_memory()
        return x, y


class TrainingConfig:
    def __init__(self):
        # Smaller model architecture (~30M params)
        self.n_layer = 4      # Further reduced
        self.n_head = 8
        self.n_embd = 384     # Further reduced
        self.block_size = 256
        self.dropout = 0.2    # Increased dropout for better regularization

        # Optimized training hyperparameters for faster convergence
        self.learning_rate = 1e-4    # Reduced learning rate for stability
        self.max_iters = 50000       # Increased max iterations
        self.batch_size = 4          # Reduced batch size
        self.grad_clip = 0.5         # Reduced gradient clipping
        self.weight_decay = 0.1
        self.betas = (0.9, 0.95)
        self.warmup_iters = 2000
        self.lr_decay_iters = 40000  # Increased decay iterations
        self.min_lr = 1e-5
        self.eval_interval = 100     # More frequent evaluation
        self.eval_iters = 20

        # Performance optimization flags
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.gradient_checkpointing = True
        self.mixed_precision = self.device == 'cuda'  # AMP is only useful on GPU
        self.gradient_accumulation_steps = 8  # Increased for larger effective batch size
        self.num_workers = 4
        self.pin_memory = True

        # Check if Triton is available before enabling compile
        try:
            import triton  # noqa: F401 -- only checking availability
            self.compile_model = True
        except ImportError:
            print("Triton not available, disabling model compilation")
            self.compile_model = False


class TrainingLogger:
    def __init__(self, log_file='training_log.txt'):
        self.log_file = log_file
        self.start_time = time.time()

        # Initialize log file
        with open(self.log_file, 'w') as f:
            f.write("Training Log\n")
            f.write("=" * 50 + "\n")
            f.write(f"Training started at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("Iteration | Train Loss | Val Loss | Learning Rate | Tokens/sec\n")
            f.write("-" * 65 + "\n")

    def log_step(self, iter_num, train_loss, val_loss, lr, tokens_per_sec):
        log_line = (f"{iter_num:>9} | {train_loss:>10.4f} | {val_loss:>8.4f} | "
                    f"{lr:>12.2e} | {tokens_per_sec:>9.2f}")
        print(log_line)
        with open(self.log_file, 'a') as f:
            f.write(log_line + "\n")

    def log_message(self, message):
        print(message)
        with open(self.log_file, 'a') as f:
            f.write("\n" + message + "\n")

    def finish(self):
        total_time = (time.time() - self.start_time) / 3600  # Convert to hours
        message = f"\nTraining completed in {total_time:.2f} hours"
        self.log_message(message)


def get_lr(it, config):
    """Cosine learning-rate schedule with linear warmup."""
    if it < config.warmup_iters:
        return config.learning_rate * it / config.warmup_iters
    if it > config.lr_decay_iters:
        return config.min_lr
    decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config.min_lr + coeff * (config.learning_rate - config.min_lr)


def evaluate_loss(model, train_loader, config):
    # Note: this estimates loss on batches drawn from the same training stream,
    # since input.txt is not split into separate train/validation sets.
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for _ in range(config.eval_iters):
            x, y = train_loader.next_batch()
            x, y = x.to(config.device), y.to(config.device)
            _, loss = model(x, y)
            total_loss += loss.item()
    model.train()
    return total_loss / config.eval_iters


def train_model():
    config = TrainingConfig()
    logger = TrainingLogger()

    # Create and optimize model
    model_config = GPTConfig(
        block_size=config.block_size,
        n_layer=config.n_layer,
        n_head=config.n_head,
        n_embd=config.n_embd,
        dropout=config.dropout
    )
    model = GPT(model_config)

    if config.compile_model and hasattr(torch, 'compile'):
        try:
            model = torch.compile(model)
            logger.log_message("Model compilation successful")
        except Exception as e:
            logger.log_message(f"Model compilation failed: {e}")
            logger.log_message("Continuing without compilation")

    if config.gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()

    model.to(config.device)
    logger.log_message(
        f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M"
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=config.betas,
        weight_decay=config.weight_decay
    )

    train_loader = DataLoaderLite(B=config.batch_size, T=config.block_size, config=config)
    scaler = GradScaler() if config.mixed_precision else None

    best_val_loss = float('inf')
    no_improvement_count = 0
    lr_scale = 1.0  # Extra multiplier applied when validation loss plateaus

    model.train()
    for iter_num in tqdm(range(config.max_iters)):
        iter_start = time.time()

        # Training step
        x, y = train_loader.next_batch()
        x = x.to(config.device, non_blocking=True)
        y = y.to(config.device, non_blocking=True)

        lr = get_lr(iter_num, config) * lr_scale
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if config.mixed_precision:
            with autocast():
                logits, loss = model(x, y)
            loss = loss / config.gradient_accumulation_steps
            scaler.scale(loss).backward()

            if (iter_num + 1) % config.gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        else:
            logits, loss = model(x, y)
            loss = loss / config.gradient_accumulation_steps
            loss.backward()

            if (iter_num + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

        # Calculate metrics
        iter_time = time.time() - iter_start
        tokens_per_sec = config.batch_size * config.block_size / iter_time

        # Evaluation and logging
        if iter_num % config.eval_interval == 0:
            val_loss = evaluate_loss(model, train_loader, config)
            # Undo the accumulation scaling so the logged train loss is comparable
            train_loss = loss.item() * config.gradient_accumulation_steps
            logger.log_step(iter_num, train_loss, val_loss, lr, tokens_per_sec)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                no_improvement_count = 0
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'iter': iter_num,
                    'config': model_config
                }, 'best_model.pt')
                logger.log_message(f"New best model saved with validation loss: {val_loss:.6f}")
            else:
                no_improvement_count += 1

            # Stop early once the target loss (< 0.1) is reached
            if val_loss < 0.099999:
                logger.log_message(f"Target loss achieved at iteration {iter_num}")
                logger.log_message(f"Final validation loss: {val_loss:.6f}")
                break

            # Halve the learning rate after 5 evaluations without improvement
            if no_improvement_count >= 5:
                lr_scale *= 0.5
                no_improvement_count = 0
                logger.log_message("Reducing learning rate due to no improvement")

    logger.finish()
    return model


def generate_text(model, prompt, max_length=100, temperature=0.7, block_size=256):
    model.eval()
    device = next(model.parameters()).device
    enc = tiktoken.get_encoding('gpt2')
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)

    with torch.no_grad():
        output_sequence = []
        for _ in range(max_length):
            # Crop the context to the model's block size (matches TrainingConfig.block_size)
            # so the position embeddings stay in range
            idx_cond = input_ids[:, -block_size:]
            outputs = model(idx_cond)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            next_token_logits = logits[:, -1, :]

            # Apply temperature and sample from the resulting distribution
            next_token_logits = next_token_logits / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            output_sequence.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)

    return enc.decode(output_sequence)


if __name__ == "__main__":
    # Train the model
    model = train_model()

    # Create and launch Gradio interface
    def predict(prompt, length, temp=0.7):
        return generate_text(model, prompt, int(length), temp)

    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Textbox(lines=2, label="Enter your prompt"),
            gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Length"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature", step=0.1)
        ],
        outputs=gr.Textbox(lines=5, label="Generated Text"),
        title="Custom Transformer Text Generator",
        description="Enter a prompt and adjust parameters to generate text"
    )
    iface.launch(share=True)