Training Questions
Dear authors,
Thank you for releasing this VLM.
I was trying the model for an image classification task. It works to some extent, but right now it cannot really compete with GPT-4V.
Are there any insights you can give me for training the model myself, or could you even share the training code?
What hardware would be necessary for training the 8B model or the 70B model?
I would greatly appreciate your reply.
Best,
Chris
I am not the maintainer of this repo, but here is some guidance: I pretrained an 8B model on 5× RTX 4090s (power-limited to 390 W each) in about 30 hours, using 26 GB of data instead of the 6 GB referenced in the script on the LLaVA page. They list all the training times: about 20 hours on 8× A100 GPUs, or about 5 hours to pretrain on 8× A100s. For reference, have a look at the LLaVA-OneVision repo on GitHub; there are training scripts there and pointers to where to get the data. Below is an example of one of their scripts (see the note after it for how I set the launcher variables and the power cap). It is for the latest model version; this repo is perhaps an earlier LLaVA 1.5, I am not sure.
export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_SOCKET_IFNAME=eth0
export NCCL_DEBUG=INFO
LLM_VERSION="Qwen/Qwen2-7B-Instruct"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
############### Pretrain ################
PROMPT_VERSION=plain
BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
    llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path ${LLM_VERSION} \
    --version ${PROMPT_VERSION} \
    --data_path /blip_558k/blip_558k_plain.json \
    --image_folder /blip_558k/images \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_tunable_parts="mm_mlp_adapter" \
    --mm_vision_select_layer -2 \
    --mm_projector_type mlp2x_gelu \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --save_steps 50000 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 8192 \
    --gradient_checkpointing True \
    --dataloader_num_workers 16 \
    --lazy_preprocess True \
    --report_to wandb \
    --run_name $BASE_RUN_NAME \
    --attn_implementation sdpa
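One note on running the script above: it expects NUM_GPUS, NNODES, RANK, ADDR and PORT to already be set in the environment, and the 390 W cap I mentioned is done separately with nvidia-smi. Here is a minimal sketch of how I launch it on a single machine; the concrete values (5 GPUs, one node, port 29500) and the filename pretrain.sh are just my setup and assumptions, not part of their script, so adjust them for your hardware.
# launcher variables the torchrun call above reads (example values: single node, 5 GPUs)
export NUM_GPUS=5
export NNODES=1
export RANK=0
export ADDR=127.0.0.1
export PORT=29500
# optional: cap each GPU at 390 W like I did (needs root; persistence mode keeps the limit applied)
sudo nvidia-smi -pm 1
sudo nvidia-smi -pl 390
# pretrain.sh is whatever filename you saved the script above under
bash pretrain.sh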