|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.optim import AdamW |
|
from torch.nn import functional as F |
|
from torch.utils.data import DataLoader |
|
from torch.nn.utils import clip_grad_norm_ |
|
|
|
import wandb |
|
from tqdm import tqdm |
|
from transformers import GPT2LMHeadModel |
|
from gated_state_spaces_pytorch import GatedStateSpacesLM |
|
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper |
|
|
|
from c4x import C4X |
|
|
|
|
|
def render_samples_html(texts):
    """Render decoded sample texts as a single HTML snippet for ``wandb.Html``.

    Each text is HTML-escaped first so that any markup that appears in the
    generated text is displayed literally instead of being interpreted by the
    browser. Newlines then become ``<br>`` and samples are separated by ``<hr>``.

    Args:
        texts: iterable of decoded strings.

    Returns:
        One HTML string (empty if ``texts`` is empty).
    """
    import html

    return '<hr>'.join(
        html.escape(text).replace('\n', '<br>') for text in texts
    )


def save_checkpoint(state, path):
    """``torch.save`` *state* to *path* without ever leaving a truncated file.

    The data is first written to ``path + '.tmp'`` and then moved over *path*
    with ``os.replace``. If the process dies during the write, the previous
    checkpoint at *path* is left untouched.

    Args:
        state: any object ``torch.save`` accepts (here, a state dict).
        path: destination file path.
    """
    import os

    tmp_path = path + '.tmp'
    torch.save(state, tmp_path)
    # Atomic on POSIX when source and destination share a filesystem,
    # which holds here because tmp_path sits next to path.
    os.replace(tmp_path, path)


if __name__ == '__main__':
    # Fine-tune a GatedStateSpacesLM (resumed from model.pt) on C4X data,
    # logging loss and periodic text samples to W&B.
    wandb.init(
        project="gated-state-space",
        entity="naxalpha",
    )

    # --- model --------------------------------------------------------
    f_emb = 1600  # model width
    eos_id = 50256  # used as the start token for sampling below
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,  # presumably the GPT-2 BPE vocab size — confirm vs C4X
            dim=f_emb,
            depth=24,
        ),
    )

    # Freeze the input embedding and the output projection weights.
    model.net.token_emb.weight.requires_grad_(False)
    model.net.to_logits.weight.requires_grad_(False)

    # Put a trainable LayerNorm in front of the frozen output projection.
    # This must happen before load_state_dict so the checkpoint keys
    # (net.to_logits.0.*, net.to_logits.1.*) line up.
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )

    # The model is still on the CPU here; load the checkpoint there too so a
    # CUDA-saved checkpoint doesn't allocate an extra copy on the GPU.
    model.load_state_dict(torch.load('model.pt', map_location='cpu'))
    model = model.cuda()

    # Watch only after the to_logits surgery so the new LayerNorm is included.
    wandb.watch(model)

    # Frozen params never receive a .grad, so AdamW would skip them anyway;
    # filtering here just makes the trainable set explicit.
    optim = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        2e-5,
    )

    # --- data ---------------------------------------------------------
    batch_size = 8
    seq_len = 128
    # C4X is asked for seq_len + 1 tokens, presumably so the wrapper can shift
    # by one to form (input, target) pairs — TODO confirm in C4X.
    dsx = C4X(seq_len + 1)
    dlx = DataLoader(
        dsx,
        batch_size=batch_size,
        num_workers=16,
    )

    # --- training -----------------------------------------------------
    accum_steps = 4        # micro-batches per optimizer step
    sample_every = 1000    # steps between sample generation + checkpoint
    n_samples, sample_len = 4, 512

    prog = tqdm(dlx)
    optim.zero_grad()

    for i, batch in enumerate(prog):
        batch = batch.cuda()
        los = model(batch)

        # Scale so the accumulated gradient is the mean over accum_steps.
        (los / accum_steps).backward()
        if (i + 1) % accum_steps == 0:
            clip_grad_norm_(
                model.parameters(),
                max_norm=1.,
            )
            optim.step()
            optim.zero_grad()

        if i % sample_every == 0:
            # NOTE(review): assumes AutoregressiveWrapper.generate disables
            # grad / switches to eval itself — confirm in the library.
            init = torch.tensor([[eos_id]] * n_samples).cuda()
            prd = model.generate(init, sample_len)
            texts = [dsx.decode(p) for p in prd]
            try:
                wandb.log(
                    dict(text=wandb.Html(render_samples_html(texts))),
                    step=i,
                )
            except Exception as ex:
                # Best-effort: a logging failure must not stop training.
                print('Failed to log to W&B...', ex)

            save_checkpoint(model.state_dict(), 'model.pt')

        # One .item() means one host/device sync per step instead of two.
        loss_value = los.item()
        wandb.log(dict(
            loss=loss_value,
        ), step=i)
        prog.set_postfix(loss=loss_value)