philipp-zettl committed on
Commit
ece90a2
1 Parent(s): ee89e6e

update train script

Browse files
Files changed (1) hide show
  1. train.py +379 -15
train.py CHANGED
@@ -1,9 +1,15 @@
 
1
  import argparse
 
2
  import torch
3
  import torch.nn as nn
 
4
  from torch.nn import functional as F
5
  from gpt_p.model import DecoderTransformer
 
 
6
  from datasets import load_dataset
 
7
 
8
 
9
  torch.manual_seed(420) # 1337
@@ -12,7 +18,7 @@ base_name = 'gpt-p_CHARS_CHAT_'
12
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
13
  context_size = 256 # how many tokens to consider while generating the next
14
  batch_size = 128 # how many independent sequences will we process in parallel
15
- max_iters = 30_000
16
  learning_rate = 3e-5
17
  eval_interval = 100
18
  eval_iters = 20 # number evaluation iterations
@@ -21,28 +27,304 @@ n_layer = 6 # number of transformer layers
21
  n_head = 6
22
  dropout = 0.2 # dropout factor
23
 
24
- dataset = load_dataset('Lichess/standard-chess-games', split='train')
25
- content = '\n'.join(list(filter(lambda x: 'eval' not in x, dataset['movetext'])))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  ## BUILD DATA SET ##
 
 
 
 
28
  book = content
29
- characters = sorted(list(set(book)))
 
 
 
30
  vocab_size = len(characters)
31
 
32
  # convert
33
- stoi = {ch: idx for idx, ch in enumerate(characters)}
34
- itos = {idx: ch for idx, ch in enumerate(characters)}
 
 
 
 
 
 
35
 
36
- encode = lambda s: [stoi[c] for c in s]
37
- decode = lambda i: ''.join([itos[x] for x in i])
38
 
 
 
 
 
 
39
 
40
- data = torch.tensor(encode(book), dtype=torch.long)
41
- n = int(0.9 * len(data))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  train_data = data[:n]
43
  val_data = data[n:]
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def get_batch(split):
47
  data = train_data if split == 'train' else val_data
48
  idx = torch.randint(len(data) - context_size, (batch_size,))
@@ -50,6 +332,9 @@ def get_batch(split):
50
  y = torch.stack([data[i+1:i+context_size+1] for i in idx])
51
  return x.to(device), y.to(device)
52
 
 
 
 
53
  ## END BUILD DATA SET ##
54
  ## MODEL DEFINITION ##
55
 
@@ -72,15 +357,58 @@ def estimate_loss():
72
  for k in range(eval_iters):
73
  X, Y = get_batch(split)
74
  logits, loss = model(X, Y)
 
 
 
 
 
 
 
 
 
75
  losses[k] = loss.item()
76
  out[split] = losses.mean()
77
 
78
- input_string = '1. e4 g6'
79
  print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
80
  model.train()
81
  return out
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if __name__ == "__main__":
85
  args = argparse.ArgumentParser()
86
  args.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
@@ -91,28 +419,62 @@ if __name__ == "__main__":
91
  params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size, 'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
92
  if args.load:
93
  m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
94
- m.load_state_dict(torch.load(f'./models/{base_name}' + ''.join(f'{key}={v}' for key, v in params.items())))
95
  else:
96
  m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
97
  model = m.to(device)
98
 
99
  if args.inference:
 
 
 
 
 
 
100
  exit()
101
  ## END MODEL ##
102
  ## START TRAINING ##
 
 
 
103
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
 
104
 
105
- for step in range(max_iters):
106
  if step % eval_interval == 0:
107
  losses = estimate_loss()
108
- print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')
 
 
 
 
109
 
110
  xb, yb = get_batch('train')
111
 
112
  logits, loss = model(xb, yb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  optimizer.zero_grad(set_to_none=True)
114
  loss.backward()
115
  optimizer.step()
 
 
116
 
117
  print()
118
  print('Loss:')
@@ -124,6 +486,8 @@ if __name__ == "__main__":
124
  ## END VALIDATION ##
125
 
126
  # save model weights
127
- torch.save(model.state_dict(), f'./models/{base_name}' + ''.join([f'{key}={v}' for key, v in params.items()]))
 
 
128
  with open('train.log', 'a') as f:
129
  f.write(f'{max_iters},{learning_rate}\n')
 
1
+ import re
2
  import argparse
3
+ import json
4
  import torch
5
  import torch.nn as nn
6
+ from tqdm import tqdm
7
  from torch.nn import functional as F
8
  from gpt_p.model import DecoderTransformer
9
+ from torch.optim.lr_scheduler import _LRScheduler
10
+ import math
11
  from datasets import load_dataset
12
+ import wandb
13
 
14
 
15
  torch.manual_seed(420) # 1337
 
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  context_size = 256 # how many tokens to consider while generating the next
20
  batch_size = 128 # how many independent sequences will we process in parallel
21
+ max_iters = 50_000
22
  learning_rate = 3e-5
23
  eval_interval = 100
24
  eval_iters = 20 # number evaluation iterations
 
27
  n_head = 6
28
  dropout = 0.2 # dropout factor
29
 
30
# Data-mode switches.
# mask_all_data=True  -> `content` is one newline-joined string of games.
# mask_all_data=False -> `content` is a list with one movetext string per game.
mask_all_data = True
use_scheduler = False

# Patterns used to strip annotations from raw movetext.
_BRACE_COMMENT_RE = re.compile(r'\{[^}]*\}')      # "{ ... }" comment blocks
_BLACK_MOVE_NUMBER_RE = re.compile(r'[0-9]+\.\.\.')  # "12..." continuation numbers
_SPACE_RUN_RE = re.compile(r' {2,}')              # gaps left behind by the removals


def _strip_annotations(movetext):
    """Return *movetext* without brace comments or ``N...`` move numbers.

    Comments are removed first, then the black-move continuation numbers;
    finally the runs of spaces left behind by both removals are collapsed to a
    single space.
    """
    without_comments = _BRACE_COMMENT_RE.sub('', movetext)
    without_numbers = _BLACK_MOVE_NUMBER_RE.sub('', without_comments)
    return _SPACE_RUN_RE.sub(' ', without_numbers)


dataset = load_dataset('Lichess/standard-chess-games', '2014-08', split='train')
# Keep only games whose movetext does not contain the substring 'eval'
# (presumably engine-evaluation annotations -- confirm against the dataset).
_plain_2014_games = [game for game in dataset['movetext'] if 'eval' not in game]


new_dataset = load_dataset('Lichess/standard-chess-games', '2024-07', split='train', data_files=[f'data/year=2024/month=07/train-{str(i).zfill(5)}-of-00384.parquet' for i in range(10)])

# Clean the freshly loaded 2024-07 games. The previous code iterated over the
# 2014-08 `dataset` here, so the new split was downloaded and then discarded.
new_dataset = [_strip_annotations(game) for game in new_dataset['movetext']]

og_samples = _plain_2014_games + new_dataset

if mask_all_data:
    # NOTE(review): this mode only uses the 2014-08 games, exactly as before;
    # confirm whether the 2024-07 games are meant to be included as well.
    content = '\n'.join(_plain_2014_games)
else:
    content = og_samples

print('Data loaded')
print('Training on ', len(content), 'characters. Good luck!')
50
 
51
  ## BUILD DATA SET ##
52
+ # load data
53
+ #with open('data.txt', 'r') as f:
54
+ # content = f.read()
55
+
56
  book = content
57
+ if mask_all_data:
58
+ characters = sorted(list(set(book)))
59
+ else:
60
+ characters = sorted(list(set('\n'.join(book))))
61
  vocab_size = len(characters)
62
 
63
  # convert
64
class Tokenizer:
    """Symbol-level tokenizer over a fixed vocabulary.

    Every vocabulary entry is assigned its position in ``vocab`` as token id.
    ``stoi`` maps symbol -> id and ``itos`` maps id -> symbol.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.stoi = {symbol: position for position, symbol in enumerate(vocab)}
        self.itos = dict(enumerate(vocab))

    def encode(self, s):
        """Return the list of token ids for the symbols of *s*.

        Raises:
            KeyError: if a symbol is not part of the vocabulary.
        """
        lookup = self.stoi
        return [lookup[symbol] for symbol in s]

    def decode(self, i):
        """Return the string spelled by the iterable of token ids *i*.

        Raises:
            KeyError: if an id is not part of the vocabulary.
        """
        lookup = self.itos
        return ''.join(lookup[token_id] for token_id in i)

    @classmethod
    def from_pretrained(cls, path):
        """Build a tokenizer from the JSON vocabulary stored at *path*."""
        with open(path, 'r') as f:
            return cls(json.load(f))

    def save_pretrained(self, path):
        """Write the vocabulary as JSON to *path*."""
        with open(path, 'w') as f:
            json.dump(self.vocab, f)
85
+
86
+
87
+ tokenizer = Tokenizer(characters)
88
+ encode = tokenizer.encode
89
+ decode = tokenizer.decode
90
+
91
+ if mask_all_data:
92
+ data = torch.tensor(encode(book), dtype=torch.long)
93
+ else:
94
+ data = [torch.tensor(encode(s), dtype=torch.long) for s in book]
95
+ max_len = max(len(x) for x in og_samples)
96
+ context_size = min(context_size, max_len)
97
+
98
+
99
+ n = int(0.8 * len(data))
100
  train_data = data[:n]
101
  val_data = data[n:]
102
 
103
 
104
+
105
+ # Constants for piece movement validation
106
# Numeric value per piece letter: uppercase letters are White pieces,
# lowercase letters are Black pieces, and both colours share the same values.
PIECE_VALUES = dict(zip('PNBRQKpnbrqk', (1, 3, 3, 5, 9, 0) * 2))
110
+
111
def initialize_board():
    """Return a new 8x8 board holding the standard starting position.

    The board is a list of eight row lists. Row 0 is the 8th rank (Black's
    back rank) and row 7 is the 1st rank (White's back rank). Uppercase
    letters are White pieces, lowercase letters are Black pieces and ``'.'``
    marks an empty square.
    """
    black_back_rank = 'rnbqkbnr'
    ranks_top_down = (
        black_back_rank,          # 8th rank (Black)
        'p' * 8,                  # 7th rank (Black)
        '.' * 8,                  # 6th rank
        '.' * 8,                  # 5th rank
        '.' * 8,                  # 4th rank
        '.' * 8,                  # 3rd rank
        'P' * 8,                  # 2nd rank (White)
        black_back_rank.upper(),  # 1st rank (White)
    )
    # list(...) per rank gives every row its own mutable list.
    return [list(rank) for rank in ranks_top_down]
123
+
124
def get_piece(board, position):
    """Return the content of the square named *position* (e.g. ``'e4'``).

    The first character of *position* is the file letter (``'a'`` -> column 0)
    and the second is the rank digit (``'8'`` -> row 0, ``'1'`` -> row 7).
    """
    file_letter, rank_digit = position[0], position[1]
    column_index = ord(file_letter) - ord('a')
    rank_row = board[8 - int(rank_digit)]
    return rank_row[column_index]
129
+
130
def set_piece(board, position, piece):
    """Store *piece* on the square named *position* (e.g. ``'e4'``), in place.

    The first character of *position* is the file letter (``'a'`` -> column 0)
    and the second is the rank digit (``'8'`` -> row 0, ``'1'`` -> row 7).
    """
    file_letter, rank_digit = position[0], position[1]
    column_index = ord(file_letter) - ord('a')
    rank_row = board[8 - int(rank_digit)]
    rank_row[column_index] = piece
135
+
136
def validate_pawn_move(board, start, end, is_white_turn):
    """Return True if a pawn may move from *start* to *end* on *board*.

    Accepted moves:
      * one square straight ahead onto an empty square;
      * two squares straight ahead from the pawn's home rank, provided both
        the jumped-over square and the destination are empty;
      * one square diagonally ahead onto a square held by an enemy piece.

    White pawns move towards row 0 (rank 8), Black pawns towards row 7.
    En passant is not recognised and the content of *start* is not checked.

    Args:
        board: 8x8 list of rows, row 0 being the 8th rank.
        start: source square such as ``'e2'``.
        end: destination square such as ``'e4'``.
        is_white_turn: True when the moving pawn is White.
    """
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    pawn_direction = -1 if is_white_turn else 1  # White moves up, Black moves down
    one_step_row = start_row + pawn_direction

    if start_col == end_col:
        # Straight advances never capture: the destination must be empty.
        if board[end_row][end_col] != '.':
            return False
        if end_row == one_step_row:
            return True
        home_row = 6 if is_white_turn else 1
        # A double step may not jump over a piece standing directly ahead.
        return (
            start_row == home_row
            and end_row == start_row + 2 * pawn_direction
            and board[one_step_row][start_col] == '.'
        )

    if abs(start_col - end_col) == 1 and end_row == one_step_row:
        target_piece = board[end_row][end_col]
        # '.' is neither upper- nor lowercase, so empty squares are rejected.
        return target_piece.islower() if is_white_turn else target_piece.isupper()

    return False
157
+
158
def validate_knight_move(start, end):
    """Return True if *start* -> *end* is a knight jump (an L-shape).

    Only the geometry is checked: one coordinate must change by two squares
    and the other by one. The board content is not consulted.
    """
    file_delta = abs(ord(start[0]) - ord(end[0]))
    rank_delta = abs(int(start[1]) - int(end[1]))
    return sorted((file_delta, rank_delta)) == [1, 2]
167
+
168
def validate_rook_move(board, start, end):
    """Return True if a rook may slide from *start* to *end* on *board*.

    The move must stay on one file or one rank, must actually change square,
    and every square strictly between *start* and *end* must be empty. The
    content of the destination square itself is not inspected.

    Args:
        board: 8x8 list of rows, row 0 being the 8th rank; ``'.'`` is empty.
        start: source square such as ``'a1'``.
        end: destination square such as ``'a8'``.
    """
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    same_file = start_col == end_col
    same_rank = start_row == end_row

    if same_file and same_rank:
        return False  # Staying on the same square is not a move
    if not same_file and not same_rank:
        return False  # Must be either same column or row

    # Collect the squares strictly between start and end.
    if same_file:
        step = 1 if end_row > start_row else -1
        between = (board[row][start_col]
                   for row in range(start_row + step, end_row, step))
    else:
        step = 1 if end_col > start_col else -1
        between = (board[start_row][col]
                   for col in range(start_col + step, end_col, step))

    return all(square == '.' for square in between)
189
+
190
def validate_bishop_move(board, start, end):
    """Return True if a bishop may slide from *start* to *end* on *board*.

    The move must follow a diagonal, must actually change square, and every
    square strictly between *start* and *end* must be empty. The content of
    the destination square itself is not inspected.

    Args:
        board: 8x8 list of rows, row 0 being the 8th rank; ``'.'`` is empty.
        start: source square such as ``'c1'``.
        end: destination square such as ``'h6'``.
    """
    start_col, start_row = ord(start[0]) - ord('a'), 8 - int(start[1])
    end_col, end_row = ord(end[0]) - ord('a'), 8 - int(end[1])

    distance = abs(end_col - start_col)
    if distance != abs(end_row - start_row):
        return False  # Must move diagonally
    if distance == 0:
        # Staying put is not a move. Without this guard the path walk below
        # would never reach its end square and would run off the board.
        return False

    col_step = 1 if end_col > start_col else -1
    row_step = 1 if end_row > start_row else -1

    # Inspect only the squares strictly between start and end.
    for offset in range(1, distance):
        if board[start_row + offset * row_step][start_col + offset * col_step] != '.':
            return False

    return True
209
+
210
# One non-castling move in standard algebraic notation: optional piece letter,
# optional source file/rank hint, optional capture mark, destination square,
# optional promotion and optional check/mate suffix.
_SAN_MOVE_RE = re.compile(
    r'(?P<piece>[PNBRQK])?'
    r'(?P<from_file>[a-h])?'
    r'(?P<from_rank>[1-8])?'
    r'x?'
    r'(?P<dest>[a-h][1-8])'
    r'(?:=(?P<promotion>[QRNB]))?'
    r'[+#]?'
)


def _castling_side(move):
    """Return 'K' or 'Q' for king-/queen-side castling notation, else None."""
    bare = move.rstrip('+#')
    if bare == 'O-O':
        return 'K'
    if bare == 'O-O-O':
        return 'Q'
    return None


def _castling_columns(side):
    """Return (king_from, king_to, rook_from, rook_to) columns for *side*."""
    return (4, 6, 7, 5) if side == 'K' else (4, 2, 0, 3)


def _can_castle(board, side, is_white_turn):
    """Check king and rook stand on their home squares with nothing between."""
    row = 7 if is_white_turn else 0
    king, rook = ('K', 'R') if is_white_turn else ('k', 'r')
    king_from, _, rook_from, _ = _castling_columns(side)
    low, high = sorted((king_from, rook_from))
    return (
        board[row][king_from] == king
        and board[row][rook_from] == rook
        and all(board[row][col] == '.' for col in range(low + 1, high))
    )


def _piece_can_reach(board, piece_type, source, dest, is_white_turn):
    """Return True if a *piece_type* on *source* may move to *dest*."""
    if piece_type == 'P':
        return validate_pawn_move(board, source, dest, is_white_turn)
    if piece_type == 'N':
        return validate_knight_move(source, dest)
    if piece_type == 'R':
        return validate_rook_move(board, source, dest)
    if piece_type == 'B':
        return validate_bishop_move(board, source, dest)
    if piece_type == 'Q':
        # A queen moves like a rook or like a bishop.
        return (validate_rook_move(board, source, dest)
                or validate_bishop_move(board, source, dest))
    # King: exactly one square in any direction.
    file_gap = abs(ord(source[0]) - ord(dest[0]))
    rank_gap = abs(int(source[1]) - int(dest[1]))
    return max(file_gap, rank_gap) == 1


def _resolve_move(board, move, is_white_turn):
    """Find the squares a non-castling SAN move refers to.

    Returns:
        ``(source, dest, promotion)`` where *promotion* is the promotion piece
        letter or None; or None when the notation is malformed, the
        destination holds a piece of the moving side, or no piece of the
        moving side can make the move.
    """
    parsed = _SAN_MOVE_RE.fullmatch(move)
    if parsed is None:
        return None

    piece_type = parsed['piece'] or 'P'  # no piece letter means a pawn move
    dest = parsed['dest']
    target = board[8 - int(dest[1])][ord(dest[0]) - ord('a')]
    if target != '.' and target.isupper() == is_white_turn:
        return None  # cannot capture a piece of one's own colour

    wanted = piece_type if is_white_turn else piece_type.lower()
    for row in range(8):
        for col in range(8):
            if board[row][col] != wanted:
                continue
            source = f"{chr(ord('a') + col)}{8 - row}"
            # Honour the disambiguation hints ("Nbd7", "N4f3") when present.
            if parsed['from_file'] not in (None, source[0]):
                continue
            if parsed['from_rank'] not in (None, source[1]):
                continue
            if _piece_can_reach(board, piece_type, source, dest, is_white_turn):
                return source, dest, parsed['promotion']
    return None


def validate_move(board, move, is_white_turn):
    """Return True if the SAN *move* is playable on *board* for the side to move.

    The source square is found by scanning the board for a piece of the right
    type and colour that can reach the destination. Castling requires king and
    rook on their home squares with empty squares between them.

    Limitations: checks, pins, castling rights and en passant are not
    modelled, so en-passant captures are reported as invalid and the first
    matching piece is used when the notation is ambiguous without those rules.

    Args:
        board: 8x8 list of rows, row 0 being the 8th rank.
        move: one SAN token such as ``'Nf3'``, ``'exd4'``, ``'O-O'``.
        is_white_turn: True when White is to move.
    """
    side = _castling_side(move)
    if side is not None:
        return _can_castle(board, side, is_white_turn)
    return _resolve_move(board, move, is_white_turn) is not None


def update_board(board, move, is_white_turn):
    """Apply the SAN *move* to *board* in place and return the board.

    Castling relocates both king and rook on the mover's back rank. A
    promotion replaces the pawn with the promoted piece in the mover's colour.
    A move that cannot be resolved leaves the board untouched.

    Args:
        board: 8x8 list of rows, row 0 being the 8th rank.
        move: one SAN token such as ``'Nf3'``, ``'e8=Q'``, ``'O-O-O'``.
        is_white_turn: True when White is to move.
    """
    side = _castling_side(move)
    if side is not None:
        row = 7 if is_white_turn else 0
        king, rook = ('K', 'R') if is_white_turn else ('k', 'r')
        king_from, king_to, rook_from, rook_to = _castling_columns(side)
        board[row][king_from] = '.'
        board[row][rook_from] = '.'
        board[row][king_to] = king
        board[row][rook_to] = rook
        return board

    resolved = _resolve_move(board, move, is_white_turn)
    if resolved is None:
        return board

    source, dest, promotion = resolved
    source_row, source_col = 8 - int(source[1]), ord(source[0]) - ord('a')
    dest_row, dest_col = 8 - int(dest[1]), ord(dest[0]) - ord('a')

    piece = board[source_row][source_col]
    if promotion is not None:
        piece = promotion if is_white_turn else promotion.lower()

    board[dest_row][dest_col] = piece
    board[source_row][source_col] = '.'
    return board
242
+
243
def validate_pgn(pgn_string):
    """Validate the PGN string format and replay its moves for legality.

    Checks, in order:
      1. every line starting with ``[`` matches the ``[Name "value"]`` tag form;
      2. the remaining text contains a game result (``1-0``, ``0-1`` or
         ``1/2-1/2``);
      3. every move token is, as a whole, well-formed algebraic notation;
      4. the moves replay successfully from the starting position via
         ``validate_move`` / ``update_board``.

    Args:
        pgn_string: PGN text, optionally starting with tag lines.

    Returns:
        True when all checks pass, otherwise False.
    """
    move_pattern = r'([PNBRQK]?[a-h]?[1-8]?[x]?[a-h][1-8](=[QRNB])?|O-O(-O)?)[+#]?'  # Chess move
    result_pattern = r'(1-0|0-1|1/2-1/2)'  # Game results
    tag_pattern = r'\[([A-Za-z0-9_]+)\s+"([^"]+)"\]'  # PGN tags

    pgn_lines = pgn_string.strip().splitlines()

    tag_lines = [line for line in pgn_lines if line.startswith('[')]
    if not all(re.match(tag_pattern, tag) for tag in tag_lines):
        return False  # Invalid tag format

    moves_section = ' '.join(
        line for line in pgn_lines if not line.startswith('[')
    ).strip()

    if not re.search(result_pattern, moves_section):
        return False  # No valid result found

    moves_section = re.sub(result_pattern, '', moves_section).strip()

    # Split on whitespace and on move numbers. `\.+` also consumes the
    # "12..." form used for Black's moves, which would otherwise leave a
    # stray ".." token behind.
    move_tokens = [
        token for token in re.split(r'\s+|\d+\.+', moves_section) if token
    ]

    # fullmatch: the whole token must be a move, not just a prefix of it.
    if not all(re.fullmatch(move_pattern, token) for token in move_tokens):
        return False  # Invalid move format

    if not move_tokens:
        return True  # Nothing to replay

    board = initialize_board()
    is_white_turn = True
    for token in move_tokens:
        if not validate_move(board, token, is_white_turn):
            return False  # Invalid chess move

        board = update_board(board, token, is_white_turn)
        is_white_turn = not is_white_turn

    return True
282
+
283
+ # Test case
284
+ pgn_string = """
285
+ [Event "World Championship"]
286
+ [Site "Moscow URS"]
287
+ [Date "1985.11.09"]
288
+ [Round "16"]
289
+ [White "Kasparov, Garry"]
290
+ [Black "Karpov, Anatoly"]
291
+ [Result "1-0"]
292
+
293
+ 1. e4 e5 2. Nf3 Nc6 3. Bb5 a6 4. Ba4 Nf6 5. O-O Be7 6. Re1 b5 7. Bb3 d6
294
+ 8. c3 O-O 9. h3 Nb8 10. d4 Nbd7 11. c4 Bb7 12. Nbd2 c6 13. Bc2 Re8 14. b3 Bf8
295
+ 15. Bb2 Qc7 16. Rc1 Rad8 17. a3 Qb8 18. Bd3 g6 19. Qc2 Nh5 20. g3 Ng7 21. Qb1
296
+ exd4 22. Nxd4 c5 23. N4f3 Ne6 24. Bf1 Ne5 25. Qa1 Nxf3+ 26. Nxf3 Qa8 27. b4
297
+ Rc8 28. Bd3 Bh6 29. Rc2 Bc6 30. h4 f5 31. exf5 Bxf3 32. fxe6 Bh1 33. Bf1 Qf3
298
+ 34. Re2 Bg7 35. Kh2 Rc7 36. Bxg7 Rxg7 37. Qf6 bxc4 38. e7 Qxf6 39. exf6 1-0
299
+ """
300
+
301
+
302
+
303
def get_batch_from_samples(split):
    """Draw a batch of fixed-length windows from individually encoded samples.

    Picks ``batch_size`` random samples from the chosen split, and from each
    sample a random window of up to ``context_size`` tokens. The target window
    is the input window shifted right by one token. Windows that run past the
    end of a sample are right-padded with the token id of ``' '``.

    NOTE(review): the padded positions are part of the returned targets;
    whether the loss ignores them depends on the model -- confirm.

    Args:
        split: ``'train'`` selects ``train_data``; anything else ``val_data``.

    Returns:
        ``(x, y)``: two stacked tensors with ``batch_size`` rows of
        ``context_size`` token ids each, moved to ``device``.
    """
    data = train_data if split == 'train' else val_data
    sample_idx = torch.randint(len(data), (batch_size,))
    space = encode(' ')[0]

    def pad_to_context(tokens):
        """Right-pad *tokens* with the space id up to ``context_size``."""
        return tokens + [space] * (context_size - len(tokens))

    inputs = []
    outputs = []
    for idx in sample_idx:
        sample = data[idx]
        sample_size = len(sample)
        # torch.randint needs a positive upper bound. The extra `1` keeps
        # samples of two tokens or fewer from raising; they start at 0.
        upper = max(1, sample_size - 2, sample_size - context_size)
        start = int(torch.randint(upper, (1,)))
        end = start + context_size

        inputs.append(torch.tensor(pad_to_context(sample[start:end].tolist())))
        outputs.append(torch.tensor(pad_to_context(sample[start + 1:end + 1].tolist())))

    x = torch.stack(inputs)
    y = torch.stack(outputs)
    return x.to(device), y.to(device)
326
+
327
+
328
  def get_batch(split):
329
  data = train_data if split == 'train' else val_data
330
  idx = torch.randint(len(data) - context_size, (batch_size,))
 
332
  y = torch.stack([data[i+1:i+context_size+1] for i in idx])
333
  return x.to(device), y.to(device)
334
 
335
+ if not mask_all_data:
336
+ get_batch = get_batch_from_samples
337
+
338
  ## END BUILD DATA SET ##
339
  ## MODEL DEFINITION ##
340
 
 
357
  for k in range(eval_iters):
358
  X, Y = get_batch(split)
359
  logits, loss = model(X, Y)
360
+ """
361
+ input_string = X[0].tolist()
362
+ gen = model.generate(X[0].view(1, -1), max_new_tokens=5, context_size=context_size)
363
+ o = tokenizer.decode(gen[0].tolist())
364
+ try:
365
+ valid = int(not validate_pgn(o))
366
+ except Exception:
367
+ valid = 2
368
+ """
369
  losses[k] = loss.item()
370
  out[split] = losses.mean()
371
 
372
+ input_string = '1. e4 g6 2.'
373
  print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
374
  model.train()
375
  return out
376
 
377
 
378
class CosineAnnealingScheduler(_LRScheduler):
    """Cosine-annealing learning-rate schedule.

    The learning rate of every parameter group follows a cosine curve from
    its base value down to ``eta_min`` over ``T_max`` scheduler steps. The
    value is updated recursively from the group's current learning rate.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            T_max (int): Maximum number of iterations.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of each parameter group for this step.

        Emits a ``UserWarning`` when called outside ``step()``; use
        ``get_last_lr()`` to read the most recently computed values.
        """
        if not self._get_lr_called_within_step:
            # Imported here so the class is self-contained; the module did not
            # import `warnings`, which made this branch raise NameError.
            import warnings

            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            # First step: keep the optimizer's current learning rates.
            return [group['lr'] for group in self.optimizer.param_groups]
        elif self._step_count == 1 and self.last_epoch > 0:
            # Started from a non-zero epoch: use the closed-form expression.
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2
                    for base_lr in self.base_lrs]
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # First step after a minimum: the recursive ratio below would
            # divide by zero here, so add the increment directly.
            return [group['lr'] + (base_lr - self.eta_min) *
                    (1 - math.cos(math.pi / self.T_max)) / 2
                    for base_lr, group in
                    zip(self.base_lrs, self.optimizer.param_groups)]
        # General case: scale the current rate by the ratio of consecutive
        # cosine factors.
        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
                (group['lr'] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]
411
+
412
  if __name__ == "__main__":
413
  args = argparse.ArgumentParser()
414
  args.add_argument('--load', '-l', action='store_true', default=False, help='Load model state.')
 
419
  params = {'vocab_size': vocab_size, 'n_embed': n_embed, 'context_size': context_size, 'n_layer': n_layer, 'n_head': n_head, 'dropout': dropout}
420
  if args.load:
421
  m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
422
+ m.load_state_dict(torch.load(f'./models/{base_name}'))# + ''.join(f'{key}={v}' for key, v in params.items())))
423
  else:
424
  m = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
425
  model = m.to(device)
426
 
427
  if args.inference:
428
+ input_string = input('Enter a PGN string: ')
429
+ print_sample(torch.tensor(encode(input_string), dtype=torch.long, device=device).view((1, len(input_string))))
430
+ with open(f'./models/{base_name}_params.json', 'w') as f:
431
+ json.dump(params, f)
432
+
433
+ tokenizer.save_pretrained(f'./models/{base_name}_vocab.json')
434
  exit()
435
  ## END MODEL ##
436
  ## START TRAINING ##
437
+ wandb.init(project='chessPT')
438
+
439
+ wandb.watch(model)
440
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
441
+ if use_scheduler:
442
+ scheduler = CosineAnnealingScheduler(optimizer, max_iters, eta_min=learning_rate//1e6)
443
 
444
+ for step in tqdm(range(max_iters), total=max_iters, desc='Training'):
445
  if step % eval_interval == 0:
446
  losses = estimate_loss()
447
+ if use_scheduler:
448
+ print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}, lr: {scheduler.get_last_lr()[0]}')
449
+ else:
450
+ print(f'step {step:4d}: train loss {losses["train"]:.4f}, val loss: {losses["val"]:.4f}')
451
+ wandb.log({'train_loss': losses['train'], 'val_loss': losses['val']})
452
 
453
  xb, yb = get_batch('train')
454
 
455
  logits, loss = model(xb, yb)
456
+ """
457
+
458
+ input_string = xb[0].tolist()
459
+ gen = model.generate(xb[0].view(1, -1), max_new_tokens=5, context_size=context_size)
460
+ out = tokenizer.decode(gen[0].tolist())
461
+ try:
462
+ valid = int(not validate_pgn(out))
463
+ except Exception:
464
+ valid = 2
465
+ loss += valid
466
+ """
467
+
468
+ if use_scheduler:
469
+ wandb.log({'running_train_loss': loss.item(), 'lr': scheduler.get_last_lr()[0]})
470
+ else:
471
+ wandb.log({'running_train_loss': loss.item()})
472
+
473
  optimizer.zero_grad(set_to_none=True)
474
  loss.backward()
475
  optimizer.step()
476
+ if use_scheduler:
477
+ scheduler.step()
478
 
479
  print()
480
  print('Loss:')
 
486
  ## END VALIDATION ##
487
 
488
  # save model weights
489
+ torch.save(model.state_dict(), f'./models/{base_name}')
490
+ with open(f'./models/{base_name}_params.json', 'w') as f:
491
+ json.dump(params, f)
492
  with open('train.log', 'a') as f:
493
  f.write(f'{max_iters},{learning_rate}\n')