# gpt-wiki-az / train.py
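# Trains a small GPT-style language model on Azerbaijani Wikipedia text
# ('az_wiki_data.json') with a pre-trained tokenizer ('az_tokenizer.json'),
# using mixed-precision training, gradient clipping, and periodic checkpoints.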
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import math
from tqdm import tqdm
import json
from tokenizers import Tokenizer
from datetime import datetime
import gc
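# Hyperparameters for a small GPT-style model; several values are scaled down
# from larger defaults (see the inline "Reduced from ..." comments).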
class GPTConfig:
def __init__(
self,
vocab_size=22588,
n_embd=768, # Reduced from 2048
n_head=12, # Reduced from 16
n_layer=8, # Reduced from 12
dropout=0.1,
block_size=256, # Reduced from 512
learning_rate=3e-4,
max_epochs=50,
batch_size=8, # Reduced from 64
grad_clip=1.0,
):
self.vocab_size = vocab_size
self.n_embd = n_embd
self.n_head = n_head
self.n_layer = n_layer
self.dropout = dropout
self.block_size = block_size
self.learning_rate = learning_rate
self.max_epochs = max_epochs
self.batch_size = batch_size
self.grad_clip = grad_clip
# Model Architecture
# Multi-head causal self-attention.
class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.w_k = nn.Linear(config.n_embd, config.n_embd)
        self.w_q = nn.Linear(config.n_embd, config.n_embd)
        self.w_v = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # Causal mask so each position only attends to earlier positions,
        # as required for autoregressive next-token prediction.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
                 .view(1, 1, config.block_size, config.block_size),
        )
        self.n_head = config.n_head
        self.n_embd = config.n_embd
    def forward(self, x):
        B, T, C = x.size()
        # Project to keys/queries/values and split into heads: (B, n_head, T, head_dim)
        k = self.w_k(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.w_q(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.w_v(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # Scaled dot-product attention with the causal mask applied before softmax
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads
        y = self.resid_drop(self.proj(y))
        return y
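# Pre-LayerNorm transformer block: self-attention plus a 4x-wide GELU MLP,
# each wrapped in a residual connection.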
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.attn = SelfAttention(config)
self.ln2 = nn.LayerNorm(config.n_embd)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.dropout),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
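# Full model: learned token and positional embeddings, a stack of transformer
# blocks, a final LayerNorm, and a linear head projecting to vocabulary logits.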
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.dropout)
self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, idx, targets=None):
b, t = idx.size()
assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
token_embeddings = self.tok_emb(idx)
position_embeddings = self.pos_emb[:, :t, :]
x = self.drop(token_embeddings + position_embeddings)
for block in self.blocks:
x = block(x)
x = self.ln_f(x)
logits = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
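# Tokenizes each article and slices the token stream into fixed-length,
# half-overlapping chunks for next-token-prediction training.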
class WikiTextDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=256): # Reduced max_length
self.tokenizer = tokenizer
self.max_length = max_length
print("Tokenizing texts...")
self.examples = []
for text in tqdm(texts):
tokens = self.tokenizer.encode(text).ids
            # Slide a window of max_length tokens over the article with 50% overlap;
            # texts shorter than max_length yield no examples with this range.
            for i in range(0, len(tokens) - max_length, max_length // 2):
                chunk = tokens[i:i + max_length]
                if len(chunk) < max_length:
                    # Defensive padding with token id 0 (chunks from the loop above are normally full length).
                    chunk = chunk + [0] * (max_length - len(chunk))
                self.examples.append(chunk)
def __len__(self):
return len(self.examples)
    def __getitem__(self, idx):
        tokens = self.examples[idx]
        # Shift by one token: the model predicts token t+1 from tokens up to t.
        return torch.tensor(tokens[:-1]), torch.tensor(tokens[1:])
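# End-to-end training: load the corpus, build train/val splits, and run the
# epoch loop with AMP, gradient clipping, cosine LR decay, and checkpointing.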
def train():
# Clear GPU memory
torch.cuda.empty_cache()
gc.collect()
print("Loading Wikipedia data...")
with open('az_wiki_data.json', 'r', encoding='utf-8') as f:
wiki_data = json.load(f)
texts = [page['text'] for page in wiki_data.values()]
tokenizer = Tokenizer.from_file("az_tokenizer.json")
dataset = WikiTextDataset(texts, tokenizer)
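    # 90/10 random train/validation split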
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
config = GPTConfig()
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=2, # Reduced from 4
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=2, # Reduced from 4
pin_memory=True
)
model = GPT(config)
model = model.to('cuda')
print(f"Number of parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=config.max_epochs)
    scaler = torch.amp.GradScaler()  # torch.amp API; replaces the deprecated torch.cuda.amp.GradScaler
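    # One full pass over the train or validation loader under autocast; returns
    # the mean loss. Gradient updates happen only for the train split.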
def run_epoch(split, epoch_num=0):
        is_train = split == 'train'
        # model.train(False) is equivalent to model.eval(), so one call covers both splits.
        model.train(is_train)
        loader = train_loader if is_train else val_loader
losses = []
pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
for it, (x, y) in pbar:
# Clear memory
torch.cuda.empty_cache()
x = x.to('cuda', non_blocking=True)
y = y.to('cuda', non_blocking=True)
            with torch.amp.autocast(device_type='cuda'):  # torch.amp API; replaces the deprecated torch.cuda.amp.autocast
logits, loss = model(x, y)
losses.append(loss.item())
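            # AMP update: scale the loss for backward, unscale before clipping
            # gradients, then step the optimizer through the scaler.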
if is_train:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
pbar.set_description(f"epoch {epoch_num+1} iter {it}: train loss {loss.item():.5f}")
# Delete unnecessary tensors
del x, y, logits
if is_train:
del loss
mean_loss = torch.tensor(losses).mean().item()
return mean_loss
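    # Epoch loop: save 'best_model.pt' whenever validation loss improves and a
    # full resume checkpoint every 5 epochs.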
    best_val_loss = float('inf')
    # Pre-set so the KeyboardInterrupt handler can still save a checkpoint if
    # training stops during the very first epoch.
    epoch, train_loss, val_loss = 0, float('nan'), float('nan')
    try:
        for epoch in range(config.max_epochs):
print(f"\nEpoch {epoch+1}/{config.max_epochs}")
train_loss = run_epoch('train', epoch_num=epoch)
with torch.no_grad():
val_loss = run_epoch('val')
scheduler.step()
if val_loss < best_val_loss:
best_val_loss = val_loss
print(f"Saving best model with val_loss: {val_loss:.4f}")
torch.save(model.state_dict(), 'best_model.pt')
print(f"Epoch {epoch+1}: train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f}")
if (epoch + 1) % 5 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
}, f'checkpoint_epoch_{epoch+1}.pt')
except KeyboardInterrupt:
print('Training interrupted, saving checkpoint...')
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
}, 'interrupt_checkpoint.pt')
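# Optional extra (not part of the original script): a minimal sampling sketch for
# trying out a trained checkpoint. It assumes 'best_model.pt' was written by
# train() above with the default GPTConfig, and reuses 'az_tokenizer.json'.
@torch.no_grad()
def generate(prompt, max_new_tokens=50, temperature=1.0, device='cuda'):
    tokenizer = Tokenizer.from_file("az_tokenizer.json")
    model = GPT(GPTConfig())
    model.load_state_dict(torch.load('best_model.pt', map_location=device))
    model.to(device).eval()
    # Start from the encoded prompt and append one sampled token at a time.
    idx = torch.tensor([tokenizer.encode(prompt).ids], device=device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.block_size:]      # crop context to block_size
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature    # logits for the next token only
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return tokenizer.decode(idx[0].tolist())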
if __name__ == '__main__':
train()