lvwerra (HF staff) committed on
Commit
47731b7
1 Parent(s): 6554e3d

Create codeparrot_training.py

Files changed (1):
  1. codeparrot_training.py +205 -0
codeparrot_training.py ADDED
@@ -0,0 +1,205 @@
from transformers import GPT2LMHeadModel, AutoTokenizer
from transformers import AdamW, get_scheduler, set_seed
from datasets import load_dataset
from accelerate import Accelerator
import datasets, transformers
from huggingface_hub import Repository

from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from argparse import Namespace
import torch
import logging
import wandb
import time


class ConstantLengthDataset(IterableDataset):
    """Iterable dataset that concatenates tokenized examples, separated by the BOS
    token, and yields fixed-length chunks of `seq_length` token ids."""
    def __init__(self, tokenizer, dataset, seq_length=1024,
                 num_of_sequences=1024, chars_per_token=3.6):
        self.tokenizer = tokenizer
        self.concatenation_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        # Number of raw characters to buffer before tokenizing one batch of sequences.
        self.input_characters = seq_length * chars_per_token * num_of_sequences
        self.produced_samples = 0

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer = []
            buffer_len = 0
            # Fill the character buffer from the (streaming) dataset.
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    buffer.append(next(iterator)['content'])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    more_examples = False
                    break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)['input_ids']
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concatenation_token_id])
            # Slice the concatenated token stream into full-length chunks; drop the remainder.
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)

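# Illustrative usage of ConstantLengthDataset (not executed by this script; the
# dataset name is an assumption based on the config defined further down):
#   stream = load_dataset('transformersbook/codeparrot-train', split='train', streaming=True)
#   chunks = ConstantLengthDataset(tokenizer, stream, seq_length=1024)
#   next(iter(chunks))  # -> torch.LongTensor of shape (1024,)
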
def setup_logging(project_name):
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,)
    if accelerator.is_main_process:  # we only want to set up logging once
        wandb.init(project=project_name, config=args)
        run_name = wandb.run.name
        tb_writer = SummaryWriter()
        tb_writer.add_hparams(vars(args), {'0': 0})
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        tb_writer = None
        run_name = ''
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    return logger, tb_writer, run_name

def create_dataloaders(dataset_name):
    train_data = load_dataset(dataset_name+'-train', split="train",
                              streaming=True)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer,
                                    seed=args.seed)
    valid_data = load_dataset(dataset_name+'-valid', split="train",
                              streaming=True)
    train_dataset = ConstantLengthDataset(tokenizer, train_data,
                                          seq_length=args.seq_length)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data,
                                          seq_length=args.seq_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader

def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [{'params': params_with_wd, 'weight_decay': args.weight_decay},
            {'params': params_without_wd, 'weight_decay': 0.0}]

def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        wandb.log(metrics)
        for k, v in metrics.items():
            tb_writer.add_scalar(k, v, step)

def evaluate():
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        # Repeat the scalar loss so `gather` returns one entry per sample across processes.
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
    # Perplexity is the exponential of the mean cross-entropy loss.
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = torch.tensor(float("inf"))  # keep `.item()` below working
    return loss.item(), perplexity.item()

# Hyperparameters
project_name = 'transformersbook/codeparrot'
dataset_name = 'transformersbook/codeparrot'
config = {"train_batch_size": 4,
          "valid_batch_size": 4,
          "weight_decay": 0.1,
          "shuffle_buffer": 1000,
          "learning_rate": 5e-4,
          "lr_scheduler_type": "cosine",
          "num_warmup_steps": 1000,
          "gradient_accumulation_steps": 2,
          "max_train_steps": 24_000,
          "max_eval_steps": 500,
          "seq_length": 1024,
          "seed": 1,
          "save_checkpoint_steps": 6_000}
args = Namespace(**config)
set_seed(args.seed)

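# Rough effective batch size per optimizer step, derived from the settings above
# (the actual value depends on how many processes `accelerate` launches):
#   train_batch_size * gradient_accumulation_steps * num_processes
#   = 4 * 2 * num_processes sequences of 1024 tokens.
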
# Accelerator
accelerator = Accelerator()
samples_per_step = accelerator.state.num_processes * args.train_batch_size

# Logging
logger, tb_writer, run_name = setup_logging(project_name.split("/")[1])
logger.info(accelerator.state)

# Load model and tokenizer
if accelerator.is_main_process:  # we only want to clone the repo once
    hf_repo = Repository("./", clone_from=project_name, revision=run_name)
model = GPT2LMHeadModel.from_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained("./")

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(dataset_name)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model), lr=args.learning_rate)
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=optimizer,
                             num_warmup_steps=args.num_warmup_steps,
                             num_training_steps=args.max_train_steps,)
def get_lr():
    return optimizer.param_groups[0]['lr']

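# With these settings, `get_scheduler("cosine", ...)` gives a linear warmup over the
# first 1,000 steps followed by a cosine decay towards zero at 24,000 training steps.
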
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader)

# Train model
model.train()
completed_steps = 0
t0 = time.time()
for step, batch in enumerate(train_dataloader, start=1):
    t1 = time.time()
    loss = model(batch, labels=batch).loss
    t2 = time.time()
    log_metrics(step, {'lr': get_lr(), 'samples': step*samples_per_step,
                       'steps': completed_steps, 'loss/train': loss.item()})
    # Gradient accumulation: scale the loss and only step the optimizer
    # every `gradient_accumulation_steps` batches.
    loss = loss / args.gradient_accumulation_steps
    accelerator.backward(loss)
    t3 = time.time()
    if step % args.gradient_accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    if step % args.save_checkpoint_steps == 0:
        logger.info('Evaluating and saving model checkpoint')
        eval_loss, perplexity = evaluate()
        log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        if accelerator.is_main_process:
            unwrapped_model.save_pretrained("./")
            hf_repo.push_to_hub(commit_message=f'step {step}')
        model.train()
    if completed_steps >= args.max_train_steps:
        break
    t4 = time.time()
    #logger.info(f'ITER: {t1-t0:.3f}, FRWD: {t2-t1:.3f}, BKWD: {t3-t2:.3f}, OPT: {t4-t3:.3f}, ALL: {t4-t0}')
    t0 = time.time()

# Evaluate and save the last checkpoint
logger.info('Evaluating and saving model after training')
eval_loss, perplexity = evaluate()
log_metrics(step, {'loss/eval': eval_loss, 'perplexity': perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
    unwrapped_model.save_pretrained("./")
    try:
        hf_repo.push_to_hub(commit_message='final model')
    except Exception:
        logger.info('No changes to previously saved model.')
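
# Note: multi-GPU training for this script is typically started with the accelerate CLI,
# e.g. `accelerate config` followed by `accelerate launch codeparrot_training.py`
# (the launch commands are an assumption about the intended setup, not part of this commit).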