# llm-t97 / train-t5-small.py
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
# Load the T5 tokenizer and model
model_name = "google-t5/t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
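# Note: T5Tokenizer requires the sentencepiece package (pip install sentencepiece).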
# Load your data
with open('data.txt', 'r') as file:
    text = file.read()
# Create a dataset from the text file
def preprocess_function(examples):
    # Tokenize the input and target pairs (here the target is the text itself)
    inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
    labels = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
    # Mask padding token ids in the labels with -100 so the loss ignores them
    inputs['labels'] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels['input_ids']
    ]
    return inputs
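# Alternative (a sketch, assuming a recent transformers version): dynamic padding
# with DataCollatorForSeq2Seq instead of fixed max_length padding. To use it, drop
# padding='max_length' above and pass data_collator=collator to the Trainer.
# from transformers import DataCollatorForSeq2Seq
# collator = DataCollatorForSeq2Seq(tokenizer, model=model)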
# For demonstration, we create a simple dataset
# You should adjust this part according to your task
def create_dataset(text):
    # Chop the raw text into 512-character chunks (characters, not tokens)
    return Dataset.from_dict({
        'text': [text[i:i+512] for i in range(0, len(text), 512)]
    })
dataset = create_dataset(text)
tokenized_dataset = dataset.map(preprocess_function, batched=True)
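# Optional sanity check (illustrative): confirm chunk count and tokenized columns
print(f"{len(tokenized_dataset)} chunks, columns: {tokenized_dataset.column_names}")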
# Split the dataset into training and validation sets
split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)  # requires at least two chunks
train_dataset = split['train']
eval_dataset = split['test']
# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",          # Output directory
    evaluation_strategy="epoch",     # Evaluation strategy to use
    learning_rate=5e-5,              # Learning rate
    per_device_train_batch_size=2,   # Batch size for training
    per_device_eval_batch_size=2,    # Batch size for evaluation
    num_train_epochs=3,              # Number of training epochs
    weight_decay=0.01,               # Strength of weight decay
    logging_dir="./logs",            # Directory for storing logs
    logging_steps=10,
)
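# Note: newer transformers releases rename `evaluation_strategy` to `eval_strategy`;
# use whichever name matches your installed version.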
# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Train and evaluate the model
trainer.train()
trainer.evaluate()
# Save the model and tokenizer
model.save_pretrained("./t5-small-finetuned")
tokenizer.save_pretrained("./t5-small-finetuned")
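# Optional sanity check, a minimal sketch assuming training completed: reload the
# saved model and generate from a throwaway prompt. The prompt is illustrative
# only; adapt it to whatever task data.txt actually encodes.
finetuned_tokenizer = T5Tokenizer.from_pretrained("./t5-small-finetuned")
finetuned_model = T5ForConditionalGeneration.from_pretrained("./t5-small-finetuned")
sample = finetuned_tokenizer("Hello, world!", return_tensors="pt")
output_ids = finetuned_model.generate(**sample, max_new_tokens=50)
print(finetuned_tokenizer.decode(output_ids[0], skip_special_tokens=True))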