"""Train the OratioAI translation transformer and sync checkpoints to the Hugging Face Hub."""

from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split

import warnings
from tqdm import tqdm
import os
import shutil
from pathlib import Path
import subprocess

from huggingface_hub import Repository

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

import torchmetrics
from torch.utils.tensorboard import SummaryWriter

# Hugging Face Hub repository that receives the training checkpoints.
REPO_ID = "torinriley/OratioAI"
# Auth token read from the TOKEN environment variable (None when unset).
HF_TOKEN = os.getenv("TOKEN")
# Local clone of REPO_ID; the checkpoint is written inside it so it can be pushed.
REPO_PATH = "/tmp/OratioAI"


def configure_git():
    """Mark /app as a git ``safe.directory`` globally; failures are reported, not raised."""
    try:
        subprocess.run(["git", "config", "--global", "--add", "safe.directory", "/app"], check=True)
        print("Configured /app as a safe directory for Git.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to configure /app as a safe directory: {e}")


def initialize_repo():
    """Return a ``Repository`` for REPO_PATH, cloning REPO_ID on first use and pulling otherwise."""
    if not os.path.exists(REPO_PATH):
        repo = Repository(local_dir=REPO_PATH, clone_from=REPO_ID, use_auth_token=HF_TOKEN)
    else:
        repo = Repository(local_dir=REPO_PATH, use_auth_token=HF_TOKEN)
        repo.git_pull()
    return repo


def greedy_decode(model, source, source_mask, tokenizer_tgt, max_len, device):
    """Greedily decode a target sequence for a single source sentence.

    Starts from ``[SOS]`` and repeatedly appends the highest-scoring next token
    until ``[EOS]`` is produced or the sequence holds ``max_len`` tokens.

    Args:
        model: Transformer exposing ``encode``, ``decode`` and ``project``.
        source: Encoder input ids; presumably shape (1, seq_len) — callers
            enforce a batch size of 1.
        source_mask: Encoder mask matching ``source``.
        tokenizer_tgt: Target tokenizer (used for the ``[SOS]``/``[EOS]`` ids).
        max_len: Maximum number of tokens in the decoded sequence.
        device: Device the decoder tensors are placed on.

    Returns:
        1-D tensor of target token ids, including the leading ``[SOS]``.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output is computed once and reused at every decoding step.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    # `<` (rather than `==`) guarantees termination even for a non-positive max_len.
    while decoder_input.size(1) < max_len:
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        # Only the last position's output predicts the next token.
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_token = next_word.item()
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_token).to(device)],
            dim=1,
        )
        if next_token == eos_idx:
            break

    return decoder_input.squeeze(0)


def _console_width(default=80):
    """Return the terminal width in columns, or *default* when it cannot be determined."""
    return shutil.get_terminal_size(fallback=(default, 24)).columns


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Decode up to ``num_examples`` validation sentences, print them and log metrics.

    Each example's source, reference and prediction are passed to ``print_msg``.
    When ``writer`` is given and at least one example was decoded, character
    error rate, word error rate and BLEU are written as TensorBoard scalars.

    Args:
        tokenizer_src: Unused; kept for interface compatibility.

    Raises:
        ValueError: If a validation batch does not have batch size 1.
    """
    model.eval()
    expected = []
    predicted = []
    console_width = _console_width()

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)

            # An explicit check survives `python -O`, unlike `assert`.
            if encoder_input.size(0) != 1:
                raise ValueError("Batch size must be 1 for validation")

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().tolist())

            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

    # Metrics are undefined on an empty set of predictions.
    if writer and predicted:
        cer = torchmetrics.CharErrorRate()(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)

        wer = torchmetrics.WordErrorRate()(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)

        # BLEUScore expects a list of reference sentences per prediction; wrapping
        # each reference keeps older torchmetrics from iterating a bare string per char.
        bleu = torchmetrics.BLEUScore()(predicted, [[ref] for ref in expected])
        writer.add_scalar('validation BLEU', bleu, global_step)

        writer.flush()


def get_all_sentences(ds, lang):
    """Yield the ``lang`` side of every ``item['translation']`` in *ds*."""
    for item in ds:
        yield item['translation'][lang]


def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for *lang*, training and saving it first if absent.

    The file location is ``config['tokenizer_file'].format(lang)``. A newly trained
    tokenizer reserves ``[UNK]``, ``[PAD]``, ``[SOS]`` and ``[EOS]`` and keeps
    words seen at least twice.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)

    if not tokenizer_path.exists():
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


def get_ds(config):
    """Build the train/validation dataloaders and both tokenizers.

    The raw dataset is split 90/10. If ``config`` contains ``split_seed``, the
    split is seeded so it stays identical across resumed runs. Without a seed,
    a resumed run may train on sentences that were previously used for validation.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
    """
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    # Assumes config is a dict (it is only ever subscripted with string keys here).
    split_seed = config.get('split_seed')
    split_kwargs = {} if split_seed is None else {"generator": torch.Generator().manual_seed(split_seed)}
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], **split_kwargs)

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Report the longest tokenized sentences so seq_len can be sanity-checked.
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    # Batch size 1 is required by run_validation / greedy_decode.
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the transformer, using ``config['seq_len']`` for both source and target lengths."""
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model


def latest_weights_file_path(config):
    """Return the most recently created ``*.pt`` file in the model folder, or None if there is none."""
    model_folder = Path(f"{config['datasource']}_{config['model_folder']}")
    model_files = list(model_folder.glob("*.pt"))
    if not model_files:
        return None
    return max(model_files, key=os.path.getctime)


def _select_device():
    """Pick CUDA, then Apple MPS, then CPU; print the choice and return it as a ``torch.device``."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("Using device:", device)
    if device.type == 'cuda':
        # torch.device("cuda").index is None, so ask for the active device explicitly.
        index = torch.cuda.current_device()
        print(f"Device name: {torch.cuda.get_device_name(index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(index).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print("Device name: Apple MPS")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    return device


def _load_checkpoint(checkpoint_path, model, optimizer, device):
    """Restore model/optimizer state from *checkpoint_path* if it exists.

    Returns:
        ``(start_epoch, global_step)``; ``(0, 0)`` when there is no checkpoint.
    """
    if not checkpoint_path.exists():
        print("No checkpoint found. Starting training from scratch.")
        return 0, 0

    print(f"Found checkpoint file: {checkpoint_path}. Resuming training...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resumed training from epoch {start_epoch}.")
    return start_epoch, checkpoint["global_step"]


def _save_checkpoint(checkpoint_path, model, optimizer, epoch, global_step):
    """Write a resumable checkpoint to *checkpoint_path* via a temp file and atomic rename."""
    checkpoint_data = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'global_step': global_step,
    }
    # An interrupted torch.save would otherwise corrupt the only resumable checkpoint.
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
    torch.save(checkpoint_data, tmp_path)
    os.replace(tmp_path, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")


def train_model(config):
    """Train the transformer described by *config*, resuming from the Hub checkpoint if present.

    After every epoch the model is validated on a few examples. A checkpoint
    (model/optimizer state, epoch, global step) is then written to
    ``REPO_PATH/EN-FR.pt``, and the local repository clone is pushed to the Hub.
    """
    device = _select_device()

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    # Labels are target-language ids, so padding is identified with the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    writer = SummaryWriter(config['experiment_name'])

    repo = initialize_repo()
    checkpoint_path = Path(REPO_PATH) / "EN-FR.pt"
    start_epoch, global_step = _load_checkpoint(checkpoint_path, model, optimizer, device)

    for epoch in range(start_epoch, config['num_epochs']):
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        model.train()
        epoch_loss = 0.0
        batch_count = 0
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            label = batch['label'].to(device)
            # Flatten (batch, seq, vocab) logits and (batch, seq) labels for CrossEntropyLoss.
            loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
            loss_value = loss.item()
            epoch_loss += loss_value
            batch_count += 1

            batch_iterator.set_postfix({"batch_loss": f"{loss_value:.3f}"})
            writer.add_scalar('train loss', loss_value, global_step)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Flushing once per epoch avoids a disk sync on every training step.
        writer.flush()

        avg_loss = epoch_loss / batch_count if batch_count else float('nan')
        print(f"Epoch {epoch:02d} - Average Loss: {avg_loss:.3f}")

        run_validation(
            model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'],
            device, batch_iterator.write, global_step, writer
        )

        _save_checkpoint(checkpoint_path, model, optimizer, epoch, global_step)

        repo.push_to_hub(commit_message=f"Add checkpoint for epoch {epoch}")
        print("Checkpoint successfully uploaded to Hugging Face Hub.")


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    config = get_config()
    train_model(config)