|
import os |
|
import sys |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments |
|
|
|
class GPTAssistant:
    """Load a causal language model and fine-tune it on a plain-text file.

    Holds a Hugging Face tokenizer/model pair loaded from ``model_name`` and
    exposes :meth:`fine_tune`, which trains the model with ``Trainer`` and
    saves the resulting model and tokenizer.
    """

    def __init__(self, model_name="/Users/migueldeguzman/Desktop/papercliptodd/phi-2b/base_model/"):
        """Load the tokenizer and model from ``model_name``.

        Args:
            model_name: Directory (or hub id) passed to ``from_pretrained``.

        On any loading error the message is printed to stderr and the process
        exits with status 1.
        """
        try:
            # trust_remote_code=True permits custom tokenizer code bundled with
            # the checkpoint to run.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            # NOTE(review): the model is loaded WITHOUT trust_remote_code, unlike
            # the tokenizer; if this checkpoint needs custom modeling code the
            # load will fail here — confirm which is intended.
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
        except Exception as e:
            print(f"Error initializing the model or tokenizer: {e}", file=sys.stderr)
            sys.exit(1)

    def fine_tune(
        self,
        answer_file_path,
        model_output_dir,
        epochs=1.0,
        *,
        block_size=128,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=48e-7,
        weight_decay=0.001,
        warmup_ratio=0.1,
    ):
        """Fine-tune the loaded model on a text file, then save model and tokenizer.

        Args:
            answer_file_path: Path to the plain-text training file.
            model_output_dir: Directory for Trainer checkpoints and for the
                final ``save_pretrained`` output.
            epochs: Number of training epochs (may be fractional).
            block_size: Number of tokens per training example.
            per_device_train_batch_size: Examples per device per forward pass.
            gradient_accumulation_steps: Forward passes accumulated per
                optimizer step.
            learning_rate: Peak learning rate of the cosine schedule.
            weight_decay: Weight decay passed to the optimizer.
            warmup_ratio: Fraction of the total optimizer steps spent in
                learning-rate warmup.

        If the training file cannot be loaded, the error is printed to stderr
        and the process exits with status 1.
        """
        try:
            train_dataset = TextDataset(
                tokenizer=self.tokenizer,
                file_path=answer_file_path,
                block_size=block_size,
            )
        except Exception as e:
            print(f"Error loading training dataset: {e}", file=sys.stderr)
            sys.exit(1)

        # mlm=False selects the causal-LM objective (no token masking).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        # Warmup is given as a ratio so Trainer computes it from the actual
        # number of optimizer steps (which depends on batch size, gradient
        # accumulation and device count). len(train_dataset) counts examples,
        # not steps, so deriving warmup from it overshoots the whole run.
        training_args = TrainingArguments(
            output_dir=model_output_dir,
            overwrite_output_dir=True,
            num_train_epochs=epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            save_steps=10_000,
            save_total_limit=2,
            weight_decay=weight_decay,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            warmup_ratio=warmup_ratio,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )

        trainer.train()

        self.model.save_pretrained(model_output_dir)
        self.tokenizer.save_pretrained(model_output_dir)
|
|
|
def main(argv=None):
    """Fine-tune the base model on the training text and save it.

    With ``argv=None`` (the default, used by programmatic callers) the
    built-in paths and a single epoch are used and ``sys.argv`` is NOT read.
    The script entry point passes ``sys.argv[1:]`` so the defaults can be
    overridden on the command line.

    Args:
        argv: Command-line arguments to parse, or ``None`` for the defaults.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Fine-tune a causal language model on a plain-text file."
    )
    parser.add_argument(
        "--text-file",
        default="/Users/migueldeguzman/Desktop/papercliptodd/phi-2b/v1/awakening.text",
        help="Plain-text training file.",
    )
    parser.add_argument(
        "--output-dir",
        default="/Users/migueldeguzman/Desktop/papercliptodd/phi-2b/v1/",
        help="Directory for checkpoints and the fine-tuned model.",
    )
    parser.add_argument(
        "--model-name",
        default=None,
        help="Base model directory; defaults to GPTAssistant's built-in path.",
    )
    parser.add_argument(
        "--epochs",
        type=float,
        default=1.0,
        help="Number of training epochs (may be fractional).",
    )
    # Parse an empty list when argv is None so programmatic callers (e.g. a
    # notebook, whose sys.argv holds kernel flags) get the plain defaults.
    args = parser.parse_args([] if argv is None else argv)

    if args.model_name is None:
        assistant = GPTAssistant()
    else:
        assistant = GPTAssistant(args.model_name)

    assistant.fine_tune(args.text_file, args.output_dir, epochs=args.epochs)


if __name__ == "__main__":
    main(sys.argv[1:])
|
|