
This is the Meta Llama 3.1 8B model merged with Google's SigLIP vision encoder (`google/siglip-so400m-patch14-384`). I built it with the LLaVA-NeXT repo. Apart from the image-caption pretraining used to align the vision encoder, it has not been trained on any data yet. It was built with `--image_aspect_ratio anyres_max_9`.
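In LLaVA-NeXT, `anyres` chooses a target resolution from the grid pinpoints (the training script below passes `"(1x1),...,(6x6)"`), resizes and pads the image to it, and cuts it into 384 px tiles next to a downscaled copy of the whole image. The `_max_9` suffix then shrinks the tile features to roughly nine tiles' worth of tokens. The sketch below approximates those steps based on a reading of the LLaVA-NeXT source (`llava/mm_utils.py`, `llava/model/llava_arch.py`); it is not the library's code, and details may differ between versions. The 27-tokens-per-side value assumes 384 px tiles with 14 px patches.

```python
import math
import re


def expand_grid_pinpoints(spec, patch_size=384):
    """Expand a range like "(1x1),...,(6x6)" into candidate (width, height) resolutions.

    Every grid between the first and last entry is allowed; each cell is patch_size pixels.
    """
    matches = re.findall(r"\((\d+)x(\d+)\)", spec)
    (w0, h0), (w1, h1) = tuple(map(int, matches[0])), tuple(map(int, matches[-1]))
    return [
        (i * patch_size, j * patch_size)
        for i in range(w0, w1 + 1)
        for j in range(h0, h1 + 1)
    ]


def select_best_resolution(original_size, candidates):
    """Pick the candidate that keeps the most image detail, then wastes the least padding."""
    orig_w, orig_h = original_size
    best, best_effective, best_wasted = None, 0, float("inf")
    for width, height in candidates:
        scale = min(width / orig_w, height / orig_h)
        down_w, down_h = int(orig_w * scale), int(orig_h * scale)
        effective = min(down_w * down_h, orig_w * orig_h)  # detail kept, capped at the original
        wasted = width * height - effective                # pixels spent on padding
        if effective > best_effective or (effective == best_effective and wasted < best_wasted):
            best, best_effective, best_wasted = (width, height), effective, wasted
    return best


def cap_feature_map(h, w, unit=27, max_num_patches=9):
    """Approximate the anyres_max_N cap on an (h x w) grid of vision tokens.

    unit is tokens per tile side. Returns the target (h, w) in tokens; LLaVA-NeXT
    itself bilinearly interpolates the feature map down to this size.
    """
    times = math.sqrt(h * w / (max_num_patches * unit**2))
    if times > 1.1:
        return int(h // times), int(w // times)
    return h, w


if __name__ == "__main__":
    pins = expand_grid_pinpoints("(1x1),...,(6x6)")
    print(len(pins))                                 # 36 candidate resolutions
    print(select_best_resolution((768, 384), pins))  # (768, 384)
    print(cap_feature_map(6 * 27, 6 * 27))           # (81, 81): a 6x6 grid shrunk to about 3x3 tiles
```

The cap only kicks in above 1.1 times the budget, so images that already fit in nine tiles keep their full token grid.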

Installation

1. Clone the LLaVA-NeXT repository and change into its folder:

```bash
git clone https://github.com/LLaVA-VL/LLaVA-NeXT
cd LLaVA-NeXT
```

2. Install the inference package:

You may have to edit the downloaded requirements file and remove its version pins so that they do not conflict with the versions installed below. Install Flash Attention afterwards.

```bash
pip3 install --upgrade pip  # Enable PEP 660 support.
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124
pip3 install -e ".[train]"
pip3 install --upgrade transformers
```

Then install Flash Attention from source: https://github.com/Dao-AILab/flash-attention
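Because the install mixes nightly PyTorch, an upgraded transformers and a source build of Flash Attention, it can help to confirm what actually ended up in the environment. This is a small stdlib-only sketch; the package selection in `KEY_PACKAGES` is an assumption about which distributions matter most, and its output can be compared with the `pip3 list` at the end of this card.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions most likely to cause version conflicts in this setup (an assumed selection)
KEY_PACKAGES = ("torch", "torchvision", "transformers", "accelerate", "deepspeed", "flash_attn", "llava")


def installed_versions(packages=KEY_PACKAGES):
    """Return {distribution name: installed version, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<14} {ver or 'NOT INSTALLED'}")
```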

Replace `llava/conversation.py` in your LLaVA-NeXT checkout with the `conversation.py` file I have placed in this model repo.
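A minimal sketch of that replacement step which keeps a backup of the upstream file so it can be restored later. It assumes the replacement file has already been downloaded locally (for example with `huggingface_hub.hf_hub_download`, if the file sits at the repo root as `conversation.py`, which is an assumption); the helper name and paths are hypothetical.

```python
import shutil
from pathlib import Path


def replace_conversation_file(new_file, llava_next_root):
    """Copy new_file over <llava_next_root>/llava/conversation.py, keeping a .bak backup.

    Returns the path of the backup copy.
    """
    target = Path(llava_next_root) / "llava" / "conversation.py"
    if not target.is_file():
        raise FileNotFoundError(f"{target} not found; is {llava_next_root} a LLaVA-NeXT checkout?")
    backup = target.with_name(target.name + ".bak")
    shutil.copy2(target, backup)    # keep the upstream version in case you need to revert
    shutil.copy2(new_file, target)  # install the replacement template file
    return backup
```

Usage: `replace_conversation_file("downloads/conversation.py", "LLaVA-NeXT")`. Since the package was installed with `pip3 install -e`, the change is picked up without reinstalling.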

Inference example:

```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import copy
import torch

# Model and device configuration
pretrained = "mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-reflection-one-vision-pretrain"
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"

# Load the pretrained model (requires Flash Attention 2 to be installed)
tokenizer, model, image_processor, max_length = load_pretrained_model(
    pretrained, None, model_name, device_map=device_map, attn_implementation="flash_attention_2"
)

# Put the model in evaluation mode and tie the input/output embedding weights
model.eval()
model.tie_weights()

# Load and preprocess the image (replace the path with your own image)
image = Image.open("path/to/your_image.jpg").convert("RGB")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]

# Conversation template (from the replaced conversation.py) and question
conv_template = "llava_llama_3_1"
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?"

# Build the prompt: user turn with the question, empty assistant turn for the reply
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

# Tokenize the prompt, inserting the image placeholder token, and move it to the GPU
input_ids = tokenizer_image_token(
    prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(device)
image_sizes = [image.size]

# Generate a sampled response
cont = model.generate(
    input_ids,
    images=image_tensor,
    image_sizes=image_sizes,
    do_sample=True,
    temperature=0.9,
    max_new_tokens=3000,
)

# Decode and print the generated text
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs[0])
```
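`tokenizer_image_token` does not tokenize the `<image>` string as text. The simplified sketch below, modelled on LLaVA's implementation, splits the prompt on `DEFAULT_IMAGE_TOKEN`, tokenizes each piece, keeps a single BOS token, and puts `IMAGE_TOKEN_INDEX` (-200 in `llava.constants`) where each `<image>` was; the model later swaps that placeholder for the projected vision features. The real function returns a tensor when `return_tensors="pt"`; the toy whitespace tokenizer here is hypothetical and only makes the example runnable without downloading a model.

```python
IMAGE_TOKEN_INDEX = -200        # same value as llava.constants.IMAGE_TOKEN_INDEX
DEFAULT_IMAGE_TOKEN = "<image>"


def tokenizer_image_token_sketch(prompt, encode, bos_id):
    """Tokenize the text around each <image> marker and splice in IMAGE_TOKEN_INDEX."""
    chunks = [encode(chunk) for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
    # Llama-style tokenizers prepend BOS to every chunk; keep only the first one.
    strip_bos = bool(chunks) and bool(chunks[0]) and chunks[0][0] == bos_id
    input_ids = [bos_id] if strip_bos else []
    for i, chunk in enumerate(chunks):
        if i > 0:
            input_ids.append(IMAGE_TOKEN_INDEX)  # placeholder later replaced by image features
        input_ids.extend(chunk[1:] if strip_bos else chunk)
    return input_ids


def make_toy_encoder(bos_id=1):
    """Hypothetical whitespace tokenizer that prepends BOS, for illustration only."""
    vocab = {}

    def encode(text):
        return [bos_id] + [vocab.setdefault(word, len(vocab) + 2) for word in text.split()]

    return encode


if __name__ == "__main__":
    encode = make_toy_encoder()
    print(tokenizer_image_token_sketch("<image>\nWhat is shown in this image?", encode, bos_id=1))
    # [1, -200, 2, 3, 4, 5, 6, 7]
```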

Training script for the one-vision pretrain stage (run from the root of the LLaVA-NeXT checkout):

```bash
LLM_VERSION="mylesgoose/Meta-Llama-3.1-8B-Instruct-goose-abliterated-reflection"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"

############### Pretrain ################
BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
PROMPT_VERSION="llava_llama_3_1"
MID_RUN_NAME="${BASE_RUN_NAME}--one-vision-pretrain"

echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
CKPT_PATH=$LLM_VERSION

accelerate launch llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path ${CKPT_PATH} \
    --version ${PROMPT_VERSION} \
    --data_path ./json/blip_laion_cc_sbu_558k.json \
    --image_folder ./data/images \
    --video_folder ./data/videos \
    --pretrain_mm_mlp_adapter "./checkpoints/projectors/llavanext-google_siglip-so400m-patch14-384-mylesgoose_Meta-Llama-3.1-8B-Instruct-goose-abliterated-mlp2x_gelu-pretrain_blip558k_plain/mm_projector.bin" \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres_max_9 \
    --image_grid_pinpoints "(1x1),...,(6x6)" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --run_name $MID_RUN_NAME \
    --output_dir "./checkpoints/${MID_RUN_NAME}" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 0 \
    --gradient_accumulation_steps 20 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 20 \
    --save_total_limit 2 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 12768 \
    --gradient_checkpointing True \
    --dataloader_num_workers 10 \
    --lazy_preprocess True \
    --report_to wandb \
    --torch_compile True \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True \
    --frames_upbound 32 \
    --attn_implementation flash_attention_2
```
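Two pieces of the script are easy to misread: the `${VAR//\//_}` expansions that turn model IDs into run names, and how `--per_device_train_batch_size 1` with `--gradient_accumulation_steps 20` turns into an effective batch of 20 × (number of GPUs) and a schedule length. The sketch below reproduces both under stated assumptions: the GPU count is not given in the script, the sample count 558,128 is an assumed size for `blip_laion_cc_sbu_558k.json` (check `len()` of your copy), and the step formula approximates Hugging Face Trainer behaviour with `--dataloader_drop_last True`; the exact count can differ slightly under DeepSpeed.

```python
import math


# Python equivalent of bash "${VAR//\//_}": replace every "/" with "_"
def clean_name(model_id):
    return model_id.replace("/", "_")


def optimizer_steps(num_samples, per_device_batch, grad_accum, num_gpus, epochs=1):
    """Optimizer updates per run when the dataloader drops the last incomplete batch."""
    batches_per_gpu = num_samples // (per_device_batch * num_gpus)
    return (batches_per_gpu // grad_accum) * epochs


def warmup_steps(total_steps, warmup_ratio):
    """Hugging Face Trainer rounds warmup_ratio * total_steps up to a whole step."""
    return math.ceil(total_steps * warmup_ratio)


if __name__ == "__main__":
    print(clean_name("google/siglip-so400m-patch14-384"))  # google_siglip-so400m-patch14-384
    steps = optimizer_steps(558_128, per_device_batch=1, grad_accum=20, num_gpus=1)  # assumed 1 GPU
    print(steps, warmup_steps(steps, 0.03))  # 27906 838
```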

Here is my `pip3 list` output (package, version, and editable project location where applicable):

```
Package Version Editable project location
absl-py 2.1.0
accelerate 0.34.2
aiofiles 22.1.0
aiohappyeyeballs 2.4.0
aiohttp 3.10.5
aiosignal 1.3.1
aiosqlite 0.20.0
altair 5.4.1
annotated-types 0.7.0
anyio 4.4.0
appdirs 1.4.4
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asttokens 2.4.1
astunparse 1.6.3
async-lru 2.0.4
async-timeout 4.0.3
attrs 24.2.0
audioread 3.0.1
av 13.0.0
babel 2.16.0
beartype 0.14.1
beautifulsoup4 4.12.3
better-abc 0.0.3
bidict 0.23.1
bitsandbytes 0.43.3
black 24.1.0
bleach 6.1.0
Brotli 1.1.0
cachetools 5.5.0
certifi 2024.8.30
cffi 1.17.1
cfgv 3.4.0
chardet 5.2.0
charset-normalizer 3.3.2
click 8.1.7
cmake 3.30.2
colorama 0.4.6
comm 0.2.2
contourpy 1.3.0
crcmod 1.7
cryptography 43.0.1
cuda-python 12.4.0 /home/myles/cuda-python-12.4.0
cycler 0.12.1
Cython 3.0.11
DataProperty 1.0.1
datasets 3.0.0
debugpy 1.8.5
decorator 5.1.1
decord 0.6.0
deepspeed 0.15.2+fc22d960
deepspeed-kernels 0.0.1.dev1698255861
defusedxml 0.7.1
Deprecated 1.2.14
dill 0.3.7
distlib 0.3.8
distro 1.9.0
dnspython 2.6.1
docker-pycreds 0.4.0
docopt 0.6.2
docstring_parser 0.16
e 1.4.5
einops 0.8.0
einops-exts 0.0.4
entrypoints 0.4
et-xmlfile 1.1.0
eval_type_backport 0.2.0
evaluate 0.4.2
exceptiongroup 1.2.2
executing 2.1.0
fancy-einsum 0.0.3
fastapi 0.112.4
fastjsonschema 2.20.0
ffmpeg-python 0.2.0
ffmpy 0.4.0
filelock 3.16.0
flash_attn 2.6.3
flatbuffers 24.3.25
fonttools 4.53.1
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2023.10.0
ftfy 6.2.3
future 1.0.0
gast 0.6.0
gitdb 4.0.11
GitPython 3.1.43
google-pasta 0.2.0
gradio 4.43.0
gradio_client 1.3.0
graphviz 0.20.3
grpcio 1.66.1
h11 0.14.0
h5py 3.11.0
hf_transfer 0.1.8
hjson 3.1.0
httpcore 1.0.5
httpx 0.27.2
huggingface-hub 0.24.6
identify 2.6.0
idna 3.8
importlib_metadata 8.4.0
importlib_resources 6.4.4
iniconfig 2.0.0
ipaddress 1.0.23
ipykernel 6.29.5
ipython 8.27.0
ipython-genutils 0.2.0
ipywidgets 8.1.5
isoduration 20.11.0
isort 5.13.2
jaxtyping 0.2.34
jedi 0.19.1
Jinja2 3.1.4
jiter 0.5.0
joblib 1.4.2
json5 0.9.25
jsonlines 4.0.0
jsonpointer 3.0.0
jsonschema 4.23.0
jsonschema-specifications 2023.12.1
jupyter 1.1.1
jupyter_client 8.6.2
jupyter-console 6.6.3
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.2
jupyter_server_fileid 0.9.3
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.3.4
jupyterlab 4.2.5
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
jupyterlab_widgets 3.0.13
keras 3.5.0
kiwisolver 1.4.7
latex2mathml 3.77.0
lazy_loader 0.4
Levenshtein 0.25.1
libclang 18.1.1
librosa 0.10.2.post1
linkify-it-py 2.0.3
llava 1.7.0.dev0 /home/myles/LLaVA-NeXT
llvmlite 0.43.0
lmms_eval 0.2.3 /home/myles/lmms-eval
loguru 0.7.2
lxml 5.3.0
Markdown 3.7
markdown-it-py 3.0.0
markdown2 2.5.0
MarkupSafe 2.1.5
matplotlib 3.9.2
matplotlib-inline 0.1.7
mbstrdecoder 1.1.3
mdit-py-plugins 0.4.1
mdurl 0.1.2
mistune 3.0.2
ml-dtypes 0.4.0
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
multiprocess 0.70.15
mutagen 1.47.0
mypy-extensions 1.0.0
namex 0.0.8
narwhals 1.6.2
nbclassic 1.1.0
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.3
ninja 1.11.1.1
nltk 3.9.1
nodeenv 1.9.1
notebook 7.2.2
notebook_shim 0.2.4
num2words 0.5.13
numba 0.60.0
numexpr 2.10.1
numpy 1.26.4
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-cutlass 3.5.1.0 /home/myles/cutlass
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
nvidia-pyindex 1.0.9
open_clip_torch 2.26.1
openai 1.44.0
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
openpyxl 3.1.5
opt-einsum 3.3.0
optree 0.12.1
orjson 3.10.7
overrides 7.7.0
packaging 24.1
pandas 2.2.2
pandocfilters 1.5.1
parso 0.8.4
pathlib2 2.3.7.post1
pathspec 0.12.1
pathvalidate 3.2.1
peft 0.12.0
pexpect 4.9.0
Pillow 10.1.0
pip 24.2
platformdirs 4.3.1
pluggy 1.5.0
ply 3.11
pooch 1.8.2
portalocker 2.10.1
pre-commit 3.8.0
prometheus_client 0.20.0
promise 2.3
prompt_toolkit 3.0.47
protobuf 4.25.4
psutil 6.0.0
ptyprocess 0.7.0
pure_eval 0.2.3
py 1.11.0
py-cpuinfo 9.0.0
py-spy 0.3.14
pyarrow 17.0.0
pyarrow-hotfix 0.6
pybind11 2.13.5
pycocoevalcap 1.2
pycocotools 2.0.8
pycparser 2.22
pycryptodomex 3.20.0
pydantic 2.9.0
pydantic_core 2.23.2
pydot 3.0.1
pydub 0.25.1
Pygments 2.18.0
PyJWT 2.9.0
pynndescent 0.5.13
pynvml 11.5.3
pyOpenSSL 24.2.1
pyparsing 3.1.4
pyproject-api 1.7.1
pytablewriter 1.2.0
pytest 8.3.2
python-consul 1.1.0
python-dateutil 2.9.0.post0
python-engineio 4.9.1
python-etcd 0.4.5
python-json-logger 2.0.7
python-multipart 0.0.9
python-socketio 5.11.4
pytorch-triton 3.0.0+757b6a61e7
pytz 2024.1
PyYAML 6.0.2
pyzmq 26.2.0
qtconsole 5.6.0
QtPy 2.4.1
rapidfuzz 3.9.7
referencing 0.35.1
regex 2024.7.24
requests 2.32.3
responses 0.25.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.8.0
ring_flash_attn 0.1 /home/myles/ring-flash-attention
rouge_score 0.1.2
rpds-py 0.20.0
ruff 0.6.4
sacrebleu 2.4.3
safetensors 0.4.5
schedule 1.2.2
scikit-learn 1.5.1
scipy 1.14.1
seaborn 0.13.2
semantic-version 2.10.0
Send2Trash 1.8.3
sentencepiece 0.2.0
sentry-sdk 2.13.0
setproctitle 1.3.3
setuptools 70.2.0
shellingham 1.5.4
shortuuid 1.0.13
shtab 1.7.1
simple-websocket 1.0.0
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
sounddevice 0.5.0
soundfile 0.12.1
soupsieve 2.6
soxr 0.5.0.post1
sqlitedict 2.1.0
stack-data 0.6.3
starlette 0.38.4
svgwrite 1.4.3
sympy 1.13.1
tabledata 1.3.3
tabulate 0.9.0
tcolorpy 0.1.6
tenacity 9.0.0
tensorboard 2.17.1
tensorboard-data-server 0.7.2
termcolor 2.4.0
terminado 0.18.1
tf_keras 2.17.0
threadpoolctl 3.5.0
thriftpy2 0.5.2
tiktoken 0.7.0
timm 1.0.9
tinycss2 1.3.0
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.1
torch 2.5.0.dev20240907+cu124
torchaudio 2.5.0.dev20240907+cu124
torchvision 0.20.0.dev20240907+cu124
tornado 6.4.1
tox 4.18.1
tqdm 4.66.5
tqdm-multiprocess 0.0.11
traitlets 5.14.3
transformer-lens 2.4.1
transformers 4.45.0.dev0 /home/myles/transformers
transformers-stream-generator 0.0.5
treelib 1.7.0
triton 3.0.0
typeguard 2.13.3
typepy 1.3.2
typer 0.12.5
types-python-dateutil 2.9.0.20240906
typing_extensions 4.12.2
tyro 0.8.10
tzdata 2024.1
uc-micro-py 1.0.3
umap-learn 0.5.6
Unidecode 1.3.8
uri-template 1.3.0
urllib3 2.2.2
uvicorn 0.30.6
virtualenv 20.26.4
wandb 0.17.9
watchdog 5.0.2
wavedrom 2.0.3.post3
wcwidth 0.2.13
webcolors 24.8.0
webencodings 0.5.1
websocket-client 1.8.0
websockets 12.0
Werkzeug 3.0.4
wheel 0.44.0
widgetsnbextension 4.0.13
wrapt 1.16.0
wsproto 1.2.0
xxhash 3.5.0
y-py 0.6.2
yarl 1.10.0
ypy-websocket 0.8.4
yt-dlp 2024.8.6
zipp 3.20.1
zss 1.2.0
zstandard 0.23.0
```
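To reproduce this environment, one option is to save the list above to a file, run `pip3 list > current.txt` in your own environment, and diff the two. This is a minimal stdlib sketch of that comparison; the file names are hypothetical, and names are normalised per PEP 503 so that `flash_attn` and `flash-attn` match. Editable project locations are ignored.

```python
import re


def normalize(name):
    # PEP 503 normalization: runs of "-", "_" and "." collapse to "-", lowercase
    return re.sub(r"[-_.]+", "-", name).lower()


def parse_pip_list(text):
    """Turn `pip3 list` output into {normalized name: version}; editable paths are ignored."""
    packages = {}
    for line in text.splitlines():
        parts = line.split()
        if len(parts) < 2 or parts[0] == "Package" or set(parts[0]) == {"-"}:
            continue  # skip blank lines, the header row and the dashed separator
        packages[normalize(parts[0])] = parts[1]
    return packages


def diff_environments(reference, current):
    """Return {name: (reference version, current version)} for every mismatch."""
    return {
        name: (reference.get(name), current.get(name))
        for name in sorted(reference.keys() | current.keys())
        if reference.get(name) != current.get(name)
    }
```

Usage: `diff_environments(parse_pip_list(open("reference.txt").read()), parse_pip_list(open("current.txt").read()))`, where a `None` on either side means the package is missing from that environment.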

Model size: 8.45B params · Tensor type: BF16 · Format: Safetensors
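A back-of-the-envelope memory estimate from those numbers: at 2 bytes per BF16 parameter, the weights alone take about 15.7 GiB. For full training, a common rule of thumb from the ZeRO paper is 16 bytes per parameter for mixed-precision Adam (16-bit weights and gradients plus fp32 master weights, momentum and variance), which ZeRO stage 3 shards evenly across GPUs. The sketch below applies both; the GPU count is a hypothetical parameter, and activations, KV cache and framework buffers are not included.

```python
GIB = 2**30


def weight_bytes(num_params, bytes_per_param=2):
    # BF16 stores each parameter in 2 bytes
    return num_params * bytes_per_param


def zero3_model_state_bytes_per_gpu(num_params, num_gpus):
    # 2 (16-bit weights) + 2 (16-bit grads) + 12 (fp32 master weights, momentum, variance)
    # = 16 bytes per parameter, sharded evenly across GPUs by ZeRO stage 3
    return 16 * num_params / num_gpus


if __name__ == "__main__":
    params = 8_450_000_000  # "8.45B params" from this card
    print(f"weights: {weight_bytes(params) / GIB:.2f} GiB")
    print(f"ZeRO-3 model states per GPU on 8 GPUs: {zero3_model_state_bytes_per_gpu(params, 8) / GIB:.2f} GiB")
```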