chbsaikiran committed
Commit 49c48e3 · 1 Parent(s): d35be66

Initial commit with model trained with loss less than 0.099999

Files changed (9)
  1. README.md +98 -4
  2. app.py +56 -0
  3. decoder_only_transformer.pth +3 -0
  4. input.txt +0 -0
  5. lr_finder.py +90 -0
  6. requirements.txt +3 -0
  7. train.py +275 -0
  8. train_get2-8-init.py +287 -0
  9. transformer.py +125 -0
README.md CHANGED
@@ -1,13 +1,107 @@
  ---
  title: NextWordGPT
- emoji: 🦀
- colorFrom: green
- colorTo: indigo
+ emoji: 🏃
+ colorFrom: purple
+ colorTo: yellow
  sdk: gradio
  sdk_version: 5.12.0
  app_file: app.py
  pinned: false
- short_description: Next Word Generator Which Trained On Shakespeare's Text
+ short_description: 'Transformer trained on Shakespearean text '
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+ <pre>
+ Epoch 1/50: 100%|██████████| 82/82 [01:16<00:00, 1.08step/s, loss=6.2489]
+ Epoch 1/50, Loss: 7.0745, Time: 76.07s
+ Epoch 2/50: 100%|██████████| 82/82 [01:22<00:00, 1.00s/step, loss=5.6592]
+ Epoch 2/50, Loss: 5.6716, Time: 82.14s
+ Epoch 3/50: 100%|██████████| 82/82 [01:25<00:00, 1.05s/step, loss=5.2294]
+ Epoch 3/50, Loss: 5.1465, Time: 85.97s
+ Epoch 4/50: 100%|██████████| 82/82 [01:27<00:00, 1.07s/step, loss=4.8800]
+ Epoch 4/50, Loss: 4.8121, Time: 87.40s
+ Epoch 5/50: 100%|██████████| 82/82 [01:28<00:00, 1.08s/step, loss=4.6155]
+ Epoch 5/50, Loss: 4.5597, Time: 88.28s
+ Epoch 6/50: 100%|██████████| 82/82 [01:29<00:00, 1.10s/step, loss=4.4006]
+ Epoch 6/50, Loss: 4.3344, Time: 89.88s
+ Epoch 7/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=4.1696]
+ Epoch 7/50, Loss: 4.1084, Time: 91.19s
+ Epoch 8/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=3.9078]
+ Epoch 8/50, Loss: 3.8753, Time: 91.43s
+ Epoch 9/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=3.6197]
+ Epoch 9/50, Loss: 3.6167, Time: 91.38s
+ Epoch 10/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=3.3067]
+ Epoch 10/50, Loss: 3.3436, Time: 91.24s
+ Epoch 11/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=3.0890]
+ Epoch 11/50, Loss: 2.9951, Time: 91.45s
+ Epoch 12/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=2.7631]
+ Epoch 12/50, Loss: 2.7189, Time: 91.25s
+ Epoch 13/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=2.5140]
+ Epoch 13/50, Loss: 2.4935, Time: 91.21s
+ Epoch 14/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=2.3475]
+ Epoch 14/50, Loss: 2.3095, Time: 91.42s
+ Epoch 15/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=2.1527]
+ Epoch 15/50, Loss: 2.1343, Time: 91.61s
+ Epoch 16/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=1.9820]
+ Epoch 16/50, Loss: 1.9522, Time: 91.35s
+ Epoch 17/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=1.7411]
+ Epoch 17/50, Loss: 1.7585, Time: 91.53s
+ Epoch 18/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=1.5516]
+ Epoch 18/50, Loss: 1.5744, Time: 91.77s
+ Epoch 19/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=1.3633]
+ Epoch 19/50, Loss: 1.4087, Time: 91.45s
+ Epoch 20/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=1.2165]
+ Epoch 20/50, Loss: 1.2397, Time: 91.37s
+ Epoch 21/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=1.1129]
+ Epoch 21/50, Loss: 1.0790, Time: 91.69s
+ Epoch 22/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.9431]
+ Epoch 22/50, Loss: 0.9302, Time: 91.61s
+ Epoch 23/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.8262]
+ Epoch 23/50, Loss: 0.8121, Time: 91.39s
+ Epoch 24/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.7406]
+ Epoch 24/50, Loss: 0.7170, Time: 91.36s
+ Epoch 25/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.6618]
+ Epoch 25/50, Loss: 0.6387, Time: 91.58s
+ Epoch 26/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.5878]
+ Epoch 26/50, Loss: 0.5709, Time: 91.55s
+ Epoch 27/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.5246]
+ Epoch 27/50, Loss: 0.5079, Time: 91.23s
+ Epoch 28/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.4453]
+ Epoch 28/50, Loss: 0.4472, Time: 91.39s
+ Epoch 29/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.3966]
+ Epoch 29/50, Loss: 0.3912, Time: 91.58s
+ Epoch 30/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.3454]
+ Epoch 30/50, Loss: 0.3401, Time: 91.14s
+ Epoch 31/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.3288]
+ Epoch 31/50, Loss: 0.3059, Time: 91.06s
+ Epoch 32/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.2900]
+ Epoch 32/50, Loss: 0.2712, Time: 91.22s
+ Epoch 33/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.2608]
+ Epoch 33/50, Loss: 0.2438, Time: 91.44s
+ Epoch 34/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.2365]
+ Epoch 34/50, Loss: 0.2215, Time: 91.02s
+ Epoch 35/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.2159]
+ Epoch 35/50, Loss: 0.2017, Time: 91.14s
+ Epoch 36/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1979]
+ Epoch 36/50, Loss: 0.1840, Time: 91.59s
+ Epoch 37/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1814]
+ Epoch 37/50, Loss: 0.1681, Time: 91.70s
+ Epoch 38/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1661]
+ Epoch 38/50, Loss: 0.1539, Time: 91.46s
+ Epoch 39/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1522]
+ Epoch 39/50, Loss: 0.1410, Time: 91.53s
+ Epoch 40/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1390]
+ Epoch 40/50, Loss: 0.1295, Time: 91.60s
+ Epoch 41/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1350]
+ Epoch 41/50, Loss: 0.1215, Time: 91.51s
+ Epoch 42/50: 100%|██████████| 82/82 [01:31<00:00, 1.11s/step, loss=0.1304]
+ Epoch 42/50, Loss: 0.1156, Time: 91.43s
+ Epoch 43/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1247]
+ Epoch 43/50, Loss: 0.1099, Time: 91.80s
+ Epoch 44/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1162]
+ Epoch 44/50, Loss: 0.1047, Time: 91.56s
+ Epoch 45/50: 100%|██████████| 82/82 [01:31<00:00, 1.12s/step, loss=0.1122]
+ Epoch 45/50, Loss: 0.0998, Time: 91.53s
+ </pre>
app.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn.functional as F
+ from transformers import GPT2Tokenizer
+ import gradio as gr
+
+ # Load tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # Using GPT-2 tokenizer for compatibility
+
+ # Load model
+ from train_get2_8_init import GPT, GPTConfig
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Initialize the model
+ config = GPTConfig()
+ model = GPT(config)
+ model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location=torch.device(device)))
+ model.eval()
+ model.to(device)
+
+ # Prediction function
+ def generate_text(input_text, max_length=50, top_k=50):
+     with torch.no_grad():
+         # Tokenize input
+         tokens = tokenizer.encode(input_text, return_tensors="pt").to(device)
+         x = tokens
+         while x.size(1) < max_length:
+             # Forward pass to get logits
+             logits = model(x)[0]  # (B, T, vocab_size)
+             logits = logits[:, -1, :]  # Take the logits at the last position
+
+             # Get probabilities and do top-k sampling
+             probs = F.softmax(logits, dim=-1)
+             topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
+             ix = torch.multinomial(topk_probs, 1)  # Sample token
+             xcol = torch.gather(topk_indices, -1, ix)  # Gather indices
+             x = torch.cat((x, xcol), dim=1)  # Append to sequence
+
+         # Decode tokens into text
+         generated_text = tokenizer.decode(x[0])
+         return generated_text
+
+ # Gradio Interface
+ def gradio_interface(input_text):
+     return generate_text(input_text)
+
+ interface = gr.Interface(
+     fn=gradio_interface,
+     inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
+     outputs=gr.Textbox(lines=2, placeholder="Generated text will appear here..."),
+     title="Text Prediction App",
+     description="Enter a text prompt to generate the next sequence of words.",
+ )
+
+ # Launch the app
+ interface.launch()
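
Note that app.py imports GPT and GPTConfig from a module named train_get2_8_init, while the training script committed here is train_get2-8-init.py; a filename containing hyphens cannot be imported under that name, so the file presumably needs to be renamed (or copied) to train_get2_8_init.py for the Space to start. Once the Space is running, it can also be queried programmatically; the sketch below uses gradio_client and assumes a hypothetical Space id of chbsaikiran/NextWordGPT and the default gr.Interface endpoint.

<pre>
# A minimal sketch of calling the deployed Space from Python.
# Assumptions: the Space id "chbsaikiran/NextWordGPT" is hypothetical (substitute
# the real owner/space name), and the default gr.Interface endpoint is exposed.
# pip install gradio_client
from gradio_client import Client

client = Client("chbsaikiran/NextWordGPT")   # hypothetical Space id
result = client.predict(
    "To be, or not to be",                   # text for the prompt Textbox
    api_name="/predict",                     # default endpoint name for gr.Interface
)
print(result)                                # generated continuation as a string
</pre>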
decoder_only_transformer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d66ba9508c76b5b60af5713845ebe0528ea2e034d4fb015243bc6f62e764e144
+ size 548151190
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
lr_finder.py ADDED
@@ -0,0 +1,90 @@
+ from torch_lr_finder import LRFinder
+ from torch.nn import CrossEntropyLoss
+ import torch.optim as optim
+ import torch
+ import tiktoken                            # needed by DataLoaderLite below
+ from transformers import GPT2Tokenizer     # needed for the tokenizer below
+ from transformer import Config, DecoderOnlyTransformer
+
+ class DataLoaderLite:
+     def __init__(self, B, T):
+         self.B = B
+         self.T = T
+
+         # at init load tokens from disk and store them in memory
+         with open('input.txt', 'r') as f:
+             text = f.read()
+         enc = tiktoken.get_encoding('gpt2')
+         tokens = enc.encode(text)
+         self.tokens = torch.tensor(tokens)
+         print(f'loaded {len(self.tokens)} tokens')
+         print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
+
+         # state
+         self.current_position = 0
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position: self.current_position + B * T + 1]
+         x = (buf[:-1]).view(B, T) # inputs
+         y = (buf[1:]).view(B, T) # targets
+         # advance the position in the tensor
+         self.current_position += B*T
+         # if loading the next batch would be out of bounds, reset
+         if self.current_position + (B * T + 1) > len(self.tokens):
+             self.current_position = 0
+         return x, y
+
+
+ batches, no_of_tokens = 16, 128
+ train_loader = DataLoaderLite(B=batches, T=no_of_tokens)
+ steps_per_epoch = len(train_loader.tokens) // (batches * no_of_tokens)
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ # Model configuration
+ config = Config()
+
+ # Tokenizer
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2") # Use GPT-2 tokenizer for compatibility
+
+ # Load trained model
+ model = DecoderOnlyTransformer(config)
+ model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location=torch.device('cpu')))
+ model.eval()
+ model.to(device)
+
+ amp_config = {
+     'device_type': 'cuda',
+     'dtype': torch.float16,
+ }
+ criterion = CrossEntropyLoss()
+ grad_scaler = torch.cuda.amp.GradScaler()
+ optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+ # Define a custom batch fetching wrapper
+ class CustomDataLoader:
+     def __init__(self, next_batch_func, num_batches):
+         self.next_batch_func = next_batch_func
+         self.num_batches = num_batches
+         self.current_batch = 0
+
+     def __iter__(self):
+         self.current_batch = 0
+         return self
+
+     def __next__(self):
+         if self.current_batch < self.num_batches:
+             self.current_batch += 1
+             return self.next_batch_func()
+         else:
+             raise StopIteration
+
+ # Create a custom data loader using next_batch
+ # (pass the bound method itself; __next__ above calls it once per batch)
+ custom_train_loader = CustomDataLoader(train_loader.next_batch, num_batches=steps_per_epoch)
+
+ # Use the custom data loader with LRFinder
+ lr_finder = LRFinder(
+     model, optimizer, criterion, device='cuda',
+     amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler
+ )
+ lr_finder.range_test(custom_train_loader, end_lr=5, num_iter=1000, step_mode='exp')
+ lr_finder.plot()
+ lr_finder.reset()
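
One detail worth flagging: DecoderOnlyTransformer.forward returns logits of shape (B, T, vocab_size) while each batch's targets have shape (B, T), and nn.CrossEntropyLoss expects the class dimension second, so the criterion as written would likely fail inside range_test. A possible adaptation (not part of the commit) is a small wrapper criterion that flattens both tensors, mirroring what train.py does inside GPT.forward:

<pre>
# Hedged sketch: a flattening criterion that could be passed to LRFinder in
# place of the plain CrossEntropyLoss above. Not part of the committed script.
import torch.nn as nn
import torch.nn.functional as F

class FlattenedCrossEntropy(nn.Module):
    def forward(self, logits, targets):
        # logits: (B, T, vocab_size) -> (B*T, vocab_size); targets: (B, T) -> (B*T,)
        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

# criterion = FlattenedCrossEntropy()
# lr_finder = LRFinder(model, optimizer, criterion, device='cuda', ...)
</pre>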
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ transformers
+ gradio
train.py ADDED
@@ -0,0 +1,275 @@
+ import os
+ import math
+ import time
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ from torch.optim.lr_scheduler import StepLR
+ from torch.cuda.amp import GradScaler, autocast
+ import tiktoken
+ from tqdm import tqdm
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANGPT_SCALE_INIT = 1
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024 # max sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
+     n_layer: int = 12 # number of layers
+     n_head: int = 12 # number of heads
+     n_embd: int = 768 # embedding dimension
+
+
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # weight sharing
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # weight initialization
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANGPT_SCALE_INIT'):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # forward the token and position embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x) # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """Loads pretrained GPT-2 model weights from huggingface"""
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+         print("loading weights from pretrained gpt: %s" % model_type)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = {
+             'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
+             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+             'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+             'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+         }[model_type]
+         config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+         config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+         # this means that we have to transpose these weights when we import them
+         assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+ # model = GPT.from_pretrained('gpt2')
+
+ device = 'cpu'
+ if torch.cuda.is_available():
+     device = 'cuda'
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = "mps"
+ print(f"using device: {device}")
+
+ # SEED
+ torch.manual_seed(1337)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(1337)
+
+ # STOP
+ num_return_sequences = 5
+ max_length = 30
+
+
+ import tiktoken
+
+ class DataLoaderLite:
+     def __init__(self, B, T):
+         self.B = B
+         self.T = T
+
+         # at init load tokens from disk and store them in memory
+         with open('/kaggle/input/input-txt/input.txt', 'r') as f:
+             text = f.read()
+         enc = tiktoken.get_encoding('gpt2')
+         tokens = enc.encode(text)
+         self.tokens = torch.tensor(tokens)
+         print(f'loaded {len(self.tokens)} tokens')
+         print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
+
+         # state
+         self.current_position = 0
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position: self.current_position + B * T + 1]
+         x = (buf[:-1]).view(B, T) # inputs
+         y = (buf[1:]).view(B, T) # targets
+         # advance the position in the tensor
+         self.current_position += B*T
+         # if loading the next batch would be out of bounds, reset
+         if self.current_position + (B * T + 1) > len(self.tokens):
+             self.current_position = 0
+         return x, y
+
+ model = GPT(GPTConfig())
+ model.to(device)
+
+ batches, no_of_tokens = 16, 256
+ train_loader = DataLoaderLite(B=batches, T=no_of_tokens)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
+ scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
+
+ # Training Loop
+ steps_per_epoch = len(train_loader.tokens) // (batches * no_of_tokens)
+ print(steps_per_epoch)
+ EPOCHS = 50
+ for epoch in range(EPOCHS):
+     loss_list = []
+     train_loader_temp = train_loader
+     start_time = time.time()
+
+     with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{EPOCHS}", unit="step") as pbar:
+         for step in range(steps_per_epoch):
+             x, y = train_loader.next_batch()
+             x, y = x.to(device), y.to(device)
+             optimizer.zero_grad()
+             logits, loss = model(x, y)
+             loss.backward()
+             optimizer.step()
+             loss_list.append(loss.item())
+             pbar.update(1)
+             pbar.set_postfix(loss=f"{loss.item():.4f}")
+
+     scheduler.step()
+     epoch_loss = sum(loss_list) / len(loss_list)
+     print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {epoch_loss:.4f}, Time: {time.time() - start_time:.2f}s")
+     if(epoch_loss < 0.099999):
+         break
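
Training stops as soon as the average epoch loss drops below 0.099999, which matches the commit message and the README log (epoch 45, loss 0.0998). The script as shown does not include the call that wrote the checkpoint shipped in this commit; below is a minimal sketch of saving it under the filename that app.py loads, assuming the standard state_dict format.

<pre>
# Minimal sketch (not shown in the diff above): persist the trained weights
# under the filename that app.py expects, using the usual state_dict format.
import torch

torch.save(model.state_dict(), "decoder_only_transformer.pth")

# To reload later into a freshly constructed model:
# model = GPT(GPTConfig())
# model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location="cpu"))
</pre>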
train_get2-8-init.py ADDED
@@ -0,0 +1,287 @@
+ # Solving for residual std scaling issue
+ import os
+ import math
+ import time
+ import inspect
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANGPT_SCALE_INIT = 1
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024 # max sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
+     n_layer: int = 12 # number of layers
+     n_head: int = 12 # number of heads
+     n_embd: int = 768 # embedding dimension
+
+
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # weight sharing
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # weight initialization
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANGPT_SCALE_INIT'):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # forward the token and position embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x) # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """Loads pretrained GPT-2 model weights from huggingface"""
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+         print("loading weights from pretrained gpt: %s" % model_type)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = {
+             'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
+             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+             'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+             'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+         }[model_type]
+         config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+         config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
+         # this means that we have to transpose these weights when we import them
+         assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+ # model = GPT.from_pretrained('gpt2')
+
+ device = 'cpu'
+ if torch.cuda.is_available():
+     device = 'cuda'
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = "mps"
+ print(f"using device: {device}")
+
+ # SEED
+ torch.manual_seed(1337)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(1337)
+
+ # STOP
+ num_return_sequences = 5
+ max_length = 30
+
+
+ import tiktoken
+
+ class DataLoaderLite:
+     def __init__(self, B, T):
+         self.B = B
+         self.T = T
+
+         # at init load tokens from disk and store them in memory
+         with open('input.txt', 'r') as f:
+             text = f.read()
+         enc = tiktoken.get_encoding('gpt2')
+         tokens = enc.encode(text)
+         self.tokens = torch.tensor(tokens)
+         print(f'loaded {len(self.tokens)} tokens')
+         print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
+
+         # state
+         self.current_position = 0
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position: self.current_position + B * T + 1]
+         x = (buf[:-1]).view(B, T) # inputs
+         y = (buf[1:]).view(B, T) # targets
+         # advance the position in the tensor
+         self.current_position += B*T
+         # if loading the next batch would be out of bounds, reset
+         if self.current_position + (B * T + 1) > len(self.tokens):
+             self.current_position = 0
+         return x, y
+
+
+ model = GPT(GPTConfig())
+ model.to(device)
+
+ train_loader = DataLoaderLite(B = 4, T = 32)
+
+ # NEW CODE
+ optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)
+ for i in range(50):
+     x, y = train_loader.next_batch()
+     x, y = x.to(device), y.to(device)
+     optimizer.zero_grad()
+     logits, loss = model(x, y)
+     loss.backward()
+     optimizer.step()
+     print(f'step{i}, loss: {loss.item()}')
+
+
+ print(loss)
+ import sys; sys.exit(0)
+
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+ while x.size(1) < max_length:
+     # forward the model to get the logits
+     with torch.no_grad():
+         logits = model(x)[0] # (B, T, vocab_size)
+         # take the logits at the last position
+         logits = logits[:, -1, :] # (B, vocab_size)
+         # get the probabilities
+         probs = F.softmax(logits, dim=-1)
+         # do top-k sampling of 50 (huggingface pipeline default)
+         # topk_probs here becomes (5, 50), topk_indices is (5, 50)
+         topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
+         # select a token from the top-k probabilities
+         # note: multinomial does not demand the input to sum to 1
+         ix = torch.multinomial(topk_probs, 1) # (B, 1)
+         # gather the corresponding indices
+         xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
+         # append to the sequence
+         x = torch.cat((x, xcol), dim=1)
+
+ # print the generated text
+ for i in range(num_return_sequences):
+     tokens = x[i, :max_length].tolist()
+     decoded = enc.decode(tokens)
+     print(">", decoded)
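
The sampling code after `import sys; sys.exit(0)` is unreachable as committed and refers to `x` and `enc` without setting them up for generation. Below is a hedged sketch of the prompt preparation that loop appears to assume: tokenize a prompt with tiktoken and repeat it for num_return_sequences completions.

<pre>
# Hedged sketch (not in the commit) of the setup the sampling loop above assumes:
# an encoder `enc` and a batch of prompt tokens `x`. The prompt string is only an
# example; num_return_sequences and device come from the script above.
import tiktoken
import torch

enc = tiktoken.get_encoding('gpt2')
prompt_tokens = enc.encode("Hello, I'm a language model,")      # example prompt
x = torch.tensor(prompt_tokens, dtype=torch.long)               # (T,)
x = x.unsqueeze(0).repeat(num_return_sequences, 1).to(device)   # (num_return_sequences, T)
</pre>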
transformer.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+
+ @dataclass
+ class Config:
+     vocab_size: int = 50257
+     max_seq_len: int = 2048
+     dim: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     dropout: float = 0.1
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.n_head = config.num_heads
+         self.n_embd = config.dim
+
+         # Linear projections for Q, K, V
+         self.c_attn = nn.Linear(config.dim, 3 * config.dim) # [n_embd, 3 * n_embd]
+         self.c_proj = nn.Linear(config.dim, config.dim) # [n_embd, n_embd]
+
+         self.attn_dropout = nn.Dropout(config.dropout)
+         self.resid_dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         B, T, C = x.size() # [B, T, n_embd]
+
+         # Linear projection and split into Q, K, V
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # [B, T, n_embd] each
+
+         # Reshape for multi-head attention
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd/n_head]
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd/n_head]
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [B, n_head, T, n_embd/n_head]
+
+         # Attention scores
+         att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5)) # [B, n_head, T, T]
+         att = F.softmax(att, dim=-1) # [B, n_head, T, T]
+         att = self.attn_dropout(att) # [B, n_head, T, T]
+
+         # Weighted sum of values
+         y = att @ v # [B, n_head, T, n_embd/n_head]
+
+         # Reshape and project
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # [B, T, n_embd]
+         y = self.c_proj(y) # [B, T, n_embd]
+         y = self.resid_dropout(y) # [B, T, n_embd]
+
+         return y
+
+ class FeedForward(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.dim, 4 * config.dim) # [n_embd, 4 * n_embd]
+         self.c_proj = nn.Linear(4 * config.dim, config.dim) # [4 * n_embd, n_embd]
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, x):
+         x = self.c_fc(x) # [B, T, 4 * n_embd]
+         x = F.gelu(x) # [B, T, 4 * n_embd]
+         x = self.c_proj(x) # [B, T, n_embd]
+         x = self.dropout(x) # [B, T, n_embd]
+         return x
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.dim) # [n_embd]
+         self.attn = MultiHeadAttention(config)
+         self.ln_2 = nn.LayerNorm(config.dim) # [n_embd]
+         self.mlp = FeedForward(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x)) # [B, T, n_embd]
+         x = x + self.mlp(self.ln_2(x)) # [B, T, n_embd]
+         return x
+
+ class DecoderOnlyTransformer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.wte = nn.Embedding(config.vocab_size, config.dim) # [vocab_size, n_embd]
+         self.wpe = nn.Embedding(config.max_seq_len, config.dim) # [max_seq_len, n_embd]
+         self.drop = nn.Dropout(config.dropout)
+         self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
+         self.ln_f = nn.LayerNorm(config.dim) # [n_embd]
+         self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False) # [n_embd, vocab_size]
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if isinstance(module, nn.Linear) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def forward(self, idx):
+         B, T = idx.size() # [B, T]
+
+         # Positional embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device).unsqueeze(0) # [1, T]
+
+         # Token and position embeddings
+         tok_emb = self.wte(idx) # [B, T, n_embd]
+         pos_emb = self.wpe(pos) # [1, T, n_embd]
+
+         # Combine embeddings and apply dropout
+         x = self.drop(tok_emb + pos_emb) # [B, T, n_embd]
+
+         # Transformer blocks
+         for block in self.blocks:
+             x = block(x) # [B, T, n_embd]
+
+         # Final layer norm and linear projection
+         x = self.ln_f(x) # [B, T, n_embd]
+         logits = self.lm_head(x) # [B, T, vocab_size]
+
+         return logits
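
For reference, DecoderOnlyTransformer returns raw logits only (the loss is computed elsewhere), so a quick smoke test needs nothing beyond the classes defined above. A minimal usage sketch:

<pre>
# Minimal usage sketch for transformer.py: build the model from its defaults and
# run a forward pass on random token ids. Shapes follow the comments above.
import torch
from transformer import Config, DecoderOnlyTransformer

config = Config()                                    # 12 layers, 12 heads, dim 768 by default
model = DecoderOnlyTransformer(config)
model.eval()

idx = torch.randint(0, config.vocab_size, (2, 16))   # [B=2, T=16] dummy token ids
with torch.no_grad():
    logits = model(idx)                              # [2, 16, vocab_size]
print(logits.shape)                                  # torch.Size([2, 16, 50257])
</pre>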