OratioAI / train.py
torinriley's picture
Upload 4 files
ec6837d verified
raw
history blame
10.9 kB
import os
import shutil
import subprocess
import warnings
from pathlib import Path

import torch
import torch.nn as nn
import torchmetrics
from datasets import load_dataset
from huggingface_hub import Repository
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from config import get_config, get_weights_file_path
from dataset import BilingualDataset, causal_mask
from model import build_transformer
# Hugging Face Hub repository that receives the per-epoch checkpoints.
REPO_ID = "torinriley/LangGPT"
# Auth token taken from the TOKEN environment variable (None when it is unset).
HF_TOKEN = os.environ.get("TOKEN")
# Local working copy of REPO_ID; checkpoints are placed here before being pushed.
REPO_PATH = "/tmp/LangGPT"
def configure_git(directory="/app"):
    """Add ``directory`` to Git's global ``safe.directory`` list (best effort).

    Runs ``git config --global --add safe.directory <directory>``. Any failure,
    whether git exits non-zero or the ``git`` executable cannot be started,
    is printed and swallowed instead of raised.

    Args:
        directory: path to register; defaults to ``/app`` (the previously
            hard-coded value), so existing callers behave exactly as before.
    """
    try:
        subprocess.run(["git", "config", "--global", "--add", "safe.directory", directory], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers FileNotFoundError when git is not installed / not on PATH,
        # which previously escaped this otherwise best-effort helper.
        print(f"Failed to configure {directory} as a safe directory: {e}")
    else:
        print(f"Configured {directory} as a safe directory for Git.")
def initialize_repo():
    """Return a ``Repository`` for ``REPO_ID`` checked out at ``REPO_PATH``.

    On first use (``REPO_PATH`` does not exist yet) the remote is cloned;
    otherwise the existing checkout is opened. The checkout is then pulled
    before being returned.
    """
    # Only a fresh directory needs the clone source.
    clone_args = {} if os.path.exists(REPO_PATH) else {"clone_from": REPO_ID}
    repo = Repository(local_dir=REPO_PATH, use_auth_token=HF_TOKEN, **clone_args)
    repo.git_pull()
    return repo
def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    """Translate one source sequence by repeatedly appending the arg-max token.

    Args:
        model: object exposing ``encode``, ``decode`` and ``project`` as used below.
        source: encoder input ids; a batch dimension of 1 is assumed because the
            decoder state is built as a 1x1 tensor (TODO confirm with callers).
        source_mask: mask forwarded to ``encode`` and ``decode``.
        tokenizer_tgt: target tokenizer; only its ``[SOS]``/``[EOS]`` ids are read.
        max_len: decoding stops once the output (including ``[SOS]``) has this length.
        device: device the decoder tensors are moved to.

    Returns:
        1-D tensor of target ids beginning with ``[SOS]``; it ends with ``[EOS]``
        unless ``max_len`` was reached first.
    """
    sos_id = tokenizer_tgt.token_to_id('[SOS]')
    eos_id = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder runs once; every decoding step reuses its output.
    memory = model.encode(source, source_mask)
    generated = torch.empty(1, 1).fill_(sos_id).type_as(source).to(device)

    while generated.size(1) != max_len:
        step_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)
        hidden = model.decode(memory, source_mask, generated, step_mask)
        # Only the last position's scores matter for choosing the next token.
        scores = model.project(hidden[:, -1])
        next_id = torch.max(scores, dim=1).indices
        appended = torch.empty(1, 1).type_as(source).fill_(next_id.item()).to(device)
        generated = torch.cat([generated, appended], dim=1)
        if next_id == eos_id:
            break

    return generated.squeeze(0)
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode a few validation batches, print them and log CER/WER/BLEU.

    Args:
        model: model exposing ``encode``/``decode``/``project``; switched to eval mode.
        validation_ds: iterable of batches with keys ``encoder_input``,
            ``encoder_mask``, ``src_text`` and ``tgt_text``; batch size must be 1.
        tokenizer_src: source tokenizer (not used here; kept for call compatibility).
        tokenizer_tgt: target tokenizer used for decoding and special-token ids.
        max_len: maximum decoded length, forwarded to ``greedy_decode``.
        device: device the batch tensors are moved to.
        print_msg: callable printing one line (e.g. ``tqdm.write``).
        global_step: step value attached to each logged scalar.
        writer: TensorBoard-style writer; metric logging is skipped when falsy.
        num_examples: number of batches to decode before stopping.
    """
    model.eval()
    expected = []
    predicted = []

    # Falls back to 80 columns when stdout is not a terminal; replaces a
    # `stty size` subprocess guarded by a bare `except:`.
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            # greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
            # passing tokenizer_src as well shifted every argument and raised TypeError.
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    # Text metrics over an empty prediction list are meaningless; skip them.
    if writer and predicted:
        for tag, metric_cls in (
            ('validation cer', torchmetrics.CharErrorRate),
            ('validation wer', torchmetrics.WordErrorRate),
            ('validation BLEU', torchmetrics.BLEUScore),
        ):
            writer.add_scalar(tag, metric_cls()(predicted, expected), global_step)
            writer.flush()
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every ``translation`` record in ``ds``."""
    yield from (record['translation'][lang] for record in ds)
def get_or_build_tokenizer(config, ds, lang):
    """Return the word-level tokenizer for ``lang``, training and saving it if needed.

    The file path is ``config['tokenizer_file'].format(lang)``. If it exists the
    tokenizer is loaded from it; otherwise a ``WordLevel`` tokenizer with
    whitespace pre-tokenization is trained on the ``lang`` sentences of ``ds``
    (special tokens ``[UNK]``, ``[PAD]``, ``[SOS]``, ``[EOS]``; words seen fewer
    than 2 times map to ``[UNK]``) and saved to that path.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    # Create the destination folder so a tokenizer_file pointing into a fresh
    # directory can be saved instead of failing after the (slow) training step.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer
def get_ds(config):
    """Load the parallel corpus and wrap it into train/validation data loaders.

    Loads the ``train`` split of ``config['datasource']`` for the
    ``"{lang_src}-{lang_tgt}"`` configuration and loads or trains one tokenizer
    per language. It then randomly splits the records 90% / 10% into train /
    validation and wraps both parts in ``BilingualDataset``. Finally it prints
    the longest tokenized source and target sentence of the whole corpus.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The training loader uses ``config['batch_size']``. The validation loader
        uses batch size 1, as ``run_validation`` asserts. Both shuffle.
    """
    lang_src = config['lang_src']
    lang_tgt = config['lang_tgt']
    ds_raw = load_dataset(f"{config['datasource']}", f"{lang_src}-{lang_tgt}", split='train')

    tokenizer_src = get_or_build_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, lang_tgt)

    n_total = len(ds_raw)
    n_train = int(0.9 * n_total)
    train_ds_raw, val_ds_raw = random_split(ds_raw, [n_train, n_total - n_train])

    train_ds, val_ds = (
        BilingualDataset(part, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
        for part in (train_ds_raw, val_ds_raw)
    )

    # Longest tokenized sentence per side, measured over the full (unsplit) corpus.
    max_len_src = max_len_tgt = 0
    for record in ds_raw:
        pair = record['translation']
        max_len_src = max(max_len_src, len(tokenizer_src.encode(pair[lang_src]).ids))
        max_len_tgt = max(max_len_tgt, len(tokenizer_tgt.encode(pair[lang_tgt]).ids))
    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the transformer for the given source/target vocabulary sizes.

    Source and target share the same maximum sequence length
    (``config['seq_len']``). ``config['d_model']`` is forwarded as ``d_model``.
    """
    seq_len = config['seq_len']
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        seq_len,
        seq_len,
        d_model=config['d_model'],
    )
def latest_weights_file_path(config):
    """Return the most recently written ``*.pt`` file in the model folder, or None.

    The folder is ``f"{config['datasource']}_{config['model_folder']}"``,
    resolved relative to the current working directory unless absolute.

    Returns:
        ``Path`` of the checkpoint with the newest modification time (ties
        broken by file name), or ``None`` if the folder is missing or holds no
        ``.pt`` files.
    """
    model_folder = Path(f"{config['datasource']}_{config['model_folder']}")
    model_files = list(model_folder.glob("*.pt"))
    if not model_files:
        return None
    # st_mtime tracks when the file contents were last written. st_ctime on Unix
    # is the inode-change time (bumped by renames/chmod), not creation time.
    return max(model_files, key=lambda p: (p.stat().st_mtime, p.name))
def _select_device():
    """Return the preferred device name: "cuda", then "mps", then "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # `torch.has_mps` only reports that PyTorch was built with MPS (and is
    # deprecated); is_available() also checks MPS is usable on this machine.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _print_device_info(device):
    """Print the selected device and, for CUDA, its name and total memory."""
    print("Using device:", device)
    if device == 'cuda':
        # `device` is a plain string, so ask CUDA for the active device index
        # (`device.index` was str.index, not a device ordinal).
        cuda_index = torch.cuda.current_device()
        print(f"Device name: {torch.cuda.get_device_name(cuda_index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(cuda_index).total_memory / 1024 ** 3} GB")
    elif device == 'mps':
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print(" On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print(" On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")


def _load_checkpoint(config, model, optimizer, device):
    """Restore model/optimizer state selected by ``config['preload']``.

    ``preload`` may be ``'latest'`` (newest file from ``latest_weights_file_path``),
    any other truthy value (passed to ``get_weights_file_path``), or missing/falsy.

    Returns:
        ``(initial_epoch, global_step)`` to resume from; ``(0, 0)`` when nothing
        is loaded.
    """
    preload = config.get('preload')
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if not model_filename:
        print('No model to preload, starting from scratch')
        return 0, 0

    print(f'Preloading model {model_filename}')
    # map_location lets a checkpoint saved on one device (e.g. CUDA) load on another.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch'] + 1, state['global_step']


def _save_and_upload_checkpoint(config, repo, model, optimizer, epoch, global_step):
    """Save the checkpoint for ``epoch``, move it into the Hub clone and push it."""
    # The return type of get_weights_file_path is not visible here; wrapping it in
    # Path makes `.parent`/`.name` work whether it returns a str or a Path.
    model_filename = Path(get_weights_file_path(config, f"{epoch:02d}"))
    model_filename.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step
    }, model_filename)

    hf_checkpoint_path = Path(REPO_PATH) / model_filename.name
    # shutil.move also works across filesystems (the model folder and REPO_PATH
    # may be different mounts), where Path.rename raises OSError.
    # NOTE(review): moving removes the local copy, so preload='latest' cannot
    # find it in the model folder later; confirm this is intended.
    shutil.move(str(model_filename), str(hf_checkpoint_path))
    repo.push_to_hub(commit_message=f"Add checkpoint for epoch {epoch}")
    print(f"Checkpoint successfully uploaded to Hugging Face Hub: {hf_checkpoint_path}")


def train_model(config):
    """Train the translation transformer described by ``config``.

    Each epoch makes one pass over the training loader and logs the loss to
    TensorBoard. It then decodes a few validation examples via
    ``run_validation``. Finally it saves a checkpoint and pushes it to the
    Hugging Face Hub repository ``REPO_ID``.

    Args:
        config: mapping read here for ``experiment_name``, ``lr``, ``num_epochs``,
            ``seq_len`` and the optional ``preload``, plus whatever ``get_ds``,
            ``get_model`` and ``get_weights_file_path`` read.
    """
    device = _select_device()
    _print_device_info(device)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)
    writer = SummaryWriter(config['experiment_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    initial_epoch, global_step = _load_checkpoint(config, model, optimizer, device)

    # Labels are target-language ids, so the padding id must come from the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    configure_git()
    repo = initialize_repo()

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)
            label = batch['label'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            # Flatten all leading dimensions so each position is scored against its label.
            loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device,
                       batch_iterator.write, global_step, writer)
        _save_and_upload_checkpoint(config, repo, model, optimizer, epoch, global_step)

    writer.close()
if __name__ == '__main__':
    # Script entry point: silence library warnings, load the run
    # configuration and start training with it.
    warnings.filterwarnings(action="ignore")
    config = get_config()
    train_model(config=config)