I trained the Llama 3.1 model with the Google vision encoder (SigLIP) integrated. This is a base model: the model itself has not been trained on images, so it is meant as a starting point for training on your own image datasets. Only the encoder has been integrated into it. It has not been trained on any closed-source datasets, only on what is listed (for some reason, the Japanese version of the dataset is listed above).

Install https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main before running the code below. Thanks to that team for their fantastic work. Use an editable install, because you will need to modify the `conversation.py` file so it points to this repo instead of the Llama 3.0 repo for the tokenizer, etc. From the root of the cloned repo:

```bash
pip install -e ".[train]"
```
You can test the model with something like the script below. Download the example image (`extreme_ironing.jpg`) and place it at the path used in the script, or use your own image.
Example of the model's first output:

["The image shows a man in a yellow shirt and shorts sitting on the hood of a car with a clothes iron and ironing board in the back.\nThis is a common sight to see in many cities, especially in major cities like new york, where ironing clothes is a common activity for people to carry out while they are at home.\nHowever, this image is a little unusual because the man is ironing clothes on top of the car.\nIt is not unusual to see people ironing clothes while driving, but this is a rare sight.\nThis image is also unusual because the person is sitting on the hood of the car with their clothes in the back, and it seems that they are using an ironing board.\nThe man in the image is wearing a yellow shirt and shorts, and his pants and shirt appear to be in a bag on the hood.\nThe man is sitting on the car with the ironing board, which has a steamer, an ironing board, and clothes.\nThis image is unusual because it is a picture of a man in the middle of ironing clothes, and it's also unusual because the car is driving down a street.\nThe man is using an ironing board with a steamer and clothes, and is sitting on the hood of the"]

```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

from PIL import Image
import requests
import copy
import torch

pretrained = "mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre-llava"
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"
# Add any other arguments you want to pass in llava_model_args
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map)

model.eval()
model.tie_weights()

# Replace with the path to your own copy of the image
image = Image.open("/home/myles/Desktop/extreme_ironing.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]

# Make sure you use the correct chat template for each model. You also need to modify
# conversation.py to point to this repo instead of the 3.0 repo, and you need a
# sufficiently recent transformers version, or you will get a tokenization error.
conv_template = "llava_llama_3"
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image? Is there anything strange about this image? Is this normal behaviour"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
image_sizes = [image.size]

cont = model.generate(
    input_ids,
    images=image_tensor,
    image_sizes=image_sizes,
    do_sample=True,
    temperature=0.9,
    max_new_tokens=256,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)
```
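The `<image>` placeholder (`DEFAULT_IMAGE_TOKEN`) in the prompt is not an ordinary vocabulary token. `tokenizer_image_token` splits the prompt on it, tokenizes the text pieces, and inserts the sentinel id `IMAGE_TOKEN_INDEX` where the image goes. The model later swaps that sentinel for the projected vision features. Below is a simplified, self-contained sketch of the splicing step. It makes three assumptions:

- A hypothetical toy whitespace tokenizer (`toy_tokenize`) stands in for the real Llama 3.1 tokenizer.
- The sentinel value is -200, as assumed from LLaVA's constants.
- BOS-token handling is ignored, although the real function does handle it.

```python
from typing import Callable, List

# Assumed sentinel id and placeholder string (mirroring llava/constants.py).
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# Hypothetical toy vocabulary; unknown words map to 0.
TOY_VOCAB = {"What": 1, "is": 2, "shown": 3, "here?": 4, "Describe": 5, "it.": 6}


def toy_tokenize(text: str) -> List[int]:
    """Stand-in tokenizer: one id per whitespace-separated word."""
    return [TOY_VOCAB.get(word, 0) for word in text.split()]


def splice_image_tokens(
    prompt: str,
    tokenize: Callable[[str], List[int]],
    image_token_index: int = IMAGE_TOKEN_INDEX,
) -> List[int]:
    """Tokenize the text around each <image> placeholder and insert the sentinel between pieces."""
    chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
    input_ids: List[int] = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            # One sentinel per placeholder; the model replaces it with image features.
            input_ids.append(image_token_index)
        input_ids.extend(tokenize(chunk))
    return input_ids


print(splice_image_tokens("<image>\nWhat is shown here?", toy_tokenize))  # [-200, 1, 2, 3, 4]
```

In LLaVA-style models the sentinel only marks a position. It is never looked up in the embedding table, which is why it can safely be a negative number.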
Example stage-1 pretraining script (only the `mm_mlp_adapter` projector is trained), run from the LLaVA-NeXT repo root:

```bash
LLM_VERSION="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated"
LLM_VERSION_CLEAN="${LLM_VERSION////}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION////}"

############### Pretrain ################

PROMPT_VERSION=plain

BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"

deepspeed llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path ${LLM_VERSION} \
    --version ${PROMPT_VERSION} \
    --data_path ./data/llava_data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
    --image_folder ./data/llava_data/LLaVA-Pretrain/images \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_tunable_parts="mm_mlp_adapter" \
    --mm_vision_select_layer -2 \
    --mm_projector_type mlp2x_gelu \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/projectors/${BASE_RUN_NAME} \
    --num_train_epochs 1 \
    --per_device_train_batch_size 6 \
    --per_device_eval_batch_size 6 \
    --gradient_accumulation_steps 6 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 500 \
    --learning_rate 1e-3 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 131072 \
    --gradient_checkpointing True \
    --dataloader_num_workers 6 \
    --lazy_preprocess True \
    --report_to wandb \
    --run_name $BASE_RUN_NAME \
    --attn_implementation flash_attention_2
```
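To sanity-check the stage-1 settings, the sketch below computes three things:

- **Run name.** It reproduces the shell `${VAR////}` expansion, which deletes every `/` so model IDs can be used in the run and directory name.
- **Global batch size.** This is `per_device_train_batch_size` × `gradient_accumulation_steps` × number of GPUs.
- **Learning-rate curve.** It follows the linear-warmup-then-cosine shape requested by `--warmup_ratio 0.03 --lr_scheduler_type "cosine"`.

Several values are assumptions for illustration: the 8-GPU count, the sample count of roughly 558k (inferred only from the file name `blip_laion_cc_sbu_558k.json`), and the helper names. The schedule mimics the usual Hugging Face cosine-with-warmup curve (decaying to 0). It is not code taken from the training script.

```python
import math


def strip_slashes(model_id: str) -> str:
    """Python equivalent of the bash expansion ${VAR////}: delete every '/'."""
    return model_id.replace("/", "")


def optimizer_steps(num_samples: int, per_device_batch: int, grad_accum: int, num_gpus: int) -> int:
    """Approximate optimizer steps in one epoch: batches per GPU divided by accumulation steps."""
    batches_per_gpu = math.ceil(num_samples / (per_device_batch * num_gpus))
    return max(batches_per_gpu // grad_accum, 1)


def lr_at(step: int, total_steps: int, peak_lr: float, warmup_ratio: float) -> float:
    """Linear warmup over ceil(total_steps * warmup_ratio) steps, then cosine decay to 0."""
    warmup = math.ceil(total_steps * warmup_ratio)
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


llm = strip_slashes("mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated")
vision = strip_slashes("google/siglip-so400m-patch14-384")
print(f"llavanext-{vision}-{llm}-mlp2x_gelu-pretrain_blip558k_plain")

NUM_GPUS = 8              # assumption: set to the number of GPUs you launch with
PER_DEVICE, ACCUM = 6, 6  # --per_device_train_batch_size, --gradient_accumulation_steps
NUM_SAMPLES = 558_000     # assumption: approximate size of blip_laion_cc_sbu_558k.json

print("global batch size:", PER_DEVICE * ACCUM * NUM_GPUS)
total = optimizer_steps(NUM_SAMPLES, PER_DEVICE, ACCUM, NUM_GPUS)
for step in (0, math.ceil(total * 0.03), total // 2, total):
    print(f"step {step}/{total}: lr = {lr_at(step, total, peak_lr=1e-3, warmup_ratio=0.03):.2e}")
```

With these assumed values, each optimizer step sees 6 × 6 × 8 = 288 samples. On a different GPU count, the same formula shows how much to adjust `--gradient_accumulation_steps` if you want to keep a comparable global batch.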
Here is the `pip list` for my system. If you use flash attention, install it after installing the above by compiling the wheel from source and installing it. Otherwise the install will pull in older versions of transformers, PyTorch, etc.
```bash
pip3 list
```

```text
Package Version Editable project location
absl-py 2.1.0
accelerate 0.34.2
aiofiles 22.1.0
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aiosignal 1.3.1
aiosqlite 0.20.0
altair 5.4.1
annotated-types 0.7.0
anyio 4.4.0
appdirs 1.4.4
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asttokens 2.4.1
astunparse 1.6.3
async-lru 2.0.4
async-timeout 4.0.3
attrs 24.2.0
audioread 3.0.1
av 13.0.0
babel 2.16.0
beartype 0.14.1
beautifulsoup4 4.12.3
better-abc 0.0.3
bidict 0.23.1
bitsandbytes 0.43.3
black 24.1.0
bleach 6.1.0
Brotli 1.1.0
cachetools 5.5.0
certifi 2024.8.30
cffi 1.17.1
cfgv 3.4.0
chardet 5.2.0
charset-normalizer 3.3.2
click 8.1.7
cmake 3.30.2
colorama 0.4.6
comm 0.2.2
contourpy 1.3.0
crcmod 1.7
cryptography 43.0.1
cuda-python 12.4.0 /home/myles/cuda-python-12.4.0
cycler 0.12.1
Cython 3.0.11
DataProperty 1.0.1
datasets 2.16.1
debugpy 1.8.5
decorator 5.1.1
decord 0.6.0
deepspeed 0.15.2+fc22d960
deepspeed-kernels 0.0.1.dev1698255861
defusedxml 0.7.1
Deprecated 1.2.14
dill 0.3.7
distlib 0.3.8
distro 1.9.0
dnspython 2.6.1
docker-pycreds 0.4.0
docopt 0.6.2
docstring_parser 0.16
e 1.4.5
einops 0.8.0
einops-exts 0.0.4
entrypoints 0.4
et-xmlfile 1.1.0
eval_type_backport 0.2.0
evaluate 0.4.2
exceptiongroup 1.2.2
executing 2.1.0
fancy-einsum 0.0.3
fastapi 0.112.4
fastjsonschema 2.20.0
ffmpeg-python 0.2.0
ffmpy 0.4.0
filelock 3.16.0
flash_attn 2.6.3
flatbuffers 24.3.25
fonttools 4.53.1
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2023.10.0
ftfy 6.2.3
future 1.0.0
gast 0.6.0
gitdb 4.0.11
GitPython 3.1.43
google-pasta 0.2.0
gradio 4.43.0
gradio_client 1.3.0
graphviz 0.20.3
grpcio 1.66.1
h11 0.14.0
h5py 3.11.0
hf_transfer 0.1.8
hjson 3.1.0
httpcore 1.0.5
httpx 0.27.2
huggingface-hub 0.24.6
identify 2.6.0
idna 3.8
importlib_metadata 8.4.0
importlib_resources 6.4.4
iniconfig 2.0.0
ipaddress 1.0.23
ipykernel 6.29.5
ipython 8.27.0
ipython-genutils 0.2.0
ipywidgets 8.1.5
isoduration 20.11.0
isort 5.13.2
jaxtyping 0.2.34
jedi 0.19.1
Jinja2 3.1.4
jiter 0.5.0
joblib 1.4.2
json5 0.9.25
jsonlines 4.0.0
jsonpointer 3.0.0
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
jupyter 1.1.1
jupyter_client 8.6.2
jupyter-console 6.6.3
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.2
jupyter_server_fileid 0.9.3
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.3.4
jupyterlab 4.2.5
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
jupyterlab_widgets 3.0.13
keras 3.5.0
kiwisolver 1.4.7
latex2mathml 3.77.0
lazy_loader 0.4
Levenshtein 0.25.1
libclang 18.1.1
librosa 0.10.2.post1
linkify-it-py 2.0.3
llava 1.7.0.dev0 /home/myles/LLaVA-NeXT
llvmlite 0.43.0
lmms_eval 0.2.3 /home/myles/lmms-eval
loguru 0.7.2
lxml 5.3.0
Markdown 3.7
markdown-it-py 3.0.0
markdown2 2.5.0
MarkupSafe 2.1.5
matplotlib 3.9.2
matplotlib-inline 0.1.7
mbstrdecoder 1.1.3
mdit-py-plugins 0.4.1
mdurl 0.1.2
mistune 3.0.2
ml-dtypes 0.4.0
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.15
mutagen 1.47.0
mypy-extensions 1.0.0
namex 0.0.8
narwhals 1.6.2
nbclassic 1.1.0
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.1
nltk 3.9.1
nodeenv 1.9.1
notebook 7.2.2
notebook_shim 0.2.4
num2words 0.5.13
numba 0.60.0
numexpr 2.10.1
numpy 1.26.4
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-cutlass 3.5.1.0 /home/myles/cutlass
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
nvidia-pyindex 1.0.9
open_clip_torch 2.26.1
openai 1.44.0
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
openpyxl 3.1.5
opt-einsum 3.3.0
optree 0.12.1
orjson 3.10.7
overrides 7.7.0
packaging 24.1
pandas 2.2.2
pandocfilters 1.5.1
parso 0.8.4
pathlib2 2.3.7.post1
pathspec 0.12.1
pathvalidate 3.2.1
peft 0.12.0
pexpect 4.9.0
Pillow 10.1.0
pip 24.2
platformdirs 4.3.1
pluggy 1.5.0
ply 3.11
pooch 1.8.2
portalocker 2.10.1
pre-commit 3.8.0
prometheus_client 0.20.0
promise 2.3
prompt_toolkit 3.0.47
protobuf 4.25.4
psutil 6.0.0
ptyprocess 0.7.0
pure_eval 0.2.3
py 1.11.0
py-cpuinfo 9.0.0
py-spy 0.3.14
pyarrow 17.0.0
pyarrow-hotfix 0.6
pybind11 2.13.5
pycocoevalcap 1.2
pycocotools 2.0.8
pycparser 2.22
pycryptodomex 3.20.0
pydantic 2.9.0
pydantic_core 2.23.2
pydot 3.0.1
pydub 0.25.1
Pygments 2.18.0
PyJWT 2.9.0
pynndescent 0.5.13
pynvml 11.5.3
pyOpenSSL 24.2.1
pyparsing 3.1.4
pyproject-api 1.7.1
pytablewriter 1.2.0
pytest 8.3.2
python-consul 1.1.0
python-dateutil 2.9.0.post0
python-engineio 4.9.1
python-etcd 0.4.5
python-json-logger 2.0.7
python-multipart 0.0.9
python-socketio 5.11.4
pytorch-triton 3.0.0+757b6a61e7
pytz 2024.1
PyYAML 6.0.2
pyzmq 26.2.0
qtconsole 5.6.0
QtPy 2.4.1
rapidfuzz 3.9.7
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
responses 0.25.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.8.0
ring_flash_attn 0.1 /home/myles/ring-flash-attention
rouge_score 0.1.2
rpds-py 0.20.0
ruff 0.6.4
sacrebleu 2.4.3
safetensors 0.4.5
schedule 1.2.2
scikit-learn 1.5.1
scipy 1.14.1
semantic-version 2.10.0
Send2Trash 1.8.3
sentencepiece 0.2.0
sentry-sdk 2.13.0
setproctitle 1.3.3
setuptools 70.2.0
shellingham 1.5.4
shortuuid 1.0.13
shtab 1.7.1
simple-websocket 1.0.0
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
sounddevice 0.5.0
soundfile 0.12.1
soupsieve 2.6
soxr 0.5.0.post1
sqlitedict 2.1.0
stack-data 0.6.3
starlette 0.38.4
svgwrite 1.4.3
sympy 1.13.1
tabledata 1.3.3
tabulate 0.9.0
tcolorpy 0.1.6
tenacity 9.0.0
tensorboard 2.17.1
tensorboard-data-server 0.7.2
tensorflow 2.17.0
termcolor 2.4.0
terminado 0.18.1
threadpoolctl 3.5.0
thriftpy2 0.5.2
tiktoken 0.7.0
timm 1.0.9
tinycss2 1.3.0
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.1
torch 2.5.0.dev20240907+cu124
torchaudio 2.5.0.dev20240907+cu124
torchvision 0.20.0.dev20240907+cu124
tornado 6.4.1
tox 4.18.1
tqdm 4.66.5
tqdm-multiprocess 0.0.11
traitlets 5.14.3
transformer-lens 2.4.1
transformers 4.45.0.dev0 /home/myles/transformers
transformers-stream-generator 0.0.5
treelib 1.7.0
triton 3.0.0
typeguard 2.13.3
typepy 1.3.2
typer 0.12.5
types-python-dateutil 2.9.0.20240906
typing_extensions 4.12.2
tyro 0.8.10
tzdata 2024.1
uc-micro-py 1.0.3
umap-learn 0.5.6
Unidecode 1.3.8
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.6
virtualenv 20.26.4
wandb 0.17.9
watchdog 5.0.2
wavedrom 2.0.3.post3
wcwidth 0.2.13
webcolors 24.8.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 12.0
Werkzeug 3.0.4
wheel 0.44.0
widgetsnbextension 4.0.13
wrapt 1.16.0
wsproto 1.2.0
xxhash 3.5.0
y-py 0.6.2
yarl 1.10.0
ypy-websocket 0.8.4
yt-dlp 2024.8.6
zipp 3.20.1
zss 1.2.0
zstandard 0.23.0
```
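In this setup flash-attn has to be built from source, so it can help to check an environment before launching training. The sketch below does two things:

- It uses `importlib.metadata` to compare a few key package versions against the `pip list` for my system.
- It picks a value for `--attn_implementation`. When `flash_attn` is not importable, it falls back to `"sdpa"` (PyTorch scaled-dot-product attention, which Hugging Face Transformers supports).

The fallback, the choice of packages, and the helper names are assumptions for illustration. The scripts in this card were run with `flash_attention_2`.

```python
import importlib.util
from importlib import metadata
from typing import Dict, Optional, Tuple

# Key versions copied from the pip list for my system (subset chosen for illustration).
REFERENCE_VERSIONS = {
    "transformers": "4.45.0.dev0",
    "torch": "2.5.0.dev20240907+cu124",
    "deepspeed": "0.15.2+fc22d960",
    "accelerate": "0.34.2",
    "flash_attn": "2.6.3",
}


def compare_versions(reference: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package to (reference version, installed version or None if it is missing)."""
    report: Dict[str, Tuple[str, Optional[str]]] = {}
    for package, expected in reference.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        report[package] = (expected, installed)
    return report


def pick_attn_implementation() -> str:
    """Use flash_attention_2 if flash_attn is importable, else fall back to sdpa (assumption)."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


for package, (expected, installed) in compare_versions(REFERENCE_VERSIONS).items():
    status = "matches" if installed == expected else "differs"
    print(f"{package}: reference {expected}, installed {installed} ({status})")
print("--attn_implementation", pick_attn_implementation())
```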
I just tested a training run for this model, using the LLaVA-NeXT (OneVision) repo above on a fresh install. Do not install the tensorflow version from the pip list above. To train on your own image datasets for your use case, first change the `conv_llava_llama_3` template in `llava/conversation.py` to this:
```python
conv_llava_llama_3 = Conversation(
    system="You are a helpful language and vision, AI. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("user", "assistant"),
    version="llama_v3",
    messages=[],
    offset=0,
    sep="<|eot_id|>",
    sep_style=SeparatorStyle.LLAMA_3,
    tokenizer_id="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre-llava",
    tokenizer=safe_load_tokenizer("mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre-llava"),
    stop_token_ids=[128009],
)
```
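With `sep_style=SeparatorStyle.LLAMA_3` and `sep="<|eot_id|>"`, the conversation is rendered in the Llama 3 chat layout. Every turn is closed by `<|eot_id|>`, which is token id 128009 in the Llama 3 tokenizer; that is why `stop_token_ids=[128009]` is used.

The sketch below builds that layout by hand to show roughly what `conv.get_prompt()` produces for this template. It follows the public Llama 3 chat format and is an approximation, not LLaVA's implementation. The real code may render through the tokenizer's chat template and handle the `<|begin_of_text|>` BOS token differently. The function names are hypothetical.

```python
from typing import List, Optional, Tuple

SYSTEM = (
    "You are a helpful language and vision, AI. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language."
)
EOT = "<|eot_id|>"  # token id 128009 in the Llama 3 tokenizer


def header(role: str) -> str:
    """Role header used by the Llama 3 chat format."""
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n"


def build_llama3_prompt(system: str, messages: List[Tuple[str, Optional[str]]]) -> str:
    """Render (role, text) turns; a final turn with text None leaves an open header for generation."""
    prompt = "<|begin_of_text|>" + header("system") + system + EOT
    for role, text in messages:
        if text is None:
            # Mirrors conv.append_message(conv.roles[1], None): the model continues from here.
            prompt += header(role)
        else:
            prompt += header(role) + text + EOT
    return prompt


prompt = build_llama3_prompt(SYSTEM, [("user", "<image>\nWhat is shown in this image?"), ("assistant", None)])
print(prompt)
```

Generation should stop as soon as the model emits `<|eot_id|>`. If your tokenizer maps a different id to that string, `stop_token_ids` would need to change with it.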
Here is an example training script. Replace the JSON dataset (`--data_path`, together with `--image_folder`) with the data you want to train your model on.

```bash
LLM_VERSION="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre-llava"
LLM_VERSION_CLEAN="${LLM_VERSION////}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION////}"

############### Pretrain ################

PROMPT_VERSION=llava_llama_3

BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
PRE_RUN_NAME="${BASE_RUN_NAME}-synthdog_en"
CKPT_PATH=$LLM_VERSION

accelerate launch llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path ${CKPT_PATH} \
    --version ${PROMPT_VERSION} \
    --data_path ./data/synthdog_en/synthdog_en_processed.json \
    --image_folder ./data/synthdog_en \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres \
    --image_grid_pinpoints "[(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --output_dir "./checkpoints/${PRE_RUN_NAME}" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 6 \
    --per_device_eval_batch_size 0 \
    --gradient_accumulation_steps 6 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 5 \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --gradient_checkpointing True \
    --dataloader_num_workers 2 \
    --lazy_preprocess True \
    --report_to wandb \
    --torch_compile True \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True \
    --attn_implementation flash_attention_2 \
    --run_name ${PRE_RUN_NAME}
```
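`--image_aspect_ratio anyres` together with `--image_grid_pinpoints` lets each image be resized to whichever grid resolution suits it best. The resized image is then cut into 384×384 tiles (the SigLIP input size), plus one downscaled view of the whole image.

The sketch below illustrates that selection rule: first maximize the image area kept after aspect-preserving downscaling, then minimize the padding. It is written from scratch, modelled on the idea behind LLaVA's `select_best_resolution` in `llava/mm_utils.py`. Two details are assumptions to check against that file:

- Each pinpoint is treated as (width, height).
- Tokens per tile are estimated as (384 // 14)² = 729 for a patch-14 encoder, ignoring the token reduction that `spatial_unpad` performs.

```python
from typing import List, Optional, Tuple

PINPOINTS = [(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]  # from the script
TILE = 384   # SigLIP so400m-patch14-384 input resolution
PATCH = 14   # patch size, so (384 // 14) ** 2 = 729 patches per tile


def select_best_resolution(size: Tuple[int, int], candidates: List[Tuple[int, int]]) -> Optional[Tuple[int, int]]:
    """Pick the (width, height) candidate that keeps the most image area, then wastes the least."""
    orig_w, orig_h = size
    best, best_effective, best_wasted = None, -1, float("inf")
    for cand_w, cand_h in candidates:
        # Aspect-preserving scale that fits the image inside the candidate canvas.
        scale = min(cand_w / orig_w, cand_h / orig_h)
        down_w, down_h = int(orig_w * scale), int(orig_h * scale)
        effective = min(down_w * down_h, orig_w * orig_h)  # upscaling adds no real detail
        wasted = cand_w * cand_h - effective               # padded area
        if effective > best_effective or (effective == best_effective and wasted < best_wasted):
            best, best_effective, best_wasted = (cand_w, cand_h), effective, wasted
    return best


def estimated_visual_tokens(size: Tuple[int, int]) -> int:
    """Tiles of the chosen grid plus one overview image, times patches per tile (before unpadding)."""
    w, h = select_best_resolution(size, PINPOINTS)
    tiles = (w // TILE) * (h // TILE) + 1
    return tiles * (TILE // PATCH) ** 2


print(select_best_resolution((768, 384), PINPOINTS))   # (768, 384)
print(select_best_resolution((400, 1200), PINPOINTS))  # (384, 1152)
print(estimated_visual_tokens((768, 384)))             # 2187
```

A wide 768×384 image fills the (768, 384) grid with no padding: 2 tiles plus the overview, roughly 3 × 729 tokens before `spatial_unpad`. This is one reason long, high-resolution samples can dominate memory use during training.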
Model tree for mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-pre-llava: base model meta-llama/Llama-3.1-8B.