Fine-tuning Mistral on Your Dataset

Community Article · Published July 22, 2024

This script is deprecated! Many updates to transformers have happened since its release!

This tutorial will walk you through fine-tuning the Mistral-7B-Instruct model on your own dataset using the Hugging Face Transformers, PEFT, and TRL libraries.

Step 0: Install required libraries

!pip install -q transformers datasets accelerate evaluate trl bitsandbytes peft

Step 1: Load and format your dataset

We'll define a function to format the prompts, then load the dataset and apply it:

def format_prompts(examples):
    """
    Define the format for your dataset
    This function should return a dictionary with a 'text' key containing the formatted prompts
    """
    pass

from datasets import load_dataset

dataset = load_dataset("your_dataset_name", split="train")
dataset = dataset.map(format_prompts, batched=True)

dataset['text'][2] # Check to see if the fields were formatted correctly
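
For example, here is a minimal sketch of what format_prompts could look like. It assumes a hypothetical dataset with "instruction" and "response" columns and the Mistral-Instruct prompt template; adapt the column names and template to your own data.

def format_prompts(examples):
    texts = []
    for instruction, response in zip(examples["instruction"], examples["response"]):
        # Mistral-Instruct style: wrap the user turn in [INST] ... [/INST]
        texts.append(f"<s>[INST] {instruction} [/INST] {response}</s>")
    return {"text": texts}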

Step 2: Set up the model and tokenizer

Next, we'll load the pre-trained Mistral-7B-Instruct model and tokenizer, and set up the model for quantization and gradient checkpointing.

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
import torch

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token # Mistral's tokenizer has no pad token by default

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
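
Optionally, you can sanity-check that the 4-bit weights actually fit on your GPU (this check is not in the original script):

print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")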

Step 3: Set up PEFT (Parameter-Efficient Fine-Tuning)

We'll use LoRA (via the PEFT library) to fine-tune the model efficiently. This involves setting up a LoraConfig and wrapping the base model with get_peft_model.

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
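
You can verify that only a small fraction of the weights will actually be trained:

model.print_trainable_parameters() # prints trainable vs. total parameter counts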

Step 4: Set up the training arguments

We'll define the training arguments for the fine-tuning process.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="your_model_name",
    num_train_epochs=4, # replace this, depending on your dataset
    per_device_train_batch_size=16,
    learning_rate=1e-5,
    optim="sgd"
)

Replace "your_model_name" with the desired name for your fine-tuned model.

Step 5: Initialize the trainer and fine-tune the model

Now, we'll initialize the SFTTrainer from the trl library and train the model.

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=dataset,
    dataset_text_field='text',
    max_seq_length=1024,
)

trainer.train()
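
Optionally, you can keep a copy of the raw LoRA adapter before merging; the output directory name below is just a placeholder.

trainer.save_model("your_model_name-adapter") # saves only the adapter weights and config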

Step 6: Merge the adapter and model back together

When the fine-tuning is finished, you can merge the LoRA adapter back into the base model.

adapter_model = trainer.model
merged_model = adapter_model.merge_and_unload()

trained_tokenizer = trainer.tokenizer
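
Before pushing, it can be worth a quick generation sanity check with the merged model. This is a minimal sketch with an illustrative prompt; swap in something from your own dataset.

prompt = "[INST] Summarize what this model was fine-tuned to do. [/INST]"
inputs = trained_tokenizer(prompt, return_tensors="pt").to(merged_model.device)
outputs = merged_model.generate(**inputs, max_new_tokens=64)
print(trained_tokenizer.decode(outputs[0], skip_special_tokens=True))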

Step 7: Push the fine-tuned model to the Hugging Face Hub

After all of that, you can optionally push the fine-tuned model to the Hugging Face Hub for easier sharing and deployment.

repo_id = "your_repo_name"

merged_model.push_to_hub(repo_id)
trained_tokenizer.push_to_hub(repo_id)

Step 8: The cursed child

If you are feeling extra spicy, you can dequantize the merged model back to full precision in a new script.

!pip install accelerate bitsandbytes peft transformers # make sure to install dependencies again

import torch
from transformers import AutoModelForCausalLM

model_id = "your_repo_name"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

config = model.config
del config.quantization_config
del config._pre_quantization_dtype
model.config = config

model.dequantize()

model.push_to_hub(model_id) # the tokenizer will stay the same

Be warned that Mistral is a very large model, and it will take quite a bit of compute (and memory) to do this.