# NextWordGPT / lr_finder.py
# Learning-rate range test (torch-lr-finder) for the DecoderOnlyTransformer model.
import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss

import tiktoken
from torch_lr_finder import LRFinder
from transformers import GPT2Tokenizer

from transformer import Config, DecoderOnlyTransformer
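
# Minimal sequential data loader: tokenizes input.txt with the GPT-2 BPE
# (tiktoken) and serves contiguous (B, T) input/target pairs, wrapping back to
# the start of the token stream when the next batch would run past the end.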
class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T
        # at init load tokens from disk and store them in memory
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f'loaded {len(self.tokens)} tokens')
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position: self.current_position + B * T + 1]
        x = (buf[:-1]).view(B, T)  # inputs
        y = (buf[1:]).view(B, T)   # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

batches, no_of_tokens = 16, 128  # batch size (B) and tokens per sequence (T)
train_loader = DataLoaderLite(B=batches, T=no_of_tokens)
steps_per_epoch = len(train_loader.tokens) // (batches * no_of_tokens)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Model configuration
config = Config()
# Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") # Use GPT-2 tokenizer for compatibility
# Load trained model
model = DecoderOnlyTransformer(config)
model.load_state_dict(torch.load("decoder_only_transformer.pth", map_location=torch.device('cpu')))
model.eval()
model.to(device)
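# The checkpoint is loaded onto CPU first (map_location) and then moved to the
# selected device, so loading also works on CPU-only machines. range_test()
# below performs training steps and switches the model to train mode
# internally, so model.eval() above only matters outside the sweep.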
amp_config = {
    'device_type': 'cuda',
    'dtype': torch.float16,
}
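# Per the torch-lr-finder README, these kwargs are forwarded to torch.autocast
# when amp_backend='torch' is selected below; float16 autocast assumes a CUDA
# device is available.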
criterion = CrossEntropyLoss()
grad_scaler = torch.cuda.amp.GradScaler()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
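# The sweep below starts from this optimizer lr (1e-3), since range_test() uses
# the optimizer's current lr when no explicit start_lr is given, and increases
# it exponentially toward end_lr.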
# Define a custom batch-fetching wrapper
class CustomDataLoader:
    def __init__(self, next_batch_func, num_batches):
        self.next_batch_func = next_batch_func
        self.num_batches = num_batches
        self.current_batch = 0

    def __iter__(self):
        self.current_batch = 0
        return self

    def __next__(self):
        if self.current_batch < self.num_batches:
            self.current_batch += 1
            return self.next_batch_func()
        else:
            raise StopIteration
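# Note: depending on the torch-lr-finder version, range_test() may require the
# loader to be a torch.utils.data.DataLoader or a TrainDataLoaderIter subclass;
# if this plain iterable is rejected, wrap the batches accordingly.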
# Create a custom data loader using next_batch (pass the bound method itself,
# not the result of calling it once)
custom_train_loader = CustomDataLoader(train_loader.next_batch, num_batches=steps_per_epoch)
# Use the custom data loader with LRFinder
lr_finder = LRFinder(
    model, optimizer, criterion, device=device,
    amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler
)
lr_finder.range_test(custom_train_loader, end_lr=5, num_iter=1000, step_mode='exp')
lr_finder.plot()
lr_finder.reset()  # restore the model and optimizer to their initial states
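
# A small post-processing sketch (not part of the original script): torch-lr-finder
# keeps the sweep results in lr_finder.history as {'lr': [...], 'loss': [...]},
# which should still be populated after reset(). The lr at the minimum recorded
# (smoothed) loss is a rough upper bound; a common heuristic is to train at
# roughly one tenth of it.
losses = lr_finder.history['loss']
lrs = lr_finder.history['lr']
best_lr = lrs[losses.index(min(losses))]
print(f'lr at minimum loss: {best_lr:.2e} (consider training around {best_lr / 10:.2e})')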