diff --git a/TinyLLaVA_Factory/.gitignore b/TinyLLaVA_Factory/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d015f47231eb5fc0758f08fb3fa041e025a4a042 --- /dev/null +++ b/TinyLLaVA_Factory/.gitignore @@ -0,0 +1,63 @@ +# These are some examples of commonly ignored file patterns. +# You should customize this list as applicable to your project. +# Learn more about .gitignore: +# https://www.atlassian.com/git/tutorials/saving-changes/gitignore + +# Node artifact files +node_modules/ +dist/ + +# Compiled Java class files +*.class + +# Compiled Python bytecode +*.py[cod] + +# Log files +*.log + +# Package files +*.jar + +# Maven +target/ +dist/ + +# JetBrains IDE +.idea/ + +# Unit test reports +TEST*.xml + +# Generated by MacOS +.DS_Store + +# Generated by Windows +Thumbs.db + +# Applications +*.app +*.exe +*.war + +# Large media files +*.mp4 +*.tiff +*.avi +*.flv +*.mov +*.wmv + +# +.ipynb_checkpoints +__pycache__ +*.egg-info +.vscode/* +.idea/* +playground/ +wandb/* +checkpoints/* +.ipynb_checkpoints/* +scripts/.ipynb_checkpoints/* +test/* +output/* diff --git a/TinyLLaVA_Factory/CUSTOM_FINETUNE.md b/TinyLLaVA_Factory/CUSTOM_FINETUNE.md new file mode 100644 index 0000000000000000000000000000000000000000..86349d8c9e8de77cbd83b068d53e040d28eade69 --- /dev/null +++ b/TinyLLaVA_Factory/CUSTOM_FINETUNE.md @@ -0,0 +1,96 @@ +# Finetune TinyLLaVA with Custom Datasets + +Given the needs of finetuning with custom datasets, we provide a tutorial on how to custom finetune on our trained model, e.g. tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B (HF path). + +## Dataset Format + +Convert your data to a JSON file of a List of all samples. Sample metadata should contain `id` (a unique identifier), `image` (the path to the image), and `conversations` (the conversation data between human and AI). + +Here's an example of the [pokemon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) turned into the data format: + +```json +[ + { + "id": "meiKqU2auAVK2vrtLhKGoJ", + "image": "pokemon/image/meiKqU2auAVK2vrtLhKGoJ.jpg", + "conversations": [ + { + "from": "human", + "value": "\nProvide a brief description of the given image." + }, + { + "from": "gpt", + "value": "a drawing of a green pokemon with red eyes" + } + ] + } +] +``` + +
+You can use the following scripts to convert the Pokemon dataset to the above data format. +converting data format + +```python +import shortuuid +from datasets import load_dataset +from PIL import Image +import random +import json +import tqdm +import os + +ds = load_dataset('lambdalabs/pokemon-blip-captions') +pokemon_data = [] + +pokemon_image_path = '/path/to/your/data/pokemon/image' +pokemon_data_path = '/path/to/your/pokemon_blip_captions.json' + +description_list = [ + "Describe the image concisely.", + "Provide a brief description of the given image.", + "Offer a succinct explanation of the picture presented.", + "Summarize the visual content of the image.", + "Give a short and clear explanation of the subsequent image.", + "Share a concise interpretation of the image provided.", + "Present a compact description of the photo's key features.", + "Relay a brief, clear account of the picture shown.", + "Render a clear and concise summary of the photo.", + "Write a terse but informative summary of the picture.", + "Create a compact narrative representing the image presented." +] + +for sample in tqdm.tqdm(ds['train']): + uuid = shortuuid.uuid() + sample_dict = dict() + sample_dict['id'] = uuid + sample_dict['image'] = 'pokemon/image/' + uuid + '.jpg' + sample['image'].save(os.path.join(pokemon_image_path, uuid + '.jpg')) + conversations = [ + {"from": "human", "value": "\n" + random.choice(description_list)}, + {"from": "gpt", "value": sample['text']} + ] + sample_dict['conversations'] = conversations + pokemon_data.append(sample_dict) + +with open(pokemon_data_path, 'w') as f: + json.dump(pokemon_data, f, indent=4) +``` + +
+ +## Custom Finetune +After acquiring the dataset following the above data format, you can finetune our trained model TinyLLaVA-Phi-2-SigLIP-3.1B checkpoint by using lora. + +- Replace data paths and `output_dir` with yours in `scripts/train/custom_finetune.sh` +- Adjust your GPU ids (localhost) and `per_device_train_batch_size` in `scripts/train/custom_finetune.sh`. + +```bash +bash scripts/train/custom_finetune.sh +``` + +## Evaluation with Custom Finetuned Model +All of the models trained by TinyLLaVA Factory have the same evaluation procedure, no matter it is trained through custom finetune or through normal training. Please see the [Evaluation](https://tinyllava-factory.readthedocs.io/en/latest/Evaluation.html) section in our Doc. + + + diff --git a/TinyLLaVA_Factory/LICENSE b/TinyLLaVA_Factory/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ddb59dfd4c7db6c3fff68e9fc46196d57021e0bc --- /dev/null +++ b/TinyLLaVA_Factory/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [TinyLLaVA] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/TinyLLaVA_Factory/README.md b/TinyLLaVA_Factory/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e40f13a38897da2f31a217b90a6beb4b52f87c5f --- /dev/null +++ b/TinyLLaVA_Factory/README.md @@ -0,0 +1,428 @@ +

TinyLLaVA Factory

+ +[![hf_space](https://img.shields.io/badge/🤗-%20Open%20In%20HF-blue.svg)](https://huggingface.co/tinyllava) [![arXiv](https://img.shields.io/badge/Arxiv-2402.14289-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2402.14289) [![arXiv](https://img.shields.io/badge/Arxiv-2405.11788-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2405.11788)[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/TinyLLaVA/TinyLLaVA_Factory/blob/main/LICENSE) [![Doc](https://img.shields.io/badge/Doc-Document-logo=read%20the%20docs&logoColor=white&label=Doc)](https://tinyllava-factory.readthedocs.io/en/latest/) [![Demo](https://img.shields.io/badge/Demo-Demo-red.svg)](http://8843843nmph5.vicp.fun/#/) + +![architecture](./assets/architecture.jpg) + +## 🎉 News +* **[2024.08.13]** A simple [visualizaiton tool](https://github.com/TinyLLaVA/TinyLLaVA_Factory/tree/main/tinyllava_visualizer) for interpreting the prediction of TinyLLaVA is added. +* **[2024.05.21]** Our paper: [TinyLLaVA Factory: A Modularized Codebase for Small-scale Large Multimodal Models](https://arxiv.org/abs/2405.11788) is released! +* **[2024.05.15]** [TinyLLaVA Factory](https://github.com/TinyLLaVA/TinyLLaVA_Factory), our new codebase, is released! **Note that the old codebase, TinyLLaVABench, is moved to the [tinyllava_bench](https://github.com/TinyLLaVA/TinyLLaVA_Factory/tree/tinyllava_bench) branch.** +* **[2024.05.04]** [TinyLLaVA Demo](http://8843843nmph5.vicp.fun/#/) is released! +* **[2024.02.21]** Our paper: [TinyLLaVA: A Framework of Small-scale Large Multimodal Models](https://arxiv.org/abs/2402.14289) is released! + +## 🔥 Takeaways +- Our best model, [TinyLLaVA-Phi-2-SigLIP-3.1B](https://huggingface.co/tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B), achieves better overall performance against existing 7B models such as LLaVA-1.5 and Qwen-VL. + +- TinyLLaVA Factory is an open-source modular codebase for small-scale large multimodal models (LMMs), implemented in PyTorch and HuggingFace, with a focus on simplicity of code implementations, extensibility of new features, and reproducibility of training results. + +- With TinyLLaVA Factory, you can customize your own large multimodal models with less coding effort and less coding mistakes. + +- TinyLLaVA Factory integrates a suite of cutting-edge models and methods. + + - LLM currently supports **OpenELM**, **TinyLlama**, **StableLM**, **Qwen**, **Gemma**, and **Phi**. + + - Vision tower currently supports **CLIP,** **SigLIP**, **Dino**, and **combination of CLIP and Dino**. + + - Connector currently supports **MLP**, **Qformer**, and **Resampler**. + + - Training Recipe currently supports **Frozen/Fully/Partially tuning** and **LoRA/QLoRA tuning**. + +- The password to access our demo is '1234'. + +## Contents + +- [🎉 News](#-news) +- [🔥 Takeaways](#-takeaways) +- [Contents](#contents) +- [Installation and Requirements](#installation-and-requirements) + - [Upgrade to the latest code base](#upgrade-to-the-latest-code-base) +- [Get Started](#get-started) + - [1. Data Preparation](#1-data-preparation) + - [2. Train](#2-train) + - [3. Evaluation](#3-evaluation) +- [Model Zoo](#model-zoo) + - [Trained Models](#trained-models) + - [Model Performance](#model-performance) + - [Legacy Models](#legacy-models) +- [Launch Demo Locally](#launch-demo-locally) + - [Gradio Web Demo](#gradio-web-demo) + - [CLI Inference](#cli-inference) + - [Quick Inference Scripts](#quick-inference-scripts) +- [Custom Finetune](#custom-finetune) +- [Customize Your Own Large Multimodel Models](#customize-your-own-large-multimodel-models) + - [LLM](#llm) + - [Vision Tower](#vision-tower) + - [Connector](#connector) +- [Acknowledgement](#acknowledgement) +- [Contact](#contact) +- [✏ Citation](#-citation) +- [❤️ Community efforts](#️-community-efforts) + + +## Installation and Requirements + +Please note that our environment requirements are different from LLaVA's environment requirements. We strongly recommend you create the environment from scratch as follows. + +1. Clone this repository and navigate to the folder +```bash +git clone https://github.com/TinyLLaVA/TinyLLaVA_Factory.git +cd TinyLLaVA_Factory +``` + +2. Create a conda environment, activate it and install Packages +```Shell +conda create -n tinyllava_factory python=3.10 -y +conda activate tinyllava_factory +pip install --upgrade pip # enable PEP 660 support +pip install -e . +``` + +3. Install additional packages +```Shell +pip install flash-attn --no-build-isolation +``` +#### Upgrade to the latest code base + +```Shell +git pull +pip install -e . +``` + +## Get Started + +#### 1. Data Preparation + +Please refer to the [Data Preparation](https://tinyllava-factory.readthedocs.io/en/latest/Prepare%20Datasets.html) section in our [Documenation](https://tinyllava-factory.readthedocs.io/en/latest/). + +#### 2. Train + +Here's an example for training a LMM using Phi-2. + +- Replace data paths with yours in `scripts/train/train_phi.sh` +- Replace `output_dir` with yours in `scripts/train/pretrain.sh` +- Replace `pretrained_model_path` and `output_dir` with yours in `scripts/train/finetune.sh` +- Adjust your GPU ids (localhost) and `per_device_train_batch_size` in `scripts/train/pretrain.sh` and `scripts/train/finetune.sh` + +```bash +bash scripts/train/train_phi.sh +``` + +Important hyperparameters used in pretraining and finetuning are provided below. + +| Training Stage | Global Batch Size | Learning rate | conv_version | +| -------------- | :---------------: | :-----------: | :----------: | +| Pretraining | 256 | 1e-3 | pretrain | +| Finetuning | 128 | 2e-5 | phi | + +**Tips:** + +Global Batch Size = num of GPUs * `per_device_train_batch_size` * `gradient_accumulation_steps`, we recommand you always keep global batch size and learning rate as above except for lora tuning your model. + +`conv_version` is a hyperparameter used for choosing different chat templates for different LLMs. In the pretraining stage, `conv_version` is the same for all LLMs, using `pretrain`. In the finetuning stage, we use + +`phi` for Phi-2, StableLM, Qwen-1.5 + +`llama` for TinyLlama, OpenELM + +`gemma` for Gemma + +#### 3. Evaluation + +Please refer to the [Evaluation](https://tinyllava-factory.readthedocs.io/en/latest/Evaluation.html) section in our [Documenation](https://tinyllava-factory.readthedocs.io/en/latest/Evaluation.html). + +## Model Zoo + +### Trained Models + +which are trained using TinyLLaVA Factory. + +- [TinyLLaVA-Phi-2-SigLIP-3.1B](https://huggingface.co/tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B) +- [TinyLLaVA-Gemma-SigLIP-2.4B](https://huggingface.co/tinyllava/TinyLLaVA-Gemma-SigLIP-2.4B) +- [TinyLLaVA-OpenELM-450M-SigLIP-0.89B](https://huggingface.co/jiajunlong/TinyLLaVA-0.89B) +- [TinyLLaVA-Qwen2-0.5B-SigLIP](https://huggingface.co/Zhang199/TinyLLaVA-Qwen2-0.5B-SigLIP) +- [TinyLLaVA-Qwen2.5-3B-SigLIP](https://huggingface.co/Zhang199/TinyLLaVA-Qwen2.5-3B-SigLIP) + +#### Model Performance + +| VT (HF Path) | LLM (HF Path) | Recipe | VQA-v2 | GQA | SQA-image | TextVQA | MM-Vet | POPE | MME | MMMU-val | +| --------------------------------- | ---------------------------------- | --------- | :----: | :--: | :-------: | :-----: | :----: | :--: | :----: | :------: | +| openai/clip-vit-large-patch14-336 | apple/OpenELM-450M-Instruct | base | 69.5 | 52.1 | 50.6 | 40.4 | 20.0 | 83.6 | 1052.9 | 23.9 | +| google/siglip-so400m-patch14-384 | apple/OpenELM-450M-Instruct | base | 71.7 | 53.9 | 54.1 | 44.0 | 20.0 | 85.4 | 1118.8 | 24.0 | +| google/siglip-so400m-patch14-384 | Qwen/Qwen2-0.5B | base | 72.3 | 55.8 | 60.1 | 45.2 | 19.5 | 86.6 | 1153.0 | 29.7 | +| google/siglip-so400m-patch14-384 | Qwen/Qwen2.5-0.5B | base | 75.3 | 59.5 | 60.3 | 48.3 | 23.9 | 86.1 | 1253.0 | 33.3 | +| google/siglip-so400m-patch14-384 | Qwen/Qwen2.5-3B | base | 79.4 | 62.5 | 74.1 | 58.3 | 34.8 | 87.4 | 1438.7 | 39.9 | +| openai/clip-vit-large-patch14-336 | TinyLlama/TinyLlama-1.1B-Chat-v1.0 | base | 73.7 | 58.0 | 59.9 | 46.3 | 23.2 | 85.5 | 1284.6 | 27.9 | +| google/siglip-so400m-patch14-384 | TinyLlama/TinyLlama-1.1B-Chat-v1.0 | base | 75.5 | 58.6 | 64.0 | 49.6 | 23.5 | 86.3 | 1256.5 | 28.3 | +| openai/clip-vit-large-patch14-336 | stabilityai/stablelm-2-zephyr-1_6b | base | 75.9 | 59.5 | 64.6 | 50.5 | 27.3 | 86.1 | 1368.1 | 31.8 | +| google/siglip-so400m-patch14-384 | stabilityai/stablelm-2-zephyr-1_6b | base | 78.2 | 60.7 | 66.7 | 56.0 | 29.4 | 86.3 | 1319.3 | 32.6 | +| google/siglip-so400m-patch14-384 | google/gemma-2b-it | base | 78.4 | 61.6 | 64.4 | 53.6 | 26.9 | 86.4 | 1339.0 | 31.7 | +| openai/clip-vit-large-patch14-336 | microsoft/phi-2 | base | 76.8 | 59.4 | 71.2 | 53.4 | 31.7 | 86.8 | 1448.6 | 36.3 | +| google/siglip-so400m-patch14-384 | microsoft/phi-2 | base | 79.2 | 61.6 | 71.9 | 57.4 | 35.0 | 87.2 | 1462.4 | 38.2 | +| google/siglip-so400m-patch14-384 | microsoft/phi-2 | base&lora | 77.6 | 59.7 | 71.6 | 53.8 | 33.3 | 87.9 | 1413.2 | 35.6 | +| google/siglip-so400m-patch14-384 | microsoft/phi-2 | share | 80.1 | 62.1 | 73.0 | 60.3 | 37.5 | 87.2 | 1466.4 | 38.4 | + +### Legacy Models + +which are trained using the old codebase TinyLLaVABench. + +- [TinyLLaVA-3.1B](https://huggingface.co/bczhou/TinyLLaVA-3.1B) +- [TinyLLaVA-2.0B](https://huggingface.co/bczhou/TinyLLaVA-2.0B) +- [TinyLLaVA-1.5B](https://huggingface.co/bczhou/TinyLLaVA-1.5B) +- [tiny-llava-hf](https://huggingface.co/bczhou/tiny-llava-v1-hf) + +If you have models trained by our old codebase TinyLLaVABench and you still want to use them, we provide an example of [TinyLLaVA-3.1B](https://huggingface.co/bczhou/TinyLLaVA-3.1B) for how to use legacy models. + +
+Example of using legacy models + + +```Python +from tinyllava.eval.run_tiny_llava import eval_model +from tinyllava.model.convert_legecy_weights_to_tinyllavafactory import * + +model = convert_legecy_weights_to_tinyllavafactory('bczhou/TinyLLaVA-3.1B') + +prompt = "What are the things I should be cautious about when I visit here?" +image_file = "https://llava-vl.github.io/static/images/view.jpg" + +args = type('Args', (), { + "model_path": None, + "model": model, + "query": prompt, + "conv_mode": "phi", # the same as conv_version in the training stage. Different LLMs have different conv_mode/conv_version, please replace it + "image_file": image_file, + "sep": ",", + "temperature": 0, + "top_p": None, + "num_beams": 1, + "max_new_tokens": 512 +})() + +eval_model(args) + +""" +Output: +When visiting this serene lakeside location with a wooden dock, there are a few things to be cautious about. First, ensure that the dock is stable and secure before stepping onto it, as it might be slippery or wet, especially if it's a wooden structure. Second, be mindful of the surrounding water, as it can be deep or have hidden obstacles, such as rocks or debris, that could pose a risk. Additionally, be aware of the weather conditions, as sudden changes in weather can make the area more dangerous. Lastly, respect the natural environment and wildlife, and avoid littering or disturbing the ecosystem. +""" +``` + +
+ + + +## Launch Demo Locally + +### Gradio Web Demo +Launch a local web demo by running: +```bash +python tinyllava/serve/app.py --model-path tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B +``` +### CLI Inference +We also support running inference with CLI. To use our model, run: +```bash +python -m tinyllava.serve.cli \ + --model-path tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B \ + --image-file "./tinyllava/serve/examples/extreme_ironing.jpg" +``` +### Quick Inference Scripts +If you want to launch the model trained by yourself or us locally, here's an example. +
+Run inference with the model trained by yourself + +```Python +from tinyllava.eval.run_tiny_llava import eval_model + +model_path = "/absolute/path/to/your/model/" +prompt = "What are the things I should be cautious about when I visit here?" +image_file = "https://llava-vl.github.io/static/images/view.jpg" +conv_mode = "phi" # or llama, gemma, etc + +args = type('Args', (), { + "model_path": model_path, + "model": None, + "query": prompt, + "conv_mode": conv_mode, + "image_file": image_file, + "sep": ",", + "temperature": 0, + "top_p": None, + "num_beams": 1, + "max_new_tokens": 512 +})() + +eval_model(args) + +""" +Output: +XXXXXXXXXXXXXXXXX +""" +``` +
+ +
+Run inference with the model trained by us using huggingface transformers + +```Python +from transformers import AutoTokenizer, AutoModelForCausalLM + +hf_path = 'tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B' +model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True) +model.cuda() +config = model.config +tokenizer = AutoTokenizer.from_pretrained(hf_path, use_fast=False, model_max_length = config.tokenizer_model_max_length,padding_side = config.tokenizer_padding_side) +prompt="What are these?" +image_url="http://images.cocodataset.org/val2017/000000039769.jpg" +output_text, genertaion_time = model.chat(prompt=prompt, image=image_url, tokenizer=tokenizer) + +print('model output:', output_text) +print('runing time:', genertaion_time) +``` +
+ +## Custom Finetune +If you want to finetune TinyLLaVA with your custom datasets, please refer to [here](https://github.com/TinyLLaVA/TinyLLaVA_Factory/blob/main/CUSTOM_FINETUNE.md). + +## Customize Your Own Large Multimodel Models + +### LLM + +If you want to add a new LLM by yourself, you need to create two files: one for chat template and the other for language model, under the folders `tinyllava/data/template/` and `tinyllava/model/llm/`. + +Here is an example of adding the Gemma model. + +Firstly, create `tinyllava/data/template/gemma_template.py`, which will be used for the finetuning stage. + +```python +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from packaging import version + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from . import register_template +from ...utils.constants import * + +from transformers import PreTrainedTokenizer +import torch +import tokenizers + + +system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + +@register_template('gemma') # Enable the TemplateFactory to obtain the added template by this string ('gemma'). +@dataclass +class GemmaTemplate(Template): + format_image_token: "Formatter" = StringFormatter(slot="\n{{content}}") + format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ") + format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "") # to be modified according to the tokenizer you choose + system: "Formatter" = EmptyFormatter(slot=system+" ") + separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '']) # to be modified according to the tokenizer you choose + + def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds): + # your code here + return labels, cur_len +``` +**Tips:** + +Please ensure that the `labels` (returned by the `_make_masks` function) follows this format: answers and the eos token id are not masked, and the other tokens are masked with `-100`. + +Secondly, create `tinyllava/model/llm/gemma.py`. + +```python +from transformers import GemmaForCausalLM, AutoTokenizer +# The LLM you want to add along with its corresponding tokenizer. + +from . import register_llm + +# Add GemmaForCausalLM along with its corresponding tokenizer and handle special tokens. +@register_llm('gemma') # Enable the LLMFactory to obtain the added LLM by this string ('gemma'). +def return_gemmaclass(): + def tokenizer_and_post_load(tokenizer): + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + return (GemmaForCausalLM, (AutoTokenizer, tokenizer_and_post_load)) +``` + +Finally, create `scripts/train/train_gemma.sh` with the corresponding `LLM_VERSION` and `CONV_VERSION`. + +### Vision Tower + +If you want to add a new vision tower, you need to implement a new vision tower class that should be inherited from the base class `VisionTower`. Here's an example of the MoF vision tower. + +First, create `tinyllava/model/vision_tower/mof.py` + +```python +@register_vision_tower('mof') +class MoFVisionTower(VisionTower): + def __init__(self, cfg): + super().__init__(cfg) + + self._vision_tower = MoF(cfg) + self._image_processor = # your image processor + + def _load_model(self, vision_tower_name, **kwargs): + # your code here, make sure your model can be correctly loaded from pretrained parameters either by huggingface or pytorch loading + + def forward(self, x, **kwargs): + # your code here +``` + +Then, modify your training scripts with the corresponding `VT_VERSION`. + +### Connector + +If you want to add a new connector, you need to implement a new connector class that should be inherited from the base class `Connector`. Here's an example of the Linear connector. + +First, create `tinyllava/model/connector/linear.py` + + +```python +import torch.nn as nn + +from . import register_connector +from .base import Connector + +@register_connector('linear') #Enable the ConnectorMFactory to obtain the added connector by this string ('linear'). +class LinearConnector(Connector): + def __init__(self, config): + super().__init__() + self._connector = nn.Linear(config.vision_hidden_size, config.hidden_size) # define your connector model +``` + +Then, modify your training scripts with the corresponding `CN_VERSION`. + +## Acknowledgement +We give special thanks to Lei Zhao, Luche Wang, Kaijun Luo, and Junchen Wang for building the [Demo](http://8843843nmph5.vicp.fun/#/). + +## Contact +If you have any questions, feel free to either initiate an *Issue* or contact us by WeChat (WeChatID: *TinyLLaVA*). + +## ✏ Citation + +If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:. + +```BibTeX +@misc{zhou2024tinyllava, + title={TinyLLaVA: A Framework of Small-scale Large Multimodal Models}, + author={Baichuan Zhou and Ying Hu and Xi Weng and Junlong Jia and Jie Luo and Xien Liu and Ji Wu and Lei Huang}, + year={2024}, + eprint={2402.14289}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` +```BibTeX +@article{jia2024tinyllava, + title={TinyLLaVA Factory: A Modularized Codebase for Small-scale Large Multimodal Models}, + author={Jia, Junlong and Hu, Ying and Weng, Xi and Shi, Yiming and Li, Miao and Zhang, Xingjian and Zhou, Baichuan and Liu, Ziyu and Luo, Jie and Huang, Lei and Wu, Ji}, + journal={arXiv preprint arXiv:2405.11788}, + year={2024} +} +``` + + +## ❤️ Community efforts +* Our codebase is built upon the [LLaVA](https://github.com/haotian-liu/LLaVA) project. Great work! +* Our project uses data from the [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V) project. Great work! diff --git a/TinyLLaVA_Factory/assets/architecture.jpg b/TinyLLaVA_Factory/assets/architecture.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eb241ea593289f0a2490524175286e37445eb253 Binary files /dev/null and b/TinyLLaVA_Factory/assets/architecture.jpg differ diff --git a/TinyLLaVA_Factory/pyproject.toml b/TinyLLaVA_Factory/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..8099e195ffc8d767116bbdb7858739fe781a7604 --- /dev/null +++ b/TinyLLaVA_Factory/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "tinyllava" +version = "1.0.0" +description = "A Framework of Small-scale Large Multimodal Models." +readme = "README.md" +requires-python = ">=3.9" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", +] +dependencies = [ + "torch==2.0.1", "torchvision==0.15.2", "tiktoken", "openpyxl", "tensorboardX", + "transformers==4.40.1", "tokenizers==0.19.0", "sentencepiece==0.1.99", "shortuuid", + "accelerate==0.27.2", "bitsandbytes==0.41.0", "peft==0.10.0", + "pydantic<2,>=1", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2", + "gradio==3.35.2", "gradio_client==0.2.9", + "requests", "httpx==0.24.0", "uvicorn", "fastapi", + "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", + "deepspeed==0.14.0", "ninja", "wandb", +] + +[project.optional-dependencies] +train = ["deepspeed==0.14.0", "ninja", "wandb"] + +[project.urls] +"Homepage" = "https://github.com/DLCV-BUAA/TinyLLaVABench" +"Bug Tracker" = "https://github.com/DLCV-BUAA/TinyLLaVABench/issues" + +[tool.setuptools.packages.find] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] + +[tool.wheel] +exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] + diff --git a/TinyLLaVA_Factory/scripts/convert_answer_to_mmmu.py b/TinyLLaVA_Factory/scripts/convert_answer_to_mmmu.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b07e76958a0691db8dabf611db1009ce610771 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/convert_answer_to_mmmu.py @@ -0,0 +1,31 @@ +import argparse +import json +import os + + +def eval_model(args): + answers = [json.loads(q) for q in open(os.path.expanduser(args.answers_file), "r")] + answers_dict = {} + for answer in answers: + answers_dict[answer["question_id"]] = answer["text"] + # print(answer) + + with open(args.answers_output, "w") as f: + json.dump(answers_dict, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--answers-file", + type=str, + required=True + ) + parser.add_argument( + "--answers-output", + type=str, + required=True + ) + args = parser.parse_args() + + eval_model(args) diff --git a/TinyLLaVA_Factory/scripts/convert_gqa_for_eval.py b/TinyLLaVA_Factory/scripts/convert_gqa_for_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..4d46c8b876df618faac548e9b369109d541f4f23 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/convert_gqa_for_eval.py @@ -0,0 +1,18 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--src", type=str) +parser.add_argument("--dst", type=str) +args = parser.parse_args() + +all_answers = [] +for line_idx, line in enumerate(open(args.src)): + res = json.loads(line) + question_id = res['question_id'] + text = res['text'].rstrip('.').lower() + all_answers.append({"questionId": question_id, "prediction": text}) + +with open(args.dst, 'w') as f: + json.dump(all_answers, f) diff --git a/TinyLLaVA_Factory/scripts/convert_mmvet_for_eval.py b/TinyLLaVA_Factory/scripts/convert_mmvet_for_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..97f5cfb7fb7691ef3921e3e6afc6d82ec54d4c6c --- /dev/null +++ b/TinyLLaVA_Factory/scripts/convert_mmvet_for_eval.py @@ -0,0 +1,18 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--src", type=str) +parser.add_argument("--dst", type=str) +args = parser.parse_args() + +cur_result = {} + +for line in open(args.src): + data = json.loads(line) + qid = data['question_id'] + cur_result[f'v1_{qid}'] = data['text'] + +with open(args.dst, 'w') as f: + json.dump(cur_result, f, indent=2) diff --git a/TinyLLaVA_Factory/scripts/convert_vqav2_for_submission.py b/TinyLLaVA_Factory/scripts/convert_vqav2_for_submission.py new file mode 100644 index 0000000000000000000000000000000000000000..cedd291d0ba0c9c8ca4f10b1862184b076221466 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/convert_vqav2_for_submission.py @@ -0,0 +1,56 @@ +import os +import argparse +import json + +from tinyllava.eval.m4c_evaluator import EvalAIAnswerProcessor + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2") + parser.add_argument('--ckpt', type=str, required=True) + parser.add_argument('--split', type=str, required=True) + return parser.parse_args() + + +if __name__ == '__main__': + + args = parse_args() + + src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl') + test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl') + dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json') + os.makedirs(os.path.dirname(dst), exist_ok=True) + + results = [] + error_line = 0 + for line_idx, line in enumerate(open(src)): + try: + results.append(json.loads(line)) + except: + error_line += 1 + + results = {x['question_id']: x['text'] for x in results} + test_split = [json.loads(line) for line in open(test_split)] + split_ids = set([x['question_id'] for x in test_split]) + + print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') + + all_answers = [] + + answer_processor = EvalAIAnswerProcessor() + + for x in test_split: + if x['question_id'] not in results: + all_answers.append({ + 'question_id': x['question_id'], + 'answer': '' + }) + else: + all_answers.append({ + 'question_id': x['question_id'], + 'answer': answer_processor(results[x['question_id']]) + }) + + with open(dst, 'w') as f: + json.dump(all_answers, open(dst, 'w')) diff --git a/TinyLLaVA_Factory/scripts/eval/gqa.sh b/TinyLLaVA_Factory/scripts/eval/gqa.sh new file mode 100644 index 0000000000000000000000000000000000000000..24e6ae52b6f8e68c989c185915d800894e44a269 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/gqa.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +gpu_list="${CUDA_VISIBLE_DEVICES:-0}" +IFS=',' read -ra GPULIST <<< "$gpu_list" + +CHUNKS=${#GPULIST[@]} + +SPLIT="llava_gqa_testdev_balanced" +GQADIR="/home/ai/data/llava/dataset/eval/gqa" + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune/" +MODEL_NAME="tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune" +EVAL_DIR="/home/ai/data/llava/dataset/eval" + +for IDX in $(seq 0 $((CHUNKS-1))); do + CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m tinyllava.eval.model_vqa_loader \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/gqa/$SPLIT.jsonl \ + --image-folder $EVAL_DIR/gqa/images \ + --answers-file $EVAL_DIR/gqa/answers/$SPLIT/$MODEL_NAME/${CHUNKS}_${IDX}.jsonl \ + --num-chunks $CHUNKS \ + --chunk-idx $IDX \ + --temperature 0 \ + --conv-mode phi & +done + +wait + +output_file=$EVAL_DIR/gqa/answers/$SPLIT/$MODEL_NAME/merge.jsonl + +# Clear out the output file if it exists. +> "$output_file" + +# Loop through the indices and concatenate each file. +for IDX in $(seq 0 $((CHUNKS-1))); do + cat $EVAL_DIR/gqa/answers/$SPLIT/$MODEL_NAME/${CHUNKS}_${IDX}.jsonl >> "$output_file" +done + +python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json + +cd $GQADIR +python eval/eval.py --tier testdev_balanced + diff --git a/TinyLLaVA_Factory/scripts/eval/mme.sh b/TinyLLaVA_Factory/scripts/eval/mme.sh new file mode 100644 index 0000000000000000000000000000000000000000..cb51edafe586102c160401cb6265f1d5133c0dd4 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/mme.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +MODEL_PATH="/home/jiajunlong/LLaVA/ying/checkpoints/tiny-llava-TinyLlama-1.1B-Chat-v1.0-clip-vit-large-patch14-336-tinyllama-llava-finetune" +MODEL_NAME="tiny-llava-TinyLlama-1.1B-Chat-v1.0-clip-vit-large-patch14-336-tinyllama-llava-finetune" +EVAL_DIR="/home/jiajunlong/llava_data/eval" + +python -m tinyllava.eval.model_vqa_loader \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/MME/llava_mme.jsonl \ + --image-folder $EVAL_DIR/MME/MME_Benchmark_release_version \ + --answers-file $EVAL_DIR/MME/answers/$MODEL_NAME.jsonl \ + --temperature 0 \ + --conv-mode llama + +cd $EVAL_DIR/MME + +python convert_answer_to_mme.py --experiment $MODEL_NAME + +cd eval_tool + +python calculation.py --results_dir answers/$MODEL_NAME + diff --git a/TinyLLaVA_Factory/scripts/eval/mmmu.sh b/TinyLLaVA_Factory/scripts/eval/mmmu.sh new file mode 100644 index 0000000000000000000000000000000000000000..2cf10394483c80f105b5a73a21cdaf2f6c3e40ad --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/mmmu.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune-final" +MODEL_NAME="tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune-final" +EVAL_DIR="/home/ai/data/llava/dataset/eval" + +python -m tinyllava.eval.model_vqa_mmmu \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/MMMU/anns_for_eval.json \ + --image-folder $EVAL_DIR/MMMU/all_images \ + --answers-file $EVAL_DIR/MMMU/answers/$MODEL_NAME.jsonl \ + --temperature 0 \ + --conv-mode phi + +python scripts/convert_answer_to_mmmu.py \ + --answers-file $EVAL_DIR/MMMU/answers/$MODEL_NAME.jsonl \ + --answers-output $EVAL_DIR/MMMU/answers/"$MODEL_NAME"_output.json + +cd $EVAL_DIR/MMMU/eval + +python main_eval_only.py --output_path $EVAL_DIR/MMMU/answers/"$MODEL_NAME"_output.json diff --git a/TinyLLaVA_Factory/scripts/eval/mmvet.sh b/TinyLLaVA_Factory/scripts/eval/mmvet.sh new file mode 100644 index 0000000000000000000000000000000000000000..9c55142c4625fb96015e1c146d90cf4a95fb4588 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/mmvet.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-phi-2-clip-vit-large-patch14-336-baseline-finetune/" +MODEL_NAME="tiny-llava-phi-2-clip-vit-large-patch14-336-baseline-finetune2" +EVAL_DIR="/home/ai/data/llava/dataset/eval" + +python -m tinyllava.eval.model_vqa \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/mm-vet/llava-mm-vet.jsonl \ + --image-folder $EVAL_DIR/mm-vet/images \ + --answers-file $EVAL_DIR/mm-vet/answers/$MODEL_NAME.jsonl \ + --temperature 0 \ + --conv-mode phi + +mkdir -p $EVAL_DIR/mm-vet/results + +python scripts/convert_mmvet_for_eval.py \ + --src $EVAL_DIR/mm-vet/answers/$MODEL_NAME.jsonl \ + --dst $EVAL_DIR/mm-vet/results/$MODEL_NAME.json diff --git a/TinyLLaVA_Factory/scripts/eval/pope.sh b/TinyLLaVA_Factory/scripts/eval/pope.sh new file mode 100644 index 0000000000000000000000000000000000000000..96a0e4fefc2cee7a9b96890a0dfca09de9b584ec --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/pope.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-stablelm-2-zephyr-1_6b-siglip-so400m-patch14-384-base-finetune/" +MODEL_NAME="tiny-llava-stablelm-2-zephyr-1_6b-siglip-so400m-patch14-384-base-finetune" +EVAL_DIR="/home/ai/data/llava/dataset/eval" + +python -m tinyllava.eval.model_vqa_pope \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/pope/llava_pope_test.jsonl \ + --image-folder $EVAL_DIR/pope/val2014 \ + --answers-file $EVAL_DIR/pope/answers/$MODEL_NAME.jsonl \ + --temperature 0 \ + --conv-mode phi + +python tinyllava/eval/eval_pope.py \ + --annotation-dir $EVAL_DIR/pope/coco \ + --question-file $EVAL_DIR/pope/llava_pope_test.jsonl \ + --result-file $EVAL_DIR/pope/answers/$MODEL_NAME.jsonl diff --git a/TinyLLaVA_Factory/scripts/eval/sqa.sh b/TinyLLaVA_Factory/scripts/eval/sqa.sh new file mode 100644 index 0000000000000000000000000000000000000000..cf21005dee7972debd35cb418a34af586f65062f --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/sqa.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune" +MODEL_NAME="tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune" +EVAL_DIR="/home/ai/data/llava/dataset/eval" +python -m tinyllava.eval.model_vqa_science \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/scienceqa/llava_test_CQM-A.json \ + --image-folder $EVAL_DIR/scienceqa/images/test \ + --answers-file $EVAL_DIR/scienceqa/answers/$MODEL_NAME.jsonl \ + --single-pred-prompt \ + --temperature 0 \ + --conv-mode phi + +python tinyllava/eval/eval_science_qa.py \ + --base-dir $EVAL_DIR/scienceqa \ + --result-file $EVAL_DIR/scienceqa/answers/$MODEL_NAME.jsonl \ + --output-file $EVAL_DIR/scienceqa/answers/"$MODEL_NAME"_output.jsonl \ + --output-result $EVAL_DIR/scienceqa/answers/"$MODEL_NAME"_result.json + diff --git a/TinyLLaVA_Factory/scripts/eval/textvqa.sh b/TinyLLaVA_Factory/scripts/eval/textvqa.sh new file mode 100644 index 0000000000000000000000000000000000000000..cdff2232f111ea2b5c97e17024b890f91ca8aa26 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/textvqa.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune/" +MODEL_NAME="tiny-llava-phi-2-siglip-so400m-patch14-384-base-finetune" +EVAL_DIR="/home/ai/data/llava/dataset/eval" + +python -m tinyllava.eval.model_vqa_loader \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/textvqa/llava_textvqa_val_v051_ocr.jsonl \ + --image-folder $EVAL_DIR/textvqa/train_images \ + --answers-file $EVAL_DIR/textvqa/answers/$MODEL_NAME.jsonl \ + --temperature 0 \ + --conv-mode phi + +python -m tinyllava.eval.eval_textvqa \ + --annotation-file $EVAL_DIR/textvqa/TextVQA_0.5.1_val.json \ + --result-file $EVAL_DIR/textvqa/answers/$MODEL_NAME.jsonl + diff --git a/TinyLLaVA_Factory/scripts/eval/vqav2.sh b/TinyLLaVA_Factory/scripts/eval/vqav2.sh new file mode 100644 index 0000000000000000000000000000000000000000..a58b35b17967b4371414dad0ac91a46c8da2ba65 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/eval/vqav2.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +gpu_list="${CUDA_VISIBLE_DEVICES:-0}" +IFS=',' read -ra GPULIST <<< "$gpu_list" + +CHUNKS=${#GPULIST[@]} + +SPLIT="llava_vqav2_mscoco_test-dev2015" + +MODEL_PATH="/mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-phi-2-clip-vit-large-patch14-336-baseline-finetune/" +MODEL_NAME="tiny-llava-phi-2-clip-vit-large-patch14-336-baseline-finetune2" +EVAL_DIR="/home/ai/data/llava/dataset/eval" + +for IDX in $(seq 0 $((CHUNKS-1))); do + CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m tinyllava.eval.model_vqa_loader \ + --model-path $MODEL_PATH \ + --question-file $EVAL_DIR/vqav2/$SPLIT.jsonl \ + --image-folder $EVAL_DIR/vqav2/test2015 \ + --answers-file $EVAL_DIR/vqav2/answers/$SPLIT/$MODEL_NAME/${CHUNKS}_${IDX}.jsonl \ + --num-chunks $CHUNKS \ + --chunk-idx $IDX \ + --temperature 0 \ + --conv-mode phi & +done + +wait + +output_file=$EVAL_DIR/vqav2/answers/$SPLIT/$MODEL_NAME/merge.jsonl + +# Clear out the output file if it exists. +> "$output_file" + +# Loop through the indices and concatenate each file. +for IDX in $(seq 0 $((CHUNKS-1))); do + cat $EVAL_DIR/vqav2/answers/$SPLIT/$MODEL_NAME/${CHUNKS}_${IDX}.jsonl >> "$output_file" +done + +python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $MODEL_NAME --dir $EVAL_DIR/vqav2 diff --git a/TinyLLaVA_Factory/scripts/train/custom_finetune.sh b/TinyLLaVA_Factory/scripts/train/custom_finetune.sh new file mode 100644 index 0000000000000000000000000000000000000000..63b036da2dd3180689a3dc014d86a6c99fb9b2cd --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/custom_finetune.sh @@ -0,0 +1,45 @@ +DATA_PATH="/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json" +IMAGE_PATH="/home/ai/data/llava/dataset" +MODEL_MAX_LENGTH=3072 +OUTPUT_DIR="/mnt/data/sata/yinghu/checkpoints/llava_factory/custom-finetune-TinyLLaVA-Phi-2-SigLIP-3.1B-lora" + +deepspeed --include localhost:0,1,2,3 --master_port 29501 tinyllava/train/custom_finetune.py \ + --deepspeed ./scripts/zero2.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version phi \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --fp16 True \ + --training_recipe lora \ + --tune_type_llm lora \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --lora_r 128 \ + --lora_alpha 256 \ + --group_by_modality_length False \ + --pretrained_model_path "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B" \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 1e-4 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name custom-finetune-TinyLLaVA-Phi-2-SigLIP-3.1B-lora diff --git a/TinyLLaVA_Factory/scripts/train/finetune.sh b/TinyLLaVA_Factory/scripts/train/finetune.sh new file mode 100644 index 0000000000000000000000000000000000000000..297b798f624807ae58af0ee2b04184af0620e018 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/finetune.sh @@ -0,0 +1,64 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm full \ + --tune_type_vision_tower frozen\ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --group_by_modality_length True \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune diff --git a/TinyLLaVA_Factory/scripts/train/gemma/finetune_gemma.sh b/TinyLLaVA_Factory/scripts/train/gemma/finetune_gemma.sh new file mode 100644 index 0000000000000000000000000000000000000000..95defd9a5eb24f9ff1234e3a4e6827c0982560ed --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/gemma/finetune_gemma.sh @@ -0,0 +1,64 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm full \ + --tune_type_vision_tower frozen\ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --group_by_modality_length True \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune diff --git a/TinyLLaVA_Factory/scripts/train/gemma/pretrain_gemma.sh b/TinyLLaVA_Factory/scripts/train/gemma/pretrain_gemma.sh new file mode 100644 index 0000000000000000000000000000000000000000..25c6c8f603816f397d194a5c0f25313e9a5172d6 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/gemma/pretrain_gemma.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +if [ $# -ne 9 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +VERSION="$7" +TRAIN_RECIPE="$8" +MODEL_MAX_LENGTH="$9" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH\ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version pretrain \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm frozen \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --save_total_limit 1 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain diff --git a/TinyLLaVA_Factory/scripts/train/gemma/train_gemma.sh b/TinyLLaVA_Factory/scripts/train/gemma/train_gemma.sh new file mode 100644 index 0000000000000000000000000000000000000000..f55647ab95a8640df35446068556f3163fbeedd9 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/gemma/train_gemma.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset + +LLM_VERSION=google/gemma-2b-it +VT_VERSION=google/siglip-so400m-patch14-384 +VT_VERSION2="" +CN_VERSION=mlp2x_gelu +CONV_VERSION=gemma +VERSION=base +TRAIN_RECIPE=common +MODEL_MAX_LENGTH=2048 + + +bash scripts/train/gemma/pretrain_gemma.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/gemma/finetune_gemma.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/lora/finetune_lora.sh b/TinyLLaVA_Factory/scripts/train/lora/finetune_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..633a0ae7a19f660c9793ad1d94a1fd1015cf2928 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/lora/finetune_lora.sh @@ -0,0 +1,66 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:2,3,4,5 --master_port 29502 tinyllava/train/train.py \ + --deepspeed ./scripts/zero2.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm lora \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --lora_r 128 \ + --lora_alpha 256 \ + --group_by_modality_length False \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-4 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune diff --git a/TinyLLaVA_Factory/scripts/train/lora/finetune_qlora.sh b/TinyLLaVA_Factory/scripts/train/lora/finetune_qlora.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f344e98e388135e3b07010e01301210dc9c16a4 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/lora/finetune_qlora.sh @@ -0,0 +1,66 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29502 tinyllava/train/train.py \ + --deepspeed ./scripts/zero2.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm qlora \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --lora_r 128 \ + --lora_alpha 256 \ + --group_by_modality_length False \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-4 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune diff --git a/TinyLLaVA_Factory/scripts/train/lora/train_phi_lora.sh b/TinyLLaVA_Factory/scripts/train/lora/train_phi_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..488f29160bee1ecdaf8da65bcc682b0887e02978 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/lora/train_phi_lora.sh @@ -0,0 +1,18 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset + +LLM_VERSION=microsoft/phi-2 +VT_VERSION=google/siglip-so400m-patch14-384 +VT_VERSION2="" +CN_VERSION=mlp2x_gelu +CONV_VERSION=phi +VERSION=base-lora-zero2-r128 +PRETRAIN_TRAIN_RECIPE=common +FINETUNE_TRAIN_RECIPE=lora +MODEL_MAX_LENGTH=3072 + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$PRETRAIN_TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/lora/finetune_lora.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$FINETUNE_TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/lora/train_phi_qlora.sh b/TinyLLaVA_Factory/scripts/train/lora/train_phi_qlora.sh new file mode 100644 index 0000000000000000000000000000000000000000..eb4d4f4203c6efb7f2b5857cb8e618436c7f4f33 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/lora/train_phi_qlora.sh @@ -0,0 +1,18 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset + +LLM_VERSION=microsoft/phi-2 +VT_VERSION=google/siglip-so400m-patch14-384 +VT_VERSION2="" +CN_VERSION=mlp2x_gelu +CONV_VERSION=phi +VERSION=base-qlora +PRETRAIN_TRAIN_RECIPE=common +FINETUNE_TRAIN_RECIPE=qlora_int8 +MODEL_MAX_LENGTH=3072 + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$PRETRAIN_TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/lora/finetune_qlora.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$FINETUNE_TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/openelm/finetune_openelm.sh b/TinyLLaVA_Factory/scripts/train/openelm/finetune_openelm.sh new file mode 100644 index 0000000000000000000000000000000000000000..c71314bc96669326e521b5d4a9299a67b57083fb --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/openelm/finetune_openelm.sh @@ -0,0 +1,64 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:2,3,4,5 --master_port 29503 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --fp16 True \ + --tokenizer_name_or_path meta-llama/Llama-2-7b-hf \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm full \ + --tune_type_vision_tower frozen\ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --group_by_modality_length True \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune diff --git a/TinyLLaVA_Factory/scripts/train/openelm/pretrain_openelm.sh b/TinyLLaVA_Factory/scripts/train/openelm/pretrain_openelm.sh new file mode 100644 index 0000000000000000000000000000000000000000..216da7d6ff5c89266696a04337ff3ea041936c3e --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/openelm/pretrain_openelm.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +if [ $# -ne 9 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +VERSION="$7" +TRAIN_RECIPE="$8" +MODEL_MAX_LENGTH="$9" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:0,1 --master_port 29503 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH\ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version pretrain \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --fp16 True \ + --tokenizer_name_or_path meta-llama/Llama-2-7b-hf \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm frozen \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 64 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --save_total_limit 1 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain diff --git a/TinyLLaVA_Factory/scripts/train/openelm/train_openelm.sh b/TinyLLaVA_Factory/scripts/train/openelm/train_openelm.sh new file mode 100644 index 0000000000000000000000000000000000000000..b81e22d0701c5dd948d04b786c7b688b936e0a2d --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/openelm/train_openelm.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset + +LLM_VERSION=apple/OpenELM-270M-Instruct +VT_VERSION=google/siglip-so400m-patch14-384 +VT_VERSION2="" +CN_VERSION=mlp2x_gelu +CONV_VERSION=llama +VERSION=elm_base +TRAIN_RECIPE=common +MODEL_MAX_LENGTH=2048 + + +bash scripts/train/openelm/pretrain_openelm.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/openelm/finetune_openelm.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/pretrain.sh b/TinyLLaVA_Factory/scripts/train/pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3cf9df7b5df075bdabce53844c1806cb58b0bfc --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/pretrain.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +if [ $# -ne 9 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +VERSION="$7" +TRAIN_RECIPE="$8" +MODEL_MAX_LENGTH="$9" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH\ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version pretrain \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm frozen \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --save_total_limit 1 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain diff --git a/TinyLLaVA_Factory/scripts/train/qwen2/finetune_qwen2.sh b/TinyLLaVA_Factory/scripts/train/qwen2/finetune_qwen2.sh new file mode 100644 index 0000000000000000000000000000000000000000..8e635f2978a0f4ba79307bfdd0c250fc72ccb49c --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/qwen2/finetune_qwen2.sh @@ -0,0 +1,64 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --bf16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm full \ + --tune_type_vision_tower frozen\ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --group_by_modality_length True \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-finetune diff --git a/TinyLLaVA_Factory/scripts/train/qwen2/pretrain_qwen2.sh b/TinyLLaVA_Factory/scripts/train/qwen2/pretrain_qwen2.sh new file mode 100644 index 0000000000000000000000000000000000000000..d2f27e6f77d1f252018043be465382c4712b5aee --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/qwen2/pretrain_qwen2.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +if [ $# -ne 9 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +VERSION="$7" +TRAIN_RECIPE="$8" +MODEL_MAX_LENGTH="$9" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH\ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version pretrain \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --bf16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm frozen \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --save_total_limit 1 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain diff --git a/TinyLLaVA_Factory/scripts/train/qwen2/readme.md b/TinyLLaVA_Factory/scripts/train/qwen2/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..499f31e8c3dd444671cddd3d419e40b3691b34f1 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/qwen2/readme.md @@ -0,0 +1 @@ +These codes work for Qwen/Qwen2-0.5B and Qwen/Qwen2-0.5B-Instruct. However there is bug causing from deepspeed when you use other Qwen2 models like qwen2-1.5B. If you want to use tinyllava to train other qwen2 models, please feel free to contact our team. diff --git a/TinyLLaVA_Factory/scripts/train/qwen2/train_qwen2_base.sh b/TinyLLaVA_Factory/scripts/train/qwen2/train_qwen2_base.sh new file mode 100644 index 0000000000000000000000000000000000000000..b3c926c780c9a8fdc070f35f5993fccc2c5f70df --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/qwen2/train_qwen2_base.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json #pretrain annotation file path +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json #finetune annotation file path +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images #pretrain image dir +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset #finetune image dir + +LLM_VERSION=Qwen/Qwen2-0.5B # llm path in huggingface +VT_VERSION=google/siglip-so400m-patch14-384 #vision tower path in huggingface +VT_VERSION2="" #if you are not using mof vision tower, keep it empty +CN_VERSION=mlp2x_gelu #connector type, other options are: qformer, resampler, etc +CONV_VERSION=qwen2_base #chat template, other options are: phi, llama, gemmma, etc +VERSION=qwen2-0_5b_base #experiment name for recording different runnings +TRAIN_RECIPE=common #training recipes, other options are: lora, qlora +MODEL_MAX_LENGTH=2048 #max model length for llm + + +bash scripts/train/qwen2/pretrain_qwen2.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/qwen2/finetune_qwen2.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/qwen2/train_qwen2_instruct.sh b/TinyLLaVA_Factory/scripts/train/qwen2/train_qwen2_instruct.sh new file mode 100644 index 0000000000000000000000000000000000000000..768eea754e7360c94549d17f68cb139f5ce1c450 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/qwen2/train_qwen2_instruct.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json #pretrain annotation file path +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json #finetune annotation file path +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images #pretrain image dir +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset #finetune image dir + +LLM_VERSION=Qwen/Qwen2-0.5B-Instruct # llm path in huggingface +VT_VERSION=google/siglip-so400m-patch14-384 #vision tower path in huggingface +VT_VERSION2="" #if you are not using mof vision tower, keep it empty +CN_VERSION=mlp2x_gelu #connector type, other options are: qformer, resampler, etc +CONV_VERSION=qwen2_instruct #chat template, other options are: phi, llama, gemmma, etc +VERSION=qwen2-0_5b_instruct #experiment name for recording different runnings +TRAIN_RECIPE=common #training recipes, other options are: lora, qlora +MODEL_MAX_LENGTH=2048 #max model length for llm + + +bash scripts/train/qwen2/pretrain_qwen2.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/qwen2/finetune_qwen2.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/share/finetune_share.sh b/TinyLLaVA_Factory/scripts/train/share/finetune_share.sh new file mode 100644 index 0000000000000000000000000000000000000000..92f050ddba0679d157732c610e1f13768014ec67 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/share/finetune_share.sh @@ -0,0 +1,64 @@ +#!/bin/bash +if [ $# -ne 10 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +CONV_VERSION="$7" +VERSION="$8" +TRAIN_RECIPE="$9" +MODEL_MAX_LENGTH="${10}" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH \ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version $CONV_VERSION \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm full \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --group_by_modality_length True \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-share-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-share-finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-share-finetune diff --git a/TinyLLaVA_Factory/scripts/train/share/pretrain_share.sh b/TinyLLaVA_Factory/scripts/train/share/pretrain_share.sh new file mode 100644 index 0000000000000000000000000000000000000000..17d89164059f0ea7adf86dbfe1e2ff788bcf3354 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/share/pretrain_share.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +if [ $# -ne 9 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the arguments to variables +DATA_PATH="$1" +IMAGE_PATH="$2" +LLM_VERSION="$3" +VT_VERSION="$4" +VT_VERSION2="$5" +CN_VERSION="$6" +VERSION="$7" +TRAIN_RECIPE="$8" +MODEL_MAX_LENGTH="$9" + +VT_VARIANT="${VT_VERSION#*/}" +LLM_VARIANT="${LLM_VERSION#*/}" + +deepspeed --include localhost:4,5,6,7 --master_port 29501 tinyllava/train/train.py \ + --deepspeed ./scripts/zero3.json \ + --data_path $DATA_PATH\ + --image_folder $IMAGE_PATH \ + --is_multimodal True \ + --conv_version pretrain \ + --model_name_or_path $LLM_VERSION \ + --vision_tower $VT_VERSION \ + --vision_tower2 "$VT_VERSION2" \ + --connector_type $CN_VERSION \ + --mm_vision_select_layer -2 \ + --image_aspect_ratio square \ + --attn_implementation flash_attention_2 \ + --fp16 True \ + --training_recipe $TRAIN_RECIPE \ + --tune_type_llm full \ + --tune_type_vision_tower frozen \ + --tune_vision_tower_from_layer 0 \ + --tune_type_connector full \ + --pretrained_model_path /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-pretrain \ + --output_dir /mnt/data/sata/yinghu/checkpoints/llava_factory/tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-share-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 False \ + --model_max_length $MODEL_MAX_LENGTH \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to tensorboard \ + --tokenizer_use_fast False \ + --run_name tiny-llava-${LLM_VARIANT}-${VT_VARIANT}-${VERSION}-share-pretrain diff --git a/TinyLLaVA_Factory/scripts/train/share/train_phi_share.sh b/TinyLLaVA_Factory/scripts/train/share/train_phi_share.sh new file mode 100644 index 0000000000000000000000000000000000000000..f208ff876418589274b6bdcf912b6edcba1c48a3 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/share/train_phi_share.sh @@ -0,0 +1,21 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json +SHARE_PRETRAIN_DATA_PATH=/mnt/data/sata/ssd/dataset/text_files/really_cleaned_share-captioner_coco_lcs_sam_1246k_1107.json +SHARE_FINETUNE_DATA_PATH=/mnt/data/sata/ssd/dataset/text_files/cleaned_sharegpt4v_mix665k_cap23k_coco-ap9k_lcs3k_sam9k_div2k.json +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images +SHARE_PRETRAIN_IMAGE_PATH=/home/ai/data/llava/dataset +SHARE_FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset + +LLM_VERSION=microsoft/phi-2 +VT_VERSION=google/siglip-so400m-patch14-384 +VT_VERSION2="" +CN_VERSION=mlp2x_gelu +CONV_VERSION=phi +VERSION=share +TRAIN_RECIPE=common +MODEL_MAX_LENGTH=3072 + + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/share/pretrain_share.sh "$SHARE_PRETRAIN_DATA_PATH" "$SHARE_PRETRAIN_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/share/finetune_share.sh "$SHARE_FINETUNE_DATA_PATH" "$SHARE_FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/train_mof.sh b/TinyLLaVA_Factory/scripts/train/train_mof.sh new file mode 100644 index 0000000000000000000000000000000000000000..c169539632ee06196b7c760c923958cf5786934d --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/train_mof.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset + +LLM_VERSION=TinyLlama/TinyLlama-1.1B-Chat-v1.0 +VT_VERSION=mof:openai/clip-vit-large-patch14 +VT_VERSION2=mof:facebook/dinov2-large +CN_VERSION=mof_mlp +CONV_VERSION=llama +VERSION=llama-mof-base +TRAIN_RECIPE=common +MODEL_MAX_LENGTH=2048 + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/finetune.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/train_phi.sh b/TinyLLaVA_Factory/scripts/train/train_phi.sh new file mode 100644 index 0000000000000000000000000000000000000000..4dbb9867a98a18838cf048390162dd1810456d2a --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/train_phi.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json #pretrain annotation file path +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json #finetune annotation file path +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images #pretrain image dir +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset #finetune image dir + +LLM_VERSION=microsoft/phi-2 # llm path in huggingface +VT_VERSION=google/siglip-so400m-patch14-384 #vision tower path in huggingface +VT_VERSION2="" #if you are not using mof vision tower, keep it empty +CN_VERSION=mlp2x_gelu #connector type, other options are: qformer, resampler, etc +CONV_VERSION=phi #chat template, other options are: phi, llama, gemmma, etc +VERSION=base #experiment name for recording different runnings +TRAIN_RECIPE=common #training recipes, other options are: lora, qlora +MODEL_MAX_LENGTH=3072 #max model length for llm + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/finetune.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/train_stablelm.sh b/TinyLLaVA_Factory/scripts/train/train_stablelm.sh new file mode 100644 index 0000000000000000000000000000000000000000..36fa77328b4c49033c449958d2b064a59337515d --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/train_stablelm.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json #pretrain annotation file path +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json #finetune annotation file path +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images #pretrain image dir +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset #finetune image dir + +LLM_VERSION=stabilityai/stablelm-2-zephyr-1_6b # llm path in huggingface +VT_VERSION=openai/clip-vit-large-patch14-336 #vision tower path in huggingface +VT_VERSION2="" #if you are not using mof vision tower, keep it empty +CN_VERSION=mlp2x_gelu #connector type, other options are: qformer, resampler, etc +CONV_VERSION=phi #chat template for stablelm is the same as that for phi +VERSION=base #experiment name for recording different runnings +TRAIN_RECIPE=common #training recipes, other options are: lora, qlora +MODEL_MAX_LENGTH=2048 #max model length for llm + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/finetune.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/train/train_tinyllama.sh b/TinyLLaVA_Factory/scripts/train/train_tinyllama.sh new file mode 100644 index 0000000000000000000000000000000000000000..a4e2fdca1155e478285802c9b4c9e97720875d40 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/train/train_tinyllama.sh @@ -0,0 +1,17 @@ +DATA_PATH=/home/ai/data/llava/dataset/text_files/blip_laion_cc_sbu_558k.json #pretrain annotation file path +FINETUNE_DATA_PATH=/home/ai/data/llava/dataset/text_files/llava_v1_5_mix665k.json #finetune annotation file path +IMAGE_PATH=/home/ai/data/llava/dataset/llava/llava_pretrain/images #pretrain image dir +FINETUNE_IMAGE_PATH=/home/ai/data/llava/dataset #finetune image dir + +LLM_VERSION=TinyLlama/TinyLlama-1.1B-Chat-v1.0 # llm path in huggingface +VT_VERSION=google/siglip-so400m-patch14-384 #vision tower path in huggingface +VT_VERSION2="" #if you are not using mof vision tower, keep it empty +CN_VERSION=mlp2x_gelu #connector type, other options are: qformer, resampler, etc +CONV_VERSION=llama #chat template, other options are: phi, llama, gemmma, etc +VERSION=base #experiment name for recording different runnings +TRAIN_RECIPE=common #training recipes, other options are: lora, qlora +MODEL_MAX_LENGTH=2048 #max model length for llm + + +bash scripts/train/pretrain.sh "$DATA_PATH" "$IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" +bash scripts/train/finetune.sh "$FINETUNE_DATA_PATH" "$FINETUNE_IMAGE_PATH" "$LLM_VERSION" "$VT_VERSION" "$VT_VERSION2" "$CN_VERSION" "$CONV_VERSION" "$VERSION" "$TRAIN_RECIPE" "$MODEL_MAX_LENGTH" diff --git a/TinyLLaVA_Factory/scripts/zero2.json b/TinyLLaVA_Factory/scripts/zero2.json new file mode 100644 index 0000000000000000000000000000000000000000..c95ebefe07b7d8d9fd0936a014679d07102cc270 --- /dev/null +++ b/TinyLLaVA_Factory/scripts/zero2.json @@ -0,0 +1,23 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 2, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto" + } +} \ No newline at end of file diff --git a/TinyLLaVA_Factory/scripts/zero3.json b/TinyLLaVA_Factory/scripts/zero3.json new file mode 100644 index 0000000000000000000000000000000000000000..14b7b3e92421662a8724ddc61026236acb7a2a3d --- /dev/null +++ b/TinyLLaVA_Factory/scripts/zero3.json @@ -0,0 +1,28 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + } +} diff --git a/TinyLLaVA_Factory/tinyllava/__init__.py b/TinyLLaVA_Factory/tinyllava/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TinyLLaVA_Factory/tinyllava/data/__init__.py b/TinyLLaVA_Factory/tinyllava/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..944291131479262a436b4ba2aaf43051262bba0c --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/__init__.py @@ -0,0 +1,4 @@ +from .template import * +from .image_preprocess import * +from .text_preprocess import * +from .dataset import * \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/data/dataset.py b/TinyLLaVA_Factory/tinyllava/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e44f66464a7f2405c128eaec883c7feacd42c36b --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/dataset.py @@ -0,0 +1,128 @@ +import copy +from dataclasses import dataclass +import json +from typing import Dict, Sequence, TYPE_CHECKING +from PIL import Image, ImageFile +import os + +from .text_preprocess import TextPreprocess +from .image_preprocess import ImagePreprocess +from ..utils.arguments import DataArguments +from ..utils.constants import * + + +import transformers +import torch +from torch.utils.data import Dataset + + + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + list_data_dict = json.load(open(data_path, "r")) + + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + self.text_preprocess = TextPreprocess(tokenizer, data_args.conv_version) + self.image_preprocess = ImagePreprocess(data_args.image_processor, data_args) + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + data_dict = self.text_preprocess(copy.deepcopy(sources["conversations"])) + if 'image' in sources: + image_file = self.list_data_dict[i]['image'] + image_folder = self.data_args.image_folder + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + image = self.image_preprocess(image) + data_dict['image'] = image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + # print(f'{i}:{sources}') + crop_size = getattr(self.data_args.image_processor, 'crop_size', getattr(self.data_args.image_processor, 'size')) + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + for input_id in input_ids: + input_id[input_id == self.tokenizer.eos_token_id] = -300 + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + attention_mask = input_ids.ne(self.tokenizer.pad_token_id) + labels = labels[:, :self.tokenizer.model_max_length] + # FIXME: This is a hack for handling phi and stablelm, as they have the same eos, pad and unk. We want the model + # FIXME: to predict the eos in the input ids, but we also use the id of eos to pad sequence, so we use a temp + # FIXME: eos id first, and convert them back. + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + for input_id in input_ids: + input_id[input_id == -300] = self.tokenizer.eos_token_id + + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) diff --git a/TinyLLaVA_Factory/tinyllava/data/image_preprocess.py b/TinyLLaVA_Factory/tinyllava/data/image_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ac1b6a0566ad192d5cba34d6e2c8cbae2a65bc --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/image_preprocess.py @@ -0,0 +1,70 @@ +import os + +from PIL import Image, ImageFile +import torch +import ast + +from ..utils.data_utils import * + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +# 可能imagepreprocess需要继承一个huggingface的图像处理类?提供from_pretrained方法 + +class ImagePreprocess: + def __init__(self, image_processor, data_args={}): + self.image_aspect_ratio = getattr(data_args, 'image_aspect_ratio', None) + self.image_processor = image_processor + self.image_grid_pinpoints = getattr(data_args, 'image_grid_pinpoints', None) + + def __call__(self, image): + if self.image_aspect_ratio == 'pad': + image = self.expand2square(image, tuple(int(x * 255) for x in self.image_processor.image_mean)) + elif self.image_aspect_ratio == "anyres": + image = self.process_anyres_image(image, self.image_processor, self.image_grid_pinpoints) + return image + image = self.image_processor(image, return_tensors='pt')['pixel_values'][0] + return image + + @classmethod + def expand2square(cls, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + @classmethod + def process_anyres_image(cls, image, processor, grid_pinpoints): + """ + Process an image with variable resolutions. + + Args: + image (PIL.Image.Image): The input image to be processed. + processor: The image processor object. + grid_pinpoints (str): A string representation of a list of possible resolutions. + + Returns: + torch.Tensor: A tensor containing the processed image patches. + """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + best_resolution = select_best_resolution(image.size, possible_resolutions) + image_padded = resize_and_pad_image(image, best_resolution) + + patches = divide_to_patches(image_padded, processor.crop_size['height']) + + image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) + + image_patches = [image_original_resize] + patches + image_patches = [processor(image_patch, return_tensors='pt')['pixel_values'][0] + for image_patch in image_patches] + return torch.stack(image_patches, dim=0) + diff --git a/TinyLLaVA_Factory/tinyllava/data/template/__init__.py b/TinyLLaVA_Factory/tinyllava/data/template/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08f48ed65e8e90f2abda8b62792a93b790160bc --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/__init__.py @@ -0,0 +1,29 @@ +import os +from typing import Dict + +from .base import * +from ...utils import import_modules + + +TEMPlATE_FACTORY: Dict[str, Template] = {} + +def TemplateFactory(version): + template = TEMPlATE_FACTORY.get(version, None) + assert template, f"{version} is not implmentation" + return template + + +def register_template(name): + def register_template_cls(cls): + if name in TEMPlATE_FACTORY: + return TEMPlATE_FACTORY[name] + + TEMPlATE_FACTORY[name] = cls + return cls + + return register_template_cls + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +import_modules(models_dir, "tinyllava.data.template") diff --git a/TinyLLaVA_Factory/tinyllava/data/template/base.py b/TinyLLaVA_Factory/tinyllava/data/template/base.py new file mode 100644 index 0000000000000000000000000000000000000000..65f475aacf148710d198d00b67ac72dd80ea0531 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/base.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +import copy + +from .formatter import EmptyFormatter, StringFormatter +from .formatter import Formatter +from ...utils.constants import * + +from transformers import PreTrainedTokenizer +import torch + + + +@dataclass +class Template: + format_image_token: "Formatter" + format_user: "Formatter" + format_assistant: "Formatter" + system: "Formatter" + separator: "Formatter" + + def encode(self, messages, tokenizer, mode='train'): + """ + 1. get list form messages(conversations:[{from:human, value:message}, {from:gpt, value:message}]) + ===> human_list, value_list + 2. prompt two list + 3. tokenize prompt + 4. make target + """ + question_list, answer_list = self.get_list_from_message(messages) + prompt = self.prompt(question_list, answer_list) + input_ids = self.tokenizer_image_token(prompt, tokenizer, return_tensors='pt') + if mode == 'train': + labels = self.make_labels(input_ids, prompt, tokenizer) + return dict( + input_ids=input_ids, + labels=labels + ) + else: + return dict(input_ids=input_ids, prompt=prompt) + + + def get_list_from_message(self, messages): + return self._get_list_from_message(messages) + + def _get_list_from_message(self, messages): + """ + messages ====> [{from:human, value:message}, {from:gpt, value:message}] + """ + question_list = [] + answer_list = [] + first_is_not_question = 0 + for i, message in enumerate(messages): + if i == 0 and message['from'] != 'human': + first_is_not_question = 1 + continue + if i % 2 == first_is_not_question: + question_list.append(message['value']) + else: + answer_list.append(message['value']) + + assert len(question_list) == len(answer_list) , \ + f"qa is not match : length_q:{len(question_list)} vs length_a:{len(answer_list)}" + return question_list, answer_list + + + def prompt( + self, + question_list, answer_list + ): + if type(question_list) is str: + question_list = [question_list] + if type(answer_list) is str: + answer_list = [answer_list] + msg = self._prompt(question_list, answer_list) + return msg + + def _prompt( + self, + question_list, answer_list, + ): + msg = "" + for i, (question, answer) in enumerate(zip(question_list, answer_list)): + if i == 0: + msg += self.system.apply() + if DEFAULT_IMAGE_TOKEN in question: + question = question.replace(DEFAULT_IMAGE_TOKEN, '').strip() + question = self.format_image_token.apply(content=question).strip() + msg += self.format_user.apply(content=question) + msg += self.format_assistant.apply(content=answer) + return msg + + def make_labels(self, input_ids, prompt, tokenizer): + labels = copy.deepcopy(input_ids) + sep, eos_token = self.separator.apply() + total_len = int(labels.ne(tokenizer.pad_token_id).sum()) + if tokenizer.pad_token_id == tokenizer.eos_token_id: + total_len += prompt.count(eos_token) + rounds = prompt.split(eos_token) + eos_token_length = len(tokenizer.encode(eos_token)) + labels, cur_len = self._make_masks(labels, tokenizer, sep, eos_token_length, rounds) + if cur_len < tokenizer.model_max_length: + import time + if cur_len != total_len: + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + print("number of rounds: ", len(rounds) - 1) + print("rounds: ", rounds[:-1]) + print("prompt: ", prompt) + print(labels) + print(input_ids) + time.sleep(5) + labels[:] = IGNORE_INDEX + return labels + + + + def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds): + cur_len = 0 + for rou in rounds: + if rou == "": + break + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(self.tokenizer_image_token(rou, tokenizer)) + eos_token_length + instruction_len = len(self.tokenizer_image_token(parts[0], tokenizer)) - 1 + labels[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += round_len + labels[cur_len:] = IGNORE_INDEX + return labels, cur_len + + @classmethod + def tokenizer_image_token(cls, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + def _insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in _insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/data/template/formatter.py b/TinyLLaVA_Factory/tinyllava/data/template/formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..14b595f4b9dc4ee9c8f37b8dde413e2670d3a6d9 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/formatter.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Union, List + + +SLOT = Union[str, List[str], Dict[str, str]] + +@dataclass +class Formatter(ABC): + slot: SLOT = "" + + @abstractmethod + def apply(self, **kwargs) -> SLOT: ... + + + +@dataclass +class EmptyFormatter(Formatter): + def apply(self, **kwargs) -> SLOT: + return self.slot + + +@dataclass +class StringFormatter(Formatter): + def apply(self, **kwargs) -> SLOT: + msg = "" + for name, value in kwargs.items(): + if value is None: + msg = self.slot.split(':')[0] + ":" + return msg + if not isinstance(value, str): + raise RuntimeError("Expected a string, got {}".format(value)) + msg = self.slot.replace("{{" + name + "}}", value, 1) + return msg diff --git a/TinyLLaVA_Factory/tinyllava/data/template/gemma_template.py b/TinyLLaVA_Factory/tinyllava/data/template/gemma_template.py new file mode 100644 index 0000000000000000000000000000000000000000..3e062ebf29f2f4bc915ecd768c837f625744b066 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/gemma_template.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from packaging import version + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from . import register_template +from ...utils.constants import * + +from transformers import PreTrainedTokenizer +import torch +import tokenizers + +system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + +@register_template('gemma') +@dataclass +class GemmaTemplate(Template): + format_image_token: "Formatter" = StringFormatter(slot="\n{{content}}") + format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ") + format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "") + system: "Formatter" = EmptyFormatter(slot=system+" ") + separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '']) + + def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds): + cur_len = 1 # bos + eos_token_length = 1 + bos_token_length = 1 + labels[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(self.tokenizer_image_token(rou, tokenizer)) + eos_token_length - bos_token_length + instruction_len = len(self.tokenizer_image_token(parts[0], tokenizer)) - 1 - bos_token_length + labels[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += round_len + + labels[cur_len:] = IGNORE_INDEX + return labels, cur_len diff --git a/TinyLLaVA_Factory/tinyllava/data/template/llama_template.py b/TinyLLaVA_Factory/tinyllava/data/template/llama_template.py new file mode 100644 index 0000000000000000000000000000000000000000..6705add0e540f3737ab057591b0e66a4103d96c9 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/llama_template.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from packaging import version + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from . import register_template +from ...utils.constants import * + +from transformers import PreTrainedTokenizer +import torch +import tokenizers + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + +system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + +@register_template('llama') +@dataclass +class LlamaTemplate(Template): + format_image_token: "Formatter" = StringFormatter(slot="\n{{content}}") + format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ") + format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "") + system: "Formatter" = EmptyFormatter(slot=system+" ") + separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '']) + + def _make_masks(self, labels, tokenizer, sep, eos_token_length, rounds): + cur_len = 1 # bos + eos_token_length = 1 + bos_token_length = 1 + labels[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(self.tokenizer_image_token(rou, tokenizer)) + eos_token_length - bos_token_length + instruction_len = len(self.tokenizer_image_token(parts[0], tokenizer)) - 1 - bos_token_length + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + labels[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += round_len + + labels[cur_len:] = IGNORE_INDEX + return labels, cur_len + + + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/data/template/phi_template.py b/TinyLLaVA_Factory/tinyllava/data/template/phi_template.py new file mode 100644 index 0000000000000000000000000000000000000000..e8aa45f3d6acd17714ddd66947674479e4164f92 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/phi_template.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from . import register_template + +from transformers import PreTrainedTokenizer +import torch + +system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + +@register_template('phi') +@dataclass +class PhiTemplate(Template): + format_image_token: "Formatter" = StringFormatter(slot="\n{{content}}") + format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ") + format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "<|endoftext|>") + system: "Formatter" = EmptyFormatter(slot=system+" ") + separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '<|endoftext|>']) + + + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/data/template/pretrain_template.py b/TinyLLaVA_Factory/tinyllava/data/template/pretrain_template.py new file mode 100644 index 0000000000000000000000000000000000000000..47139361e1499bcad31b2a2d99497e5f5f21ade9 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/pretrain_template.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +import copy + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from ...utils.constants import * +from . import register_template + +from transformers import PreTrainedTokenizer +import torch + + + +@register_template('pretrain') +@dataclass +class PretrainTemplate(Template): + format_image_token: "Formatter" = EmptyFormatter(slot="") + format_user: "Formatter" = EmptyFormatter(slot="") + format_assistant: "Formatter" = StringFormatter(slot="{{content}}\n") + system: "Formatter" = EmptyFormatter(slot="") + separator: "Formatter" = EmptyFormatter(slot=['', '']) + + def make_labels(self, input_ids, prompt, tokenizer): + labels = copy.deepcopy(input_ids) + mask_len = len(self.tokenizer_image_token("", tokenizer)) + labels[:mask_len] = IGNORE_INDEX + return labels + + + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/data/template/qwen2_base_template.py b/TinyLLaVA_Factory/tinyllava/data/template/qwen2_base_template.py new file mode 100644 index 0000000000000000000000000000000000000000..adfff5eb42a73da2cc28c7b02cb9e5d28d00be2a --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/qwen2_base_template.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from . import register_template + +from transformers import PreTrainedTokenizer +import torch + +system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + +@register_template('qwen2_base') +@dataclass +class Qwen2BaseTemplate(Template): + format_image_token: "Formatter" = StringFormatter(slot="\n{{content}}") + format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ") + format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "<|endoftext|>") + system: "Formatter" = EmptyFormatter(slot=system+" ") + separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '<|endoftext|>']) + + + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/data/template/qwen2_instruct_template.py b/TinyLLaVA_Factory/tinyllava/data/template/qwen2_instruct_template.py new file mode 100644 index 0000000000000000000000000000000000000000..1cdb687b0ac366f931a33f544b2c0980d14dea26 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/template/qwen2_instruct_template.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from .formatter import EmptyFormatter, StringFormatter +from .base import Template +from .formatter import Formatter +from . import register_template + +from transformers import PreTrainedTokenizer +import torch + +system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." + +@register_template('qwen2_instruct') +@dataclass +class Qwen2InstructTemplate(Template): + format_image_token: "Formatter" = StringFormatter(slot="\n{{content}}") + format_user: "Formatter" = StringFormatter(slot="USER" + ": " + "{{content}}" + " ") + format_assistant: "Formatter" = StringFormatter(slot="ASSISTANT" + ": " + "{{content}}" + "<|im_end|>") + system: "Formatter" = EmptyFormatter(slot=system+" ") + separator: "Formatter" = EmptyFormatter(slot=[' ASSISTANT: ', '<|im_end|>']) + + + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/data/text_preprocess.py b/TinyLLaVA_Factory/tinyllava/data/text_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..a50c722a80a32819a3ef900e7a948bb1084c3fa1 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/data/text_preprocess.py @@ -0,0 +1,12 @@ +from typing import Any + +from .template import TemplateFactory + + +class TextPreprocess: + def __init__(self, tokenizer, version): + self.tokenizer = tokenizer + self.template = TemplateFactory(version)() + + def __call__(self, messages, mode='train'): + return self.template.encode(messages, self.tokenizer, mode) \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/eval/__init__.py b/TinyLLaVA_Factory/tinyllava/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TinyLLaVA_Factory/tinyllava/eval/eval_pope.py b/TinyLLaVA_Factory/tinyllava/eval/eval_pope.py new file mode 100644 index 0000000000000000000000000000000000000000..b115b8f2327ea9d972f9e41bcbb03c68be6b3508 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/eval_pope.py @@ -0,0 +1,81 @@ +import os +import json +import argparse + +def eval_pope(answers, label_file): + label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] + + for answer in answers: + text = answer['text'] + + # Only keep the first sentence + if text.find('.') != -1: + text = text.split('.')[0] + + text = text.replace(',', '') + words = text.split(' ') + if 'No' in words or 'not' in words or 'no' in words: + answer['text'] = 'no' + else: + answer['text'] = 'yes' + + for i in range(len(label_list)): + if label_list[i] == 'no': + label_list[i] = 0 + else: + label_list[i] = 1 + + pred_list = [] + for answer in answers: + if answer['text'] == 'no': + pred_list.append(0) + else: + pred_list.append(1) + + pos = 1 + neg = 0 + yes_ratio = pred_list.count(1) / len(pred_list) + + TP, TN, FP, FN = 0, 0, 0, 0 + for pred, label in zip(pred_list, label_list): + if pred == pos and label == pos: + TP += 1 + elif pred == pos and label == neg: + FP += 1 + elif pred == neg and label == neg: + TN += 1 + elif pred == neg and label == pos: + FN += 1 + + print('TP\tFP\tTN\tFN\t') + print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) + + precision = float(TP) / float(TP + FP) + recall = float(TP) / float(TP + FN) + f1 = 2*precision*recall / (precision + recall) + acc = (TP + TN) / (TP + TN + FP + FN) + print('Accuracy: {}'.format(acc)) + print('Precision: {}'.format(precision)) + print('Recall: {}'.format(recall)) + print('F1 score: {}'.format(f1)) + print('Yes ratio: {}'.format(yes_ratio)) + print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--annotation-dir", type=str) + parser.add_argument("--question-file", type=str) + parser.add_argument("--result-file", type=str) + args = parser.parse_args() + + questions = [json.loads(line) for line in open(args.question_file)] + questions = {question['question_id']: question for question in questions} + answers = [json.loads(q) for q in open(args.result_file)] + for file in os.listdir(args.annotation_dir): + assert file.startswith('coco_pope_') + assert file.endswith('.json') + category = file[10:-5] + cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] + print('Category: {}, # samples: {}'.format(category, len(cur_answers))) + eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) + print("====================================") diff --git a/TinyLLaVA_Factory/tinyllava/eval/eval_science_qa.py b/TinyLLaVA_Factory/tinyllava/eval/eval_science_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf206bbd7a5d6376eef82d61b3ef8bbe0f71c6c --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/eval_science_qa.py @@ -0,0 +1,114 @@ +import argparse +import json +import os +import re +import random + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--base-dir', type=str) + parser.add_argument('--result-file', type=str) + parser.add_argument('--output-file', type=str) + parser.add_argument('--output-result', type=str) + parser.add_argument('--split', type=str, default='test') + parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) + return parser.parse_args() + + +def convert_caps(results): + fakecaps = [] + for result in results: + image_id = result['question_id'] + caption = result['text'] + fakecaps.append({"image_id": int(image_id), "caption": caption}) + return fakecaps + + +def get_pred_idx(prediction, choices, options): + """ + Get the index (e.g. 2) from the prediction (e.g. 'C') + """ + if prediction in options[:len(choices)]: + return options.index(prediction) + else: + return -1 + return random.choice(range(len(choices))) + + +if __name__ == "__main__": + args = get_args() + + base_dir = args.base_dir + split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] + problems = json.load(open(os.path.join(base_dir, "problems.json"))) + predictions = [json.loads(line) for line in open(args.result_file)] + predictions = {pred['question_id']: pred for pred in predictions} + split_problems = {idx: problems[idx] for idx in split_indices} + + results = {'correct': [], 'incorrect': []} + sqa_results = {} + sqa_results['acc'] = None + sqa_results['correct'] = None + sqa_results['count'] = None + sqa_results['results'] = {} + sqa_results['outputs'] = {} + + for prob_id, prob in split_problems.items(): + if prob_id not in predictions: + pred = {'text': 'FAILED', 'prompt': 'Unknown'} + pred_text = 'FAILED' + else: + pred = predictions[prob_id] + pred_text = pred['text'] + + if pred_text in args.options: + answer = pred_text + elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": + answer = pred_text[0] + else: + pattern = re.compile(r'The answer is ([A-Z]).') + res = pattern.findall(pred_text) + if len(res) == 1: + answer = res[0] # 'A', 'B', ... + else: + answer = "FAILED" + + pred_idx = get_pred_idx(answer, prob['choices'], args.options) + + analysis = { + 'question_id': prob_id, + 'parsed_ans': answer, + 'ground_truth': args.options[prob['answer']], + 'question': pred['prompt'], + 'pred': pred_text, + 'is_multimodal': '' in pred['prompt'], + } + + sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) + sqa_results['outputs'][prob_id] = pred_text + + if pred_idx == prob['answer']: + results['correct'].append(analysis) + else: + results['incorrect'].append(analysis) + + correct = len(results['correct']) + total = len(results['correct']) + len(results['incorrect']) + + ###### IMG ###### + multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) + multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) + multimodal_total = multimodal_correct + multimodal_incorrect + ###### IMG ###### + + print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') + + sqa_results['acc'] = correct / total * 100 + sqa_results['correct'] = correct + sqa_results['count'] = total + + with open(args.output_file, 'w') as f: + json.dump(results, f, indent=2) + with open(args.output_result, 'w') as f: + json.dump(sqa_results, f, indent=2) diff --git a/TinyLLaVA_Factory/tinyllava/eval/eval_textvqa.py b/TinyLLaVA_Factory/tinyllava/eval/eval_textvqa.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b67c04e53268744c95cbb307a5221d459e9a4b --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/eval_textvqa.py @@ -0,0 +1,65 @@ +import os +import argparse +import json +import re + +from tinyllava.eval.m4c_evaluator import TextVQAAccuracyEvaluator + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--annotation-file', type=str) + parser.add_argument('--result-file', type=str) + parser.add_argument('--result-dir', type=str) + return parser.parse_args() + + +def prompt_processor(prompt): + if prompt.startswith('OCR tokens: '): + pattern = r"Question: (.*?) Short answer:" + match = re.search(pattern, prompt, re.DOTALL) + question = match.group(1) + elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: + if prompt.startswith('Reference OCR token:'): + question = prompt.split('\n')[1] + else: + question = prompt.split('\n')[0] + elif len(prompt.split('\n')) == 2: + question = prompt.split('\n')[0] + else: + assert False + + return question.lower() + + +def eval_single(annotation_file, result_file): + experiment_name = os.path.splitext(os.path.basename(result_file))[0] + print(experiment_name) + annotations = json.load(open(annotation_file))['data'] + annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} + results = [json.loads(line) for line in open(result_file)] + + pred_list = [] + for result in results: + annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] + pred_list.append({ + "pred_answer": result['text'], + "gt_answers": annotation['answers'], + }) + + evaluator = TextVQAAccuracyEvaluator() + print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) + + +if __name__ == "__main__": + args = get_args() + + if args.result_file is not None: + eval_single(args.annotation_file, args.result_file) + + if args.result_dir is not None: + for result_file in sorted(os.listdir(args.result_dir)): + if not result_file.endswith('.jsonl'): + print(f'Skipping {result_file}') + continue + eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) diff --git a/TinyLLaVA_Factory/tinyllava/eval/m4c_evaluator.py b/TinyLLaVA_Factory/tinyllava/eval/m4c_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e30e958da061a4f0a0bfe34b12d2fcaeba7ff2f4 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/m4c_evaluator.py @@ -0,0 +1,334 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import re + +from tqdm import tqdm + + +class EvalAIAnswerProcessor: + """ + Processes an answer similar to Eval AI + copied from + https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897 + """ + + CONTRACTIONS = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + + NUMBER_MAP = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + ARTICLES = ["a", "an", "the"] + PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)") + COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)") + PUNCTUATIONS = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def __init__(self, *args, **kwargs): + pass + + def word_tokenize(self, word): + word = word.lower() + word = word.replace(",", "").replace("?", "").replace("'s", " 's") + return word.strip() + + def process_punctuation(self, in_text): + out_text = in_text + for p in self.PUNCTUATIONS: + if (p + " " in in_text or " " + p in in_text) or ( + re.search(self.COMMA_STRIP, in_text) is not None + ): + out_text = out_text.replace(p, "") + else: + out_text = out_text.replace(p, " ") + out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE) + return out_text + + def process_digit_article(self, in_text): + out_text = [] + temp_text = in_text.lower().split() + for word in temp_text: + word = self.NUMBER_MAP.setdefault(word, word) + if word not in self.ARTICLES: + out_text.append(word) + else: + pass + for word_id, word in enumerate(out_text): + if word in self.CONTRACTIONS: + out_text[word_id] = self.CONTRACTIONS[word] + out_text = " ".join(out_text) + return out_text + + def __call__(self, item): + item = self.word_tokenize(item) + item = item.replace("\n", " ").replace("\t", " ").strip() + item = self.process_punctuation(item) + item = self.process_digit_article(item) + return item + + +class TextVQAAccuracyEvaluator: + def __init__(self): + self.answer_processor = EvalAIAnswerProcessor() + + def _compute_answer_scores(self, raw_answers): + """ + compute the accuracy (soft score) of human answers + """ + answers = [self.answer_processor(a) for a in raw_answers] + assert len(answers) == 10 + gt_answers = list(enumerate(answers)) + unique_answers = set(answers) + unique_answer_scores = {} + + for unique_answer in unique_answers: + accs = [] + for gt_answer in gt_answers: + other_answers = [item for item in gt_answers if item != gt_answer] + matching_answers = [ + item for item in other_answers if item[1] == unique_answer + ] + acc = min(1, float(len(matching_answers)) / 3) + accs.append(acc) + unique_answer_scores[unique_answer] = sum(accs) / len(accs) + + return unique_answer_scores + + def eval_pred_list(self, pred_list): + pred_scores = [] + for entry in tqdm(pred_list): + pred_answer = self.answer_processor(entry["pred_answer"]) + unique_answer_scores = self._compute_answer_scores(entry["gt_answers"]) + score = unique_answer_scores.get(pred_answer, 0.0) + pred_scores.append(score) + + accuracy = sum(pred_scores) / len(pred_scores) + return accuracy + + +class STVQAAccuracyEvaluator: + def __init__(self): + self.answer_processor = EvalAIAnswerProcessor() + + def eval_pred_list(self, pred_list): + pred_scores = [] + for entry in pred_list: + pred_answer = self.answer_processor(entry["pred_answer"]) + gts = [self.answer_processor(a) for a in entry["gt_answers"]] + score = 1.0 if pred_answer in gts else 0.0 + pred_scores.append(score) + + accuracy = sum(pred_scores) / len(pred_scores) + return accuracy + + +class STVQAANLSEvaluator: + def __init__(self): + import editdistance # install with `pip install editdistance` + + self.get_edit_distance = editdistance.eval + + def get_anls(self, s1, s2): + s1 = s1.lower().strip() + s2 = s2.lower().strip() + iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2)) + anls = iou if iou >= 0.5 else 0.0 + return anls + + def eval_pred_list(self, pred_list): + pred_scores = [] + for entry in pred_list: + anls = max( + self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"] + ) + pred_scores.append(anls) + + accuracy = sum(pred_scores) / len(pred_scores) + return accuracy + + +class TextCapsBleu4Evaluator: + def __init__(self): + # The following script requires Java 1.8.0 and pycocotools installed. + # The pycocoevalcap can be installed with pip as + # pip install git+https://github.com/ronghanghu/coco-caption.git@python23 + # Original pycocoevalcap code is at https://github.com/tylin/coco-caption + # but has no python3 support yet. + try: + from pycocoevalcap.bleu.bleu import Bleu + from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer + except ModuleNotFoundError: + print( + "Please install pycocoevalcap module using " + "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa + ) + raise + + self.tokenizer = PTBTokenizer() + self.scorer = Bleu(4) + + def eval_pred_list(self, pred_list): + # Create reference and hypotheses captions. + gts = {} + res = {} + for idx, entry in enumerate(pred_list): + gts[idx] = [{"caption": a} for a in entry["gt_answers"]] + res[idx] = [{"caption": entry["pred_answer"]}] + + gts = self.tokenizer.tokenize(gts) + res = self.tokenizer.tokenize(res) + score, _ = self.scorer.compute_score(gts, res) + + bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4) + return bleu4 diff --git a/TinyLLaVA_Factory/tinyllava/eval/model_vqa.py b/TinyLLaVA_Factory/tinyllava/eval/model_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..91db09c826271e9b1388ef72736300071657d0d1 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/model_vqa.py @@ -0,0 +1,103 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + + model, tokenizer, image_processor, context_len = load_pretrained_model(model_path) + model.to(device='cuda') + text_processor = TextPreprocess(tokenizer, args.conv_mode) + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + + questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + for line in tqdm(questions): + idx = line["question_id"] + image_file = line["image"] + qs = line["text"] + cur_prompt = qs + + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + + msg = Message() + msg.add_message(qs) + + result = text_processor(msg.messages, mode='eval') + input_ids = result['input_ids'] + prompt = result['prompt'] + input_ids = input_ids.unsqueeze(0).cuda() + + image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') + image_tensor = image_processor(image) + image_tensors = image_tensor.unsqueeze(0).half().cuda() + image_sizes = [image.size] + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensors, + image_sizes=image_sizes, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + # no_repeat_ngram_size=3, + max_new_tokens=1024, + use_cache=True) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": args.model_base, + "metadata": {}}) + "\n") + ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_v1") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + args = parser.parse_args() + + eval_model(args) diff --git a/TinyLLaVA_Factory/tinyllava/eval/model_vqa_loader.py b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..599292a0ee5954a2a1f18bebb1fc96e1a6cfe815 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_loader.py @@ -0,0 +1,146 @@ +import argparse +import time + +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + +from torch.utils.data import Dataset, DataLoader + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +# Custom dataset class +class CustomDataset(Dataset): + def __init__(self, questions, image_folder, text_processor, image_processor): + self.questions = questions + self.image_folder = image_folder + self.text_processor = text_processor + self.image_processor = image_processor + + def __getitem__(self, index): + line = self.questions[index] + image_file = line["image"] + qs = line["text"] + + image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') + image_tensor = self.image_processor(image) + + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + msg = Message() + msg.add_message(qs) + #print(prompt) + result = self.text_processor(msg.messages, mode='eval') + input_ids = result['input_ids'] + + return input_ids, image_tensor, image.size + + def __len__(self): + return len(self.questions) + + +def collate_fn(batch): + input_ids, image_tensors, image_sizes = zip(*batch) + input_ids = torch.stack(input_ids, dim=0) + image_tensors = torch.stack(image_tensors, dim=0) + return input_ids, image_tensors, image_sizes + + +# DataLoader +def create_data_loader(questions, image_folder, text_processor, image_processor, batch_size=1, num_workers=4): + assert batch_size == 1, "batch_size must be 1" + dataset = CustomDataset(questions, image_folder, text_processor, image_processor) + data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) + return data_loader + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model, tokenizer, image_processor, context_len = load_pretrained_model(model_path) + + text_processor = TextPreprocess(tokenizer, args.conv_mode) + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + + questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + + data_loader = create_data_loader(questions, args.image_folder, text_processor, image_processor) + # print("Tokenizer's eos token: ", tokenizer.eos_token) + model.to(device='cuda') + for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)): + idx = line["question_id"] + cur_prompt = line["text"] + # keywords = [tokenizer.eos_token] + # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + input_ids = input_ids.to(device='cuda', non_blocking=True) + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + pad_token_id=tokenizer.pad_token_id, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + max_new_tokens=args.max_new_tokens, + # stopping_criteria=[stopping_criteria], + image_sizes=image_sizes, + use_cache=True) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + # print("Printing outputs") + # print(outputs) + # time.sleep(5) + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": args.model_base, + "metadata": {}}) + "\n") + # ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llama") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--max_new_tokens", type=int, default=128) + parser.add_argument("--image_aspect_ratio", type=str, default="pad") + args = parser.parse_args() + + eval_model(args) diff --git a/TinyLLaVA_Factory/tinyllava/eval/model_vqa_mmmu.py b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_mmmu.py new file mode 100644 index 0000000000000000000000000000000000000000..c96baece6a38cb06a7b9e9faf7dcc4bb036c45e3 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_mmmu.py @@ -0,0 +1,181 @@ +import argparse +import torch +import os +import json +import random +import numpy as np +from tqdm import tqdm +import shortuuid + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def parse_multi_choice_response(response, all_choices, index2ans): + """ + Parse the prediction from the generated response. + Return the predicted index e.g., A, B, C, D. + """ + for char in [",", ".", "!", "?", ";", ":", "'"]: + response = response.strip(char) + response = " " + response + " " # add space to avoid partial match + + index_ans = True + ans_with_brack = False + candidates = [] + for choice in all_choices: # e.g., (A) (B) (C) (D) + if f"({choice})" in response: + candidates.append(choice) + ans_with_brack = True + + if len(candidates) == 0: + for choice in all_choices: # e.g., A B C D + if f" {choice} " in response: + candidates.append(choice) + + # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example + if len(candidates) == 0 and len(response.split()) > 5: + for index, ans in index2ans.items(): + if ans.lower() in response.lower(): + candidates.append(index) + index_ans = False # it's content ans. + + if len(candidates) == 0: # still not get answer, randomly choose one. + pred_index = random.choice(all_choices) + elif len(candidates) > 1: + start_indexes = [] + if index_ans: + if ans_with_brack: + for can in candidates: + index = response.rfind(f"({can})") + start_indexes.append(index) # -1 will be ignored anyway + # start_indexes = [generated_response.index(f'({can})') for can in candidates] + else: + for can in candidates: + index = response.rfind(f" {can} ") + start_indexes.append(index) + else: + for can in candidates: + index = response.lower().rfind(index2ans[can].lower()) + start_indexes.append(index) + # get the last one + pred_index = candidates[np.argmax(start_indexes)] + else: # if only one candidate, use it. + pred_index = candidates[0] + + return pred_index + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model, tokenizer, image_processor, context_len = load_pretrained_model(model_path) + + text_processor = TextPreprocess(tokenizer, args.conv_mode) + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + + questions = json.load(open(os.path.expanduser(args.question_file), "r")) + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + model.to(device="cuda") + for i, line in enumerate(tqdm(questions)): + idx = line["id"] + question = line["prompt"] + + if "image" in line: + image_file = line["image"] + # image = Image.open(image_file).convert("RGB") + image = Image.open(os.path.join(args.image_folder, image_file)).convert("RGB") + image_sizes = [image.size] + image = image_processor(image) + images = image.unsqueeze(0).half().cuda() + question = "" + "\n" + question + else: + images = None + image_sizes = None + + msg = Message() + msg.add_message(question) + # print(msg.messages) + + result = text_processor(msg.messages, mode='eval') + # print(result["prompt"]) + input_ids = result['input_ids'] + input_ids = input_ids.unsqueeze(0).cuda() + + with torch.inference_mode(): + if images is not None: + output_ids = model.generate( + input_ids, + images=images, + image_sizes=image_sizes, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + max_new_tokens=1024, + use_cache=True, + pad_token_id=tokenizer.pad_token_id, + ) + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + else: + if line["question_type"] == "multiple-choice": + all_choices = line["all_choices"] + outputs = random.choice(all_choices) + else: + outputs = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" + + if line["question_type"] == "multiple-choice": + pred_ans = parse_multi_choice_response( + outputs, line["all_choices"], line["index2ans"] + ) + else: # open question + pred_ans = outputs + + # print(outputs, pred_ans) + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": questions, + "text": pred_ans, + "answer_id": ans_id, + "model_id": args.model_path.split("/")[-1], + "metadata": {}}) + "\n") + ans_file.flush() + ans_file.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.json") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llama") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--answer-prompter", action="store_true") + parser.add_argument("--image_aspect_ratio", type=str, default="pad") + args = parser.parse_args() + + eval_model(args) + diff --git a/TinyLLaVA_Factory/tinyllava/eval/model_vqa_pope.py b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_pope.py new file mode 100644 index 0000000000000000000000000000000000000000..242e90aaa90de4c88b42382df9850869a7197df4 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_pope.py @@ -0,0 +1,147 @@ +import argparse +import time + +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + +from torch.utils.data import Dataset, DataLoader + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +# Custom dataset class +class CustomDataset(Dataset): + def __init__(self, questions, image_folder, text_processor, image_processor): + self.questions = questions + self.image_folder = image_folder + self.text_processor = text_processor + self.image_processor = image_processor + + def __getitem__(self, index): + line = self.questions[index] + image_file = line["image"] + qs = line["text"] + + image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') + image_tensor = self.image_processor(image) + + qs = DEFAULT_IMAGE_TOKEN + '\n' + qs + msg = Message() + msg.add_message(qs) + #print(prompt) + result = self.text_processor(msg.messages, mode='eval') + input_ids = result['input_ids'] + + return input_ids, image_tensor, image.size + + def __len__(self): + return len(self.questions) + + +def collate_fn(batch): + input_ids, image_tensors, image_sizes = zip(*batch) + input_ids = torch.stack(input_ids, dim=0) + image_tensors = torch.stack(image_tensors, dim=0) + return input_ids, image_tensors, image_sizes + + +# DataLoader +def create_data_loader(questions, image_folder, text_processor, image_processor, batch_size=1, num_workers=4): + assert batch_size == 1, "batch_size must be 1" + dataset = CustomDataset(questions, image_folder, text_processor, image_processor) + data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) + return data_loader + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model, tokenizer, image_processor, context_len = load_pretrained_model(model_path) + + text_processor = TextPreprocess(tokenizer, args.conv_mode) + #model.config.image_aspect_ratio = 'pad' + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + + questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + + data_loader = create_data_loader(questions, args.image_folder, text_processor, image_processor) + # print("Tokenizer's eos token: ", tokenizer.eos_token) + model.to(device='cuda') + for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)): + idx = line["question_id"] + cur_prompt = line["text"] + # keywords = [tokenizer.eos_token] + # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + input_ids = input_ids.to(device='cuda', non_blocking=True) + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), + pad_token_id=tokenizer.pad_token_id, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + max_new_tokens=args.max_new_tokens, + # stopping_criteria=[stopping_criteria], + image_sizes=image_sizes, + use_cache=True) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + # print("Printing outputs") + # print(outputs) + # time.sleep(5) + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": args.model_base, + "metadata": {}}) + "\n") + # ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llama") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--max_new_tokens", type=int, default=128) + parser.add_argument("--image_aspect_ratio", type=str, default="pad") + args = parser.parse_args() + + eval_model(args) diff --git a/TinyLLaVA_Factory/tinyllava/eval/model_vqa_science.py b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_science.py new file mode 100644 index 0000000000000000000000000000000000000000..bce3e191d2c35ee6cead9937fa0e33ab84f07ad0 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/model_vqa_science.py @@ -0,0 +1,110 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def eval_model(args): + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model, tokenizer, image_processor, context_len = load_pretrained_model(model_path) + + text_processor = TextPreprocess(tokenizer, args.conv_mode) + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + + questions = json.load(open(os.path.expanduser(args.question_file), "r")) + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + model.to(device='cuda') + for i, line in enumerate(tqdm(questions)): + idx = line["id"] + question = line['conversations'][0] + question = question['value'].replace('', '').strip() + if 'image' in line: + image_file = line["image"] + image = Image.open(os.path.join(args.image_folder, image_file)) + image_sizes = [image.size] + image = image_processor(image) + images = image.unsqueeze(0).half().cuda() + question = '' + '\n' + question + else: + images = None + image_sizes = None + + if args.single_pred_prompt: + question = question + '\n' + "Answer with the option's letter from the given choices directly." + msg = Message() + msg.add_message(question) + + result = text_processor(msg.messages, mode='eval') + input_ids = result['input_ids'] + prompt = result['prompt'] + input_ids = input_ids.unsqueeze(0).cuda() + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=images, + image_sizes=image_sizes, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + max_new_tokens=1024, + use_cache=True, + pad_token_id=tokenizer.pad_token_id + + ) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({"question_id": idx, + "prompt": prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": args.model_path.split('/')[-1], + "metadata": {}}) + "\n") + ans_file.flush() + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.json") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llama") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--answer-prompter", action="store_true") + parser.add_argument("--single-pred-prompt", action="store_true") + parser.add_argument("--image_aspect_ratio", type=str, default='pad') + args = parser.parse_args() + + eval_model(args) + + diff --git a/TinyLLaVA_Factory/tinyllava/eval/run_tiny_llava.py b/TinyLLaVA_Factory/tinyllava/eval/run_tiny_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..78c1b3744db0bc3fd34e05a08f08314360bb2bb4 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/eval/run_tiny_llava.py @@ -0,0 +1,118 @@ +import argparse +import re +import requests +from PIL import Image +from io import BytesIO + +import torch +from transformers import PreTrainedModel + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + + +def image_parser(args): + out = args.image_file.split(args.sep) + return out + + +def load_image(image_file): + if image_file.startswith("http") or image_file.startswith("https"): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + return image + + +def load_images(image_files): + out = [] + for image_file in image_files: + image = load_image(image_file) + out.append(image) + return out + + +def eval_model(args): + # Model + disable_torch_init() + + if args.model_path is not None: + model, tokenizer, image_processor, context_len = load_pretrained_model(args.model_path) + else: + assert args.model is not None, 'model_path or model must be provided' + model = args.model + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + tokenizer = model.tokenizer + image_processor = model.vision_tower._image_processor + qs = args.query + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + + text_processor = TextPreprocess(tokenizer, args.conv_mode) + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + model.cuda() + + msg = Message() + msg.add_message(qs) + + result = text_processor(msg.messages, mode='eval') + input_ids = result['input_ids'] + prompt = result['prompt'] + input_ids = input_ids.unsqueeze(0).cuda() + + + image_files = image_parser(args) + images = load_images(image_files)[0] + images_tensor = image_processor(images) + images_tensor = images_tensor.unsqueeze(0).half().cuda() + + + + stop_str = text_processor.template.separator.apply()[1] + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=images_tensor, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=args.max_new_tokens, + use_cache=True, + stopping_criteria=[stopping_criteria], + ) + + outputs = tokenizer.batch_decode( + output_ids, skip_special_tokens=True + )[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[: -len(stop_str)] + outputs = outputs.strip() + print(outputs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--model", type=PreTrainedModel, default=None) + parser.add_argument("--image-file", type=str, required=True) + parser.add_argument("--query", type=str, required=True) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--sep", type=str, default=",") + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--max_new_tokens", type=int, default=512) + args = parser.parse_args() + + eval_model(args) \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/model/__init__.py b/TinyLLaVA_Factory/tinyllava/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1539774ad4e5e90109d9a6b22d8c75b75aec4f38 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/__init__.py @@ -0,0 +1,7 @@ +from .llm import * +from .connector import * +from .vision_tower import * +from .configuration_tinyllava import * +from .modeling_tinyllava import * +from .convert_legecy_weights_to_tinyllavafactory import * +from .load_model import * diff --git a/TinyLLaVA_Factory/tinyllava/model/configuration_tinyllava.py b/TinyLLaVA_Factory/tinyllava/model/configuration_tinyllava.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea904cd1acb54dbf508b0d6267ce4b97e86a486 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/configuration_tinyllava.py @@ -0,0 +1,133 @@ +from transformers import PretrainedConfig, LlavaConfig +from transformers import CONFIG_MAPPING +from transformers import AutoConfig +from tinyllava.utils.constants import * + +class TinyLlavaConfig(PretrainedConfig): + + model_type = "tinyllava" + def __init__( + self, + llm_model_name_or_path = '', + tokenizer_name_or_path = None, + vision_model_name_or_path = '', + vision_model_name_or_path2 = '', + connector_type = None, + text_config=None, + hidden_size=2048, + vocab_size=32000, + ignore_index=-100, + image_token_index=32000, + pad_token = None, + pad_token_id = None, + tokenizer_padding_side = 'right', + tokenizer_model_max_length = 2048, + vision_config = None, + vision_hidden_size = None, + vision_feature_layer = -2, + vision_feature_select_strategy = 'patch', + image_aspect_ratio = 'square', + resampler_hidden_size = None, + num_queries = None, + num_resampler_layers = None, + use_cache = False, + cache_dir = None, + tokenizer_use_fast = False, + tune_type_llm = 'frozen', + tune_type_connector = 'frozen', + tune_type_vision_tower = 'frozen', + tune_vision_tower_from_layer = -1, + + **kwargs + + ): + self.llm_model_name_or_path = llm_model_name_or_path + self.tokenizer_name_or_path = tokenizer_name_or_path or self.llm_model_name_or_path + self.vision_model_name_or_path = vision_model_name_or_path + self.vision_model_name_or_path2 = vision_model_name_or_path2 + self.connector_type = connector_type + self.tune_type_llm = tune_type_llm + self.tune_type_connector = tune_type_connector + self.tune_type_vision_tower = tune_type_vision_tower + self.tune_vision_tower_from_layer = tune_vision_tower_from_layer + + self.ignore_index = IGNORE_INDEX + self.image_token_index = IMAGE_TOKEN_INDEX + self.pad_token = pad_token + self.pad_token_id = pad_token_id + self.tokenizer_padding_side = tokenizer_padding_side + self.tokenizer_model_max_length = tokenizer_model_max_length + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + self.image_aspect_ratio = image_aspect_ratio + self.resampler_hidden_size = resampler_hidden_size + self.num_queries = num_queries + self.num_resampler_layers = num_resampler_layers + self.use_cache = use_cache + self.cache_dir = cache_dir + self.tokenizer_use_fast = tokenizer_use_fast + self._load_text_config(text_config) + self._load_vision_config(vision_config) + + super().__init__(**kwargs) + + def load_from_config(self, config): + self.llm_model_name_or_path = getattr(config, 'model_name_or_path', '') + self.tokenizer_name_or_path = getattr(config, 'tokenizer_name_or_path', None) or self.llm_model_name_or_path + self.vision_model_name_or_path = getattr(config, 'vision_tower', '') + self.vision_model_name_or_path2 = getattr(config, 'vision_tower2', '') + self.connector_type = getattr(config, 'connector_type', None) + self.vision_feature_layer = getattr(config, 'mm_vision_select_layer', -2) + self.vision_feature_select_strategy = getattr(config, 'mm_vision_select_feature', "patch") + self.image_aspect_ratio = getattr(config, 'image_aspect_ratio', "pad") + self.resampler_hidden_size = getattr(config, 'resampler_hidden_size', None) + self.num_queries = getattr(config, 'num_queries', None) + self.num_resampler_layers = getattr(config, 'num_resampler_layers', None) + + self.cache_dir = getattr(config, 'cache_dir', None) + self.tokenizer_use_fast = getattr(config, 'tokenizer_use_fast', False) + self.tokenizer_model_max_length = getattr(config, 'model_max_length', 2048) + self.tokenizer_padding_side = getattr(config, 'tokenizer_padding_side', 'right') + + self._load_text_config() + self._load_vision_config() + + + def _load_text_config(self, text_config=None): + if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '': + self.text_config = CONFIG_MAPPING['llama']() + + else: + self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True) + if text_config is not None: + self.text_config = self.text_config.from_dict(text_config) + + self.hidden_size = getattr(self.text_config, 'hidden_size', getattr(self.text_config, 'model_dim', None)) + self.vocab_size = getattr(self.text_config, 'vocab_size', None) + + + + def _load_vision_config(self, vision_config=None): + if self.vision_model_name_or_path is None or self.vision_model_name_or_path == '': + self.vision_config = CONFIG_MAPPING['clip_vision_model']( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + else: + self.vision_config = AutoConfig.from_pretrained(self.vision_model_name_or_path.split(':')[-1]) + self.vision_config = getattr(self.vision_config, 'vision_config', self.vision_config) + if vision_config is not None: + self.vision_config = self.vision_config.from_dict(vision_config) + + self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1] + self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1] + self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None) + + diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/__init__.py b/TinyLLaVA_Factory/tinyllava/model/connector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71b569594e3cd352485a451bfd61f2d0b1252eef --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/__init__.py @@ -0,0 +1,28 @@ +import os + +from ...utils import import_modules + + +CONNECTOR_FACTORY = {} + +def ConnectorFactory(connector_name): + model = None + for name in CONNECTOR_FACTORY.keys(): + if name.lower() in connector_name.lower(): + model = CONNECTOR_FACTORY[name] + assert model, f"{connector_name} is not registered" + return model + + +def register_connector(name): + def register_connector_cls(cls): + if name in CONNECTOR_FACTORY: + return CONNECTOR_FACTORY[name] + CONNECTOR_FACTORY[name] = cls + return cls + return register_connector_cls + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +import_modules(models_dir, "tinyllava.model.connector") diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/base.py b/TinyLLaVA_Factory/tinyllava/model/connector/base.py new file mode 100644 index 0000000000000000000000000000000000000000..44d17be95b43113c3ce6bd31e1944b69b22c93fd --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/base.py @@ -0,0 +1,32 @@ +import os + +import torch +import torch.nn as nn + + +class Connector(nn.Module): + def __init__(self, config=None): + super().__init__() + self._connector = None + + def load_model(self, **kwargs): + pretrained_connector_path = kwargs.get('pretrained_connector_path', None) + if pretrained_connector_path is not None: + pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin') + connector_weights = torch.load(pretrained_connector_path, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + self._connector.load_state_dict(get_w(connector_weights, '_connector')) + print(f'Loading connector from {pretrained_connector_path}...') + + for p in self._connector.parameters(): + p.requires_grad = False + + + + + def forward(self, x): + return self._connector(x) + + + diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/identity.py b/TinyLLaVA_Factory/tinyllava/model/connector/identity.py new file mode 100644 index 0000000000000000000000000000000000000000..1ffbff4aedd26c4c9cc9c9643a1237eb19ca69cf --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/identity.py @@ -0,0 +1,15 @@ +import torch.nn as nn + +from . import register_connector +from .base import Connector + + + +@register_connector('identity') +class IdentityConnector(Connector): + def __init__(self, config=None): + super().__init__() + self._connector = nn.Identity() + + + diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/linear.py b/TinyLLaVA_Factory/tinyllava/model/connector/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..2effa1a1aaf183ddbb17aeeecee908d366f7a5b3 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/linear.py @@ -0,0 +1,22 @@ +import torch.nn as nn + +from . import register_connector +from .base import Connector + + + + + +@register_connector('linear') +class LinearConnector(Connector): + def __init__(self, config): + super().__init__() + self._connector = nn.Linear(config.vision_hidden_size, config.hidden_size) + + + # @property + # def config(self): + # return {"connector_type": 'linear', + # "in_hidden_size": self.in_hidden_size, + # "out_hidden_size": self.out_hidden_size + # } diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/mlp.py b/TinyLLaVA_Factory/tinyllava/model/connector/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..15032a0f4d2553019b7fdacd88a921a4f0efb1eb --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/mlp.py @@ -0,0 +1,40 @@ +import re + +import torch.nn as nn + +from . import register_connector +from .base import Connector + + +ACT_TYPE = { + 'relu': nn.ReLU, + 'gelu': nn.GELU +} + + + + +@register_connector('mlp') +class MLPConnector(Connector): + def __init__(self, config): + super().__init__() + + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.connector_type) + act_type = config.connector_type.split('_')[-1] + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.vision_hidden_size, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(ACT_TYPE[act_type]()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + + self._connector = nn.Sequential(*modules) + + + +# @property +# def config(self): +# return {"connector_type": 'mlp', +# "in_hidden_size": self.in_hidden_size, +# "out_hidden_size": self.out_hidden_size +# } + diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/mof_mlp.py b/TinyLLaVA_Factory/tinyllava/model/connector/mof_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..6c099a728952956b32d81a49139df9ec53e10deb --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/mof_mlp.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn + +from . import register_connector +from .base import Connector + + + + +class MoFMLP(nn.Module): + def __init__(self, config): + super().__init__() + + modules_clip = [nn.Linear(config.vision_hidden_size, config.hidden_size), + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size) + ] + + modules_dinov2 = [nn.Linear(config.vision_hidden_size, config.hidden_size), + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size) + ] + + self.clip = nn.Sequential(*modules_clip) + self.dinov2 = nn.Sequential(*modules_dinov2) + + + + def forward(self, x): + + image_features_clip = self.clip(x[0]) + image_features_dinov2 = self.dinov2(x[1]) + + bs = image_features_clip.size(0) + total_len = image_features_clip.size(1)+image_features_dinov2.size(1) + dim = image_features_clip.size(-1) + + merged_features = torch.empty(bs, total_len, dim).to(device=x[0].device, dtype=x[0].dtype) + merged_features[:,0::2] = image_features_clip + merged_features[:,1::2] = image_features_dinov2 + + return merged_features + + + + +@register_connector('mof_mlp') +class MoFMLPConnector(Connector): + def __init__(self, config): + super().__init__() + + self._connector = MoFMLP(config) diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/qformer.py b/TinyLLaVA_Factory/tinyllava/model/connector/qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..91dd542a0bd170fcae9d1a7f6dfef224abef538f --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/qformer.py @@ -0,0 +1,1283 @@ + +import torch +import torch.nn as nn + +from transformers.models.bert.configuration_bert import BertConfig + +from . import register_connector +from .base import Connector + + + +class QFormer(nn.Module): + def __init__(self, config): + super().__init__() + + bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased') + bert_config.encoder_width = config.vision_hidden_size + # insert cross-attention layer every other block + bert_config.add_cross_attention = True + bert_config.cross_attention_freq = 2 + bert_config.query_length = config.num_queries + self.bert = BertModel(config=bert_config, add_pooling_layer=False) + self.bert.embeddings.word_embeddings = None + self.bert.embeddings.position_embeddings = None + self.bert.embeddings.LayerNorm.weight = None # added by ying + self.bert.embeddings.LayerNorm.bias = None # added by ying + for layer in self.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.query_tokens = nn.Parameter( + torch.zeros(1, config.num_queries, bert_config.hidden_size) + ) + self.query_tokens.data.normal_(mean=0.0, std=bert_config.initializer_range) + + self.projector = nn.Linear(bert_config.hidden_size, config.hidden_size) + + + + def forward(self, x): + device = x.device + image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(device) + query_tokens = self.query_tokens.expand(x.shape[0], -1, -1).to(device) + query_output = self.bert( + query_embeds=query_tokens, + encoder_hidden_states=x, + encoder_attention_mask=image_atts, + return_dict=True, + ) + image_embeds = query_output.last_hidden_state + image_embeds = self.projector(image_embeds) + return image_embeds + + + +@register_connector('qformer') +class QFormerConnector(Connector): + def __init__(self, config): + super().__init__() + self._connector = QFormer(config) + + + def load_model(self, **kwargs): + pretrained_connector_path = kwargs.get('pretrained_connector_path', None) + if pretrained_connector_path is not None: + pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin') + connector_weights = torch.load(pretrained_connector_path, map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + self._connector.load_state_dict(get_w(connector_weights, '_connector'), strict=False) + print(f'Loading connector from {pretrained_connector_path}...') + + for p in self._connector.parameters(): + p.requires_grad = False + +# =================================qformer bert related ================================= +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + + +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/TinyLLaVA_Factory/tinyllava/model/connector/resampler.py b/TinyLLaVA_Factory/tinyllava/model/connector/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..e7dcc2ce6ee7893bddc01cd6d95f26275d2ddef2 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/connector/resampler.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from . import register_connector +from .base import Connector +import torch +from einops import rearrange, repeat +from einops_exts import rearrange_many +from torch import einsum + + + +class PerceiverResampler(nn.Module): + def __init__(self, config): + super().__init__() + dim = config.hidden_size + depth=config.num_resampler_layers + num_latents=config.num_queries + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + self.layers = nn.ModuleList([]) + self.linear = nn.Linear(config.vision_hidden_size, config.hidden_size) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=64, heads=8), + FeedForward(dim=dim, mult=4), + ] + ) + ) + + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + b, v = x.shape[:2] + x = self.linear(x) + # blocks + latents = repeat(self.latents, "n d -> b T n d", b=b, T=1) + x = x.unsqueeze(1) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + return self.norm(latents).squeeze(1) + + +@register_connector('resampler') +class ResamplerConnector(Connector): + def __init__(self, config): + super().__init__() + + self._connector = PerceiverResampler(config) + + +# =================================resampler related ================================= +def exists(val): + return val is not None + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + + self.norm_media = nn.LayerNorm(dim) + self.norm_latents = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, T, n1, D) + latent (torch.Tensor): latent features + shape (b, T, n2, D) + """ + x = self.norm_media(x) + latents = self.norm_latents(latents) + + h = self.heads + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) + q = q * self.scale + + # attention + sim = einsum("... i d, ... j d -> ... i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("... i j, ... j d -> ... i d", attn, v) + out = rearrange(out, "b h t n d -> b t n (h d)", h=h) + return self.to_out(out) + diff --git a/TinyLLaVA_Factory/tinyllava/model/convert_legecy_weights_to_tinyllavafactory.py b/TinyLLaVA_Factory/tinyllava/model/convert_legecy_weights_to_tinyllavafactory.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e2070ec771a977019298f7c976821b2be689f --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/convert_legecy_weights_to_tinyllavafactory.py @@ -0,0 +1,103 @@ +import os +import json + +from huggingface_hub import hf_hub_download +import torch + +from safetensors import safe_open +from .modeling_tinyllava import TinyLlavaForConditionalGeneration +from .configuration_tinyllava import TinyLlavaConfig + +KEYS_TO_MODIFY_MAPPING = { + "model.vision_tower.vision_tower": "vision_tower._vision_tower", + "model.mm_projector": "connector._connector", + "model.embed_tokens": "language_model.model.embed_tokens", + "model.layers": "language_model.model.layers", + "model.norm": "language_model.model.norm", + "lm_head": "language_model.lm_head", + "model.final_layernorm": "language_model.model.final_layernorm" +} +KEYS_TO_MODELNAME_MAPPING = { + "TinyLlavaLlamaForCausalLM": 'TinyLlama/TinyLlama-1.1B-chat-v1.0', + "TinyLlavaStablelmForCausalLM": 'stabilityai/stablelm-2-zephyr-1_6b', + "TinyLlavaPhiForCausalLM": 'microsoft/phi-2', + "bczhou/TinyLLaVA-3.1B-SigLIP": 'google/siglip-so400m-patch14-384', + "bczhou/TinyLLaVA-2.0B-SigLIP": 'google/siglip-so400m-patch14-384', + "bczhou/TinyLLaVA-1.5B-SigLIP": 'google/siglip-so400m-patch14-384', +} + +def convert_legecy_config_to_tinyllavaconfig(old_config_path): + if os.path.exists(old_config_path): + config_path = os.path.join(old_config_path, 'config.json') + else: + config_path = hf_hub_download(old_config_path, "config.json") + + with open(config_path, 'r') as f: + old_config = json.load(f) + llm_model_name_or_path = KEYS_TO_MODELNAME_MAPPING[old_config['architectures'][0]] + vision_model_name_or_path = KEYS_TO_MODELNAME_MAPPING[old_config['mm_vision_tower']] + model_config = TinyLlavaConfig( + llm_model_name_or_path = llm_model_name_or_path, + vision_model_name_or_path = vision_model_name_or_path, + connector_type = old_config['mm_projector_type'], + hidden_size = old_config['hidden_size'], + vocab_size = old_config['vocab_size'], + pad_token = old_config['pad_token'], + tokenizer_padding_side = old_config['tokenizer_padding_side'], + tokenizer_model_max_length = old_config['tokenizer_model_max_length'], + vision_feature_layer = old_config['mm_vision_select_layer'], + vision_feature_select_strategy = old_config['mm_vision_select_feature'], + image_aspect_ratio = old_config['image_aspect_ratio'], + use_cache = old_config['use_cache'] + ) + return model_config + + +def convert_state_dict_to_tinyllavafactory(old_state_dict_path): + old_state_dict = [] + if os.path.exists(old_state_dict_path): + meta_file_name = os.path.join(old_state_dict_path, 'model.safetensors.index.json') + if os.path.exists(meta_file_name): + with open(meta_file_name, 'r') as f: + meta_file = json.load(f) + meta_file = list(set(meta_file['weight_map'].values())) + for name in meta_file: + old_state_dict.append(os.path.join(old_state_dict_path, name)) + else: + old_state_dict.append(os.path.join(old_state_dict_path, 'model.safetensors')) + else: + try: + meta_file_name = hf_hub_download(old_state_dict_path, 'model.safetensors.index.json') + with open(meta_file_name, 'r') as f: + meta_file = json.load(f) + meta_file = list(set(meta_file['weight_map'].values())) + for name in meta_file: + old_state_dict.append(hf_hub_download(old_state_dict_path, name)) + except: + old_state_dict.append(hf_hub_download(old_state_dict_path, 'model.safetensors')) + state_dict = {} + for osd in old_state_dict: + with safe_open(osd, framework="pt",device=0) as f: + for k in f.keys(): + state_dict[k]= f.get_tensor(k) + + new_state_dict={} + for key, value in state_dict.items(): + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + new_state_dict[key] = value + return new_state_dict + +def convert_legecy_weights_to_tinyllavafactory(old_state_dict_path, new_state_dict_path=None): + model_config = convert_legecy_config_to_tinyllavaconfig(old_state_dict_path) + model = TinyLlavaForConditionalGeneration(model_config) + # For the checkpoints saved as '*.safetensors. + + state_dict = convert_state_dict_to_tinyllavafactory(old_state_dict_path) + model.load_state_dict(state_dict, False) + if new_state_dict_path is not None: + model.config.save_pretained(new_state_dict_path) + model.tokenizer.save_pretrained(new_state_dict_path) + model.save_pretrained(new_state_dict_path) + return model diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/__init__.py b/TinyLLaVA_Factory/tinyllava/model/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..afb80e93909032c240345498dec5fea6408d137b --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/__init__.py @@ -0,0 +1,28 @@ +import os + +from ...utils import import_modules + + +LLM_FACTORY = {} + +def LLMFactory(model_name_or_path): + model, tokenizer_and_post_load = None, None + for name in LLM_FACTORY.keys(): + if name in model_name_or_path.lower(): + model, tokenizer_and_post_load = LLM_FACTORY[name]() + assert model, f"{model_name_or_path} is not registered" + return model, tokenizer_and_post_load + + +def register_llm(name): + def register_llm_cls(cls): + if name in LLM_FACTORY: + return LLM_FACTORY[name] + LLM_FACTORY[name] = cls + return cls + return register_llm_cls + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +import_modules(models_dir, "tinyllava.model.llm") diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/gemma.py b/TinyLLaVA_Factory/tinyllava/model/llm/gemma.py new file mode 100644 index 0000000000000000000000000000000000000000..a4482711da6e89758cad994ef6a0a4fa7564d5ba --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/gemma.py @@ -0,0 +1,11 @@ +from transformers import GemmaForCausalLM, AutoTokenizer + +from . import register_llm + +@register_llm('gemma') +def return_gemmaclass(): + def tokenizer_and_post_load(tokenizer): + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + return (GemmaForCausalLM, (AutoTokenizer, tokenizer_and_post_load)) + diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/openelm.py b/TinyLLaVA_Factory/tinyllava/model/llm/openelm.py new file mode 100644 index 0000000000000000000000000000000000000000..42f3c39ce1282631038414ed29dff9c9cacd3a8a --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/openelm.py @@ -0,0 +1,1297 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +# this import has to be relative, otherwise, when setting trust_remote_code=True +# huggingface transformers won't be able to load the module correctly +from numbers import Number +from typing import List, Optional, Union + +import numpy as np +from transformers import PretrainedConfig, AutoTokenizer + +from . import register_llm + +@register_llm('openelm') +def return_openelmclass(): + def tokenizer_and_post_load(tokenizer): + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + return OpenELMForCausalLM, (AutoTokenizer, tokenizer_and_post_load) + + +def make_divisible( + v: Union[float, int], + divisor: Optional[int] = 8, + min_value: Optional[Union[float, int]] = None, +) -> Union[float, int]: + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by the divisor + It can be seen at: + https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62 + Args: + v: input value + divisor: default to 8 + min_value: minimum divisor value + Returns: + new_v: new divisible value + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def compute_heads(model_dim: int, head_dim: int) -> int: + """Compute the number of heads. + Args: + model_dim: Model dimension. + head_dim: Head dimension. + Returns: + An integer denoting number of heads in multi-head attention is returned. + Raises: + ValueError: if model dimension is not divisible by head dimension. + """ + if model_dim % head_dim == 0: + return model_dim // head_dim + else: + raise ValueError( + f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}." + ) + + +OpenELM_CONFIGS = { + "OpenELM-270M": dict( + num_transformer_layers=16, + model_dim=1280, + head_dim=64, + num_gqa_groups=4, + normalize_qk_projections=True, + share_input_output_layers=True, + # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively. + ffn_multipliers=(0.5, 4.0), + qkv_multipliers=(0.5, 1.0), + ), + "OpenELM-450M": dict( + num_transformer_layers=20, + model_dim=1536, + head_dim=64, + num_gqa_groups=4, + normalize_qk_projections=True, + share_input_output_layers=True, + # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively. + ffn_multipliers=(0.5, 4.0), + qkv_multipliers=(0.5, 1.0), + ), + "OpenELM-1_1B": dict( + num_transformer_layers=28, + model_dim=2048, + head_dim=64, + num_gqa_groups=4, + normalize_qk_projections=True, + share_input_output_layers=True, + # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively. + ffn_multipliers=(0.5, 4.0), + qkv_multipliers=(0.5, 1.0), + ), + "OpenELM-3B": dict( + num_transformer_layers=36, + model_dim=3072, + head_dim=128, + num_gqa_groups=4, + normalize_qk_projections=True, + share_input_output_layers=True, + # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively. + ffn_multipliers=(0.5, 4.0), + qkv_multipliers=(0.5, 1.0), + ), +} + + +class OpenELMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the OpenELM model. + max_context_length (`int`, *optional*, defaults to 2048): + Maximum number of input tokens. + num_transformer_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer decoder. + model_dim (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + head_dim (`int`, *optional*, defaults to 128): + The attention head dimension. + qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0): + If the qkv_multipliers is a Number, then all attention layers have the same latent dimensions, + resulting in uniform allocation of parameters. + If the qkv_multipliers is a List of Number, then each attention layer have different latent dimensions + assuming qkv_multipliers[0] != qkv_multipliers[1]. This results in variable allocation of parameters in attention layer. + This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623 + num_query_heads (`Union[int, None]`, *optional*, defaults to None): + The number of query heads, computed from `compute_heads(model_dim=model_dim, head_dim=head_dim)`. + num_gqa_groups (`int`, *optional*, defaults to 1): + This variable allows to switch between multi-head attention, group query attention, and multi-query attention. + When num_gqa_groups == 1, then it is multi-head attention. + When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, then it is group query attention + When num_gqa_groups == num_heads, then it is multi-query attention + ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0): + Feed-forward network (FFN) multipliers. + If the ffn_multipliers is a Number, then all FFN layers have the same latent dimensions, + resulting in uniform allocation of parameters. + If the ffn_multipliers is a List of Number, then each FFN layer have different latent dimensions + assuming ffn_multipliers[0] != ffn_multipliers[1]. This results in variable allocation of parameters in FFN layer. + This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623 + ffn_with_glu (`bool`, *optional*, defaults to True): + Whether to use FFN with Gated Linear Unit (GLU) + ffn_dim_divisor (`int`, *optional*, defaults to 256): + The ffn layer dimension divisor. + activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`): + The non-linear activation function (function or string) in the decoder. + normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`): + Type of normalization layer. + normalize_qk_projections (`bool`, *optional*, defaults to False): + Whether to normalize queries and keys after projections + share_input_output_layers (`bool`, *optional*, defaults to False): + Whether to share the embedding between input and output linear layer + rope_freq_constant (`int`, *optional*, defaults to 10000): + The base period of the RoPE embeddings. + rope_max_length (`int`, *optional*, defaults to 4096): + That rope_max_length is set to twice of max_context_length. + This allows flexibility in token lengths during training or fine-tuning. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + """ + + model_type = "openelm" + + def __init__( + self, + vocab_size: int = 32000, + max_context_length: int = 2048, + num_transformer_layers: int = 12, + model_dim: int = 2048, + head_dim: int = 128, + qkv_multipliers: Union[Number, List[Number]] = 1.0, + num_query_heads: Union[int, None] = None, + num_gqa_groups: int = 1, + ffn_multipliers: Union[Number, List[Number]] = 4.0, + ffn_with_glu: bool = True, + ffn_dim_divisor: int = 256, + activation_fn_name: str = "swish", + normalization_layer_name: str = "rms_norm", + normalize_qk_projections: bool = False, + share_input_output_layers: bool = False, + rope_freq_constant: int = 10000, + rope_max_length: int = 4096, + initializer_range: float = 0.02, + use_cache: bool = True, + bos_token_id: int = 1, + eos_token_id: int = 2, + **kwargs, + ) -> None: + self.vocab_size = vocab_size + self.max_context_length = max_context_length + self.num_transformer_layers = num_transformer_layers + self.model_dim = model_dim + self.head_dim = head_dim + self.qkv_multipliers = qkv_multipliers + self.num_query_heads = num_query_heads + self.num_gqa_groups = num_gqa_groups + self.ffn_multipliers = ffn_multipliers + self.ffn_with_glu = ffn_with_glu + self.ffn_dim_divisor = ffn_dim_divisor + self.activation_fn_name = activation_fn_name + self.normalization_layer_name = normalization_layer_name + self.normalize_qk_projections = normalize_qk_projections + self.share_input_output_layers = share_input_output_layers + self.rope_freq_constant = rope_freq_constant + self.rope_max_length = rope_max_length + self.num_query_heads = ( + compute_heads(model_dim=model_dim, head_dim=head_dim) + if num_query_heads is None + else num_query_heads + ) + self.initializer_range = initializer_range + + self.__post_init__() + super().__init__( + use_cache=use_cache, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + def __post_init__(self) -> None: + if self.num_gqa_groups is not None: + head_multiple_of = self.num_gqa_groups + else: + head_multiple_of = 2 + + if isinstance(self.qkv_multipliers, Number): + # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters. + qkv_dim = make_divisible( + self.model_dim * self.qkv_multipliers, + divisor=self.head_dim * head_multiple_of, + ) + query_dims = [int(qkv_dim)] * self.num_transformer_layers + + elif ( + isinstance(self.qkv_multipliers, (tuple, list)) + and len(self.qkv_multipliers) == 2 + ): + # Each attention layer have different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1]. + # This results in variable allocation of parameters in attention layer. + # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623 + qkv_multipliers = [ + round(v, 2) + for v in np.linspace( + self.qkv_multipliers[0], + self.qkv_multipliers[1], + num=self.num_transformer_layers, + dtype=float, + ) + ] + # Make sure that scaled model dimension is divisible by scaled head dimension. + query_dims = [ + int( + make_divisible( + self.model_dim * m, divisor=self.head_dim * head_multiple_of + ) + ) + for m in qkv_multipliers + ] + else: + raise NotImplementedError( + f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}." + ) + + # compute the number of query, key, and value heads + # For multi-head and multi-query attention, the number of heads for query, key, and value are the same. + # For group query attention, the number of key and value heads are the same. + self.num_query_heads = [ + int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims + ] + self.num_kv_heads = [ + q_heads // self.num_gqa_groups for q_heads in self.num_query_heads + ] + + # Feed-forward network (FFN) multipliers + if isinstance(self.ffn_multipliers, Number): + # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters. + self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers + elif isinstance(self.ffn_multipliers, (tuple, list)): + # Each FFN layer have different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1]. + # This results in variable allocation of parameters in FFN layer. + # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623 + if len(self.ffn_multipliers) == 2: + self.ffn_multipliers = [ + round(v, 2) + for v in np.linspace( + self.ffn_multipliers[0], + self.ffn_multipliers[1], + num=self.num_transformer_layers, + dtype=float, + ) + ] + else: + assert ( + len(self.ffn_multipliers) == self.num_transformer_layers + ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}" + else: + raise NotImplementedError( + f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}." + ) + + # check num_query_heads divisible by num_kv_heads for every layer + for layer_idx in range(len(query_dims)): + assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0 + +class OpenELMRMSNorm(nn.Module): + def __init__(self, num_features: int, eps: float = 1e-6): + """ + Initialize the OpenELMRMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.num_features = num_features + + def _norm(self, x: Tensor) -> Tensor: + """ + Apply the OpenELMRMSNorm normalization to the input tensor. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The normalized tensor. + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the OpenELMRMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying OpenELMRMSNorm. + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def extra_repr(self) -> str: + return ( + super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}" + ) + + +class OpenELMPreTrainedModel(PreTrainedModel): + config_class = OpenELMConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["OpenELMDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def __init__(self, *inputs, **kwargs) -> None: + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, OpenELMRMSNorm): + module.weight.data.fill_(1.0) + + +def _rotate_half(x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor: + return (x * pos_cos) + (_rotate_half(x) * pos_sin) + + +class OpenELMRotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings (aka RoPE) from `RoFormer `_. + RoPE encodes the position information of tokens using a rotation matrix, and is able to capture + explicit relative positional dependencies. + Args: + model_dim: The dimensionality of the model's hidden state. + max_seq_length: Maximum sequence length. + freq_constant: A constant used for computing frequencies. + """ + + def __init__( + self, model_dim: int, max_seq_length: int, freq_constant: int = 10000 + ) -> None: + inv_freq = 1.0 / ( + freq_constant + ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim) + ) + super().__init__() + + self.model_dim = model_dim + self.freq_constant = freq_constant + self.max_seq_length = max_seq_length + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._cached_cos = None + self._cached_sin = None + self._cached_seq_length = max_seq_length + self._compute_sin_cos_embeddings(max_seq_length) + + def extra_repr(self) -> str: + return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}" + + def _compute_sin_cos_embeddings( + self, + key_len: int, + key_device: torch.device = torch.device("cpu"), + key_dtype: torch.dtype = torch.float32, + ) -> None: + """ + Compute sine and cos embeddings. + Args: + key_len: Number of tokens in the key embeddings in the transformer model. + device: Device where the key embeddings are stored. + key_dtype: Data type of the key embeddings. + Returns: + None + ...note: + We recalculate the sine and cosine embeddings if any of the following conditions are met: + 1. The number of tokens in key embeddings are greater than the cached sequence length. + 2. Sine and cosine caches are empty. + 3. The device and data type of sine and cosine embeddings does not match with the key embeddings. + """ + if ( + key_len > self._cached_seq_length + or self._cached_cos is None + or (self._cached_cos is not None and self._cached_cos.device != key_device) + or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype) + or self._cached_sin is None + or (self._cached_sin is not None and self._cached_sin.device != key_device) + or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype) + ): + self._cached_seq_length = max(key_len, self._cached_seq_length) + + # The shape of 'pos_index' is [number of key tokens] + pos_index = torch.arange( + self._cached_seq_length, + dtype=torch.float32, + device=self.inv_freq.device, + ) + # The shape of 'pos_index_theta' is [number of key tokens, model dimension] + pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq) + # The shape of 'emb' is [number of key tokens, model dimension] + emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1) + + # the shape of cos and sin embeddings is [number of key tokens, model_dim] + cos_emb = emb.cos().to(dtype=key_dtype, device=key_device) + sin_emb = emb.sin().to(dtype=key_dtype, device=key_device) + + # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim] + self._cached_cos = cos_emb[None, None, :, :] + self._cached_sin = sin_emb[None, None, :, :] + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + The forward function of RoPE embeddings. + Args: + query: Query embeddings in the transformer model. The shape of query embeddings is + [Batch, number of query heads, number of query tokens, model dimension]. + key: Key embeddings in the transformer model. The shape of key embeddings is + [Batch, number of key heads, number of key tokens, model dimension]. + Returns: + A tuple containing the query and key embeddings with positional information. The shape of the returned query + and key embeddings is the same as the input query and key embeddings respectively. + ...note: + The RoPE embedding computation is done in full-precision. After the computation, input query and key tensors + are casted to original input datatype. + """ + dim = key.shape[-1] + key_len = key.shape[2] + query_len = query.shape[2] + + assert dim == self.model_dim + assert key.device == query.device + assert key.dtype == query.dtype + + # In the context of self-attention, the lengths of keys and queries are equal. + # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries + # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys + # represent embeddings of previous tokens and the current token, while the query corresponds + # to the embedding of the current token only. + assert ( + key_len >= query_len + ), "Number of keys has to be greater than or equal to number of queries." + + query_float = query.float() + key_float = key.float() + + self._compute_sin_cos_embeddings( + key_len, key_device=key_float.device, key_dtype=key_float.dtype + ) + query_float = _apply_rotary_pos_emb( + x=query_float, + pos_sin=self._cached_sin[..., key_len - query_len : key_len, :], + pos_cos=self._cached_cos[..., key_len - query_len : key_len, :], + ) + key_float = _apply_rotary_pos_emb( + x=key_float, + pos_sin=self._cached_sin[..., :key_len, :], + pos_cos=self._cached_cos[..., :key_len, :], + ) + + return query_float.type_as(query), key_float.type_as(key) + + +class OpenELMMultiHeadCausalAttention(nn.Module): + def __init__(self, config: OpenELMConfig, layer_idx: int) -> None: + super().__init__() + self.layer_idx = layer_idx + head_dim = config.head_dim + q_heads = config.num_query_heads[layer_idx] + k_heads = config.num_kv_heads[layer_idx] + v_heads = config.num_kv_heads[layer_idx] + + self.qkv_proj = nn.Linear( + in_features=config.model_dim, + out_features=(q_heads + k_heads + v_heads) * head_dim, + bias=False, + ) + + self.pos_embedding = OpenELMRotaryEmbedding( + model_dim=config.head_dim, + max_seq_length=config.rope_max_length, + freq_constant=config.rope_freq_constant, + ) + + if config.normalize_qk_projections: + self.q_norm = OpenELMRMSNorm( + num_features=config.head_dim, + ) + self.k_norm = OpenELMRMSNorm( + num_features=config.head_dim, + ) + else: + self.q_norm = None + self.k_norm = None + + self.out_proj = nn.Linear( + in_features=q_heads * head_dim, + out_features=config.model_dim, + bias=False, + ) + + self.head_dim = config.head_dim + self.num_q_heads = q_heads + self.num_k_heads = k_heads + self.num_v_heads = v_heads + self.transformer_dim = config.model_dim + self.num_groups = self.num_q_heads // self.num_k_heads + + def extra_repr(self) -> str: + return ( + super().extra_repr() + + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}" + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Forward pass of multi-head self-attention. + Args: + hidden_states: Input tensor of the shape [batch size, sequence length, model dimension]. + past_key_value: Tensor storing the cached keys and values. + output_attentions: output attention weights. + use_cache: Specifies whether to use kv-cache for generation. + cache_position: used for updating the kv-cache. + Returns: + The output of the same shape as the input, optionally with a tensor containing cached keys and values. + """ + + # scaled_dot_product_attention does not return attention weights, set output_attentions to False + output_attentions = False + batch_size, seq_length, d_model = hidden_states.size() + + # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h] + qkv = self.qkv_proj(hidden_states) + # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h] + qkv = qkv.reshape( + batch_size, + seq_length, + self.num_q_heads + self.num_k_heads + self.num_v_heads, + self.head_dim, + ) + # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h] + qkv = qkv.transpose(1, 2) + # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h] + queries, keys, values = qkv.split( + [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1 + ) + + if self.q_norm is not None: + queries = self.q_norm(queries) + + if self.k_norm is not None: + keys = self.k_norm(keys) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + cache_kwargs = {"cache_position": cache_position} + keys, values = past_key_value.update( + keys, values, self.layer_idx, cache_kwargs + ) + + # Add positional embedding + queries, keys = self.pos_embedding(queries, keys) + + if self.num_groups != 1: + # GQA + # [B, k_h, S, h] --> [B, q_h, S, h] + keys = keys.repeat_interleave(self.num_groups, dim=1) + # [B, v_h, S, h] --> [B, q_h, S, h] + values = values.repeat_interleave(self.num_groups, dim=1) + + causal_mask = attention_mask + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]] + + attn_output = F.scaled_dot_product_attention( + queries, + keys, + values, + attn_mask=causal_mask, + dropout_p=0, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape( + batch_size, seq_length, self.num_q_heads * self.head_dim + ) + attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + +class OpenELMFeedForwardNetwork(nn.Module): + def __init__(self, config: OpenELMConfig, layer_idx: int) -> None: + super().__init__() + ffn_multiplier = config.ffn_multipliers[layer_idx] + intermediate_dim = int( + make_divisible( + ffn_multiplier * config.model_dim, + divisor=config.ffn_dim_divisor, + ) + ) + if config.ffn_with_glu: + # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1. + self.proj_1 = nn.Linear( + in_features=config.model_dim, + out_features=2 * intermediate_dim, + bias=False, + ) + self.proj_2 = nn.Linear( + in_features=intermediate_dim, + out_features=config.model_dim, + bias=False, + ) + self.ffn_with_glu = True + else: + # Standard FFN, as described in https://arxiv.org/abs/1706.03762 + self.proj_1 = nn.Linear( + in_features=config.model_dim, + out_features=intermediate_dim, + bias=False, + ) + self.proj_2 = nn.Linear( + in_features=intermediate_dim, + out_features=config.model_dim, + bias=False, + ) + self.ffn_with_glu = False + + self.act = ACT2FN[config.activation_fn_name] + + def extra_repr(self) -> str: + return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}" + + def forward(self, x: Tensor) -> Tensor: + """Forward function of FFN layer. + Args: + x: Input tensor of the shape [batch size, sequence length, model dimension]. + Returns: + A tensor of the same shape as the input. + """ + if self.ffn_with_glu: + y_12 = self.proj_1(x) + y_1, y_2 = y_12.chunk(2, dim=-1) + y = self.act(y_1) * y_2 + return self.proj_2(y) + else: + return self.proj_2(self.act(self.proj_1(x))) + + +class OpenELMDecoderLayer(nn.Module): + def __init__(self, config: OpenELMConfig, layer_idx: int) -> None: + super().__init__() + self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx) + self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx) + self.ffn_norm = OpenELMRMSNorm( + num_features=config.model_dim, + ) + self.attn_norm = OpenELMRMSNorm( + num_features=config.model_dim, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class OpenELMModel(OpenELMPreTrainedModel): + config_class = OpenELMConfig + + def __init__(self, config: OpenELMConfig): + super().__init__(config) + self.config = config + + self.token_embeddings = nn.Embedding( + embedding_dim=config.model_dim, + num_embeddings=config.vocab_size, + ) + + self.layers = nn.ModuleList( + OpenELMDecoderLayer(config=config, layer_idx=layer_idx) + for layer_idx in range(config.num_transformer_layers) + ) + self.norm = OpenELMRMSNorm(num_features=config.model_dim) + if config.share_input_output_layers: + self.classifier = None + else: + self.classifier = nn.Linear( + in_features=config.model_dim, + out_features=config.vocab_size, + bias=False, + ) + self.num_transformer_layers = config.num_transformer_layers + self.gradient_checkpointing = False + + # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class. + # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`. + causal_mask = torch.full( + (config.max_context_length, config.max_context_length), + fill_value=True, + dtype=torch.bool, + ) + self.register_buffer( + "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False + ) + + # Initialize weights and apply final processing + self.post_init() + self.reset_parameters(config=config) + + def get_input_embeddings(self): + return self.token_embeddings + + def set_input_embeddings(self, new_embeddings: torch.Tensor): + self.token_embeddings = new_embeddings + + def reset_parameters(self, config: OpenELMConfig) -> None: + """Initialize the layers in Language Model + The initialization scheme is followed, following `OPT `_. + Args: + use_megatron_std: Use standard deviation as described in Megatron-LM. + Returns: + None + """ + for module in self.modules(): + if isinstance(module, nn.Linear): + std = module.in_features**-0.5 + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + std = module.embedding_dim**-0.5 + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + elif isinstance(module, OpenELMRMSNorm): + if module.weight is not None: + torch.nn.init.ones_(module.weight) + if hasattr(module, "bias") and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + model_dim = config.model_dim + n_layers = config.num_transformer_layers + std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5) + for param_name, param in self.named_parameters(): + if param_name.endswith("out_proj.weight") or param_name.endswith( + "ffn.proj_2.weight" + ): + torch.nn.init.normal_(param, mean=0.0, std=std) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.token_embeddings(input_ids) + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, Cache) + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask(self, attention_mask, input_tensor): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + batch_size, seq_length = input_tensor.shape[:2] + dtype = input_tensor.dtype + device = input_tensor.device + + # support going beyond cached `max_position_embedding` + if seq_length > self.causal_mask.shape[-1]: + causal_mask = torch.full( + (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), + fill_value=1, + ) + self.register_buffer( + "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False + ) + + # We use the current dtype to avoid any overflows + min_dtype = torch.finfo(dtype).min + causal_mask = ( + self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) + * min_dtype + ) + + causal_mask = causal_mask.to(dtype=dtype, device=device) + if attention_mask is not None and attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[ + :, None, None, : + ].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( + padding_mask, min_dtype + ) + + if self.config._attn_implementation == "sdpa" and attention_mask is not None: + # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). + is_tracing = ( + torch.jit.is_tracing() + or isinstance(input_tensor, torch.fx.Proxy) + or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) + ) + if not is_tracing and torch.any(attention_mask != 1): + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = causal_mask.mul( + ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True) + ).to(dtype) + + return causal_mask + + +class OpenELMForCausalLM(OpenELMPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: OpenELMConfig): + super().__init__(config) + self.transformer = OpenELMModel(config) + self.vocab_size = config.vocab_size + if config.share_input_output_layers: + self.lm_head = None + else: + self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.token_embeddings + + def set_input_embeddings(self, value): + self.transformer.token_embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.transformer = decoder + + def get_decoder(self): + return self.transformer + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.lm_head is None: + # shared + logits = F.linear( + hidden_states, weight=self.transformer.token_embeddings.weight + ) + else: + logits = self.lm_head(hidden_states) + logits = logits[:, : self.config.vocab_size] + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + if self.generation_config.cache_implementation == "static": + # generation with static cache + cache_position = kwargs.get("cache_position", None) + if cache_position is None: + past_length = 0 + else: + past_length = cache_position[-1] + 1 + input_ids = input_ids[:, past_length:] + position_ids = position_ids[:, past_length:] + + # we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. + cache_position = torch.arange( + past_length, + past_length + position_ids.shape[-1], + device=position_ids.device, + ) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # We could use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "position_ids": position_ids.contiguous(), + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/phi.py b/TinyLLaVA_Factory/tinyllava/model/llm/phi.py new file mode 100644 index 0000000000000000000000000000000000000000..98ec6d368d6748bc35c5aae8c0b9f447dc2ed3cf --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/phi.py @@ -0,0 +1,10 @@ +from transformers import PhiForCausalLM, AutoTokenizer + +from . import register_llm + +@register_llm('phi') +def return_phiclass(): + def tokenizer_and_post_load(tokenizer): + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + return (PhiForCausalLM, (AutoTokenizer, tokenizer_and_post_load)) diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/qwen2.py b/TinyLLaVA_Factory/tinyllava/model/llm/qwen2.py new file mode 100644 index 0000000000000000000000000000000000000000..919bbdc47cf156ed1fab9da0f8e9650b35574b12 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/qwen2.py @@ -0,0 +1,11 @@ +from transformers import Qwen2ForCausalLM, AutoTokenizer + +from . import register_llm + +@register_llm('qwen2') +def return_qwen2class(): + def tokenizer_and_post_load(tokenizer): + tokenizer.unk_token = tokenizer.pad_token +# tokenizer.pad_token = tokenizer.unk_token + return tokenizer + return Qwen2ForCausalLM, (AutoTokenizer, tokenizer_and_post_load) diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/stablelm.py b/TinyLLaVA_Factory/tinyllava/model/llm/stablelm.py new file mode 100644 index 0000000000000000000000000000000000000000..11c55a3990ba0a9564a6b516fffeec6161c16fae --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/stablelm.py @@ -0,0 +1,9 @@ +from transformers import StableLmForCausalLM, AutoTokenizer + +from . import register_llm + +@register_llm('stablelm') +def return_phiclass(): + def tokenizer_and_post_load(tokenizer): + return tokenizer + return (StableLmForCausalLM, (AutoTokenizer, tokenizer_and_post_load)) diff --git a/TinyLLaVA_Factory/tinyllava/model/llm/tinyllama.py b/TinyLLaVA_Factory/tinyllava/model/llm/tinyllama.py new file mode 100644 index 0000000000000000000000000000000000000000..29b58e14aacd4ebc243437d5ea58a673d4308576 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/llm/tinyllama.py @@ -0,0 +1,10 @@ +from transformers import LlamaForCausalLM, AutoTokenizer + +from . import register_llm + +@register_llm('tinyllama') +def return_tinyllamaclass(): + def tokenizer_and_post_load(tokenizer): + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + return LlamaForCausalLM, (AutoTokenizer, tokenizer_and_post_load) diff --git a/TinyLLaVA_Factory/tinyllava/model/load_model.py b/TinyLLaVA_Factory/tinyllava/model/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c8a18a85a6288f33dddccecadbf7bf55568418 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/load_model.py @@ -0,0 +1,65 @@ +import os +import torch +from collections import OrderedDict +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig + +from .modeling_tinyllava import TinyLlavaForConditionalGeneration +from .configuration_tinyllava import TinyLlavaConfig + +def load_base_ckp_for_lora(ckp_path): + ckp = torch.load(ckp_path, map_location=torch.device('cpu')) + new_ckp = OrderedDict() + for k, v in ckp.items(): + new_k = k.replace('.base_layer', '') + new_ckp[new_k] = v + return new_ckp + + +def load_pretrained_model(model_name_or_path, load_type='hf', load_8bit=False, load_4bit=False, device_map="auto", + device="cuda", **kwargs): + kwargs = {"device_map": device_map, **kwargs} + if device != "cuda": + kwargs['device_map'] = {"": device} + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + if model_name_or_path is not None and 'lora' not in model_name_or_path: + model = TinyLlavaForConditionalGeneration.from_pretrained(model_name_or_path,low_cpu_mem_usage=True) + + elif model_name_or_path is not None and 'lora' in model_name_or_path: + if os.path.exists(os.path.join(model_name_or_path, 'adapter_config.json')): + model_config = TinyLlavaConfig.from_pretrained(model_name_or_path) + model = TinyLlavaForConditionalGeneration(model_config) + language_model_ckp_path = os.path.join(model_name_or_path, 'language_model/pytorch_model.bin') + language_model_ckp = load_base_ckp_for_lora(language_model_ckp_path) + model.language_model.load_state_dict(language_model_ckp) + vision_tower_ckp_path = os.path.join(model_name_or_path, 'vision_tower/pytorch_model.bin') + vision_tower_ckp = load_base_ckp_for_lora(vision_tower_ckp_path) + model.vision_tower._vision_tower.load_state_dict(vision_tower_ckp) + connector_ckp_path = os.path.join(model_name_or_path, 'connector/pytorch_model.bin') + connector_ckp = load_base_ckp_for_lora(connector_ckp_path) + model.connector.load_state_dict(connector_ckp, strict=False) + model.to(torch.float16) + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_name_or_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + + image_processor = model.vision_tower._image_processor + context_len = getattr(model.config, 'max_sequence_length', 2048) + # tokenizer = AutoTokenizer.from_pretrained(model.config.llm_model_name_or_path, use_fast=False, padding_side="right") + tokenizer = model.tokenizer + #tokenizer.pad_token = tokenizer.eos_token + return model, tokenizer, image_processor, context_len diff --git a/TinyLLaVA_Factory/tinyllava/model/modeling_tinyllava.py b/TinyLLaVA_Factory/tinyllava/model/modeling_tinyllava.py new file mode 100644 index 0000000000000000000000000000000000000000..b674a0be88a7e81db9cfba261dae35a239392ce0 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/modeling_tinyllava.py @@ -0,0 +1,384 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import ast + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.generation.utils import GenerateOutput + +from . import LLMFactory, ConnectorFactory, VisionTowerFactory +from .configuration_tinyllava import TinyLlavaConfig +from ..utils.constants import * +# from tinyllava.utils.data_utils import get_value_from_kwargs + +def get_value_from_kwargs(kwargs, name): + if name in kwargs: + return kwargs.pop(name) + else: + return None + + + +class TinyLlavaPreTrainedModel(PreTrainedModel): + config_class = TinyLlavaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlavaVisionAttention"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def _supports_sdpa(self): + return self.language_model._supports_sdpa + + +class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel): + def __init__(self, config: TinyLlavaConfig): + + super().__init__(config) + + self.language_model = LLMFactory(config.llm_model_name_or_path)[0](config.text_config) + self.vision_tower = VisionTowerFactory(config.vision_model_name_or_path)(config.vision_config) + self.connector = ConnectorFactory(config.connector_type)(config) + + (Tokenizer, post_load) = LLMFactory(config.llm_model_name_or_path)[1] + self.tokenizer = post_load(Tokenizer.from_pretrained( + config.tokenizer_name_or_path, + cache_dir = config.cache_dir, + model_max_length = config.tokenizer_model_max_length, + padding_side = config.tokenizer_padding_side, + use_fast = config.tokenizer_use_fast, + )) + self.post_init() + + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + image_sizes: Optional[List[List[int]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + if inputs_embeds is None: + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels + ) = self.prepare_inputs_labels_for_multimodal( + input_ids, + position_ids, + attention_mask, + past_key_values, + labels, + images, + image_sizes + ) + return self.language_model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_sizes: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("`inputs_embeds` is not supported") + + if images is not None: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _ + ) = self.prepare_inputs_labels_for_multimodal( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes + ) + else: + inputs_embeds = self.language_model.get_input_embeddings()(inputs) + + return self.language_model.generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs + ) + + def encode_images(self, images): + kwargs = {} + kwargs['vision_feature_layer'] = self.config.vision_feature_layer + kwargs['vision_feature_select_strategy'] = self.config.vision_feature_select_strategy + images = images.to(device=self.device, dtype=self.dtype) + image_features = self.vision_tower(images, **kwargs) + image_features = self.connector(image_features) + return image_features + + + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, + inputs_embeds=None, **kwargs): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + inputs = self.language_model.prepare_inputs_for_generation( + input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs + ) + if images is not None: + inputs['images'] = images + if image_sizes is not None: + inputs['image_sizes'] = image_sizes + return inputs + + def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, + images, image_sizes=None + ): + vision_tower = self.vision_tower + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, 'tune_mm_mlp_adapter', False): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + _input_ids = input_ids + input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.language_model.get_input_embeddings()(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": + new_input_embeds_padded.append(torch.cat(( + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), + cur_new_embed + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + else: + new_input_embeds_padded.append(torch.cat(( + cur_new_embed, + torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) + ), dim=0)) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + + + + def load_llm(self, **kwargs): + language_model_name = get_value_from_kwargs(kwargs, 'model_name_or_path') + pretrained_llm_path = get_value_from_kwargs(kwargs, 'pretrained_llm_path') + if pretrained_llm_path is not None: + language_model_name = pretrained_llm_path + if language_model_name is not None: + self.language_model = self.language_model.from_pretrained( + language_model_name, **kwargs + ) + print('loading language model from ', language_model_name) + self.language_model.requires_grad_(False) + + self.config.text_config.torch_dtype = kwargs.get('torch_dtype', None) + self.config.pad_token = getattr(self.tokenizer, 'pad_token', None) + self.config.pad_token_id = getattr(self.tokenizer, 'pad_token_id', None) + #self.config.tokenizer_padding_side = getattr(self.tokenizer, 'padding_side', None) + #self.config.tokenizer_model_max_length = getattr(self.tokenizer, 'model_max_length', None) + + + def load_vision_tower(self, **kwargs): + vision_tower_name = get_value_from_kwargs(kwargs, 'model_name_or_path') + self.vision_tower.load_model(vision_tower_name, **kwargs) + + + def load_connector(self, **kwargs): + self.connector.load_model(**kwargs) + + + + + diff --git a/TinyLLaVA_Factory/tinyllava/model/vision_tower/__init__.py b/TinyLLaVA_Factory/tinyllava/model/vision_tower/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1e0e25b9ec9f97ae880b3e813ac5c8cdfd19ae --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/vision_tower/__init__.py @@ -0,0 +1,29 @@ +import os + +from ...utils import import_modules + + +VISION_TOWER_FACTORY = {} + +def VisionTowerFactory(vision_tower_name): + vision_tower_name = vision_tower_name.split(':')[0] + model = None + for name in VISION_TOWER_FACTORY.keys(): + if name.lower() in vision_tower_name.lower(): + model = VISION_TOWER_FACTORY[name] + assert model, f"{vision_tower_name} is not registered" + return model + + +def register_vision_tower(name): + def register_vision_tower_cls(cls): + if name in VISION_TOWER_FACTORY: + return VISION_TOWER_FACTORY[name] + VISION_TOWER_FACTORY[name] = cls + return cls + return register_vision_tower_cls + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +import_modules(models_dir, "tinyllava.model.vision_tower") diff --git a/TinyLLaVA_Factory/tinyllava/model/vision_tower/base.py b/TinyLLaVA_Factory/tinyllava/model/vision_tower/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c6414c9bf8e43492bbe26a0153c4dca0ecbee072 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/vision_tower/base.py @@ -0,0 +1,70 @@ +import os + +import torch +import torch.nn as nn + +from transformers import PreTrainedModel +# from tinyllava.utils.data_utils import get_value_from_kwargs + +def get_value_from_kwargs(kwargs, name): + if name in kwargs: + return kwargs.pop(name) + else: + return None + +class VisionTower(nn.Module): + def __init__(self, cfg): + super().__init__() + self._vision_tower = None + self._image_processor = None + self.config = cfg + + + def load_model(self, vision_tower_name, **kwargs): + self._load_model(vision_tower_name, **kwargs) + self._vision_tower.requires_grad_(False) + + + + + def _load_model(self, vision_tower_name, **kwargs): + pretrained_vision_tower_path = get_value_from_kwargs(kwargs, 'pretrained_vision_tower_path') + if isinstance(self._vision_tower, PreTrainedModel): # hf model + if pretrained_vision_tower_path is not None: + vision_tower_name = pretrained_vision_tower_path + self._vision_tower = self._vision_tower.from_pretrained(vision_tower_name, **kwargs) + else: # nn.Module + if pretrained_vision_tower_path is not None: + vision_tower_weights = torch.load(os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin'), map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + self._vision_tower.load_state_dict(vision_tower_weights) + + print("Loading vision tower from ", vision_tower_name) + + + + def forward(self, x, **kwargs): + image_features = self._vision_tower(x, output_hidden_states=True) + image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)] + + if kwargs.get('vision_feature_select_strategy', 'patch') == 'patch': + image_features = image_features[:, 1:] + elif kwargs.get('vision_feature_select_strategy', 'patch') == 'cls_patch': + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}") + + return image_features + + + + @property + def vision_tower(self): + return self._vision_tower + + @vision_tower.setter + def vision_tower(self, vision_tower): + self._vision_tower = vision_tower + + diff --git a/TinyLLaVA_Factory/tinyllava/model/vision_tower/clip.py b/TinyLLaVA_Factory/tinyllava/model/vision_tower/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f5098975d8155d028e9113092602f99810a81425 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/vision_tower/clip.py @@ -0,0 +1,14 @@ +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig + +from . import register_vision_tower +from .base import VisionTower + + +@register_vision_tower('clip') +class CLIPVisionTower(VisionTower): + def __init__(self, cfg): + super().__init__(cfg) + self._vision_tower = CLIPVisionModel(cfg) + self._image_processor = CLIPImageProcessor.from_pretrained(cfg.model_name_or_path) + + diff --git a/TinyLLaVA_Factory/tinyllava/model/vision_tower/dinov2.py b/TinyLLaVA_Factory/tinyllava/model/vision_tower/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..28ae5b2a3c519428eb5f331adf459dbeaa6a7f3e --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/vision_tower/dinov2.py @@ -0,0 +1,13 @@ +from transformers import Dinov2Model, AutoImageProcessor + +from . import register_vision_tower +from .base import VisionTower + + +@register_vision_tower('dinov2') +class DINOv2VisionTower(VisionTower): + def __init__(self, cfg): + super().__init__(cfg) + self._vision_tower = Dinov2Model(cfg) + self._image_processor = AutoImageProcessor.from_pretrained(cfg.model_name_or_path) + diff --git a/TinyLLaVA_Factory/tinyllava/model/vision_tower/mof.py b/TinyLLaVA_Factory/tinyllava/model/vision_tower/mof.py new file mode 100644 index 0000000000000000000000000000000000000000..05e1d35661bffeb60d4c9304b40b7546178f0096 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/vision_tower/mof.py @@ -0,0 +1,96 @@ +import os +import torch +import torch.nn as nn +from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig, Dinov2Model, AutoConfig + +from . import register_vision_tower +from .base import VisionTower + + + + + +class MoF(nn.Module): + def __init__(self, cfg): + super().__init__() + self.clip = CLIPVisionModel(cfg) + + cfg_dinov2 = AutoConfig.from_pretrained(cfg.model_name_or_path2) + self.dinov2 = Dinov2Model(cfg_dinov2) + + +# def enable_input_require_grads(self): +# def make_inputs_require_grad(module, input, output): +# output.requires_grads() + +# if hasattr(self.clip, 'enable_input_require_grads'): +# self.clip.enable_input_require_grads() +# else: +# self.clip.get_input_embeddings(make_inputs_require_grad) + +# if hasattr(self.dinov2, 'enable_input_require_grads'): +# self.dinov2.enable_input_require_grads() +# else: +# self.dinov2.get_input_embeddings(make_inputs_require_grad) + + + def forward(self, x, **kwargs): + + image_features_clip = self.clip(x, output_hidden_states=True) + image_features_clip = image_features_clip.hidden_states[kwargs.get('vision_feature_layer', -2)] + + image_features_dinov2 = self.dinov2(x, output_hidden_states=True) + image_features_dinov2 = image_features_dinov2.hidden_states[kwargs.get('vision_feature_layer', -2)] + + if kwargs.get('vision_feature_select_strategy', 'patch') == 'patch': + image_features_clip = image_features_clip[:, 1:] + image_features_dinov2 = image_features_dinov2[:, 1:] + elif kwargs.get('vision_feature_select_strategy', 'patch') == 'cls_patch': + image_features_clip = image_features_clip + image_features_dinov2 = image_features_dinov2 + else: + raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}") + + + image_features = image_features_clip, image_features_dinov2 + + return image_features + + + + + + +@register_vision_tower('mof') +class MoFVisionTower(VisionTower): + def __init__(self, cfg): + super().__init__(cfg) + + self._vision_tower = MoF(cfg) + + self._image_processor = CLIPImageProcessor.from_pretrained(cfg.model_name_or_path) + + + def _load_model(self, vision_tower_name, **kwargs): + pretrained_vision_tower_path = kwargs.pop('pretrained_vision_tower_path', None) + if pretrained_vision_tower_path is None: + model_name_or_path_dinov2 = kwargs.pop('model_name_or_path2') + self._vision_tower.clip = self._vision_tower.clip.from_pretrained(vision_tower_name, **kwargs) + self._vision_tower.dinov2 = self._vision_tower.dinov2.from_pretrained(model_name_or_path_dinov2, **kwargs) + print("Loading vision tower1 from ", vision_tower_name) + print("Loading vision tower2 from ", model_name_or_path_dinov2) + else: # nn.Module + if pretrained_vision_tower_path is not None: + vision_tower_weights = torch.load(os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin'), map_location='cpu') + def get_w(weights, keyword): + return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} + self._vision_tower.load_state_dict(vision_tower_weights) + print("Loading vision tower from ", pretrained_vision_tower_path) + + + def forward(self, x, **kwargs): + device = x.data.device + self.to(device) + return self._vision_tower(x, **kwargs) + + diff --git a/TinyLLaVA_Factory/tinyllava/model/vision_tower/siglip.py b/TinyLLaVA_Factory/tinyllava/model/vision_tower/siglip.py new file mode 100644 index 0000000000000000000000000000000000000000..b640d27d0cc9453776f537b2242a297c656482e0 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/model/vision_tower/siglip.py @@ -0,0 +1,20 @@ +from transformers import SiglipVisionModel, SiglipVisionConfig, SiglipImageProcessor + +from . import register_vision_tower +from .base import VisionTower + + +@register_vision_tower('siglip') +class SIGLIPVisionTower(VisionTower): + def __init__(self, cfg): + super().__init__(cfg) + self._vision_tower = SiglipVisionModel(cfg) + self._image_processor = SiglipImageProcessor.from_pretrained(cfg.model_name_or_path) + + +# def forward(self, x, **kwargs): +# image_features = self._vision_tower(x, output_hidden_states=True) +# image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)] + + +# return image_features diff --git a/TinyLLaVA_Factory/tinyllava/serve/__init__.py b/TinyLLaVA_Factory/tinyllava/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TinyLLaVA_Factory/tinyllava/serve/app.py b/TinyLLaVA_Factory/tinyllava/serve/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d1de68ce8b97478191158e3efddb09acb974e7b3 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/serve/app.py @@ -0,0 +1,358 @@ +''' +@Description: +@Author: jiajunlong +@Date: 2024-06-19 19:30:17 +@LastEditTime: 2024-06-19 19:32:47 +@LastEditors: jiajunlong +''' +import argparse +import hashlib +import json +from pathlib import Path +import time +from threading import Thread +import logging + +import gradio as gr +import torch +from transformers import TextIteratorStreamer + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + +DEFAULT_MODEL_PATH = "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B" + + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} +""" +title_markdown = """ +# TinyLLaVA: A Framework of Small-scale Large Multimodal Models +[[Code](https://github.com/DLCV-BUAA/TinyLLaVABench)] | 📚 [[Paper](https://arxiv.org/pdf/2402.14289.pdf)] +""" +tos_markdown = """ +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""" +learn_more_markdown = """ +### License +The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""" +ack_markdown = """ +### Acknowledgement +The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community! +""" + + +def regenerate(state, image_process_mode): + state.messages[-1]['value'] = None + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + + +def clear_history(): + state = Message() + return (state, state.to_gradio_chatbot(), "", None) + + +def add_text(state, text, image, image_process_mode): + if len(text) <= 0 and image is None: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + + text = text[:1536] # Hard cut-off + if image is not None: + text = text[:1200] # Hard cut-off for images + if "" not in text: + # text = '' + text + text = text + "\n" + if len(state.images) > 0: + state = Message() + state.add_image(image, len(state.messages)) + state.add_message(text, None) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + + +def load_demo(): + state = Message() + return state + + +@torch.inference_mode() +def get_response(params): + input_ids = params["input_ids"] + prompt = params["prompt"] + images = params.get("images", None) + num_image_tokens = 0 + if images is not None and len(images) > 0: + if len(images) > 0: + # image = [load_image_from_base64(img) for img in images][0] + image = images[0][0] + image = image_processor(image) + image = image.unsqueeze(0).to(model.device, dtype=torch.float16) + num_image_tokens = getattr(model.vision_tower._vision_tower, "num_patches", 336) + else: + image = None + image_args = {"images": image} + else: + image = None + image_args = {} + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, "max_position_embeddings", 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + do_sample = True if temperature > 0.001 else False + logger.info(prompt) + input_ids = input_ids.unsqueeze(0).to(model.device) + # keywords = [stop_str] + + # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextIteratorStreamer( + tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15 + ) + + max_new_tokens = min( + max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens + ) + + if max_new_tokens < 1: + yield json.dumps( + { + "text": prompt + + "Exceeds max token length. Please start a new conversation, thanks.", + "error_code": 0, + } + ).encode() + b"\0" + return + + generate_kwargs = dict( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + streamer=streamer, + use_cache=True, + pad_token_id = tokenizer.eos_token_id, + **image_args, + ) + thread = Thread(target=model.generate, kwargs=generate_kwargs) + thread.start() + logger.debug(prompt) + logger.debug(generate_kwargs) + generated_text = prompt + for new_text in streamer: + generated_text += new_text + # print(f"new_text:{new_text}") + if generated_text.endswith(stop_str): + generated_text = generated_text[: -len(stop_str)] + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + + +def http_bot(state, temperature, top_p, max_new_tokens): + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, state.to_gradio_chatbot()) + return + + + images = state.images + result = text_processor(state.messages, mode='eval') + prompt = result['prompt'] + input_ids = result['input_ids'] + pload = { + "model": model_name, + "prompt": prompt, + "input_ids": input_ids, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1536), + "stop": ( + text_processor.template.separator.apply()[1] + ), "images": images} + + state.messages[-1]['value'] = "▌" + yield (state, state.to_gradio_chatbot()) + + # for stream + output = get_response(pload) + for chunk in output: + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt) :].strip() + state.messages[-1]['value'] = output + "▌" + yield (state, state.to_gradio_chatbot()) + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1]['value'] = output + yield (state, state.to_gradio_chatbot()) + return + time.sleep(0.03) + + state.messages[-1]['value'] = state.messages[-1]['value'][:-1] + yield (state, state.to_gradio_chatbot()) + + +def build_demo(): + textbox = gr.Textbox( + show_label=False, placeholder="Enter text and press ENTER", container=False + ) + with gr.Blocks(title="TinyLLaVA", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + gr.Markdown(title_markdown) + + with gr.Row(): + with gr.Column(scale=5): + with gr.Row(elem_id="Model ID"): + gr.Dropdown( + choices=[DEFAULT_MODEL_PATH.split('/')[-1]], + value=DEFAULT_MODEL_PATH.split('/')[-1], + interactive=True, + label="Model ID", + container=False, + ) + imagebox = gr.Image(type="pil") + image_process_mode = gr.Radio( + ["Crop", "Resize", "Pad", "Default"], + value="Default", + label="Preprocess for non-square image", + visible=False, + ) + + # cur_dir = os.path.dirname(os.path.abspath(__file__)) + cur_dir = Path(__file__).parent + gr.Examples( + examples=[ + [ + f"{cur_dir}/examples/extreme_ironing.jpg", + "What is unusual about this image?", + ], + [ + f"{cur_dir}/examples/waterview.jpg", + "What are the things I should be cautious about when I visit here?", + ], + ], + inputs=[imagebox, textbox], + ) + + with gr.Accordion("Parameters", open=False) as _: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + with gr.Column(scale=8): + chatbot = gr.Chatbot(elem_id="chatbot", label="Chatbot", height=550) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button(value="Send", variant="primary") + with gr.Row(elem_id="buttons") as _: + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True) + clear_btn = gr.Button(value="🗑️ Clear", interactive=True) + + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + gr.Markdown(ack_markdown) + + regenerate_btn.click( + regenerate, + [state, image_process_mode], + [state, chatbot, textbox, imagebox], + queue=False, + ).then( + http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot] + ) + + clear_btn.click( + clear_history, None, [state, chatbot, textbox, imagebox], queue=False + ) + + textbox.submit( + add_text, + [state, textbox, imagebox, image_process_mode], + [state, chatbot, textbox, imagebox], + queue=False, + ).then( + http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot] + ) + + submit_btn.click( + add_text, + [state, textbox, imagebox, image_process_mode], + [state, chatbot, textbox, imagebox], + queue=False, + ).then( + http_bot, [state, temperature, top_p, max_output_tokens], [state, chatbot] + ) + + demo.load(load_demo, None, [state], queue=False) + return demo + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=None) + parser.add_argument("--share", default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--conv-mode", type=str, default="phi") + parser.add_argument("--model-path", type=str, default=DEFAULT_MODEL_PATH) + parser.add_argument("--model-name", type=str, default=DEFAULT_MODEL_PATH.split('/')[-1]) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logger = logging.getLogger(__name__) + logger.info(gr.__version__) + args = parse_args() + model_name = args.model_name + model, tokenizer, image_processor, context_len = load_pretrained_model( + args.model_path, + load_4bit=args.load_4bit, + load_8bit=args.load_8bit + ) + model.to(args.device) + image_processor = ImagePreprocess(image_processor, model.config) + text_processor = TextPreprocess(tokenizer, args.conv_mode) + demo = build_demo() + demo.queue() + demo.launch(server_name=args.host, server_port=args.port, share=args.share) diff --git a/TinyLLaVA_Factory/tinyllava/serve/cli.py b/TinyLLaVA_Factory/tinyllava/serve/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..fed66e376d85cd3342b00391246b8c9b58df81e7 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/serve/cli.py @@ -0,0 +1,120 @@ +''' +@Description: +@Author: jiajunlong +@Date: 2024-06-19 19:30:17 +@LastEditTime: 2024-06-19 19:32:47 +@LastEditors: jiajunlong +''' +import argparse +import requests +from PIL import Image +from io import BytesIO + +import torch +from transformers import TextStreamer + +from tinyllava.utils import * +from tinyllava.data import * +from tinyllava.model import * + + +def load_image(image_file): + if image_file.startswith('http://') or image_file.startswith('https://'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + if args.model_path is not None: + model, tokenizer, image_processor, context_len = load_pretrained_model(model_name_or_path=args.model_path, load_8bit=args.load_8bit, load_4bit=args.load_4bit, device=args.device) + else: + assert args.model is not None, 'model_path or model must be provided' + model = args.model + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + tokenizer = model.tokenizer + image_processor = model.vision_tower._image_processor + + text_processor = TextPreprocess(tokenizer, args.conv_mode) + data_args = model.config + image_processor = ImagePreprocess(image_processor, data_args) + model.to(args.device) + if getattr(text_processor.template, 'role', None) is None: + roles = ['USER', 'ASSISTANT'] + else: + roles = text_processor.template.role.apply() + msg = Message() + image = load_image(args.image_file) + # Similar operation in model_worker.py + image_tensor = image_processor(image) + image_tensor = image_tensor.unsqueeze(0).to(model.device, dtype=torch.float16) + + while True: + try: + inp = input(f"{roles[0]}: ") + except EOFError: + inp = "" + if not inp: + print("exit...") + break + + print(f"{roles[1]}: ", end="") + + if image is not None: + # first message + inp = DEFAULT_IMAGE_TOKEN + '\n' + inp + msg.add_message(inp) + image = None + else: + # later messages + msg.add_message(inp) + result = text_processor(msg.messages, mode='eval') + prompt = result['prompt'] + input_ids = result['input_ids'].unsqueeze(0).to(model.device) + + # stop_str = text_processor.template.separator.apply()[1] + # keywords = [stop_str] + # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensor, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + max_new_tokens=args.max_new_tokens, + streamer=streamer, + use_cache=True, + pad_token_id = tokenizer.eos_token_id, + # stopping_criteria=[stopping_criteria] + ) + + outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() + msg.messages[-1]['value'] = outputs + + if args.debug: + print("\n", {"prompt": prompt, "outputs": outputs}, "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B") + parser.add_argument("--model", type=str, default=None) + parser.add_argument("--image-file", type=str, required=True) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--conv-mode", type=str, default='phi') + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + main(args) diff --git a/TinyLLaVA_Factory/tinyllava/serve/examples/extreme_ironing.jpg b/TinyLLaVA_Factory/tinyllava/serve/examples/extreme_ironing.jpg new file mode 100644 index 0000000000000000000000000000000000000000..638b078837f175039b2db49a63821288d9681daa Binary files /dev/null and b/TinyLLaVA_Factory/tinyllava/serve/examples/extreme_ironing.jpg differ diff --git a/TinyLLaVA_Factory/tinyllava/serve/examples/waterview.jpg b/TinyLLaVA_Factory/tinyllava/serve/examples/waterview.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f44ebaba1aa493b8bab3baa4e827b76752b1869 Binary files /dev/null and b/TinyLLaVA_Factory/tinyllava/serve/examples/waterview.jpg differ diff --git a/TinyLLaVA_Factory/tinyllava/train/__init__.py b/TinyLLaVA_Factory/tinyllava/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16b20bdc57500a00978e4e8bb104e2bbfea65084 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/train/__init__.py @@ -0,0 +1,2 @@ +from .train import * +from .tinyllava_trainer import * \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/train/custom_finetune.py b/TinyLLaVA_Factory/tinyllava/train/custom_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..9e4700f7cd7c6c5367663fedf7d7e60ef4f4a588 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/train/custom_finetune.py @@ -0,0 +1,52 @@ +import tokenizers +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoImageProcessor + +from tinyllava.train.tinyllava_trainer import LLaVATrainer +from tinyllava.training_recipe import TrainingRecipeFactory +from tinyllava.utils import * +from tinyllava.model import * +from tinyllava.data.dataset import make_supervised_data_module + +def load_settings(model_arguments, data_arguments, training_arguments): + model_arguments.tune_type_connector = training_arguments.tune_type_connector + model_arguments.tune_type_llm = training_arguments.tune_type_llm + model_arguments.tune_type_vision_tower = training_arguments.tune_type_vision_tower + model_arguments.image_aspect_ratio = data_arguments.image_aspect_ratio + + + + + +def train(): + + # load argument + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_arguments, data_arguments, training_arguments = parser.parse_args_into_dataclasses() + logger_setting(getattr(training_arguments, 'output_dir', None)) + training_recipe = TrainingRecipeFactory(training_arguments.training_recipe)(training_arguments) + load_settings(model_arguments, data_arguments, training_arguments) + # load pretrained checkpoint + model = AutoModelForCausalLM.from_pretrained(training_arguments.pretrained_model_path, trust_remote_code=True) + config = model.config + tokenizer = AutoTokenizer.from_pretrained(training_arguments.pretrained_model_path, use_fast=False, model_max_length = config.tokenizer_model_max_length,padding_side = config.tokenizer_padding_side) + model.tokenizer = tokenizer + model = training_recipe(model) + model.config.use_cache = False + model.config.image_aspect_ratio = data_arguments.image_aspect_ratio + data_arguments.image_processor = AutoImageProcessor.from_pretrained(config.vision_model_name_or_path) + data_arguments.is_multimodal = True + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_arguments) + log_trainable_params(model) # not work well with zero3 + trainer = LLaVATrainer(model=model, #does not require model.to(device), huggingface/deepspeed does it for you? + tokenizer=tokenizer, + args=training_arguments, + **data_module) + + trainer.train() + + training_recipe.save(model, trainer) + +if __name__ == "__main__": + train() diff --git a/TinyLLaVA_Factory/tinyllava/train/tinyllava_trainer.py b/TinyLLaVA_Factory/tinyllava/train/tinyllava_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca26c59c01b069fd02330cbf2a3452b753249c9a --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/train/tinyllava_trainer.py @@ -0,0 +1,231 @@ +import os +import torch +from torch import nn + +from torch.utils.data import Sampler + +from transformers import Trainer +from transformers.trainer import ( + is_sagemaker_mp_enabled, + get_parameter_names, + has_length, + ALL_LAYERNORM_LAYERS, + # ShardedDDPOption, + logger, +) +from typing import List, Optional + +from ..utils.train_utils import * + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + assert len(mm_indices) > 0, "Should have at least one multimodal sample." + assert len(lang_indices) > 0, "Should have at least one language sample." + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) >= megabatch_size: + megabatches = [additional_batch[:megabatch_size]] + megabatches + additional_batch = additional_batch[megabatch_size:] + + if len(additional_batch) > 0: + megabatches.append(additional_batch) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + group_by_modality: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.group_by_modality = group_by_modality + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.group_by_modality: + indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +class LLaVATrainer(Trainer): + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + world_size=self.args.world_size, + lengths=lengths, + group_by_modality=True, + ) + else: + return super()._get_train_sampler() + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + # if self.sharded_ddp == ShardedDDPOption.SIMPLE: + # return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + if self.args.mm_projector_lr is not None: + connector_parameters = [name for name, _ in opt_model.named_parameters() if "connector" in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in connector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "name": "decay_no_connector_parameters" + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in connector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "name": "no_decay_no_connector_parameters" + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in connector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "lr": self.args.mm_projector_lr, + "name": "decay_connector_parameters" + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in connector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "lr": self.args.mm_projector_lr, + "name": "no_decay_proj_parameters" + }, + ] + else: + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "name": "decay_parameters" + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "name": "no_decay_parameters" + }, + ] + + if getattr(self.args, "moe_enable", False): + from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer + optimizer_grouped_parameters = split_params_into_different_moe_groups_for_optimizer(optimizer_grouped_parameters) + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + + + + diff --git a/TinyLLaVA_Factory/tinyllava/train/train.py b/TinyLLaVA_Factory/tinyllava/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe2ab2bdc4073c55a9987bfaaf56b9b4afed3c0 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/train/train.py @@ -0,0 +1,92 @@ +from packaging import version +import pathlib + +import tokenizers +import transformers + + +from tinyllava.train.tinyllava_trainer import LLaVATrainer +from tinyllava.training_recipe import TrainingRecipeFactory +from tinyllava.utils import * +from tinyllava.model import * +from tinyllava.data.dataset import make_supervised_data_module + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') + + +def load_settings(model_arguments, data_arguments, training_arguments): + model_arguments.tune_type_connector = training_arguments.tune_type_connector + model_arguments.tune_type_llm = training_arguments.tune_type_llm + model_arguments.tune_type_vision_tower = training_arguments.tune_type_vision_tower + model_arguments.image_aspect_ratio = data_arguments.image_aspect_ratio + + model_args = {} + model_args['llm'] = _load_llm_settings(model_arguments) + model_args['vision_tower'] = _load_vision_settings(model_arguments) + model_args['connector'] = _load_connector_settings(model_arguments) + return model_args + +def _load_llm_settings(model_arguments): + llm_args = {} + llm_args['model_name_or_path'] = model_arguments.model_name_or_path + llm_args['cache_dir'] = model_arguments.cache_dir + llm_args['attn_implementation'] = model_arguments.attn_implementation # flash_attention_2 only supports torch.float16 and torch.bfloat16 dtypes + return llm_args + +def _load_vision_settings(model_arguments): + vision_args = {} + vision_args['model_name_or_path'] = model_arguments.vision_tower.split(':')[-1] + if model_arguments.vision_tower2 != '': + vision_args['model_name_or_path2'] = model_arguments.vision_tower2.split(':')[-1] + return vision_args + +def _load_connector_settings(model_arguments): + connector_args = {} + connector_args['connector_type'] = model_arguments.connector_type + return connector_args + + +def train(): + + # load argument + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_arguments, data_arguments, training_arguments = parser.parse_args_into_dataclasses() + + logger_setting(getattr(training_arguments, 'output_dir', None)) + + training_recipe = TrainingRecipeFactory(training_arguments.training_recipe)(training_arguments) + # model_args contain arguements for huggingface model .from_pretrained function + model_args = load_settings(model_arguments, data_arguments, training_arguments) + model_args = training_recipe.add_args(model_args) + model_config = TinyLlavaConfig() + model_config.load_from_config(model_arguments) + model = TinyLlavaForConditionalGeneration(model_config) + # load pretrained checkpoint + if training_arguments.pretrained_model_path is not None: + model = training_recipe.load(model, model_args) + else: + model.load_llm(**model_args['llm']) + model.load_vision_tower(**model_args['vision_tower']) + model.load_connector(**model_args['connector']) + + model = training_recipe(model) + model.config.use_cache = False + model.config.image_aspect_ratio = data_arguments.image_aspect_ratio + tokenizer = model.tokenizer + data_arguments.image_processor = model.vision_tower._image_processor + data_arguments.is_multimodal = True + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_arguments) + log_trainable_params(model) # not work well with zero3 + trainer = LLaVATrainer(model=model, #does not require model.to(device), huggingface/deepspeed does it for you? + tokenizer=tokenizer, + args=training_arguments, + **data_module) + + trainer.train() + + training_recipe.save(model, trainer) + +if __name__ == "__main__": + train() diff --git a/TinyLLaVA_Factory/tinyllava/training_recipe/__init__.py b/TinyLLaVA_Factory/tinyllava/training_recipe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f771097aac5c5f35bb99ff45b84200e634a20f0 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/training_recipe/__init__.py @@ -0,0 +1,27 @@ +import os + +from ..utils import import_modules + + +RECIPE_FACTORY = {} + +def TrainingRecipeFactory(training_recipe): + recipe = None + for name in RECIPE_FACTORY.keys(): + if name.lower() == training_recipe.lower(): + recipe = RECIPE_FACTORY[name] + assert recipe, f"{training_recipe} is not registered" + return recipe + + +def register_training_recipe(name): + def register_training_recipe_cls(cls): + if name in RECIPE_FACTORY: + return RECIPE_FACTORY[name] + RECIPE_FACTORY[name] = cls + return cls + return register_training_recipe_cls + + +models_dir = os.path.dirname(__file__) +import_modules(models_dir, "tinyllava.training_recipe") diff --git a/TinyLLaVA_Factory/tinyllava/training_recipe/base.py b/TinyLLaVA_Factory/tinyllava/training_recipe/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab3c1398a7f3cf9e06712aaf899fd32f94e1f3a --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/training_recipe/base.py @@ -0,0 +1,167 @@ +import os +import torch + +from ..utils import * +from ..model import * + +class BaseTrainingRecipe: + + def __init__(self, training_arguments): + self.training_arguments = training_arguments + + + def __call__(self, model): + model = self.training_model_converse(model) + model = self.tune_type_setting(model) + model.config.tune_type_connector = self.training_arguments.tune_type_connector + model.config.tune_type_vision_tower = self.training_arguments.tune_type_vision_tower + model.config.tune_type_llm = self.training_arguments.tune_type_llm + model.config.tune_vision_tower_from_layer = self.training_arguments.tune_vision_tower_from_layer + return model + + + def add_args(self, model_args): + llm_dtype = (torch.float16 if self.training_arguments.fp16 else (torch.bfloat16 if self.training_arguments.bf16 else torch.float32)) + model_args['llm'].update(dict(torch_dtype=llm_dtype)) + if self.training_arguments.pretrained_model_path is not None: + model_args['llm'].update(dict(pretrained_llm_path=os.path.join(self.training_arguments.pretrained_model_path, 'language_model'))) + model_args['vision_tower'].update(dict(pretrained_vision_tower_path=os.path.join(self.training_arguments.pretrained_model_path, 'vision_tower'))) + model_args['connector'].update(dict(pretrained_connector_path=os.path.join(self.training_arguments.pretrained_model_path, 'connector'))) + return model_args + + def tune_type_setting(self, model): + model = self._llm_tune_type_setting(model) + model = self._vision_tower_tune_type_setting(model) + model = self._connector_tune_type_setting(model) + return model + + + + def _llm_tune_type_setting(self, model): + tune_type = self.training_arguments.tune_type_llm.lower() + assert tune_type in ('frozen', 'full', 'lora', 'qlora'), f'tune_type {tune_type} not supported in this training recipe!' + if tune_type == 'full': + model.language_model.requires_grad_(True) + elif tune_type == 'frozen': + model.language_model.requires_grad_(False) + self.support_gradient_checkpoint(model.language_model, self.training_arguments.gradient_checkpointing) + return model + + def _vision_tower_tune_type_setting(self, model): + tune_type = self.training_arguments.tune_type_vision_tower.lower() + assert tune_type in ('frozen', 'full', 'partially-tune', 'lora', 'qlora'), f'tune_type {tune_type} not supported in this training recipe!' + if tune_type == 'full': + model.vision_tower.requires_grad_(True) + elif tune_type == 'frozen': + model.vision_tower.requires_grad_(False) + elif tune_type == 'partially-tune': + #-------------------------------------------- + #-------------------------------------------- + #TODO gradient checkpointing related??? + #-------------------------------------------- + #-------------------------------------------- + from_layer = self.training_arguments.tune_vision_tower_from_layer + if from_layer > -1: + log(f'Tune the vision tower from layer {from_layer}!') + for n, p in model.vision_tower.named_parameters(): + if 'vision_model.encoder.layers.' in n: #TODO not sure if other visual encoders contain 'vision_model.encoder.layers.' + layer_id = int(n.split('vision_model.encoder.layers.')[-1].split('.')[0]) + if layer_id >= from_layer: + p.requires_grad = True + else: + p.requires_grad = False + else: + p.requires_grad = False + #self.support_gradient_checkpoint(model.vision_tower._vision_tower, self.training_arguments.gradient_checkpointing) + return model + + def _connector_tune_type_setting(self, model): + tune_type = self.training_arguments.tune_type_connector.lower() + assert tune_type in ('frozen', 'full', 'lora', 'qlora'), f'tune_type {tune_type} not supported in this training recipe!' + if tune_type == 'full': + for p in model.connector.parameters(): + p.requires_grad = True + elif tune_type == 'frozen': + for p in model.connector.parameters(): + p.requires_grad = False + return model + + + + def training_model_converse(self, model): + return model + + + def save(self, model, trainer): + model.config.use_cache = True + #save tokenizer + model.tokenizer.save_pretrained(self.training_arguments.output_dir) + #save entire model config + model.config.save_pretrained(self.training_arguments.output_dir, from_pt=True) + #save trainer + trainer.save_state() + + if 'finetune' in self.training_arguments.output_dir and self.training_arguments.pretrained_model_path is not None: # for finetune stage + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(self.training_arguments.output_dir) + return + + #the followings are for pretrain stage + #save language model + language_model_state_dict = get_state_maybe_zero_3(model.language_model.named_parameters(), [''], False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + language_model_output_dir = os.path.join(self.training_arguments.output_dir, 'language_model') + os.makedirs(language_model_output_dir, exist_ok=True) + language_model_output_path = os.path.join(self.training_arguments.output_dir, 'language_model/pytorch_model.bin') + torch.save(language_model_state_dict, language_model_output_path) + model.config.text_config.save_pretrained(language_model_output_dir, from_pt=True) + #save vision tower + vision_tower_state_dict = get_state_maybe_zero_3(model.vision_tower._vision_tower.named_parameters(), [''], False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + vision_tower_output_dir = os.path.join(self.training_arguments.output_dir, 'vision_tower') + os.makedirs(vision_tower_output_dir, exist_ok=True) + vision_tower_output_path = os.path.join(self.training_arguments.output_dir, 'vision_tower/pytorch_model.bin') + torch.save(vision_tower_state_dict, vision_tower_output_path) + if isinstance(model.vision_tower._vision_tower, PreTrainedModel): + model.vision_tower._vision_tower.config.save_pretrained(vision_tower_output_dir, from_pt=True) + #save connector + connector_state_dict = get_state_maybe_zero_3(model.connector.named_parameters(), [''], False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + connector_output_dir = os.path.join(self.training_arguments.output_dir, 'connector') + os.makedirs(connector_output_dir, exist_ok=True) + connector_output_path = os.path.join(self.training_arguments.output_dir, 'connector/pytorch_model.bin') + torch.save(connector_state_dict, connector_output_path) + + + def load(self, model, model_args={}): + if not ('lora' in self.training_arguments.pretrained_model_path and os.path.exists(os.path.join(self.training_arguments.pretrained_model_path, 'adapter_config.json'))): # loading model for non-lora/non-qlora pretraining + model.load_llm(**model_args['llm']) + model.load_vision_tower(**model_args['vision_tower']) + model.load_connector(**model_args['connector']) + else: + model.language_model = model.language_model.from_pretrained(model_args['llm']['model_name_or_path'],attn_implementation='flash_attention_2',torch_dtype=model_args['llm']['torch_dtype']) + model.load_vision_tower(**model_args['vision_tower']) + model.load_connector(**model_args['connector']) + model.to(model_args['llm']['torch_dtype']) + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, self.training_arguments.pretrained_model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + + return model + + + def support_gradient_checkpoint(self, model, gradient_checkpointing=False): + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + if gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + + diff --git a/TinyLLaVA_Factory/tinyllava/training_recipe/common_recipe.py b/TinyLLaVA_Factory/tinyllava/training_recipe/common_recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..a38a0e4097745d9b507f46d6f0d050e15f5ba50e --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/training_recipe/common_recipe.py @@ -0,0 +1,14 @@ +import os + +import torch + +from .base import BaseTrainingRecipe +from . import register_training_recipe +from ..utils import log +from ..utils import get_state_maybe_zero_3 +from ..model import TinyLlavaConfig, TinyLlavaForConditionalGeneration + + +@register_training_recipe('common') +class CommonTrainingRecipe(BaseTrainingRecipe): + ... diff --git a/TinyLLaVA_Factory/tinyllava/training_recipe/lora_recipe.py b/TinyLLaVA_Factory/tinyllava/training_recipe/lora_recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..96509f1847f1b7f2abae353f9a3d5d442b23db37 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/training_recipe/lora_recipe.py @@ -0,0 +1,91 @@ +import os + +from collections import OrderedDict + +import torch +from transformers import BitsAndBytesConfig +from peft import prepare_model_for_kbit_training +from peft import LoraConfig, get_peft_model, TaskType, PeftModel +from peft.tuners.lora import LoraLayer + +from .base import BaseTrainingRecipe +from . import register_training_recipe +from ..utils.train_utils import * +from ..utils import log +from ..model import TinyLlavaConfig, TinyLlavaForConditionalGeneration + + +@register_training_recipe('lora') +class LoRATrainingRecipe(BaseTrainingRecipe): + def __init__(self, training_arguments): + super().__init__(training_arguments) + self.training_arguments = training_arguments + self.lora_skip_module = ['connector', 'vision_tower', 'language_model'] + + + def training_model_converse(self, model): + if self.training_arguments.tune_type_connector == 'lora': + self.lora_skip_module.remove('connector') + if self.training_arguments.tune_type_llm == 'lora': + self.lora_skip_module.remove('language_model') + if self.training_arguments.tune_type_vision_tower == 'lora': + self.lora_skip_module.remove('vision_tower') + lora_config = LoraConfig( + r=self.training_arguments.lora_r, + lora_alpha=self.training_arguments.lora_alpha, + target_modules=find_all_linear_names(model, self.lora_skip_module), + lora_dropout=self.training_arguments.lora_dropout, + bias=self.training_arguments.lora_bias, + task_type="CAUSAL_LM", + ) + if self.training_arguments.bits == 16: + if self.training_arguments.bf16: + model.to(torch.bfloat16) + if self.training_arguments.fp16: + model.to(torch.float16) + if model.peft_config is None: + log("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + return model + + + def save(self, model, trainer): + model.config.use_cache = True + #save tokenizer + model.tokenizer.save_pretrained(self.training_arguments.output_dir) + #save entire model config + model.config.save_pretrained(self.training_arguments.output_dir, from_pt=True) + #save trainer + trainer.save_state() + + #save language model base params + language_model_state_dict = get_peft_state_non_lora_maybe_zero_3(model.language_model.named_parameters(), False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + language_model_output_dir = os.path.join(self.training_arguments.output_dir, 'language_model') + os.makedirs(language_model_output_dir, exist_ok=True) + language_model_output_path = os.path.join(self.training_arguments.output_dir, 'language_model/pytorch_model.bin') + torch.save(language_model_state_dict, language_model_output_path) + model.config.text_config.save_pretrained(language_model_output_dir, from_pt=True) + #save vision tower base params + vision_tower_state_dict = get_peft_state_non_lora_maybe_zero_3(model.vision_tower._vision_tower.named_parameters(), False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + vision_tower_output_dir = os.path.join(self.training_arguments.output_dir, 'vision_tower') + os.makedirs(vision_tower_output_dir, exist_ok=True) + vision_tower_output_path = os.path.join(self.training_arguments.output_dir, 'vision_tower/pytorch_model.bin') + torch.save(vision_tower_state_dict, vision_tower_output_path) + model.config.vision_config.save_pretrained(vision_tower_output_dir, from_pt=True) + #save connector base params + connector_state_dict = get_peft_state_non_lora_maybe_zero_3(model.connector.named_parameters(), False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + connector_output_dir = os.path.join(self.training_arguments.output_dir, 'connector') + os.makedirs(connector_output_dir, exist_ok=True) + connector_output_path = os.path.join(self.training_arguments.output_dir, 'connector/pytorch_model.bin') + torch.save(connector_state_dict, connector_output_path) + # save lora params + lora_state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), self.training_arguments.lora_bias + ) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + model.save_pretrained(self.training_arguments.output_dir, state_dict=lora_state_dict) + + diff --git a/TinyLLaVA_Factory/tinyllava/training_recipe/qlora_recipe.py b/TinyLLaVA_Factory/tinyllava/training_recipe/qlora_recipe.py new file mode 100644 index 0000000000000000000000000000000000000000..657ba7edce1332d61e0acd5032d0fbb4f6b13b86 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/training_recipe/qlora_recipe.py @@ -0,0 +1,109 @@ +import os + +from collections import OrderedDict + +import torch +from transformers import BitsAndBytesConfig +from peft import prepare_model_for_kbit_training +from peft import LoraConfig, get_peft_model, TaskType, PeftModel +from peft.tuners.lora import LoraLayer + +from .base import BaseTrainingRecipe +from . import register_training_recipe +from ..utils.train_utils import * +from ..utils import log +from ..model import TinyLlavaConfig, TinyLlavaForConditionalGeneration + + +@register_training_recipe('qlora_int8') +class QLoRAInt8TrainingRecipe(BaseTrainingRecipe): + def __init__(self, training_arguments): + super().__init__(training_arguments) + self.training_arguments = training_arguments + self.lora_skip_module = ['connector', 'vision_tower', 'language_model'] + + + def add_args(self, model_args): + llm_dtype = (torch.float16 if self.training_arguments.fp16 else (torch.bfloat16 if self.training_arguments.bf16 else torch.float32)) + model_args['llm'].update(dict(torch_dtype=llm_dtype)) + model_args['llm'].update(dict(low_cpu_mem_usage=True)) + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=llm_dtype + ) + model_args['llm'].update(dict(quantization_config=quantization_config)) + + if self.training_arguments.pretrained_model_path is not None: + model_args['llm'].update(dict(pretrained_llm_path=os.path.join(self.training_arguments.pretrained_model_path, 'language_model'))) + model_args['vision_tower'].update(dict(pretrained_vision_tower_path=os.path.join(self.training_arguments.pretrained_model_path, 'vision_tower'))) + model_args['connector'].update(dict(pretrained_connector_path=os.path.join(self.training_arguments.pretrained_model_path, 'connector'))) + return model_args + + + def training_model_converse(self, model): + if self.training_arguments.tune_type_connector == 'qlora': + self.lora_skip_module.remove('connector') + if self.training_arguments.tune_type_llm == 'qlora': + self.lora_skip_module.remove('language_model') + if self.training_arguments.tune_type_vision_tower == 'qlora': + self.lora_skip_module.remove('vision_tower') + lora_config = LoraConfig( + r=self.training_arguments.lora_r, + lora_alpha=self.training_arguments.lora_alpha, + target_modules=find_all_linear_names(model, self.lora_skip_module), + lora_dropout=self.training_arguments.lora_dropout, + bias=self.training_arguments.lora_bias, + task_type="CAUSAL_LM", + ) + if self.training_arguments.bits == 16: + if self.training_arguments.bf16: + model.to(torch.bfloat16) + if self.training_arguments.fp16: + model.to(torch.float16) + log("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + return model + + + def save(self, model, trainer): + model.config.use_cache = True + #save tokenizer + model.tokenizer.save_pretrained(self.training_arguments.output_dir) + #save entire model config + model.config.save_pretrained(self.training_arguments.output_dir, from_pt=True) + #save trainer + trainer.save_state() + + #save language model base params + language_model_state_dict = get_peft_state_non_lora_maybe_zero_3(model.language_model.named_parameters(), False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + language_model_output_dir = os.path.join(self.training_arguments.output_dir, 'language_model') + os.makedirs(language_model_output_dir, exist_ok=True) + language_model_output_path = os.path.join(self.training_arguments.output_dir, 'language_model/pytorch_model.bin') + torch.save(language_model_state_dict, language_model_output_path) + model.config.text_config.save_pretrained(language_model_output_dir, from_pt=True) + #save vision tower base params + vision_tower_state_dict = get_peft_state_non_lora_maybe_zero_3(model.vision_tower._vision_tower.named_parameters(), False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + vision_tower_output_dir = os.path.join(self.training_arguments.output_dir, 'vision_tower') + os.makedirs(vision_tower_output_dir, exist_ok=True) + vision_tower_output_path = os.path.join(self.training_arguments.output_dir, 'vision_tower/pytorch_model.bin') + torch.save(vision_tower_state_dict, vision_tower_output_path) + model.config.vision_config.save_pretrained(vision_tower_output_dir, from_pt=True) + #save connector base params + connector_state_dict = get_peft_state_non_lora_maybe_zero_3(model.connector.named_parameters(), False) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + connector_output_dir = os.path.join(self.training_arguments.output_dir, 'connector') + os.makedirs(connector_output_dir, exist_ok=True) + connector_output_path = os.path.join(self.training_arguments.output_dir, 'connector/pytorch_model.bin') + torch.save(connector_state_dict, connector_output_path) + # save lora params + lora_state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), self.training_arguments.lora_bias + ) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + model.save_pretrained(self.training_arguments.output_dir, state_dict=lora_state_dict) + + diff --git a/TinyLLaVA_Factory/tinyllava/utils/__init__.py b/TinyLLaVA_Factory/tinyllava/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28e4afa4defaa665cca8979508dc5e927c102ac6 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/__init__.py @@ -0,0 +1,8 @@ +from .arguments import * +from .constants import * +from .import_module import * +from .logging import * +from .train_utils import * +from .message import * +from .eval_utils import * +from .data_utils import * diff --git a/TinyLLaVA_Factory/tinyllava/utils/arguments.py b/TinyLLaVA_Factory/tinyllava/utils/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c0bf843f24e67d697682cc8e7d8b34fae18c01 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/arguments.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence, TYPE_CHECKING +import transformers + + +if TYPE_CHECKING: + import transformers + +@dataclass +class ModelArguments: + cache_dir: Optional[str] = field(default=None) + + model_name_or_path: Optional[str] = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0") + tokenizer_name_or_path: Optional[str] = field(default=None) + attn_implementation: Optional[str] = field(default=None) + vision_tower: Optional[str] = field(default='') + vision_tower2: Optional[str] = field(default='') + connector_type: str = field(default='linear') + + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + mm_patch_merge_type: Optional[str] = field(default='flat') + mm_vision_select_feature: Optional[str] = field(default="patch") + resampler_hidden_size: Optional[int] = field(default=768) + num_queries: Optional[int] = field(default=128) + num_resampler_layers: Optional[int] = field(default=3) + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + tokenizer_use_fast: bool = field(default=False) + tokenizer_padding_side: str = field(default='right') + + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = True + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + conv_version: str = 'pretrain' + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + training_recipe: str = field(default='common') + tune_type_llm: str = field(default="frozen") # support only: frozen, full, lora, qlora_int4, qlora_int8 + tune_type_vision_tower: str = field(default="frozen") # support only: frozen, full, partially-tune + tune_vision_tower_from_layer: Optional[int] = field(default=10) + tune_type_connector: str = field(default="full") # support only: frozen, full + tune_embed_tokens: Optional[int] = field(default=False) + + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + vision_tower_lr: Optional[float] = None + pretrained_model_path: Optional[str] = field(default=None) + + + + diff --git a/TinyLLaVA_Factory/tinyllava/utils/constants.py b/TinyLLaVA_Factory/tinyllava/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6964799be0cf9b66dde7fca8e09b709711d8edcb --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/constants.py @@ -0,0 +1,13 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "." + +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" +IMAGE_PLACEHOLDER = "" \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/utils/data_utils.py b/TinyLLaVA_Factory/tinyllava/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e95cb9505a43809a541651a6c3a0a9e3e94e1ab --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/data_utils.py @@ -0,0 +1,114 @@ +import ast +import math +from PIL import Image + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + + Args: + image_size (tuple): The size of the input image in the format (width, height). + grid_pinpoints (str): A string representation of a list of possible resolutions. + patch_size (int): The size of each image patch. + + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + if type(grid_pinpoints) is list: + possible_resolutions = grid_pinpoints + else: + possible_resolutions = ast.literal_eval(grid_pinpoints) + width, height = select_best_resolution(image_size, possible_resolutions) + return width // patch_size, height // patch_size + +def select_best_resolution(original_size, possible_resolutions): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float('inf') + + for width, height in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + return best_fit + +## added by llava-1.6 +def divide_to_patches(image, patch_size): + """ + Divides an image into patches of a specified size. + + Args: + image (PIL.Image.Image): The input image. + patch_size (int): The size of each patch. + + Returns: + list: A list of PIL.Image.Image objects representing the patches. + """ + patches = [] + width, height = image.size + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + box = (j, i, j + patch_size, i + patch_size) + patch = image.crop(box) + patches.append(patch) + + return patches + +## added by llava-1.6 +def resize_and_pad_image(image, target_resolution): + """ + Resize and pad an image to a target resolution while maintaining aspect ratio. + + Args: + image (PIL.Image.Image): The input image. + target_resolution (tuple): The target resolution (width, height) of the image. + + Returns: + PIL.Image.Image: The resized and padded image. + """ + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / original_width + scale_h = target_height / original_height + + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + # Resize the image + resized_image = image.resize((new_width, new_height)) + + new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + new_image.paste(resized_image, (paste_x, paste_y)) + + return new_image + +def get_value_from_kwargs(kwargs, name): + if name in kwargs: + return kwargs.pop(name) + else: + return None \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/utils/eval_utils.py b/TinyLLaVA_Factory/tinyllava/utils/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..61b70fc71dbcc58ef67f5bfbf7866a8c9c4fdd40 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/eval_utils.py @@ -0,0 +1,50 @@ +import os +from PIL import Image +from io import BytesIO +import base64 +from transformers import AutoTokenizer +import torch +from transformers import StoppingCriteria, PhiForCausalLM + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + outputs = [] + for i in range(output_ids.shape[0]): + outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) + return all(outputs) + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) diff --git a/TinyLLaVA_Factory/tinyllava/utils/import_module.py b/TinyLLaVA_Factory/tinyllava/utils/import_module.py new file mode 100644 index 0000000000000000000000000000000000000000..274a3a2d293132cbcb95d939392b2c2133dbe4c5 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/import_module.py @@ -0,0 +1,13 @@ +import importlib +import os + +def import_modules(models_dir, namespace): + for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and file.endswith(".py") + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + model_name) \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/utils/logging.py b/TinyLLaVA_Factory/tinyllava/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..634b6c9ffb99fcef55f43dcccdc2940a3541b483 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/logging.py @@ -0,0 +1,56 @@ +import logging +import os +import sys + +import torch.distributed as dist + + +root_logger = None + +def print_rank0(*args): + local_rank = dist.get_rank() + if local_rank == 0: + print(*args) + +def logger_setting(save_dir=None): + global root_logger + if root_logger is not None: + return root_logger + else: + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s | %(levelname)s: %(message)s") + ch.setFormatter(formatter) + root_logger.addHandler(ch) + + if save_dir: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + save_file = os.path.join(save_dir, 'log.txt') + if not os.path.exists(save_file): + os.system(f"touch {save_file}") + fh = logging.FileHandler(save_file, mode='a') + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + root_logger.addHandler(fh) + return root_logger + +def log(*args): + global root_logger + local_rank = dist.get_rank() + if local_rank == 0: + root_logger.info(*args) + + + + +def log_trainable_params(model): + total_params = sum(p.numel() for p in model.parameters()) + total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + log(f'Total Parameters: {total_params}, Total Trainable Parameters: {total_trainable_params}') + log(f'Trainable Parameters:') + for name, param in model.named_parameters(): + if param.requires_grad: + print_rank0(f"{name}: {param.numel()} parameters") diff --git a/TinyLLaVA_Factory/tinyllava/utils/message.py b/TinyLLaVA_Factory/tinyllava/utils/message.py new file mode 100644 index 0000000000000000000000000000000000000000..518d1641977aac2eba00c4ccba2c340b7f76c4c8 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/message.py @@ -0,0 +1,66 @@ +''' +@Description: +@Author: jiajunlong +@Date: 2024-06-19 19:30:17 +@LastEditTime: 2024-06-19 19:32:47 +@LastEditors: jiajunlong +''' +class Message: + def __init__(self, msg=None): + self._messages = msg if msg else [] + self._images = [] + self.skip_next = False + + def add_message(self, question, answer=None): + quension_msg_dict = {'from': 'human'} + quension_msg_dict['value'] = question + answer_msg_dict = {'from': 'gpt'} + answer_msg_dict['value'] = answer + self._messages.append(quension_msg_dict) + self._messages.append(answer_msg_dict) + + def add_image(self, image, index=0): + self._images.append((image, index)) + + @property + def images(self): + return self._images + + @property + def messages(self): + return self._messages + + def copy(self): + return Message(self._messages) + + def to_gradio_chatbot(self): + ret = [] + for i, msg in enumerate(self.messages): + if i % 2 == 0: + if len(self.images) != 0 and i == self.images[0][1]: + image = self.images[0][0] + import base64 + from io import BytesIO + msg = msg['value'] + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = img_str + msg.replace('', '').strip() + ret.append([msg, None]) + else: + ret.append([msg['value'], None]) + else: + ret[-1][-1] = msg['value'] + return ret \ No newline at end of file diff --git a/TinyLLaVA_Factory/tinyllava/utils/train_utils.py b/TinyLLaVA_Factory/tinyllava/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62ebf2731d7f8ac62a058df592ce9fd935c50f6c --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava/utils/train_utils.py @@ -0,0 +1,95 @@ +import logging +import os + +import torch +from peft.tuners.lora import LoraLayer +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + +def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + +def lora_kbit_setting(model, training_args): + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + + +def maybe_zero_3(param, ignore_status=False, name=None): + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_state_maybe_zero_3(named_params, keys_to_match=[''], require_grad_only=True): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model, skip_keywords=['connector', 'vision_tower']): + cls = torch.nn.Linear + lora_module_names = set() + skip_keywords = skip_keywords + for name, module in model.named_modules(): + if any(skip_keyword in name for skip_keyword in skip_keywords) or 'lm_head' in name or 'output_layer' in name or 'head' in name: + continue + if isinstance(module, cls): + names = name.split('.') + #lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + lora_module_names.add(name) + # if 'lm_head' in lora_module_names: + # lora_module_names.remove('lm_head') + return list(lora_module_names) diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/README.md b/TinyLLaVA_Factory/tinyllava_visualizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cd509a71e1186b20a2b4baa6ec06b9d47f11d089 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava_visualizer/README.md @@ -0,0 +1,93 @@ +# TinyLLaVA Visualizer + +TinyLLaVA Visualizer is a specialized visualization tool designed to work with the TinyLLaVA model, a multimodal large model. This tool enables users to visualize the relationships between generated words, their connections to the input image, and the probability distributions of these words during the model's inference process. + +## Features + +TinyLLaVA Visualizer provides three main visualization functionalities: + +1. **Word Relationships**: Visualize the relationships between each generated word and the words generated before it. This allows users to understand how the model builds up context over time. +2. **Word-Image Relationships**: Visualize the relationship between each generated word and the input image. This feature helps users see how the model links textual output to visual input. +3. **Word Probability Distributions**: Visualize the probability distribution of each word during the generation process, providing insight into the model's confidence for each word choice. + +## Installation + +To use TinyLLaVA Visualizer, simply ensure that you have an environment capable of running TinyLLaVA. If you already have TinyLLaVA set up, you're good to go! No additional installation steps are required for this tool. + +``` +conda create -n tinyllava_factory python=3.10 -y +conda activate tinyllava_factory +pip install --upgrade pip # enable PEP 660 support +pip install -e . +``` + +## Usage + +### Inference and Visualization + +Place the `inference.py` in the root directory of the project where your model's code resides. During the model's inference process, integrate the `Monitor` class from the visualizer to generate visual outputs. Below is an example use case: + +``` +from tinyllava.eval.run_tiny_llava import eval_model +from tinyllava.model.convert_legecy_weights_to_tinyllavafactory import convert_legecy_weights_to_tinyllavafactory +from tinyllava_visualizer.tinyllava_visualizer import Monitor + +def main(): + model = convert_legecy_weights_to_tinyllavafactory('TinyLLaVA-3.1B') + prompt = "What are the things I should be cautious about when I visit here?" + image_file = "image_test/1.jpeg" + + args = type('Args', (), { + "model_path": None, + "model": model, + "query": prompt, + "conv_mode": "phi", # Adjust based on the LLM version + "image_file": image_file, + "sep": ",", + "temperature": 0, + "top_p": None, + "num_beams": 1, + "max_new_tokens": 512 + })() + +monitor = Monitor(args, llm_layers_index=31) +eval_model(args) +monitor.get_output(output_dir='results/') + +if __name__ == "__main__": + main() +``` + +This example demonstrates how to set up and use TinyLLaVA Visualizer in a typical inference workflow. After running this code, the visual outputs will be stored in the `results` directory, categorized by the type of visualization. + +## Project Structure + +- `tinyllava_visualizer/tinyllava_visualizer.py`: The main script for visualization. +- `tinyllava/`: Directory containing core model and data processing code. +- `scripts/`: Contains utility scripts. +- `eval/`: Evaluation scripts and tools. +- `results/`: Storage for visualization results. + +## Example Visualizations + +Here are examples of the types of visual outputs you can expect: + +prompt = "What is it?" + +image: + + + +output: + +The image features a small, fluffy, light brown dog with a pink collar. The dog is wearing a sweater, which adds a touch of warmth and style to its appearance. The dog is standing on a wooden floor, and its gaze is directed straight at the camera, creating a sense of connection between the viewer and the subject. The dog's fur appears soft and fluffy, and its pink collar stands out against its light brown coat. The wooden floor provides a natural and warm background that contrasts with the dog's vibrant colors. The dog's position and the way it looks at the camera give the image a sense of liveliness and personality. The image does not contain any text or other objects. The focus is solely on the dog, making it the central element of the image. The relative position of the dog to the camera and the wooden floor suggests that the photo was taken in a home setting, possibly in the living room or a similar area. The image does not provide any additional context or information about the dog's breed, age, or any other details beyond what can be seen in the image. + +## *the visualization of word 'sweater'* + +| Word Probability Distributions | Word Relationships | Word-Image Relationships | +| ------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| ![](https://raw.githubusercontent.com/lingcco/TinyLLaVA_Factory/tinyllava_visualizer/tinyllava_visualizer/demo/Word%20Probability%20Distributions.png) | ![](https://raw.githubusercontent.com/lingcco/TinyLLaVA_Factory/tinyllava_visualizer/tinyllava_visualizer/demo/Word%20Relationships.png) | | + +--- + +If you encounter any issues or have suggestions, reach out to us at [21376195@buaa.edu.cn]. diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word Probability Distributions.png b/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word Probability Distributions.png new file mode 100644 index 0000000000000000000000000000000000000000..1493c26a37a8532176e58a0c283a67bfffeb9bfb Binary files /dev/null and b/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word Probability Distributions.png differ diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word Relationships.png b/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word Relationships.png new file mode 100644 index 0000000000000000000000000000000000000000..5f01ee6ecaddea53af95f88838436b222bd40833 Binary files /dev/null and b/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word Relationships.png differ diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word-Image Relationships.png b/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word-Image Relationships.png new file mode 100644 index 0000000000000000000000000000000000000000..06e3d15655244435ed7b8c338c367bade9cf3f35 Binary files /dev/null and b/TinyLLaVA_Factory/tinyllava_visualizer/demo/Word-Image Relationships.png differ diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/demo/demo_picture.webp b/TinyLLaVA_Factory/tinyllava_visualizer/demo/demo_picture.webp new file mode 100644 index 0000000000000000000000000000000000000000..d6ab0a51befacbec77c29840088941421412131f Binary files /dev/null and b/TinyLLaVA_Factory/tinyllava_visualizer/demo/demo_picture.webp differ diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/inference_example.py b/TinyLLaVA_Factory/tinyllava_visualizer/inference_example.py new file mode 100644 index 0000000000000000000000000000000000000000..af9fdd4a474a43564d773ffd875afe043dd1fb72 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava_visualizer/inference_example.py @@ -0,0 +1,27 @@ +from tinyllava.eval.run_tiny_llava import eval_model +from transformers import AutoTokenizer, AutoModelForCausalLM +from tinyllava_visualizer.tinyllava_visualizer import * + +prompt = "What are the things I should be cautious about when I visit here?" +image_file = "https://llava-vl.github.io/static/images/view.jpg" + +model = AutoModelForCausalLM.from_pretrained("/mnt/hwfile/opendatalab/wensiwei/checkpoint/TinyLLaVA-Phi-2-SigLIP-3.1B", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("/mnt/hwfile/opendatalab/wensiwei/checkpoint/TinyLLaVA-Phi-2-SigLIP-3.1B", trust_remote_code=True) +model.tokenizer = tokenizer + +args = type('Args', (), { + "model_path": None, + "model": model, + "query": prompt, + "conv_mode": "phi", # the same as conv_version in the training stage. Different LLMs have different conv_mode/conv_version, please replace it + "image_file": image_file, + "sep": ",", + "temperature": 0, + "top_p": None, + "num_beams": 1, + "max_new_tokens": 512 +})() + +monitor = Monitor(args, model, llm_layers_index=31) +eval_model(args) +monitor.get_output(output_dir='results/') diff --git a/TinyLLaVA_Factory/tinyllava_visualizer/tinyllava_visualizer.py b/TinyLLaVA_Factory/tinyllava_visualizer/tinyllava_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4603c66a5f375bd89d3fd038f57956a0549da87 --- /dev/null +++ b/TinyLLaVA_Factory/tinyllava_visualizer/tinyllava_visualizer.py @@ -0,0 +1,215 @@ +from collections import defaultdict +from torch.nn.parallel import DistributedDataParallel +from matplotlib import pyplot as plt +import torch +import requests +from io import BytesIO +from PIL import Image, ImageDraw +from torchvision.transforms import ToPILImage +import torch.nn.functional as F +import numpy as np +import os +import datetime +from tinyllava.data import * +from tinyllava.utils import * +from tinyllava.model import * +import pdb + +def load_image(image_file): + if image_file.startswith("http") or image_file.startswith("https"): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + return image + + +def load_images(image_files): + out = [] + for image_file in image_files: + image = load_image(image_file) + out.append(image) + return out + + +def extract_max_values_and_indices(tensor, k): + max_values, max_indices = torch.topk(tensor, k, dim=2) + max_values_with_indices = torch.stack((max_indices, max_values), dim=3) + return max_values_with_indices + + +def visualize_grid_to_grid(i, mask, image, output_dir, grid_size=27, alpha=0.6): + if not isinstance(grid_size, tuple): + grid_size = (grid_size, grid_size) + mask = mask.detach().cpu().numpy() + mask = Image.fromarray(mask).resize((384, 384)) + fig, ax = plt.subplots(1, 2, figsize=(10, 7)) + fig.tight_layout() + + ax[0].imshow(image) + ax[0].axis('off') + + ax[1].imshow(image) + im = ax[1].imshow(mask / np.max(mask), alpha=alpha, cmap='rainbow') + ax[1].axis('off') + cbar = fig.colorbar(im, ax=ax[1]) + cbar.set_label('Color Temperature') + name = os.path.join(output_dir, "hot_image", f"{i}.png") + plt.savefig(name) + plt.close(fig) + + +def generate_square_subsequent_mask(sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + +def generate_word_images(tokenizer, top_words_tensor, num, input_ids, embed_tokens, output_dir): + num_top_words = top_words_tensor.shape[1] + for i in range(num_top_words - num, num_top_words): + fig, ax = plt.subplots() + word_indices = top_words_tensor[0, i, :, 0].detach().cpu().numpy() + probabilities = top_words_tensor[0, i, :, 1].detach().cpu().numpy() + colors = plt.cm.viridis(probabilities) + + for j, (word_index, color, prob) in enumerate(zip(word_indices, colors, probabilities)): + word = tokenizer.decode([int(word_index)]) + prob_text = f"{word} P: {prob:.2f}" + ax.text(0.5, 0.9 - j * 0.1, prob_text, color=color, ha='center', va='center', transform=ax.transAxes) + ax.axis('off') + ax.set_title('Top Words for Index {}'.format(i - num_top_words + num + 1)) + plt.savefig(os.path.join(output_dir, 'word', f"word_image_{i - num_top_words + num + 1}.png")) + plt.close() + + +def generate_word_images_before(tokenizer, input_ids, tensor, num, top_words_tensor, output_dir): + num_top_words = tensor.shape[2] + result = tensor.mean(dim=1) # [1, len, len] + input_ids_fir = input_ids[input_ids != -200].unsqueeze(0) + for i in range(num_top_words - num, num_top_words - 1): + top1_indices = top_words_tensor[0, i, 0, 0].long() + fig, ax = plt.subplots() + result_1 = result[0, i, 0:input_ids.shape[1]] + result_1 = result_1[input_ids.squeeze() != -200] + if not i == num_top_words - num: + result_2 = result[0, i, num_top_words - num + 1:i + 1] + result_1 = torch.cat((result_1, result_2), dim=0) + + if not i == num_top_words - num: + output_ids = top_words_tensor[0, num_top_words - num:i, 0, 0].unsqueeze(0).long() + input_ids_fir = torch.cat((input_ids_fir, output_ids), dim=1) + + tv, ti = torch.topk(result_1.squeeze(), 8) + tv = tv / torch.max(tv) + probabilities = tv.detach().cpu().numpy() + colors = plt.cm.viridis(probabilities) + for j, (word_index, color, prob) in enumerate(zip(ti, colors, probabilities)): + word = tokenizer.decode(input_ids_fir[0, word_index.item()]) + prob_text = f"{word} P: {prob:.2f}" + ax.text(0.5, 0.9 - j * 0.1, prob_text, color=color, ha='center', va='center', transform=ax.transAxes) + ax.axis('off') + ax.set_title( + 'similarities of output word {}'.format(tokenizer.decode([top1_indices.detach().cpu().numpy()]))) + plt.savefig(os.path.join(output_dir, 'word_before', f"word_image_{i - (num_top_words - num - 1)}.png")) + plt.close() + + +class Monitor: + def __init__(self, args, model, llm_layers_index): + self.model = model + self.args = args + self.input_ids = None + self.image = None + self.params = list(model.parameters()) + self.output = defaultdict(dict) + self.attentions = [] + self.hidden = [] + self.logit = [] + self.image_token = [] + self.llm_layers_index = llm_layers_index + self._register(llm_layers_index) + + def _register(self, llm_layers_index): + def attention_hook(module, input, output): + self.hidden.append(input[0]) + + def output_hook(module, input, output): + self.logit.append(output) + + def image_hook(module, input, output): + self.image_token.append(output) + + mod = self.model + mod.language_model.model.layers[llm_layers_index].register_forward_hook(attention_hook) + mod.language_model.lm_head.register_forward_hook(output_hook) + mod.connector.register_forward_hook(image_hook) + + def prepare_input(self): + # 获得input_ids + qs = self.args.query + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + text_processor = TextPreprocess(self.model.tokenizer, self.args.conv_mode) + msg = Message() + msg.add_message(qs) + result = text_processor(msg.messages, mode='eval') + self.input_ids = result['input_ids'].unsqueeze(0).cuda() + # 获得图片tensor + data_args = self.model.config + image_processor = self.model.vision_tower._image_processor + image_processor = ImagePreprocess(image_processor, data_args) + image_files = self.args.image_file.split(self.args.sep) + images = load_images(image_files)[0] + images_tensor = image_processor(images) + image_tensor = 255 * (images_tensor - images_tensor.min()) / (images_tensor.max() - images_tensor.min()) + image_tensor = image_tensor.clamp(0, 255) + image_tensor = image_tensor.byte() + to_pil = ToPILImage() + self.image = to_pil(image_tensor).convert('RGB') + self.model.cuda() + self.logit = F.softmax(torch.cat(self.logit, dim=1), dim=2) + hidden_tensor = torch.cat(self.hidden, dim=1) + length = hidden_tensor.shape[1] + attention_mask = torch.unsqueeze( + torch.unsqueeze(generate_square_subsequent_mask(length).clone().detach(), dim=0), + dim=0).cuda() + + self.hidden = self.model.language_model.model.layers[self.llm_layers_index](hidden_tensor, + output_attentions=True, + attention_mask=attention_mask) + self.image_token = self.image_token[0].squeeze() + self.image_token = torch.cat((torch.zeros(1, 2560).cuda(), self.image_token), dim=0) + + def get_output(self, output_dir='results/'): + print("Starting visualization...") + self.prepare_input() + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = os.path.join(output_dir, f"run_{timestamp}") + os.makedirs(output_dir, exist_ok=True) + + os.makedirs(os.path.join(output_dir, 'word'), exist_ok=True) + os.makedirs(os.path.join(output_dir, 'word_before'), exist_ok=True) + os.makedirs(os.path.join(output_dir, 'hot_image'), exist_ok=True) + + num = self.logit.shape[1] - 726 - len(self.input_ids[0]) + result = extract_max_values_and_indices(self.logit, 8) + generate_word_images(self.model.tokenizer, result, num, self.input_ids, + self.model.language_model.model.embed_tokens.weight, output_dir) + + generate_word_images_before(self.model.tokenizer, self.input_ids, self.hidden[1], num, result, output_dir) + + result_top1 = result[0, :, 0, 0].squeeze() + for i in range(len(result_top1) - num, len(result_top1)): + word_id = result_top1[i] + word_id_tensor = torch.tensor([word_id]).long().cuda() + word_vector = self.model.language_model.model.embed_tokens(word_id_tensor).squeeze().detach() + vector_expanded = word_vector.unsqueeze(0).expand_as(self.image_token) + vector_norm = F.normalize(vector_expanded, p=2, dim=1) + matrix_norm = F.normalize(self.image_token, p=2, dim=1) + cosine_similarities = torch.sum(vector_norm * matrix_norm, dim=1) + normalized_similarities = F.softmax(cosine_similarities, dim=0) + visualize_grid_to_grid('hot_image_' + str(i - (len(result_top1) - num) + 1), + normalized_similarities.view(27, 27), + self.image, output_dir) + print("Completed visualization.")