# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.config import read_base
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune

with read_base():
    from .map_fn import pretrain_map_fn as dataset_map_fn
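# NOTE: `map_fn.py` is expected to sit next to this config and define
# `pretrain_map_fn`. In xtuner's incremental-pretraining examples it typically
# maps a raw text field of each sample to
# {'conversation': [{'input': '', 'output': <text>}]}; the exact field name
# depends on your data.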

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm-7b'

# Data
data_path = './data.json'
max_length = 2048
pack_to_max_length = True
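# `pack_to_max_length=True` concatenates multiple tokenized samples into a
# single `max_length`-token sequence, improving GPU utilization during
# incremental pretraining.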

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16
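# Effective (global) batch size = batch_size * accumulative_counts * num_gpus,
# i.e. 16 samples per GPU per optimizer step with the defaults above.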
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during training
# (consumed by the optional EvaluateChatHook sketched in PART 5)
evaluation_freq = 500
SYSTEM = ''
evaluation_inputs = [
    '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')
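
# QLoRA setup: the frozen base LLM is loaded in 4-bit NF4 via bitsandbytes
# (fp16 compute dtype, double quantization), and trainable LoRA adapters
# (r=64, alpha=16) are attached to it by `SupervisedFinetune`.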
model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(
        type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=dataset_map_fn,
    template_map_fn=None,
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)
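# `template_map_fn` is left as None above: incremental pretraining trains on
# raw text, so no chat prompt template is applied to the samples.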

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=0.0,
    by_epoch=True,
    end=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Custom hooks: log a few preprocessed dataset samples before training, optional
custom_hooks = [dict(type=DatasetInfoHook, tokenizer=tokenizer)]
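
# `evaluation_freq`, `SYSTEM` and `evaluation_inputs` from PART 1 are only
# consumed by an EvaluateChatHook, which is not registered above. A minimal
# sketch of how xtuner pretrain configs typically add it (assuming
# `from xtuner.engine.hooks import EvaluateChatHook`):
#
# custom_hooks = [
#     dict(type=DatasetInfoHook, tokenizer=tokenizer),
#     dict(
#         type=EvaluateChatHook,
#         tokenizer=tokenizer,
#         every_n_iters=evaluation_freq,
#         evaluation_inputs=evaluation_inputs,
#         system=SYSTEM)
# ]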

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
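
# Assuming this file is saved as, e.g., `internlm_7b_qlora_custom_pretrain.py`
# (the name is illustrative), training can be launched with the xtuner CLI:
#     xtuner train internlm_7b_qlora_custom_pretrain.py
# and multi-GPU runs follow xtuner's usual convention, e.g.
#     NPROC_PER_NODE=8 xtuner train internlm_7b_qlora_custom_pretrain.py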