Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +18 -0
- Dockerfile +39 -0
- LICENSE +21 -0
- README.md +13 -0
- __init__.py +4 -0
- audio_qa_out_cache.wav +3 -0
- data/figures/framework.jpeg +3 -0
- data/figures/inputids.png +3 -0
- data/figures/samples.png +3 -0
- data/figures/title_new.png +3 -0
- data/figures/training.jpeg +3 -0
- data/omni2-demo.mp4 +3 -0
- data/samples/output1.wav +0 -0
- data/samples/output2.wav +3 -0
- data/samples/output3.wav +0 -0
- data/samples/output4.wav +0 -0
- data/samples/output5.wav +3 -0
- data/samples/vision_qa_audio.wav +3 -0
- hotkey.txt +1 -0
- inference.py +705 -0
- inference_vision.py +259 -0
- litgpt/__init__.py +19 -0
- litgpt/config.py +181 -0
- litgpt/generate/__init__.py +0 -0
- litgpt/generate/base.py +795 -0
- litgpt/model.py +654 -0
- litgpt/tokenizer.py +131 -0
- litgpt/utils.py +641 -0
- models/README.md +142 -0
- models/ViT-B-32.pt +3 -0
- models/data/figures/framework.jpeg +3 -0
- models/data/figures/inputids.png +3 -0
- models/data/figures/samples.png +3 -0
- models/data/figures/title.png +3 -0
- models/data/figures/training.jpeg +3 -0
- models/data/omni2-demo.mp4 +3 -0
- models/hub/.locks/models--hubertsiuzdak--snac_24khz/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40.lock +0 -0
- models/hub/.locks/models--hubertsiuzdak--snac_24khz/a9e7ef62bf7e1eb94d2713721029837aacab3b55.lock +0 -0
- models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 +3 -0
- models/hub/models--hubertsiuzdak--snac_24khz/blobs/a9e7ef62bf7e1eb94d2713721029837aacab3b55 +13 -0
- models/hub/models--hubertsiuzdak--snac_24khz/refs/main +1 -0
- models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/config.json +13 -0
- models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/pytorch_model.bin +3 -0
- models/hub/version.txt +1 -0
- models/lit_model.pth +3 -0
- models/model_config.yaml +43 -0
- models/small.pt +3 -0
- models/tokenizer.json +0 -0
- models/tokenizer_config.json +40 -0
- requirements.txt +20 -0
.gitattributes
CHANGED
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+audio_qa_out_cache.wav filter=lfs diff=lfs merge=lfs -text
+data/figures/framework.jpeg filter=lfs diff=lfs merge=lfs -text
+data/figures/inputids.png filter=lfs diff=lfs merge=lfs -text
+data/figures/samples.png filter=lfs diff=lfs merge=lfs -text
+data/figures/title_new.png filter=lfs diff=lfs merge=lfs -text
+data/figures/training.jpeg filter=lfs diff=lfs merge=lfs -text
+data/omni2-demo.mp4 filter=lfs diff=lfs merge=lfs -text
+data/samples/output2.wav filter=lfs diff=lfs merge=lfs -text
+data/samples/output5.wav filter=lfs diff=lfs merge=lfs -text
+data/samples/vision_qa_audio.wav filter=lfs diff=lfs merge=lfs -text
+models/data/figures/framework.jpeg filter=lfs diff=lfs merge=lfs -text
+models/data/figures/inputids.png filter=lfs diff=lfs merge=lfs -text
+models/data/figures/samples.png filter=lfs diff=lfs merge=lfs -text
+models/data/figures/title.png filter=lfs diff=lfs merge=lfs -text
+models/data/figures/training.jpeg filter=lfs diff=lfs merge=lfs -text
+models/data/omni2-demo.mp4 filter=lfs diff=lfs merge=lfs -text
+models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 filter=lfs diff=lfs merge=lfs -text
+vision_qa_out_cache.wav filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,39 @@
FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    DEBIAN_FRONTEND=noninteractive \
    CUDA_HOME=/usr/local/cuda \
    PATH=/usr/local/cuda/bin:$PATH \
    LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility \
    HF_HOME=/app/models

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    build-essential \
    git \
    ffmpeg \
    libsndfile1 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip and install build tools
RUN python3 -m pip install --upgrade pip setuptools wheel uv

WORKDIR /app

COPY requirements.txt .

# Install other requirements
RUN python3 -m uv pip install --no-cache-dir -r requirements.txt --prerelease=allow

COPY . .

EXPOSE 8000

CMD ["python3", "server.py"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 gpt-omni

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,13 @@
---
license: mit
tags:
- any-to-any
- omega
- omegalabs
- bittensor
- agi
---

This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.

Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
__init__.py
ADDED
@@ -0,0 +1,4 @@
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
audio_qa_out_cache.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5be14b744cde16792c9bb0a66e4ce9b899ad39ce076996d8803cb41ce010bb9d
size 159788
data/figures/framework.jpeg
ADDED
Git LFS image (preview omitted)

data/figures/inputids.png
ADDED
Git LFS image (preview omitted)

data/figures/samples.png
ADDED
Git LFS image (preview omitted)

data/figures/title_new.png
ADDED
Git LFS image (preview omitted)

data/figures/training.jpeg
ADDED
Git LFS image (preview omitted)
data/omni2-demo.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c2098124af391dca9c48854f5686143c137cc069f08b5e457675b9ba744bd2f
size 11784395
data/samples/output1.wav
ADDED
Binary file (62.2 kB)

data/samples/output2.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b50c4df6f508a4367e5a49e90f974f8786c6d9ffb2599a8abcd25e693399735a
size 105176

data/samples/output3.wav
ADDED
Binary file (70.4 kB)

data/samples/output4.wav
ADDED
Binary file (67.6 kB)

data/samples/output5.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33f58e7cc49a4e4fd4809d20cde2fb22855054cf61558be8ffef347fc35ce8f2
size 114732

data/samples/vision_qa_audio.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:18eba79742ad8074074a113e6df56410bdf66e34a645d619a4ad7b8171f6d7d7
size 150572
hotkey.txt
ADDED
@@ -0,0 +1 @@
5CWzdquwP5vjnCphhkPMDo51G7MrWXdTJaEiKZYadbBoiiQ6
inference.py
ADDED
@@ -0,0 +1,705 @@
import os
import lightning as L
import torch
import glob
import time
from snac import SNAC
from litgpt import Tokenizer
from litgpt.utils import (
    num_parameters,
)
from litgpt.generate.base import (
    generate_AA,
    generate_ASR,
    generate_TA,
    generate_TT,
    generate_AT,
    generate_TA_BATCH,
    next_token_image_batch
)
import soundfile as sf
from litgpt.model import GPT, Config
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
from utils.snac_utils import get_snac, generate_audio_data
import whisper
from tqdm import tqdm
from huggingface_hub import snapshot_download


torch.set_printoptions(sci_mode=False)


# TODO
text_vocabsize = 151936
text_specialtokens = 64
audio_vocabsize = 4096
audio_specialtokens = 64

padded_text_vocabsize = text_vocabsize + text_specialtokens
padded_audio_vocabsize = audio_vocabsize + audio_specialtokens

_eot = text_vocabsize
_pad_t = text_vocabsize + 1
_input_t = text_vocabsize + 2
_answer_t = text_vocabsize + 3
_asr = text_vocabsize + 4

_eoa = audio_vocabsize
_pad_a = audio_vocabsize + 1
_input_a = audio_vocabsize + 2
_answer_a = audio_vocabsize + 3
_split = audio_vocabsize + 4
_image = audio_vocabsize + 5
_eoimage = audio_vocabsize + 6


def get_input_ids_TA(text, text_tokenizer):
    input_ids_item = [[] for _ in range(8)]
    text_tokens = text_tokenizer.encode(text)
    for i in range(7):
        input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
            layershift(_answer_a, i)
        ]
        input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
    input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
    return input_ids_item


def get_input_ids_TT(text, text_tokenizer):
    input_ids_item = [[] for i in range(8)]
    text_tokens = text_tokenizer.encode(text).tolist()

    for i in range(7):
        input_ids_item[i] = torch.tensor(
            [layershift(_pad_a, i)] * (len(text_tokens) + 3)
        ).unsqueeze(0)
    input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)

    return input_ids_item


def get_input_ids_whisper(
    mel, leng, whispermodel, device,
    special_token_a=_answer_a, special_token_t=_answer_t,
):

    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        # audio_feature = whisper.decode(whispermodel, mel, options).audio_features
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]

    T = audio_feature.size(0)
    input_ids = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
    input_ids.append(input_id_T.unsqueeze(0))
    return audio_feature.unsqueeze(0), input_ids


def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        # audio_feature = whisper.decode(whispermodel, mel, options).audio_features
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
    T = audio_feature.size(0)
    input_ids_AA = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
        input_ids_AA.append(torch.tensor(input_ids_item))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
    input_ids_AA.append(input_id_T)

    input_ids_AT = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
        input_ids_AT.append(torch.tensor(input_ids_item))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
    input_ids_AT.append(input_id_T)

    input_ids = [input_ids_AA, input_ids_AT]
    stacked_inputids = [[] for _ in range(8)]
    for i in range(2):
        for j in range(8):
            stacked_inputids[j].append(input_ids[i][j])
    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
    return torch.stack([audio_feature, audio_feature]), stacked_inputids


def load_audio(path):
    audio = whisper.load_audio(path)
    duration_ms = (len(audio) / 16000) * 1000
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio)
    return mel, int(duration_ms / 20) + 1


def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                snacmodel, out_dir=None):
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=2)
    tokenlist = generate_TA_BATCH(
        model,
        audio_feature,
        input_ids,
        [leng, leng],
        ["A1A2", "A1T2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    text_tokenlist = tokenlist[-1]
    if text_vocabsize in text_tokenlist:
        text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
    text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()

    audio_tokenlist = tokenlist[:-1]
    audiolist = reconscruct_snac(audio_tokenlist)
    audio = reconstruct_tensors(audiolist)
    if out_dir is None:
        out_dir = "./output/default/A1-A2-batch"
    else:
        out_dir = out_dir + "/A1-A2-batch"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    sf.write(
        f"{out_dir}/{step:02d}.wav",
        audio_hat.squeeze().cpu().numpy(),
        24000,
    )
    model.clear_kv_cache()
    return text


def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_AT(
        model,
        audio_feature,
        input_ids,
        [leng],
        ["AT"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()


def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
          snacmodel, out_dir=None):
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_AA(
        model,
        audio_feature,
        input_ids,
        [leng],
        ["A1T2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    audiolist = reconscruct_snac(tokenlist)
    tokenlist = tokenlist[-1]
    if text_vocabsize in tokenlist:
        tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
    if out_dir is None:
        out_dir = "./output/default/A1-A2"
    else:
        out_dir = out_dir + "/A1-A2"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    sf.write(
        f"{out_dir}/{step:02d}.wav",
        audio_hat.squeeze().cpu().numpy(),
        24000,
    )
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()


def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_ASR(
        model,
        audio_feature,
        input_ids,
        [leng],
        ["A1T1"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()


def T1_A2(fabric, input_ids, model, text_tokenizer, step,
          snacmodel, out_dir=None):
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_TA(
        model,
        None,
        input_ids,
        None,
        ["T1A2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )

    audiolist = reconscruct_snac(tokenlist)
    tokenlist = tokenlist[-1]

    if text_vocabsize in tokenlist:
        tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
    audio = reconstruct_tensors(audiolist)
    if out_dir is None:
        out_dir = "./output/default/T1-A2"
    else:
        out_dir = out_dir + "/T1-A2"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    sf.write(
        f"{out_dir}/{step:02d}.wav",
        audio_hat.squeeze().cpu().numpy(),
        24000,
    )
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()


def T1_T2(fabric, input_ids, model, text_tokenizer, step):

    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_TT(
        model,
        None,
        input_ids,
        None,
        ["T1T2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()


def load_model(ckpt_dir, device):
    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
    whisper_model_path = ckpt_dir + "/small.pt"
    if not os.path.exists(whisper_model_path):
        whisper_model_path = "small"
    whispermodel = whisper.load_model(whisper_model_path).to(device)
    text_tokenizer = Tokenizer(ckpt_dir)
    fabric = L.Fabric(devices=1, strategy="auto")
    config = Config.from_file(ckpt_dir + "/model_config.yaml")
    config.post_adapter = False

    with fabric.init_module(empty_init=False):
        model = GPT(config)

    model = fabric.setup(model)
    state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
    model.load_state_dict(state_dict, strict=True)
    model.to(device).eval()

    return fabric, model, text_tokenizer, snacmodel, whispermodel


def download_model(ckpt_dir):
    repo_id = "gpt-omni/mini-omni2"
    snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")


def get_text_stream(list_output, index, text_tokenizer):
    text_tokens = list_output[-1][index:]
    index += len(text_tokens)
    is_text_end = False
    if text_vocabsize in text_tokens:
        text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
        is_text_end = True
    if len(text_tokens) == 0:
        return "", index, is_text_end
    res_text = text_tokenizer.decode(torch.tensor(text_tokens))
    return res_text, index, is_text_end


class OmniInference:

    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)

    def warm_up(self, sample='./data/samples/output1.wav'):
        for _ in self.run_AT_batch_stream(sample):
            pass

    @torch.inference_mode()
    def run_AT_batch_stream(self,
                            audio_path,
                            stream_stride=4,
                            max_returned_tokens=2048,
                            temperature=0.9,
                            top_k=1,
                            top_p=1.0,
                            eos_id_a=_eoa,
                            eos_id_t=_eot,
                            save_path=None,
                            sample_rate=24000,
                            ):

        assert os.path.exists(audio_path), f"audio file {audio_path} not found"
        model = self.model

        with self.fabric.init_tensor():
            model.set_kv_cache(batch_size=2, device=self.device)

        mel, leng = load_audio(audio_path)
        audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
        T = input_ids[0].size(1)
        device = input_ids[0].device

        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"

        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )

        input_pos = torch.tensor([T], device=device)
        list_output = [[] for i in range(8)]
        tokens_A, token_T = next_token_image_batch(
            model,
            audio_feature.to(torch.float32).to(model.device),
            None,
            input_ids,
            [T - 3, T - 3],
            ["A1T2", "A1T2"],
            input_pos=torch.arange(0, T, device=device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        model_input_ids = [[] for i in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
            model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        text_end = False
        index = 1
        nums_generate = stream_stride
        begin_generate = False
        current_index = 0

        text_index = 0
        is_text_end = False

        for _ in tqdm(range(2, max_returned_tokens - T + 1)):
            tokens_A, token_T = next_token_image_batch(
                model,
                None,
                None,
                model_input_ids,
                None,
                None,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

            if text_end:
                token_T = torch.tensor([_pad_t], device=device)

            if tokens_A[-1] == eos_id_a:
                break

            if token_T == eos_id_t:
                text_end = True

            for i in range(7):
                list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])

            model_input_ids = [[] for i in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
                model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
                model_input_ids[i].append(
                    torch.tensor([layershift(4097, i)], device=device)
                )
                model_input_ids[i] = torch.stack(model_input_ids[i])

            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])

            if index == 7:
                begin_generate = True

            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                    if is_text_end:
                        text_stream = ""
                    else:
                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)

                    yield (audio_stream, text_stream)

            input_pos = input_pos.add_(1)
            index += 1
        text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
        print(f"text output: {text}")

        if save_path is not None:
            audiolist = reconscruct_snac(list_output)
            audio = reconstruct_tensors(audiolist)
            with torch.inference_mode():
                audio_hat = self.snacmodel.decode(audio)
            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), sample_rate)

        model.clear_kv_cache()
        return list_output


def test_infer():
    device = "cuda:0"
    out_dir = f"./output/{get_time_str()}"
    ckpt_dir = f"./checkpoint"
    if not os.path.exists(ckpt_dir):
        print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
        download_model(ckpt_dir)

    fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)

    task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']

    # prepare test data
    # TODO
    test_audio_list = sorted(glob.glob('./data/samples/output*.wav'))
    test_audio_transcripts = [
        "What is your name?",
        "what are your hobbies?",
        "Do you like beijing",
        "How are you feeling today?",
        "what is the weather like today?",
    ]
    test_text_list = [
        "What is your name?",
        "How are you feeling today?",
        "Can you describe your surroundings?",
        "What did you do yesterday?",
        "What is your favorite book and why?",
        "How do you make a cup of tea?",
        "What is the weather like today?",
        "Can you explain the concept of time?",
        "Can you tell me a joke?",
    ]

    # LOAD MODEL
    with torch.no_grad():
        if "A1A2" in task:
            print("===============================================================")
            print(" testing A1A2")
            print("===============================================================")
            step = 0
            for path in test_audio_list:
                try:
                    mel, leng = load_audio(path)
                    audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
                    text = A1_A2(
                        fabric,
                        audio_feature,
                        input_ids,
                        leng,
                        model,
                        text_tokenizer,
                        step,
                        snacmodel,
                        out_dir=out_dir,
                    )
                    print(f"input: {test_audio_transcripts[step]}")
                    print(f"output: {text}")
                    step += 1
                    print(
                        "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
                    )
                except:
                    print(f"[error] failed to process {path}")
            print("===============================================================")

        if 'asr' in task:
            print("===============================================================")
            print(" testing asr")
            print("===============================================================")

            index = 0
            step = 0
            for path in test_audio_list:
                mel, leng = load_audio(path)
                audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
                output = A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, index).lower().replace(',', '').replace('.', '').replace('?', '')
                print(f"audio_path: {path}")
                print(f"audio transcript: {test_audio_transcripts[index]}")
                print(f"asr output: {output}")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                index += 1

        if "T1A2" in task:
            step = 0
            print("\n")
            print("===============================================================")
            print(" testing T1A2")
            print("===============================================================")
            for text in test_text_list:
                input_ids = get_input_ids_TA(text, text_tokenizer)
                text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
                                    snacmodel, out_dir=out_dir)
                print(f"input: {text}")
                print(f"output: {text_output}")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                step += 1
            print("===============================================================")

        if "T1T2" in task:
            step = 0
            print("\n")
            print("===============================================================")
            print(" testing T1T2")
            print("===============================================================")

            for text in test_text_list:
                input_ids = get_input_ids_TT(text, text_tokenizer)
                text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
                print(f" Input: {text}")
                print(f"Output: {text_output}")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print("===============================================================")

        if "AT" in task:
            print("===============================================================")
            print(" testing A1T2")
            print("===============================================================")
            step = 0
            for path in test_audio_list:
                mel, leng = load_audio(path)
                audio_feature, input_ids = get_input_ids_whisper(
                    mel, leng, whispermodel, device,
                    special_token_a=_pad_a, special_token_t=_answer_t
                )
                text = A1_T2(
                    fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
                )
                print(f"input: {test_audio_transcripts[step]}")
                print(f"output: {text}")
                step += 1
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print("===============================================================")

        if "AA-BATCH" in task:
            print("===============================================================")
            print(" testing A1A2-BATCH")
            print("===============================================================")
            step = 0
            for path in test_audio_list:
                mel, leng = load_audio(path)
                audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
                text = A1_A2_batch(
                    fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                    snacmodel, out_dir=out_dir
                )
                print(f"input: {test_audio_transcripts[step]}")
                print(f"output: {text}")
                step += 1
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("===============================================================")

    print("*********************** test end *****************************")



if __name__ == "__main__":
    test_infer()
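For orientation, here is a minimal usage sketch of the streaming entry point defined above. It is not part of the commit: the checkpoint path, the output file name, and the treatment of each yielded audio chunk as bytes-like are assumptions (generate_audio_data lives in utils/snac_utils.py, which is not included in this diff), and a CUDA device is assumed to be available.

# Hypothetical driver for OmniInference.run_AT_batch_stream (not part of this commit).
# Assumes ./checkpoint exists or can be downloaded on first use, and that each
# audio chunk produced by generate_audio_data is bytes-like.
from inference import OmniInference

client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
client.warm_up()

reply_text = ""
with open("reply_chunks.bin", "wb") as f:  # raw chunk dump; format depends on generate_audio_data
    for audio_chunk, text_chunk in client.run_AT_batch_stream(
        "./data/samples/output1.wav",
        stream_stride=4,                       # emit audio every 4 generated frames
        save_path="./audio_qa_out_cache.wav",  # full decoded reply is also written here at the end
    ):
        f.write(audio_chunk)    # assumption: chunks are bytes-like
        reply_text += text_chunk
print(reply_text)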
inference_vision.py
ADDED
@@ -0,0 +1,259 @@
import os
import torch
from litgpt.generate.base import next_token_image_batch
import soundfile as sf
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
from utils.snac_utils import get_snac, generate_audio_data
import clip
import inference
from tqdm import tqdm
from inference import OmniInference, load_model, load_audio, download_model
from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
from PIL import Image


torch.set_printoptions(sci_mode=False)

_image = inference._image
_eoimage = inference._eoimage
_pad_t = inference._pad_t
_input_t = inference._input_t
_answer_t = inference._answer_t
_eot = inference._eot
_eoa = inference._eoa
_pad_a = inference._pad_a
_input_a = inference._input_a
_answer_a = inference._answer_a


def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):

    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]

    audio_len = audio_feature.size(0)

    input_ids = []
    input_ids_item = [[] for i in range(8)]
    for i in range(7):
        input_ids_item[i] = [layershift(_image, i)] + [layershift(_pad_a, i)] * 50 + [layershift(_eoimage, i)]
        input_ids_item[i] += [layershift(_input_a, i)] + [layershift(_pad_a, i)] * (audio_len) + [layershift(_eoa, i)]
        input_ids_item[i] += [layershift(_answer_a, i)]

    input_ids_item[-1] = [_pad_t] * (52 + 2 + audio_len) + [_answer_t]
    input_ids_item = [torch.tensor(item) for item in input_ids_item]

    input_ids.append(input_ids_item)

    input_ids_item = [[] for i in range(8)]
    for i in range(7):
        input_ids_item[i] = [layershift(_image, i)] + [layershift(_pad_a, i)] * 50 + [layershift(_eoimage, i)]
        input_ids_item[i] += [layershift(_input_a, i)] + [layershift(_pad_a, i)] * (audio_len) + [layershift(_eoa, i)] + [layershift(_pad_a, i)]

    input_ids_item[-1] = [_pad_t] * (52 + 2 + audio_len) + [_answer_t]

    input_ids_item = [torch.tensor(item) for item in input_ids_item]
    input_ids.append(input_ids_item)

    stacked_inputids = [[] for _ in range(8)]
    for i in range(2):
        for j in range(8):
            stacked_inputids[j].append(input_ids[i][j])
    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]

    return torch.stack([audio_feature, audio_feature]), stacked_inputids


def load_clip_model(ckpt_dir, device):
    clip_model_path = ckpt_dir + "/ViT-B-32.pt"
    if not os.path.exists(clip_model_path):
        clip_model_path = "ViT-B/32"
    clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
    return clipmodel, clippreprocess


class OmniVisionInference(OmniInference):

    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
        self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)

    def warm_up(self,
                audio_sample='./data/samples/vision_qa_audio.wav',
                image_sample='./data/samples/vision_qa_image.jpg'
                ):
        for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
                                                 save_path="./data/samples/vision_qa_output.wav",
                                                 warm_up=True):
            pass

    @torch.inference_mode()
    def run_vision_AA_batch_stream(self, audio_path, image_path,
                                   stream_stride=4,
                                   max_returned_tokens=2048,
                                   temperature=0.9,
                                   top_k=1,
                                   top_p=1.0,
                                   eos_id_a=_eoa,
                                   eos_id_t=_eot,
                                   pad_id=_pad_t,
                                   save_path=None,
                                   warm_up=False
                                   ):
        with self.fabric.init_tensor():
            self.model.set_kv_cache(batch_size=2)

        model = self.model

        mel, leng = load_audio(audio_path)
        img = Image.open(image_path)

        audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
        ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
        ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)

        ima_feature = torch.stack([ima_feature.clone(), ima_feature.clone()]).to(self.device)
        leng = [leng, leng]
        task = ['ImageQA_A', 'ImageQA_AT']

        T = input_ids[0].size(1)
        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"

        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )

        list_output = [[] for i in range(8)]

        tokens_A, token_T = next_token_image_batch(
            model,
            audio_feature.to(torch.float32).to(self.device),
            ima_feature.to(torch.float32).to(self.device),
            input_ids,
            whisper_lens=leng,
            task=task,
            input_pos=torch.arange(0, T, device=self.device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        text_end = False
        index = 1
        nums_generate = stream_stride
        begin_generate = False
        current_index = 0
        input_pos = torch.tensor([T], device=self.device)

        model_input_ids = [[] for i in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
            model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        text_index = 0
        is_text_end = False

        for _ in tqdm(range(2, max_returned_tokens - T + 1)):

            tokens_A, token_T = next_token_image_batch(model, None, None,
                                                       input_ids=model_input_ids,
                                                       whisper_lens=None,
                                                       task=None,
                                                       input_pos=input_pos,
                                                       temperature=temperature,
                                                       top_k=top_k,
                                                       top_p=top_p)

            if text_end:
                token_T = torch.tensor([_pad_t], device=self.device)

            if tokens_A[-1] == eos_id_a:
                break
            if token_T == eos_id_t:
                text_end = True

            for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])


            if index == 7:
                begin_generate = True

            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                    if is_text_end:
                        text_stream = ""
                    else:
                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)

                    yield (audio_stream, text_stream)

            if warm_up:
                break

            input_pos = input_pos.add_(1)
            model_input_ids = [[] for i in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
                model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
                model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
                model_input_ids[i] = torch.stack(model_input_ids[i])

            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])

            index += 1

        text_tokens = list_output[-1]
        if text_vocabsize in text_tokens:
            text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
        res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
        print(f"text output: {res_text}")

        if save_path is not None:
            audiolist = reconscruct_snac(list_output)
            audio = reconstruct_tensors(audiolist)
            with torch.inference_mode():
                audio_hat = self.snacmodel.decode(audio)
            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)

        model.clear_kv_cache()


def test_vision_infer():
    client = OmniVisionInference()
    client.warm_up()
    input_audio_path = './data/samples/vision_qa_audio.wav'
    input_image_path = './data/samples/vision_qa_image.jpg'

    res_text = ""
    for audio_stream, text_stream in client.run_vision_AA_batch_stream(
        input_audio_path,
        input_image_path,
        save_path="./vision_qa_output.wav"
    ):
        res_text += text_stream
    print(f"text_output: {res_text}")


if __name__ == "__main__":
    test_vision_infer()
litgpt/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import logging
import re
from litgpt.model import GPT  # needs to be imported before config
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer

# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")
logging.getLogger("torch._dynamo.variables.torch").addFilter(
    lambda record: not pattern.search(record.getMessage())
)

# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True

__all__ = ["GPT", "Config", "Tokenizer"]
litgpt/config.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
2 |
+
|
3 |
+
from copy import deepcopy
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any, Literal, Optional, Type, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import yaml
|
10 |
+
from typing_extensions import Self
|
11 |
+
|
12 |
+
import litgpt.model
|
13 |
+
from litgpt.utils import find_multiple
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class Config:
|
18 |
+
name: str = ""
|
19 |
+
hf_config: dict = field(default_factory=dict)
|
20 |
+
scale_embeddings: bool = False
|
21 |
+
block_size: int = 4096
|
22 |
+
vocab_size: int = 50254
|
23 |
+
padding_multiple: int = 512
|
24 |
+
padded_vocab_size: Optional[int] = None
|
25 |
+
n_layer: int = 16
|
26 |
+
n_head: int = 32
|
27 |
+
head_size: Optional[int] = None
|
28 |
+
n_embd: int = 4096
|
29 |
+
rotary_percentage: float = 0.25
|
30 |
+
parallel_residual: bool = True
|
31 |
+
bias: bool = True
|
32 |
+
lm_head_bias: bool = False
|
33 |
+
# to use multi-head attention (MHA), set this to `n_head` (default)
|
34 |
+
# to use multi-query attention (MQA), set this to 1
|
35 |
+
# to use grouped-query attention (GQA), set this to a value in between
|
36 |
+
# Example with `n_head=4`
|
37 |
+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
38 |
+
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
|
39 |
+
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
40 |
+
# │ │ │ │ │ │ │
|
41 |
+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
|
42 |
+
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
|
43 |
+
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
|
44 |
+
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
|
45 |
+
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
|
46 |
+
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
|
47 |
+
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
|
48 |
+
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
|
49 |
+
# MHA GQA MQA
|
50 |
+
# n_query_groups=4 n_query_groups=2 n_query_groups=1
|
51 |
+
#
|
52 |
+
# credit https://arxiv.org/pdf/2305.13245.pdf
|
53 |
+
n_query_groups: Optional[int] = None
|
54 |
+
shared_attention_norm: bool = False
|
55 |
+
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
|
56 |
+
norm_eps: float = 1e-5
|
57 |
+
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
|
58 |
+
"GptNeoxMLP"
|
59 |
+
)
|
60 |
+
gelu_approximate: str = "none"
|
61 |
+
intermediate_size: Optional[int] = None
|
62 |
+
rope_condense_ratio: int = 1
|
63 |
+
rope_base: int = 10000
|
64 |
+
n_expert: int = 0
|
65 |
+
n_expert_per_token: int = 0
|
66 |
+
|
67 |
+
add_qkv_bias: Optional[bool] = None
|
68 |
+
prompt_vocab_size: Optional[int] = None
|
69 |
+
attn_dropout: float = 0.0
|
70 |
+
pos_type: str = "rope"
|
71 |
+
force_align: bool = False
|
72 |
+
use_pretrain_phoneme_emb: bool = False
|
73 |
+
tie_word_embeddings: bool = False
|
74 |
+
|
75 |
+
# setting for mini-omni
|
76 |
+
text_vocab_size:int = 152000
|
77 |
+
cat_audio_vocab_size: int = 29120
|
78 |
+
audio_vocab_size: int = 4160
|
79 |
+
whisper_adapter_dim: int = 768
|
80 |
+
vision_adapter_dim: int = 512
|
81 |
+
|
82 |
+
post_adapter: bool = False
|
83 |
+
post_adapter_layers: int = 6
|
84 |
+
asr_adapter: str = "llamamlp"
|
85 |
+
|
86 |
+
def __post_init__(self):
|
87 |
+
if not self.name:
|
88 |
+
self.name = self.hf_config.get("name", self.name)
|
89 |
+
|
90 |
+
if self.head_size is None:
|
91 |
+
assert self.n_embd % self.n_head == 0
|
92 |
+
self.head_size = self.n_embd // self.n_head
|
93 |
+
|
94 |
+
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
|
95 |
+
if self.padded_vocab_size is None:
|
96 |
+
self.padded_vocab_size = find_multiple(
|
97 |
+
self.vocab_size, self.padding_multiple
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
# vocab size shouldn't be larger than padded vocab size
|
101 |
+
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
|
102 |
+
|
103 |
+
# compute the number of query groups
|
104 |
+
if self.n_query_groups is not None:
|
105 |
+
assert self.n_head % self.n_query_groups == 0
|
106 |
+
else:
|
107 |
+
self.n_query_groups = self.n_head
|
108 |
+
|
109 |
+
# compute the intermediate size for MLP if not set
|
110 |
+
if self.intermediate_size is None:
|
111 |
+
if self.mlp_class_name == "LLaMAMLP":
|
112 |
+
raise ValueError(
|
113 |
+
f"The config {self.name!r}, needs to set the `intermediate_size`"
|
114 |
+
)
|
115 |
+
self.intermediate_size = 4 * self.n_embd
|
116 |
+
|
117 |
+
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
|
118 |
+
|
119 |
+
if self.add_qkv_bias is None:
|
120 |
+
self.add_qkv_bias = self.bias
|
121 |
+
|
122 |
+
@classmethod
|
123 |
+
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
|
124 |
+
if name not in name_to_config:
|
125 |
+
# search through all `config['hf_config']['name']`
|
126 |
+
try:
|
127 |
+
conf_dict = next(
|
128 |
+
config
|
129 |
+
for config in configs
|
130 |
+
if name == config["hf_config"]["name"]
|
131 |
+
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
|
132 |
+
== name
|
133 |
+
)
|
134 |
+
except StopIteration:
|
135 |
+
raise ValueError(f"{name!r} is not a supported config name")
|
136 |
+
else:
|
137 |
+
conf_dict = name_to_config[name]
|
138 |
+
|
139 |
+
conf_dict = conf_dict.copy()
|
140 |
+
conf_dict.update(kwargs)
|
141 |
+
return cls(**conf_dict)
|
142 |
+
|
143 |
+
@classmethod
|
144 |
+
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
|
145 |
+
with open(path, encoding="utf-8") as fp:
|
146 |
+
file_kwargs = yaml.safe_load(fp)
|
147 |
+
if file_kwargs is None:
|
148 |
+
raise ValueError(f"{path} is empty which is likely unexpected.")
|
149 |
+
file_kwargs.update(kwargs)
|
150 |
+
return cls(**file_kwargs)
|
151 |
+
|
152 |
+
@classmethod
|
153 |
+
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
|
154 |
+
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
|
155 |
+
if (config_path := path / "model_config.yaml").is_file():
|
156 |
+
return cls.from_file(config_path, **kwargs)
|
157 |
+
if (model_name := path.name) in name_to_config:
|
158 |
+
return cls.from_name(model_name, **kwargs)
|
159 |
+
raise FileNotFoundError(
|
160 |
+
f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
|
161 |
+
)
|
162 |
+
|
163 |
+
@property
|
164 |
+
def mlp_class(self) -> Type:
|
165 |
+
# `self.mlp_class_name` cannot be the type to keep the config serializable
|
166 |
+
return getattr(litgpt.model, self.mlp_class_name)
|
167 |
+
|
168 |
+
@property
|
169 |
+
def norm_class(self) -> Type:
|
170 |
+
# `self.norm_class_name` cannot be the type to keep the config serializable
|
171 |
+
if self.norm_class_name == "RMSNorm":
|
172 |
+
from functools import partial
|
173 |
+
|
174 |
+
from litgpt.model import RMSNorm
|
175 |
+
|
176 |
+
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
|
177 |
+
return getattr(torch.nn, self.norm_class_name)
|
178 |
+
|
179 |
+
|
180 |
+
configs = []
|
181 |
+
name_to_config = {config["name"]: config for config in configs}
|
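Note: the `configs` list above is left empty, so `Config.from_name` only resolves names registered elsewhere; in practice the configuration ships with the checkpoint as `model_config.yaml` and is loaded via `from_file` / `from_checkpoint`. A minimal usage sketch (the path and the keyword override are illustrative, not fixed by this file):

from pathlib import Path
from litgpt.config import Config

# Load the config bundled with a checkpoint directory (path is an example).
config = Config.from_file(Path("models") / "model_config.yaml", attn_dropout=0.0)

# Fields derived in __post_init__:
print(config.head_size)          # n_embd // n_head unless set explicitly
print(config.padded_vocab_size)  # vocab_size rounded up to padding_multiple
print(config.rope_n_elem)        # int(rotary_percentage * head_size)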
litgpt/generate/__init__.py
ADDED
File without changes
|
litgpt/generate/base.py
ADDED
@@ -0,0 +1,795 @@
|
1 |
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
2 |
+
|
3 |
+
from typing import Any, Literal, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
# import torch._dynamo.config
|
7 |
+
# import torch._inductor.config
|
8 |
+
|
9 |
+
from litgpt.model import GPT
|
10 |
+
from utils.snac_utils import layershift, snac_config
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
|
15 |
+
if torch._dynamo.is_compiling():
|
16 |
+
# Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
|
17 |
+
distribution = torch.empty_like(probs).exponential_(1)
|
18 |
+
return torch.argmax(probs / distribution, dim=-1, keepdim=True)
|
19 |
+
return torch.multinomial(probs, num_samples=1)
|
20 |
+
|
21 |
+
|
22 |
+
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
|
23 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
24 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
25 |
+
# Example:
|
26 |
+
# sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
|
27 |
+
# sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
|
28 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
29 |
+
# Keep at least 1 token always to prevent the case where no token is selected
|
30 |
+
# In this case the most probable one is always kept
|
31 |
+
sorted_indices_to_remove[-1:] = 0
|
32 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
33 |
+
0, sorted_indices, sorted_indices_to_remove
|
34 |
+
)
|
35 |
+
logits = logits.masked_fill(indices_to_remove, float("-inf"))
|
36 |
+
return logits
|
37 |
+
|
38 |
+
|
39 |
+
def sample(
|
40 |
+
logits: torch.Tensor,
|
41 |
+
temperature: float = 1.0,
|
42 |
+
top_k: Optional[int] = None,
|
43 |
+
top_p: float = 1.0,
|
44 |
+
) -> torch.Tensor:
|
45 |
+
if top_p < 0.0 or top_p > 1.0:
|
46 |
+
raise ValueError(f"top_p must be in [0, 1], got {top_p}")
|
47 |
+
logits = logits[0, -1]
|
48 |
+
# optionally crop the logits to only the top k options
|
49 |
+
if top_k is not None:
|
50 |
+
v, i = torch.topk(logits, min(top_k, logits.size(-1)))
|
51 |
+
# do not use `torch.where` as in nanogpt because it will repeat top-k collisions
|
52 |
+
logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
|
53 |
+
# optionally scale the logits and sample from a probability distribution
|
54 |
+
if temperature > 0.0 or top_p > 0.0:
|
55 |
+
if temperature > 0.0:
|
56 |
+
logits = logits / temperature
|
57 |
+
# optionally crop the logits to smallest set of logits with a cumulative probability above top_p
|
58 |
+
if top_p < 1.0:
|
59 |
+
logits = sample_top_p(logits, top_p)
|
60 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
61 |
+
return multinomial_num_samples_1(probs)
|
62 |
+
return torch.argmax(logits, dim=-1, keepdim=True)
|
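`sample` applies the filters in the order documented later in `generate`: top-k cropping first, then temperature scaling, then top-p cropping, before drawing a single token id. A standalone sketch of calling it (shapes and values are illustrative only, and it assumes the repo root is on PYTHONPATH):

import torch
from litgpt.generate.base import sample

logits = torch.randn(1, 1, 32)   # (batch, time, vocab); sample reads position [0, -1]
token = sample(logits, temperature=0.8, top_k=10, top_p=0.9)
print(token.shape)               # torch.Size([1]): one sampled token id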
63 |
+
|
64 |
+
|
65 |
+
def next_token(
|
66 |
+
model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
|
67 |
+
) -> torch.Tensor:
|
68 |
+
input_pos = input_pos.to(model.device)
|
69 |
+
logits_a, logit_t = model(None, x, None, input_pos)
|
70 |
+
|
71 |
+
next_audio_tokens = []
|
72 |
+
for logit_a in logits_a:
|
73 |
+
next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
|
74 |
+
next_audio_tokens.append(next_a)
|
75 |
+
next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
|
76 |
+
return next_audio_tokens, next_t
|
77 |
+
|
78 |
+
|
79 |
+
def next_token_asr(
|
80 |
+
model: GPT,
|
81 |
+
input_pos: torch.Tensor,
|
82 |
+
audio_features: torch.Tensor,
|
83 |
+
lens: int,
|
84 |
+
input_ids: list,
|
85 |
+
**kwargs: Any,
|
86 |
+
) -> torch.Tensor:
|
87 |
+
input_pos = input_pos.to(model.device)
|
88 |
+
input_ids = [input_id.to(model.device) for input_id in input_ids]
|
89 |
+
logits_a, logit_t = model(audio_features, input_ids, None, input_pos, whisper_lens=lens)
|
90 |
+
|
91 |
+
next_audio_tokens = []
|
92 |
+
for logit_a in logits_a:
|
93 |
+
next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
|
94 |
+
next_audio_tokens.append(next_a)
|
95 |
+
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
|
96 |
+
return next_audio_tokens, next_t
|
97 |
+
|
98 |
+
|
99 |
+
def next_token_A1T2(
|
100 |
+
model: GPT,
|
101 |
+
audio_features: torch.Tensor,
|
102 |
+
input_ids: list,
|
103 |
+
whisper_lens: int,
|
104 |
+
task: list,
|
105 |
+
input_pos: torch.Tensor,
|
106 |
+
**kwargs: Any,
|
107 |
+
) -> torch.Tensor:
|
108 |
+
input_pos = input_pos.to(model.device)
|
109 |
+
input_ids = [input_id.to(model.device) for input_id in input_ids]
|
110 |
+
logits_a, logit_t = model(
|
111 |
+
audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
|
112 |
+
)
|
113 |
+
|
114 |
+
next_audio_tokens = []
|
115 |
+
for logit_a in logits_a:
|
116 |
+
next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
|
117 |
+
next_audio_tokens.append(next_a)
|
118 |
+
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
|
119 |
+
return next_audio_tokens, next_t
|
120 |
+
|
121 |
+
|
122 |
+
def next_token_A1T1(
|
123 |
+
model: GPT,
|
124 |
+
audio_features: torch.Tensor,
|
125 |
+
input_ids: list,
|
126 |
+
whisper_lens: int,
|
127 |
+
task: list,
|
128 |
+
input_pos: torch.Tensor,
|
129 |
+
**kwargs: Any,
|
130 |
+
) -> torch.Tensor:
|
131 |
+
input_pos = input_pos.to(model.device)
|
132 |
+
input_ids = [input_id.to(model.device) for input_id in input_ids]
|
133 |
+
logits_a, logit_t = model(
|
134 |
+
audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
|
135 |
+
)
|
136 |
+
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
|
137 |
+
return next_t
|
138 |
+
|
139 |
+
|
140 |
+
def next_token_image_batch(model: GPT,
|
141 |
+
audio_features: torch.Tensor,
|
142 |
+
clip_features: torch.Tensor,
|
143 |
+
input_ids: list,
|
144 |
+
whisper_lens: int,
|
145 |
+
task: list,
|
146 |
+
input_pos: torch.Tensor,
|
147 |
+
**kwargs: Any) -> torch.Tensor:
|
148 |
+
input_pos = input_pos.to(model.device)
|
149 |
+
input_ids = [input_id.to(model.device) for input_id in input_ids]
|
150 |
+
logits_a, logit_t = model(audio_features, input_ids, clip_features,
|
151 |
+
input_pos, whisper_lens=whisper_lens, task=task)
|
152 |
+
|
153 |
+
for i in range(7):
|
154 |
+
logits_a[i] = logits_a[i][0].unsqueeze(0)
|
155 |
+
logit_t = logit_t[1].unsqueeze(0)
|
156 |
+
|
157 |
+
next_audio_tokens = []
|
158 |
+
for logit_a in logits_a:
|
159 |
+
next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
|
160 |
+
next_audio_tokens.append(next_a)
|
161 |
+
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
|
162 |
+
return next_audio_tokens, next_t
|
163 |
+
|
164 |
+
|
165 |
+
# torch._dynamo.config.automatic_dynamic_shapes = True
|
166 |
+
# torch._inductor.config.triton.unique_kernel_names = True
|
167 |
+
# torch._inductor.config.coordinate_descent_tuning = True
|
168 |
+
# next_token = torch.compile(next_token, mode="reduce-overhead")
|
169 |
+
|
170 |
+
|
171 |
+
@torch.inference_mode()
|
172 |
+
def generate(
|
173 |
+
model: GPT,
|
174 |
+
input_ids: list,
|
175 |
+
max_returned_tokens: int,
|
176 |
+
*,
|
177 |
+
temperature: float = 1.0,
|
178 |
+
top_k: Optional[int] = None,
|
179 |
+
top_p: float = 1.0,
|
180 |
+
eos_id_a: Optional[int] = None,
|
181 |
+
eos_id_t: Optional[int] = None,
|
182 |
+
pad_id: Optional[int] = None,
|
183 |
+
shift: Optional[int] = None,
|
184 |
+
include_prompt: bool = True,
|
185 |
+
generate_text=False,
|
186 |
+
) -> torch.Tensor:
|
187 |
+
# print("eos_id_a:", eos_id_a)
|
188 |
+
# print("eos_id_t:", eos_id_t)
|
189 |
+
# print("pad_id:", pad_id)
|
190 |
+
"""
|
191 |
+
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
192 |
+
The implementation of this function is modified from A. Karpathy's nanoGPT.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
model: The model to use.
|
196 |
+
input_ids: List of tensors (one per SNAC codebook plus one text stream) with indices of the prompt sequence.
|
197 |
+
max_returned_tokens: The maximum number of tokens to return (given plus generated).
|
198 |
+
temperature: Scales the predicted logits by 1 / temperature.
|
199 |
+
top_k: If specified, only sample among the tokens with the k highest probabilities.
|
200 |
+
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
|
201 |
+
In top-p sampling, the next token is sampled from the highest probability tokens
|
202 |
+
whose cumulative probability exceeds the threshold `top_p`. When specified,
|
203 |
+
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
|
204 |
+
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
|
205 |
+
It can be used in conjunction with `top_k` and `temperature` with the following order
|
206 |
+
of application:
|
207 |
+
|
208 |
+
1. `top_k` sampling
|
209 |
+
2. `temperature` scaling
|
210 |
+
3. `top_p` sampling
|
211 |
+
|
212 |
+
For more details, see https://arxiv.org/abs/1904.09751
|
213 |
+
or https://huyenchip.com/2024/01/16/sampling.html#top_p
|
214 |
+
eos_id_a / eos_id_t: If specified, stop generating more tokens once the corresponding audio / text <eos> token is triggered.
|
215 |
+
include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
|
216 |
+
"""
|
217 |
+
T = input_ids[0].size(0)
|
218 |
+
device = input_ids[0].device
|
219 |
+
assert max_returned_tokens > T
|
220 |
+
if model.max_seq_length < max_returned_tokens - 1:
|
221 |
+
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
|
222 |
+
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
|
223 |
+
# not support it to avoid negatively impacting the overall speed
|
224 |
+
raise NotImplementedError(
|
225 |
+
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
|
226 |
+
)
|
227 |
+
|
228 |
+
for input_id in input_ids:
|
229 |
+
input_id = [input_id]
|
230 |
+
(
|
231 |
+
tokens_A1,
|
232 |
+
tokens_A2,
|
233 |
+
tokens_A3,
|
234 |
+
tokens_A4,
|
235 |
+
tokens_A5,
|
236 |
+
tokens_A6,
|
237 |
+
tokens_A7,
|
238 |
+
tokens_T,
|
239 |
+
) = input_ids
|
240 |
+
|
241 |
+
tokens_A1_output = [tokens_A1]
|
242 |
+
tokens_A2_output = [tokens_A2]
|
243 |
+
tokens_A3_output = [tokens_A3]
|
244 |
+
tokens_A4_output = [tokens_A4]
|
245 |
+
tokens_A5_output = [tokens_A5]
|
246 |
+
tokens_A6_output = [tokens_A6]
|
247 |
+
tokens_A7_output = [tokens_A7]
|
248 |
+
tokens_T_output = [tokens_T]
|
249 |
+
|
250 |
+
list_output = [
|
251 |
+
tokens_A1_output,
|
252 |
+
tokens_A2_output,
|
253 |
+
tokens_A3_output,
|
254 |
+
tokens_A4_output,
|
255 |
+
tokens_A5_output,
|
256 |
+
tokens_A6_output,
|
257 |
+
tokens_A7_output,
|
258 |
+
tokens_T_output,
|
259 |
+
]
|
260 |
+
|
261 |
+
input_pos = torch.tensor([T], device=device)
|
262 |
+
model_input_ids = [
|
263 |
+
tokens_A1.view(1, -1),
|
264 |
+
tokens_A2.view(1, -1),
|
265 |
+
tokens_A3.view(1, -1),
|
266 |
+
tokens_A4.view(1, -1),
|
267 |
+
tokens_A5.view(1, -1),
|
268 |
+
tokens_A6.view(1, -1),
|
269 |
+
tokens_A7.view(1, -1),
|
270 |
+
tokens_T.view(1, -1),
|
271 |
+
]
|
272 |
+
|
273 |
+
tokens_A, token_T = next_token(
|
274 |
+
model,
|
275 |
+
torch.arange(0, T, device=device),
|
276 |
+
model_input_ids,
|
277 |
+
temperature=temperature,
|
278 |
+
top_k=top_k,
|
279 |
+
top_p=top_p,
|
280 |
+
)
|
281 |
+
for i in range(7):
|
282 |
+
list_output[i].append(tokens_A[i].clone())
|
283 |
+
list_output[7].append(token_T.clone())
|
284 |
+
|
285 |
+
# prepare the input for the next iteration
|
286 |
+
for i in range(7):
|
287 |
+
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
|
288 |
+
token_T = token_T.clone()
|
289 |
+
|
290 |
+
text_end = False
|
291 |
+
max_returned_tokens = 1000
|
292 |
+
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
|
293 |
+
model_input_ids = [
|
294 |
+
token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
|
295 |
+
] + [token_T.view(1, -1).to(torch.int32)]
|
296 |
+
tokens_A, token_T = next_token(
|
297 |
+
model,
|
298 |
+
input_pos,
|
299 |
+
model_input_ids,
|
300 |
+
temperature=temperature,
|
301 |
+
top_k=top_k,
|
302 |
+
top_p=top_p,
|
303 |
+
)
|
304 |
+
if text_end:
|
305 |
+
token_T = torch.tensor([pad_id], device=device)
|
306 |
+
|
307 |
+
for i in range(7):
|
308 |
+
list_output[i].append(tokens_A[i].clone())
|
309 |
+
list_output[7].append(token_T.clone())
|
310 |
+
|
311 |
+
if tokens_A[-1] == eos_id_a:
|
312 |
+
break
|
313 |
+
if token_T == eos_id_t:
|
314 |
+
if generate_text:
|
315 |
+
break
|
316 |
+
text_end = True
|
317 |
+
|
318 |
+
for i in range(7):
|
319 |
+
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
|
320 |
+
token_T = token_T.clone()
|
321 |
+
input_pos = input_pos.add_(1)
|
322 |
+
|
323 |
+
for i in range(len(list_output)):
|
324 |
+
list_output[i] = torch.cat(list_output[i])
|
325 |
+
return list_output
|
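Inside `generate`, each of the 7 SNAC codebook predictions is shifted by `shift + i * snac_config.padded_vocab_size` before being fed back as input, so every codebook occupies its own disjoint range of the combined vocabulary. A hedged arithmetic sketch of that mapping, assuming `shift` equals `text_vocab_size` (152000) and `snac_config.padded_vocab_size` equals `audio_vocab_size` (4160) from the config defaults above; the actual values are supplied by the caller and `utils.snac_utils`:

def layer_shift_example(raw_audio_token: int, codebook: int,
                        shift: int = 152000, padded_audio_vocab: int = 4160) -> int:
    # Assumed defaults: codebook 0 lands in [152000, 156160), codebook 1 in [156160, 160320), ...
    return raw_audio_token + shift + codebook * padded_audio_vocab

print(layer_shift_example(10, 0))  # 152010
print(layer_shift_example(10, 6))  # 176970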
326 |
+
|
327 |
+
|
328 |
+
@torch.inference_mode()
|
329 |
+
def generate_TA_BATCH(
|
330 |
+
model: GPT,
|
331 |
+
audio_features: torch.Tensor,
|
332 |
+
input_ids: list,
|
333 |
+
leng,
|
334 |
+
task,
|
335 |
+
max_returned_tokens: int = 1000,
|
336 |
+
*,
|
337 |
+
temperature: float = 1.0,
|
338 |
+
top_k: Optional[int] = None,
|
339 |
+
top_p: float = 1.0,
|
340 |
+
eos_id_a: Optional[int] = None,
|
341 |
+
eos_id_t: Optional[int] = None,
|
342 |
+
pad_id_t: Optional[int] = None,
|
343 |
+
shift: Optional[int] = None,
|
344 |
+
include_prompt: bool = True,
|
345 |
+
generate_text=False,
|
346 |
+
) -> torch.Tensor:
|
347 |
+
|
348 |
+
T = input_ids[0].size(1)
|
349 |
+
device = input_ids[0].device
|
350 |
+
assert max_returned_tokens > T
|
351 |
+
if model.max_seq_length < max_returned_tokens - 1:
|
352 |
+
raise NotImplementedError(
|
353 |
+
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
|
354 |
+
)
|
355 |
+
|
356 |
+
input_pos = torch.tensor([T], device=device)
|
357 |
+
model_input_ids = input_ids
|
358 |
+
|
359 |
+
list_output = [[] for i in range(8)]
|
360 |
+
|
361 |
+
tokens_A, token_T = next_token_image_batch(
|
362 |
+
model,
|
363 |
+
audio_features.to(torch.float32).to(model.device),
|
364 |
+
None,
|
365 |
+
input_ids,
|
366 |
+
[T - 3, T - 3],
|
367 |
+
["A1T2", "A1T2"],
|
368 |
+
input_pos=torch.arange(0, T, device=device),
|
369 |
+
temperature=temperature,
|
370 |
+
top_k=top_k,
|
371 |
+
top_p=top_p,
|
372 |
+
)
|
373 |
+
|
374 |
+
for i in range(7):
|
375 |
+
list_output[i].append(tokens_A[i].tolist()[0])
|
376 |
+
list_output[7].append(token_T.tolist()[0])
|
377 |
+
|
378 |
+
model_input_ids = [[] for i in range(8)]
|
379 |
+
for i in range(7):
|
380 |
+
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
|
381 |
+
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
|
382 |
+
model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
|
383 |
+
model_input_ids[i] = torch.stack(model_input_ids[i])
|
384 |
+
|
385 |
+
model_input_ids[-1].append(token_T.clone().to(torch.int32))
|
386 |
+
model_input_ids[-1].append(token_T.clone().to(torch.int32))
|
387 |
+
model_input_ids[-1] = torch.stack(model_input_ids[-1])
|
388 |
+
|
389 |
+
text_end = False
|
390 |
+
|
391 |
+
for _ in range(2, max_returned_tokens - T + 1):
|
392 |
+
tokens_A, token_T = next_token_image_batch(
|
393 |
+
model,
|
394 |
+
None,
|
395 |
+
None,
|
396 |
+
model_input_ids,
|
397 |
+
None,
|
398 |
+
None,
|
399 |
+
input_pos=input_pos,
|
400 |
+
temperature=temperature,
|
401 |
+
top_k=top_k,
|
402 |
+
top_p=top_p,
|
403 |
+
)
|
404 |
+
|
405 |
+
if text_end:
|
406 |
+
token_T = torch.tensor([pad_id_t], device=device)
|
407 |
+
|
408 |
+
if tokens_A[-1] == eos_id_a:
|
409 |
+
break
|
410 |
+
if token_T == eos_id_t:
|
411 |
+
text_end = True
|
412 |
+
|
413 |
+
for i in range(7):
|
414 |
+
list_output[i].append(tokens_A[i].tolist()[0])
|
415 |
+
list_output[7].append(token_T.tolist()[0])
|
416 |
+
|
417 |
+
model_input_ids = [[] for i in range(8)]
|
418 |
+
for i in range(7):
|
419 |
+
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
|
420 |
+
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
|
421 |
+
model_input_ids[i].append(
|
422 |
+
torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
|
423 |
+
)
|
424 |
+
model_input_ids[i] = torch.stack(model_input_ids[i])
|
425 |
+
|
426 |
+
model_input_ids[-1].append(token_T.clone().to(torch.int32))
|
427 |
+
model_input_ids[-1].append(token_T.clone().to(torch.int32))
|
428 |
+
model_input_ids[-1] = torch.stack(model_input_ids[-1])
|
429 |
+
|
430 |
+
input_pos = input_pos.add_(1)
|
431 |
+
|
432 |
+
return list_output
|
433 |
+
|
434 |
+
|
435 |
+
@torch.inference_mode()
|
436 |
+
def generate_TT(
|
437 |
+
model: GPT,
|
438 |
+
audio_features: torch.Tensor,
|
439 |
+
input_ids: list,
|
440 |
+
leng,
|
441 |
+
task,
|
442 |
+
max_returned_tokens: int = 2048,
|
443 |
+
*,
|
444 |
+
temperature: float = 1.0,
|
445 |
+
top_k: Optional[int] = None,
|
446 |
+
top_p: float = 1.0,
|
447 |
+
eos_id_a: Optional[int] = None,
|
448 |
+
eos_id_t: Optional[int] = None,
|
449 |
+
pad_id_t: Optional[int] = None,
|
450 |
+
shift: Optional[int] = None,
|
451 |
+
include_prompt: bool = True,
|
452 |
+
generate_text=False,
|
453 |
+
) -> torch.Tensor:
|
454 |
+
|
455 |
+
T = input_ids[0].size(1)
|
456 |
+
device = input_ids[0].device
|
457 |
+
|
458 |
+
output = []
|
459 |
+
token_T = next_token_A1T1(
|
460 |
+
model,
|
461 |
+
None,
|
462 |
+
input_ids,
|
463 |
+
None,
|
464 |
+
None,
|
465 |
+
input_pos=torch.arange(0, T, device=device),
|
466 |
+
temperature=temperature,
|
467 |
+
top_k=top_k,
|
468 |
+
top_p=top_p,
|
469 |
+
)
|
470 |
+
|
471 |
+
output.append(token_T.clone().tolist()[0])
|
472 |
+
input_pos = torch.tensor([T], device=device)
|
473 |
+
|
474 |
+
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
|
475 |
+
model_input_ids = []
|
476 |
+
for i in range(7):
|
477 |
+
model_input_ids.append(
|
478 |
+
torch.tensor([layershift(snac_config.end_of_audio, i)])
|
479 |
+
.view(1, -1)
|
480 |
+
.to(torch.int32)
|
481 |
+
.to(device)
|
482 |
+
)
|
483 |
+
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
|
484 |
+
token_T = next_token_A1T1(
|
485 |
+
model,
|
486 |
+
None,
|
487 |
+
model_input_ids,
|
488 |
+
None,
|
489 |
+
None,
|
490 |
+
input_pos=input_pos,
|
491 |
+
temperature=temperature,
|
492 |
+
top_k=top_k,
|
493 |
+
top_p=top_p,
|
494 |
+
)
|
495 |
+
if token_T == eos_id_t:
|
496 |
+
break
|
497 |
+
output.append(token_T.clone().tolist()[0])
|
498 |
+
input_pos = input_pos.add_(1)
|
499 |
+
return output
|
500 |
+
|
501 |
+
|
502 |
+
@torch.inference_mode()
|
503 |
+
def generate_AT(
|
504 |
+
model: GPT,
|
505 |
+
audio_features: torch.Tensor,
|
506 |
+
input_ids: list,
|
507 |
+
leng,
|
508 |
+
task,
|
509 |
+
max_returned_tokens: int = 2048,
|
510 |
+
*,
|
511 |
+
temperature: float = 1.0,
|
512 |
+
top_k: Optional[int] = None,
|
513 |
+
top_p: float = 1.0,
|
514 |
+
eos_id_a: Optional[int] = None,
|
515 |
+
eos_id_t: Optional[int] = None,
|
516 |
+
pad_id_t: Optional[int] = None,
|
517 |
+
shift: Optional[int] = None,
|
518 |
+
include_prompt: bool = True,
|
519 |
+
generate_text=False,
|
520 |
+
) -> torch.Tensor:
|
521 |
+
|
522 |
+
T = input_ids[0].size(1)
|
523 |
+
device = input_ids[0].device
|
524 |
+
|
525 |
+
output = []
|
526 |
+
token_T = next_token_A1T1(
|
527 |
+
model,
|
528 |
+
audio_features.to(torch.float32).to(model.device),
|
529 |
+
input_ids,
|
530 |
+
[T - 3],
|
531 |
+
["AT"],
|
532 |
+
input_pos=torch.arange(0, T, device=device),
|
533 |
+
temperature=temperature,
|
534 |
+
top_k=top_k,
|
535 |
+
top_p=top_p,
|
536 |
+
)
|
537 |
+
output.append(token_T.clone().tolist()[0])
|
538 |
+
input_pos = torch.tensor([T], device=device)
|
539 |
+
text_end = False
|
540 |
+
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
|
541 |
+
model_input_ids = []
|
542 |
+
for i in range(7):
|
543 |
+
model_input_ids.append(
|
544 |
+
torch.tensor([layershift(snac_config.end_of_audio, i)])
|
545 |
+
.view(1, -1)
|
546 |
+
.to(torch.int32)
|
547 |
+
.to(device)
|
548 |
+
)
|
549 |
+
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
|
550 |
+
token_T = next_token_A1T1(
|
551 |
+
model,
|
552 |
+
None,
|
553 |
+
model_input_ids,
|
554 |
+
None,
|
555 |
+
None,
|
556 |
+
input_pos=input_pos,
|
557 |
+
temperature=temperature,
|
558 |
+
top_k=top_k,
|
559 |
+
top_p=top_p,
|
560 |
+
)
|
561 |
+
if token_T == eos_id_t:
|
562 |
+
break
|
563 |
+
output.append(token_T.clone().tolist()[0])
|
564 |
+
input_pos = input_pos.add_(1)
|
565 |
+
return output
|
566 |
+
|
567 |
+
|
568 |
+
@torch.inference_mode()
|
569 |
+
def generate_TA(
|
570 |
+
model: GPT,
|
571 |
+
audio_features: torch.Tensor,
|
572 |
+
input_ids: list,
|
573 |
+
leng,
|
574 |
+
task,
|
575 |
+
max_returned_tokens: int = 2048,
|
576 |
+
*,
|
577 |
+
temperature: float = 1.0,
|
578 |
+
top_k: Optional[int] = None,
|
579 |
+
top_p: float = 1.0,
|
580 |
+
eos_id_a: Optional[int] = None,
|
581 |
+
eos_id_t: Optional[int] = None,
|
582 |
+
pad_id_t: Optional[int] = None,
|
583 |
+
shift: Optional[int] = None,
|
584 |
+
include_prompt: bool = True,
|
585 |
+
generate_text=False,
|
586 |
+
) -> torch.Tensor:
|
587 |
+
|
588 |
+
T = input_ids[0].size(1)
|
589 |
+
device = input_ids[0].device
|
590 |
+
|
591 |
+
output = [[] for _ in range(8)]
|
592 |
+
tokens_A, token_T = next_token_A1T2(
|
593 |
+
model,
|
594 |
+
None,
|
595 |
+
input_ids,
|
596 |
+
None,
|
597 |
+
None,
|
598 |
+
input_pos=torch.arange(0, T, device=device),
|
599 |
+
temperature=temperature,
|
600 |
+
top_k=top_k,
|
601 |
+
top_p=top_p,
|
602 |
+
)
|
603 |
+
for i in range(7):
|
604 |
+
output[i].append(tokens_A[i].clone().tolist()[0])
|
605 |
+
output[7].append(token_T.clone().tolist()[0])
|
606 |
+
|
607 |
+
input_pos = torch.tensor([T], device=device)
|
608 |
+
text_end = False
|
609 |
+
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
|
610 |
+
|
611 |
+
model_input_ids = []
|
612 |
+
for i in range(7):
|
613 |
+
model_input_ids.append(
|
614 |
+
layershift(tokens_A[i].clone(), i)
|
615 |
+
.view(1, -1)
|
616 |
+
.to(torch.int32)
|
617 |
+
.to(device)
|
618 |
+
)
|
619 |
+
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
|
620 |
+
|
621 |
+
tokens_A, token_T = next_token_A1T2(
|
622 |
+
model,
|
623 |
+
None,
|
624 |
+
model_input_ids,
|
625 |
+
None,
|
626 |
+
None,
|
627 |
+
input_pos=input_pos,
|
628 |
+
temperature=temperature,
|
629 |
+
top_k=top_k,
|
630 |
+
top_p=top_p,
|
631 |
+
)
|
632 |
+
|
633 |
+
if text_end:
|
634 |
+
token_T = torch.tensor([pad_id_t], device=device)
|
635 |
+
|
636 |
+
if tokens_A[-1] == eos_id_a:
|
637 |
+
break
|
638 |
+
|
639 |
+
if token_T == eos_id_t:
|
640 |
+
text_end = True
|
641 |
+
|
642 |
+
for i in range(7):
|
643 |
+
output[i].append(tokens_A[i].clone().tolist()[0])
|
644 |
+
output[7].append(token_T.clone().tolist()[0])
|
645 |
+
input_pos = input_pos.add_(1)
|
646 |
+
|
647 |
+
return output
|
648 |
+
|
649 |
+
|
650 |
+
@torch.inference_mode()
|
651 |
+
def generate_AA(
|
652 |
+
model: GPT,
|
653 |
+
audio_features: torch.Tensor,
|
654 |
+
input_ids: list,
|
655 |
+
leng,
|
656 |
+
task,
|
657 |
+
max_returned_tokens: int = 2048,
|
658 |
+
*,
|
659 |
+
temperature: float = 1.0,
|
660 |
+
top_k: Optional[int] = None,
|
661 |
+
top_p: float = 1.0,
|
662 |
+
eos_id_a: Optional[int] = None,
|
663 |
+
eos_id_t: Optional[int] = None,
|
664 |
+
pad_id_t: Optional[int] = None,
|
665 |
+
shift: Optional[int] = None,
|
666 |
+
include_prompt: bool = True,
|
667 |
+
generate_text=False,
|
668 |
+
) -> torch.Tensor:
|
669 |
+
|
670 |
+
T = input_ids[0].size(1)
|
671 |
+
device = input_ids[0].device
|
672 |
+
|
673 |
+
output = [[] for _ in range(8)]
|
674 |
+
tokens_A, token_T = next_token_A1T2(
|
675 |
+
model,
|
676 |
+
audio_features.to(torch.float32).to(model.device),
|
677 |
+
input_ids,
|
678 |
+
[T - 3],
|
679 |
+
["A1T2"],
|
680 |
+
input_pos=torch.arange(0, T, device=device),
|
681 |
+
temperature=temperature,
|
682 |
+
top_k=top_k,
|
683 |
+
top_p=top_p,
|
684 |
+
)
|
685 |
+
for i in range(7):
|
686 |
+
output[i].append(tokens_A[i].clone().tolist()[0])
|
687 |
+
output[7].append(token_T.clone().tolist()[0])
|
688 |
+
|
689 |
+
input_pos = torch.tensor([T], device=device)
|
690 |
+
|
691 |
+
text_end = False
|
692 |
+
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
|
693 |
+
|
694 |
+
model_input_ids = []
|
695 |
+
for i in range(7):
|
696 |
+
model_input_ids.append(
|
697 |
+
layershift(tokens_A[i].clone(), i)
|
698 |
+
.view(1, -1)
|
699 |
+
.to(torch.int32)
|
700 |
+
.to(device)
|
701 |
+
)
|
702 |
+
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
|
703 |
+
|
704 |
+
tokens_A, token_T = next_token_A1T2(
|
705 |
+
model,
|
706 |
+
None,
|
707 |
+
model_input_ids,
|
708 |
+
None,
|
709 |
+
None,
|
710 |
+
input_pos=input_pos,
|
711 |
+
temperature=temperature,
|
712 |
+
top_k=top_k,
|
713 |
+
top_p=top_p,
|
714 |
+
)
|
715 |
+
|
716 |
+
if text_end:
|
717 |
+
token_T = torch.tensor([pad_id_t], device=device)
|
718 |
+
|
719 |
+
if tokens_A[-1] == eos_id_a:
|
720 |
+
break
|
721 |
+
if token_T == eos_id_t:
|
722 |
+
# print("text_end")
|
723 |
+
text_end = True
|
724 |
+
|
725 |
+
for i in range(7):
|
726 |
+
output[i].append(tokens_A[i].clone().tolist()[0])
|
727 |
+
output[7].append(token_T.clone().tolist()[0])
|
728 |
+
input_pos = input_pos.add_(1)
|
729 |
+
|
730 |
+
return output
|
731 |
+
|
732 |
+
|
733 |
+
@torch.inference_mode()
|
734 |
+
def generate_ASR(
|
735 |
+
model: GPT,
|
736 |
+
audio_features: torch.Tensor,
|
737 |
+
input_ids: list,
|
738 |
+
leng,
|
739 |
+
task,
|
740 |
+
max_returned_tokens: int = 1200,
|
741 |
+
*,
|
742 |
+
temperature: float = 1.0,
|
743 |
+
top_k: Optional[int] = None,
|
744 |
+
top_p: float = 1.0,
|
745 |
+
eos_id_a: Optional[int] = None,
|
746 |
+
eos_id_t: Optional[int] = None,
|
747 |
+
pad_id_t: Optional[int] = None,
|
748 |
+
shift: Optional[int] = None,
|
749 |
+
include_prompt: bool = True,
|
750 |
+
generate_text=False,
|
751 |
+
) -> torch.Tensor:
|
752 |
+
|
753 |
+
T = input_ids[0].size(1)
|
754 |
+
device = input_ids[0].device
|
755 |
+
output = []
|
756 |
+
token_T = next_token_A1T1(
|
757 |
+
model,
|
758 |
+
audio_features.to(torch.float32).to(model.device),
|
759 |
+
input_ids,
|
760 |
+
[T - 3],
|
761 |
+
["asr"],
|
762 |
+
input_pos=torch.arange(0, T, device=device),
|
763 |
+
temperature=temperature,
|
764 |
+
top_k=top_k,
|
765 |
+
top_p=top_p,
|
766 |
+
)
|
767 |
+
output.append(token_T.clone().tolist()[0])
|
768 |
+
input_pos = torch.tensor([T], device=device)
|
769 |
+
text_end = False
|
770 |
+
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
|
771 |
+
model_input_ids = []
|
772 |
+
for i in range(7):
|
773 |
+
model_input_ids.append(
|
774 |
+
torch.tensor([layershift(snac_config.end_of_audio, i)])
|
775 |
+
.view(1, -1)
|
776 |
+
.to(torch.int32)
|
777 |
+
.to(device)
|
778 |
+
)
|
779 |
+
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
|
780 |
+
token_T = next_token_A1T1(
|
781 |
+
model,
|
782 |
+
None,
|
783 |
+
model_input_ids,
|
784 |
+
None,
|
785 |
+
None,
|
786 |
+
input_pos=input_pos,
|
787 |
+
temperature=temperature,
|
788 |
+
top_k=top_k,
|
789 |
+
top_p=top_p,
|
790 |
+
)
|
791 |
+
if token_T == eos_id_t:
|
792 |
+
break
|
793 |
+
output.append(token_T.clone().tolist()[0])
|
794 |
+
input_pos = input_pos.add_(1)
|
795 |
+
return output
|
litgpt/model.py
ADDED
@@ -0,0 +1,654 @@
|
1 |
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
2 |
+
|
3 |
+
"""Full definition of a decoder-only transformer-based language model, all of it in this single file.
|
4 |
+
|
5 |
+
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
|
6 |
+
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import math
|
10 |
+
from typing import Any, Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from typing_extensions import Self
|
15 |
+
from litgpt.config import Config
|
16 |
+
|
17 |
+
|
18 |
+
class GPT(nn.Module):
|
19 |
+
def __init__(self, config: Config) -> None:
|
20 |
+
super().__init__()
|
21 |
+
assert config.padded_vocab_size is not None
|
22 |
+
self.config = config
|
23 |
+
if self.config.asr_adapter == "mlp":
|
24 |
+
print("Using MLP adapter for ASR feature")
|
25 |
+
self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
|
26 |
+
elif self.config.asr_adapter == "llamamlp":
|
27 |
+
print("using LLAMA MLP adapter for ASR feature")
|
28 |
+
self.whisper_adapter = whisperMLP(config=config)
|
29 |
+
else:
|
30 |
+
raise ValueError("asr_adapter should be mlp or llamamlp")
|
31 |
+
self.lm_head = nn.Linear(
|
32 |
+
config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
|
33 |
+
)
|
34 |
+
|
35 |
+
self.vision_adapter = visionMLP(config=config)
|
36 |
+
if config.post_adapter:
|
37 |
+
self.transformer = nn.ModuleDict(
|
38 |
+
dict(
|
39 |
+
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
40 |
+
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
41 |
+
post_adapter=nn.ModuleList(
|
42 |
+
Block(config) for _ in range(config.post_adapter_layers)
|
43 |
+
),
|
44 |
+
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
45 |
+
post_adapter_audio_ln=config.norm_class(
|
46 |
+
config.n_embd, eps=config.norm_eps
|
47 |
+
),
|
48 |
+
post_adapter_audio_lm_head=nn.Linear(
|
49 |
+
config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
|
50 |
+
),
|
51 |
+
)
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
self.transformer = nn.ModuleDict(
|
55 |
+
dict(
|
56 |
+
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
|
57 |
+
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
|
58 |
+
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.max_seq_length = self.config.block_size
|
62 |
+
self.mask_cache: Optional[torch.Tensor] = None
|
63 |
+
if config.tie_word_embeddings:
|
64 |
+
self.lm_head.weight = self.transformer.wte.weight
|
65 |
+
|
66 |
+
@property
|
67 |
+
def max_seq_length(self) -> int:
|
68 |
+
return self._max_seq_length
|
69 |
+
|
70 |
+
@max_seq_length.setter
|
71 |
+
def max_seq_length(self, value: int) -> None:
|
72 |
+
"""
|
73 |
+
When doing inference, the sequences used might be shorter than the model's context length.
|
74 |
+
This allows setting a smaller number to avoid allocating unused memory
|
75 |
+
"""
|
76 |
+
if value > self.config.block_size:
|
77 |
+
raise ValueError(
|
78 |
+
f"Cannot attend to {value}, block size is only {self.config.block_size}"
|
79 |
+
)
|
80 |
+
self._max_seq_length = value
|
81 |
+
if not hasattr(self, "cos"):
|
82 |
+
# first call
|
83 |
+
cos, sin = self.rope_cache()
|
84 |
+
self.register_buffer("cos", cos, persistent=False)
|
85 |
+
self.register_buffer("sin", sin, persistent=False)
|
86 |
+
# override
|
87 |
+
elif value != self.cos.size(0):
|
88 |
+
self.cos, self.sin = self.rope_cache(device=self.cos.device)
|
89 |
+
# the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
|
90 |
+
# if the kv cache is expected
|
91 |
+
|
92 |
+
def reset_parameters(self) -> None:
|
93 |
+
# Trigger resetting the rope-cache
|
94 |
+
self.cos, self.sin = self.rope_cache(device=self.cos.device)
|
95 |
+
|
96 |
+
def _init_weights(self, module: nn.Module) -> None:
|
97 |
+
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
|
98 |
+
if isinstance(module, nn.Linear):
|
99 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
100 |
+
if module.bias is not None:
|
101 |
+
torch.nn.init.zeros_(module.bias)
|
102 |
+
elif isinstance(module, nn.Embedding):
|
103 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
104 |
+
|
105 |
+
def concat_feat(self, audio_feature, clip_feature, input_ids, T, task):
|
106 |
+
|
107 |
+
for j in range(len(T)):
|
108 |
+
if task[j] not in ('T1T2', 'T1A2', 'ImageQA_T', 'ImageCAP', 'ImageQA_A', 'ImageQA_AT'):
|
109 |
+
for i in range(7):
|
110 |
+
input_ids[i][j,1:T[j]+1,:] = audio_feature[j][:T[j]].clone()
|
111 |
+
assert task[j] != 'ImageQ', "ImageQ should be concat with audio feature"
|
112 |
+
|
113 |
+
elif task[j] == 'ImageQA_A' or task[j] == 'ImageQA_AT':
|
114 |
+
print("concat ImageQA_A feature")
|
115 |
+
for i in range(7):
|
116 |
+
input_ids[i][j,1:51,:] = clip_feature[j].clone()
|
117 |
+
|
118 |
+
input_ids[i][j,52 : 52 + T[j],:] = audio_feature[j][:T[j]].clone()
|
119 |
+
|
120 |
+
elif task[j] == 'ImageQA_T' or task[j] =='ImageCAP':
|
121 |
+
for i in range(7):
|
122 |
+
input_ids[i][j,1:51,:] = clip_feature[j].clone()
|
123 |
+
|
124 |
+
return input_ids
|
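`concat_feat` overwrites slices of the already-embedded input streams with the adapted whisper/CLIP features, leaving position 0 (and the trailing positions) for special tokens. A small sketch of the resulting layout per task, derived directly from the branches above (the 50-slot CLIP window and the whisper length `t` come from that code):

def feature_layout(task: str, t: int) -> dict:
    if task in ("ImageQA_A", "ImageQA_AT"):
        return {"clip": (1, 51), "audio": (52, 52 + t)}   # image tokens, then speech frames
    if task in ("ImageQA_T", "ImageCAP"):
        return {"clip": (1, 51)}                          # image tokens only
    if task in ("T1T2", "T1A2"):
        return {}                                         # text-only input: nothing is overwritten
    return {"audio": (1, 1 + t)}                          # speech-only tasks such as A1T2 / asr

print(feature_layout("A1T2", 120))  # {'audio': (1, 121)}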
125 |
+
|
126 |
+
def forward(
|
127 |
+
self,
|
128 |
+
audio_features: torch.Tensor,
|
129 |
+
input_ids: torch.Tensor,
|
130 |
+
clip_features: torch.Tensor,
|
131 |
+
input_pos: Optional[torch.Tensor] = None,
|
132 |
+
whisper_lens: Optional[list] = None,
|
133 |
+
task: Optional[str] = None,
|
134 |
+
) -> torch.Tensor:
|
135 |
+
|
136 |
+
show = False
|
137 |
+
T = input_ids[0].size(1)
|
138 |
+
if self.max_seq_length < T:
|
139 |
+
raise ValueError(
|
140 |
+
f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
|
141 |
+
)
|
142 |
+
|
143 |
+
if input_pos is not None: # use the kv cache
|
144 |
+
cos = self.cos.index_select(0, input_pos)
|
145 |
+
sin = self.sin.index_select(0, input_pos)
|
146 |
+
if self.mask_cache is None:
|
147 |
+
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
148 |
+
mask = self.mask_cache.index_select(2, input_pos)
|
149 |
+
else:
|
150 |
+
cos = self.cos[:T]
|
151 |
+
sin = self.sin[:T]
|
152 |
+
mask = None
|
153 |
+
|
154 |
+
if audio_features is not None:
|
155 |
+
# get whisper feature
|
156 |
+
x_a = self.whisper_adapter(audio_features)
|
157 |
+
if clip_features is not None:
|
158 |
+
x_v = self.vision_adapter(clip_features)
|
159 |
+
else:
|
160 |
+
x_v = None
|
161 |
+
# get input_ids embedding
|
162 |
+
x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
|
163 |
+
|
164 |
+
x0 = self.transformer.wte(x0)
|
165 |
+
x1 = self.transformer.wte(x1)
|
166 |
+
x2 = self.transformer.wte(x2)
|
167 |
+
x3 = self.transformer.wte(x3)
|
168 |
+
x4 = self.transformer.wte(x4)
|
169 |
+
x5 = self.transformer.wte(x5)
|
170 |
+
x6 = self.transformer.wte(x6)
|
171 |
+
x7 = self.transformer.wte(x7)
|
172 |
+
|
173 |
+
# concat whisper feature
|
174 |
+
input_emb = self.concat_feat(
|
175 |
+
x_a, x_v, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
|
176 |
+
)
|
177 |
+
x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
|
178 |
+
|
179 |
+
else:
|
180 |
+
x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
|
181 |
+
|
182 |
+
x0 = self.transformer.wte(x0)
|
183 |
+
x1 = self.transformer.wte(x1)
|
184 |
+
x2 = self.transformer.wte(x2)
|
185 |
+
x3 = self.transformer.wte(x3)
|
186 |
+
x4 = self.transformer.wte(x4)
|
187 |
+
x5 = self.transformer.wte(x5)
|
188 |
+
x6 = self.transformer.wte(x6)
|
189 |
+
x7 = self.transformer.wte(x7)
|
190 |
+
|
191 |
+
x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
|
192 |
+
|
193 |
+
if self.config.scale_embeddings:
|
194 |
+
x = x * (self.config.n_embd**0.5)
|
195 |
+
|
196 |
+
for block in self.transformer.h:
|
197 |
+
x = block(x, cos, sin, mask, input_pos)
|
198 |
+
|
199 |
+
|
200 |
+
text_vocab_size = self.config.text_vocab_size
|
201 |
+
audio_vocab_size = self.config.audio_vocab_size
|
202 |
+
|
203 |
+
x_ori = x
|
204 |
+
x_ori = self.transformer.ln_f(x_ori)
|
205 |
+
x_ori = self.lm_head(x_ori) # (b, t, vocab_size)
|
206 |
+
xt = x_ori[..., :text_vocab_size]
|
207 |
+
|
208 |
+
if self.config.post_adapter:
|
209 |
+
for block in self.transformer.post_adapter:
|
210 |
+
x = block(x, cos, sin, mask, input_pos)
|
211 |
+
x = self.transformer.post_adapter_audio_ln(x)
|
212 |
+
x = self.transformer.post_adapter_audio_lm_head(x) # (b, t, vocab_size)
|
213 |
+
xa = []
|
214 |
+
for i in range(7):
|
215 |
+
xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
|
216 |
+
else:
|
217 |
+
xa = []
|
218 |
+
for i in range(7):
|
219 |
+
xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
|
220 |
+
|
221 |
+
return xa, xt
|
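The split at the end of `forward` means the shared `lm_head` width must cover the text vocabulary plus seven audio codebooks. With the defaults from `litgpt/config.py` above (text_vocab_size=152000, audio_vocab_size=4160, cat_audio_vocab_size=29120 = 7 * 4160), the slice boundaries work out as follows; this is just arithmetic on those defaults, not additional configuration:

text_vocab_size = 152000
audio_vocab_size = 4160
num_codebooks = 7

print(text_vocab_size + num_codebooks * audio_vocab_size)  # 181120: total logit width
for i in range(num_codebooks):
    lo = text_vocab_size + i * audio_vocab_size
    print(f"codebook {i}: logits [{lo}, {lo + audio_vocab_size})")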
222 |
+
|
223 |
+
@classmethod
|
224 |
+
def from_name(cls, name: str, **kwargs: Any) -> Self:
|
225 |
+
return cls(Config.from_name(name, **kwargs))
|
226 |
+
|
227 |
+
def rope_cache(
|
228 |
+
self, device: Optional[torch.device] = None
|
229 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
230 |
+
return build_rope_cache(
|
231 |
+
seq_len=self.max_seq_length,
|
232 |
+
n_elem=self.config.rope_n_elem,
|
233 |
+
device=device,
|
234 |
+
condense_ratio=self.config.rope_condense_ratio,
|
235 |
+
base=self.config.rope_base,
|
236 |
+
)
|
237 |
+
|
238 |
+
def set_kv_cache(
|
239 |
+
self,
|
240 |
+
batch_size: int,
|
241 |
+
rope_cache_length: Optional[int] = None,
|
242 |
+
device: Optional[torch.device] = None,
|
243 |
+
dtype: Optional[torch.dtype] = None,
|
244 |
+
) -> None:
|
245 |
+
if rope_cache_length is None:
|
246 |
+
rope_cache_length = self.cos.size(-1)
|
247 |
+
max_seq_length = self.max_seq_length
|
248 |
+
|
249 |
+
# initialize the kv cache for all blocks
|
250 |
+
for block in self.transformer.h:
|
251 |
+
block.attn.kv_cache = block.attn.build_kv_cache(
|
252 |
+
batch_size, max_seq_length, rope_cache_length, device, dtype
|
253 |
+
)
|
254 |
+
if self.config.post_adapter:
|
255 |
+
for block in self.transformer.post_adapter:
|
256 |
+
block.attn.kv_cache = block.attn.build_kv_cache(
|
257 |
+
batch_size, max_seq_length, rope_cache_length, device, dtype
|
258 |
+
)
|
259 |
+
|
260 |
+
if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
|
261 |
+
# passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
|
262 |
+
# for the kv-cache support (only during inference), we only create it in that situation
|
263 |
+
self.mask_cache = build_mask_cache(max_seq_length, device)
|
264 |
+
|
265 |
+
def clear_kv_cache(self) -> None:
|
266 |
+
self.mask_cache = None
|
267 |
+
for block in self.transformer.h:
|
268 |
+
block.attn.kv_cache = None
|
269 |
+
|
270 |
+
|
271 |
+
class visionMLP(nn.Module):
|
272 |
+
def __init__(self, config: Config) -> None:
|
273 |
+
super().__init__()
|
274 |
+
vision_adapter_dim = config.vision_adapter_dim
|
275 |
+
self.fc_1 = nn.Linear(vision_adapter_dim, config.intermediate_size, bias=config.bias)
|
276 |
+
self.fc_2 = nn.Linear(vision_adapter_dim, config.intermediate_size, bias=config.bias)
|
277 |
+
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
278 |
+
|
279 |
+
self.config = config
|
280 |
+
|
281 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
282 |
+
x_fc_1 = self.fc_1(x)
|
283 |
+
x_fc_2 = self.fc_2(x)
|
284 |
+
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
285 |
+
return self.proj(x)
|
286 |
+
|
287 |
+
|
288 |
+
class Block(nn.Module):
|
289 |
+
|
290 |
+
def __init__(self, config: Config) -> None:
|
291 |
+
super().__init__()
|
292 |
+
if not config.parallel_residual and config.shared_attention_norm:
|
293 |
+
raise NotImplementedError(
|
294 |
+
"No checkpoint amongst the ones we support uses this configuration"
|
295 |
+
" (non-parallel residual and shared attention norm)."
|
296 |
+
)
|
297 |
+
|
298 |
+
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
|
299 |
+
self.attn = CausalSelfAttention(config)
|
300 |
+
self.norm_2 = (
|
301 |
+
None
|
302 |
+
if config.shared_attention_norm
|
303 |
+
else config.norm_class(config.n_embd, eps=config.norm_eps)
|
304 |
+
)
|
305 |
+
self.mlp = config.mlp_class(config)
|
306 |
+
|
307 |
+
self.config = config
|
308 |
+
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
x: torch.Tensor,
|
312 |
+
cos: torch.Tensor,
|
313 |
+
sin: torch.Tensor,
|
314 |
+
mask: Optional[torch.Tensor] = None,
|
315 |
+
input_pos: Optional[torch.Tensor] = None,
|
316 |
+
) -> torch.Tensor:
|
317 |
+
"""
|
318 |
+
Non-parallel residual Parallel residual
|
319 |
+
┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True,
|
320 |
+
│ ↓ │ ↓ ↓ the output from `norm_1` is reused
|
321 |
+
│ norm_1 │ norm_1 ───► norm_2
|
322 |
+
│ ↓ │ ↓ ↓
|
323 |
+
│ attn │ attn mlp
|
324 |
+
│ ↓ │ ↓ │
|
325 |
+
┌─ └► + └► + ◄───────────┘
|
326 |
+
│ norm_2
|
327 |
+
│ ↓
|
328 |
+
│ mlp
|
329 |
+
│ ↓
|
330 |
+
└───► +
|
331 |
+
"""
|
332 |
+
|
333 |
+
x_normed = self.norm_1(x)
|
334 |
+
attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
|
335 |
+
|
336 |
+
if self.config.parallel_residual:
|
337 |
+
x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
|
338 |
+
x = self.mlp(x_normed) + attention_output + x
|
339 |
+
else:
|
340 |
+
x = attention_output + x
|
341 |
+
x = self.mlp(self.norm_2(x)) + x
|
342 |
+
return x
|
343 |
+
|
344 |
+
|
345 |
+
class CausalSelfAttention(nn.Module):
|
346 |
+
def __init__(self, config: Config) -> None:
|
347 |
+
super().__init__()
|
348 |
+
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
|
349 |
+
# key, query, value projections for all heads, but in a batch
|
350 |
+
self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
|
351 |
+
# output projection
|
352 |
+
# if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
|
353 |
+
self.proj = nn.Linear(
|
354 |
+
config.head_size * config.n_head, config.n_embd, bias=config.bias
|
355 |
+
)
|
356 |
+
# disabled by default
|
357 |
+
self.kv_cache: Optional[KVCache] = None
|
358 |
+
|
359 |
+
self.config = config
|
360 |
+
|
361 |
+
def forward(
|
362 |
+
self,
|
363 |
+
x: torch.Tensor,
|
364 |
+
cos: torch.Tensor,
|
365 |
+
sin: torch.Tensor,
|
366 |
+
mask: Optional[torch.Tensor] = None,
|
367 |
+
input_pos: Optional[torch.Tensor] = None,
|
368 |
+
) -> torch.Tensor:
|
369 |
+
B, T, C = (
|
370 |
+
x.size()
|
371 |
+
) # batch size, sequence length, embedding dimensionality (n_embd)
|
372 |
+
|
373 |
+
qkv = self.attn(x)
|
374 |
+
|
375 |
+
# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
|
376 |
+
q_per_kv = self.config.n_head // self.config.n_query_groups
|
377 |
+
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
|
378 |
+
qkv = qkv.view(
|
379 |
+
B, T, self.config.n_query_groups, total_qkv, self.config.head_size
|
380 |
+
)
|
381 |
+
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
|
382 |
+
|
383 |
+
# split batched computation into three
|
384 |
+
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
|
385 |
+
|
386 |
+
# maybe repeat k and v for the non multi-head attention cases (GQA/MQA)
|
387 |
+
# training: flash attention requires it
|
388 |
+
# inference: multi-query would require a full kv cache so avoid it to limit its memory usage
|
389 |
+
if self.config.n_query_groups != self.config.n_head and (
|
390 |
+
input_pos is None or self.config.n_query_groups != 1
|
391 |
+
):
|
392 |
+
k = k.expand(
|
393 |
+
B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
|
394 |
+
)
|
395 |
+
v = v.expand(
|
396 |
+
B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
|
397 |
+
)
|
398 |
+
|
399 |
+
q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
|
400 |
+
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
|
401 |
+
v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
|
402 |
+
|
403 |
+
q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
|
404 |
+
k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
|
405 |
+
q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
|
406 |
+
k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
|
407 |
+
|
408 |
+
if input_pos is not None:
|
409 |
+
if not isinstance(self.kv_cache, KVCache):
|
410 |
+
raise TypeError("You need to call `gpt.set_kv_cache()`")
|
411 |
+
k, v = self.kv_cache(input_pos, k, v)
|
412 |
+
|
413 |
+
y = self.scaled_dot_product_attention(q, k, v, mask)
|
414 |
+
|
415 |
+
y = y.reshape(
|
416 |
+
B, T, self.config.head_size * self.config.n_head
|
417 |
+
) # re-assemble all head outputs side by side
|
418 |
+
|
419 |
+
# output projection
|
420 |
+
return self.proj(y)
|
421 |
+
|
422 |
+
def scaled_dot_product_attention(
|
423 |
+
self,
|
424 |
+
q: torch.Tensor,
|
425 |
+
k: torch.Tensor,
|
426 |
+
v: torch.Tensor,
|
427 |
+
mask: Optional[torch.Tensor] = None,
|
428 |
+
) -> torch.Tensor:
|
429 |
+
scale = 1.0 / math.sqrt(self.config.head_size)
|
430 |
+
y = torch.nn.functional.scaled_dot_product_attention(
|
431 |
+
q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
|
432 |
+
)
|
433 |
+
return y.transpose(1, 2)
|
434 |
+
|
435 |
+
def build_kv_cache(
|
436 |
+
self,
|
437 |
+
batch_size: int,
|
438 |
+
max_seq_length: int,
|
439 |
+
rope_cache_length: Optional[int] = None,
|
440 |
+
device: Optional[torch.device] = None,
|
441 |
+
dtype: Optional[torch.dtype] = None,
|
442 |
+
) -> "KVCache":
|
443 |
+
heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
|
444 |
+
v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
|
445 |
+
if rope_cache_length is None:
|
446 |
+
if self.config.rotary_percentage != 1.0:
|
447 |
+
raise TypeError(
|
448 |
+
"Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
|
449 |
+
)
|
450 |
+
k_shape = v_shape
|
451 |
+
else:
|
452 |
+
k_shape = (
|
453 |
+
batch_size,
|
454 |
+
heads,
|
455 |
+
max_seq_length,
|
456 |
+
rope_cache_length + self.config.head_size - self.config.rope_n_elem,
|
457 |
+
)
|
458 |
+
return KVCache(k_shape, v_shape, device=device, dtype=dtype)
|
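`build_kv_cache` only shrinks the cache in the pure multi-query case; for grouped-query configurations the keys/values are expanded to `n_head` before caching (see the `expand` calls in `forward` above). A tiny sketch of that head-count rule, with illustrative head counts:

def kv_cache_heads(n_head: int, n_query_groups: int) -> int:
    # mirrors: heads = 1 if n_query_groups == 1 else n_head
    return 1 if n_query_groups == 1 else n_head

print(kv_cache_heads(32, 32))  # MHA -> 32
print(kv_cache_heads(32, 8))   # GQA -> 32 (k/v expanded before caching)
print(kv_cache_heads(32, 1))   # MQA -> 1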
459 |
+
|
460 |
+
|
461 |
+
class GptNeoxMLP(nn.Module):
|
462 |
+
def __init__(self, config: Config) -> None:
|
463 |
+
super().__init__()
|
464 |
+
self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
465 |
+
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
466 |
+
|
467 |
+
self.config = config
|
468 |
+
|
469 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
470 |
+
x = self.fc(x)
|
471 |
+
x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
|
472 |
+
return self.proj(x)
|
473 |
+
|
474 |
+
|
475 |
+
class LLaMAMLP(nn.Module):
|
476 |
+
def __init__(self, config: Config) -> None:
|
477 |
+
super().__init__()
|
478 |
+
self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
479 |
+
self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
|
480 |
+
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
481 |
+
|
482 |
+
self.config = config
|
483 |
+
|
484 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
485 |
+
x_fc_1 = self.fc_1(x)
|
486 |
+
x_fc_2 = self.fc_2(x)
|
487 |
+
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
488 |
+
return self.proj(x)
|
489 |
+
|
490 |
+
|
491 |
+
class whisperMLP(nn.Module):
|
492 |
+
def __init__(self, config: Config) -> None:
|
493 |
+
super().__init__()
|
494 |
+
self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
|
495 |
+
self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
|
496 |
+
self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
|
497 |
+
|
498 |
+
self.config = config
|
499 |
+
|
500 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
501 |
+
x_fc_1 = self.fc_1(x)
|
502 |
+
x_fc_2 = self.fc_2(x)
|
503 |
+
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
|
504 |
+
return self.proj(x)
|
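`LLaMAMLP`, `whisperMLP`, and `visionMLP` all use the same SwiGLU-style gate, out = proj(silu(fc_1(x)) * fc_2(x)); `GemmaMLP` below swaps the SiLU for a GELU. A minimal shape check with illustrative dimensions (not taken from any real config):

import torch
import torch.nn.functional as F

fc_1, fc_2 = torch.nn.Linear(16, 32), torch.nn.Linear(16, 32)  # hypothetical sizes
proj = torch.nn.Linear(32, 16)
x = torch.randn(2, 16)
out = proj(F.silu(fc_1(x)) * fc_2(x))
print(out.shape)  # torch.Size([2, 16])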
505 |
+
|
506 |
+
|
507 |
+
class GemmaMLP(LLaMAMLP):
|
508 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
509 |
+
x_fc_1 = self.fc_1(x)
|
510 |
+
x_fc_2 = self.fc_2(x)
|
511 |
+
x = (
|
512 |
+
torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate)
|
513 |
+
* x_fc_2
|
514 |
+
)
|
515 |
+
return self.proj(x)
|
516 |
+
|
517 |
+
|
518 |
+
class LLaMAMoE(nn.Module):
|
519 |
+
def __init__(self, config: Config) -> None:
|
520 |
+
super().__init__()
|
521 |
+
self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
|
522 |
+
self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
|
523 |
+
|
524 |
+
self.config = config
|
525 |
+
|
526 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
527 |
+
"""
|
528 |
+
Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
|
529 |
+
See also figure 1 in https://arxiv.org/abs/2211.15841
|
530 |
+
"""
|
531 |
+
B, T, C = (
|
532 |
+
x.size()
|
533 |
+
) # batch size, sequence length, embedding dimensionality (n_embd)
|
534 |
+
x = x.view(-1, C) # (B*T, C)
|
535 |
+
router = self.gate(x) # (B*T, n_expert)
|
536 |
+
probs, indices = torch.topk(
|
537 |
+
router, self.config.n_expert_per_token
|
538 |
+
) # (B*T, n_expert_per_token)
|
539 |
+
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
|
540 |
+
masks = indices.unsqueeze(-1) == torch.arange(
|
541 |
+
self.config.n_expert, device=x.device
|
542 |
+
)
|
543 |
+
masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
|
544 |
+
y = torch.zeros_like(x) # (B*T, C)
|
545 |
+
for mask, expert in zip(masks, self.experts):
|
546 |
+
token_idx, expert_idx = torch.where(mask)
|
547 |
+
y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
|
548 |
+
return y.view(B, T, C)
|
549 |
+
|
550 |
+
|
551 |
+
def build_rope_cache(
|
552 |
+
seq_len: int,
|
553 |
+
n_elem: int,
|
554 |
+
device: Optional[torch.device] = None,
|
555 |
+
base: int = 10000,
|
556 |
+
condense_ratio: int = 1,
|
557 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
558 |
+
"""Enhanced Transformer with Rotary Position Embedding.
|
559 |
+
|
560 |
+
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
|
561 |
+
transformers/rope/__init__.py. MIT License:
|
562 |
+
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
|
563 |
+
"""
|
564 |
+
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
565 |
+
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
|
566 |
+
|
567 |
+
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
568 |
+
seq_idx = torch.arange(seq_len, device=device) / condense_ratio
|
569 |
+
|
570 |
+
# Calculate the product of position index and $\theta_i$
|
571 |
+
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
|
572 |
+
|
573 |
+
return torch.cos(idx_theta), torch.sin(idx_theta)
|
574 |
+
|
575 |
+
|
576 |
+
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
577 |
+
head_size = x.size(-1)
|
578 |
+
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
|
579 |
+
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
|
580 |
+
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
|
581 |
+
roped = (x * cos) + (rotated * sin)
|
582 |
+
return roped.to(dtype=x.dtype)
|
583 |
+
|
584 |
+
|
585 |
+
class KVCache(nn.Module):
|
586 |
+
def __init__(
|
587 |
+
self,
|
588 |
+
k_shape: Tuple[int, int, int, int],
|
589 |
+
v_shape: Tuple[int, int, int, int],
|
590 |
+
device: Optional[torch.device] = None,
|
591 |
+
dtype: Optional[torch.dtype] = None,
|
592 |
+
) -> None:
|
593 |
+
super().__init__()
|
594 |
+
self.register_buffer(
|
595 |
+
"k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
|
596 |
+
)
|
597 |
+
self.register_buffer(
|
598 |
+
"v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
|
599 |
+
)
|
600 |
+
|
601 |
+
def forward(
|
602 |
+
self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
|
603 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
604 |
+
# move the buffer to the activation dtype for when AMP is used
|
605 |
+
self.k = self.k.to(k.dtype)
|
606 |
+
self.v = self.v.to(v.dtype)
|
607 |
+
# update the cache
|
608 |
+
k = self.k.index_copy_(2, input_pos, k)
|
609 |
+
v = self.v.index_copy_(2, input_pos, v)
|
610 |
+
return k, v
|
611 |
+
|
612 |
+
def reset_parameters(self) -> None:
|
613 |
+
torch.nn.init.zeros_(self.k)
|
614 |
+
torch.nn.init.zeros_(self.v)
|
615 |
+
|
616 |
+
|
617 |
+
def build_mask_cache(
|
618 |
+
max_seq_length: int, device: Optional[torch.device] = None
|
619 |
+
) -> torch.Tensor:
|
620 |
+
ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
|
621 |
+
return torch.tril(ones).unsqueeze(0).unsqueeze(0)
|
622 |
+
|
623 |
+
|
624 |
+
class RMSNorm(torch.nn.Module):
|
625 |
+
"""Root Mean Square Layer Normalization.
|
626 |
+
|
627 |
+
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
|
628 |
+
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
|
629 |
+
"""
|
630 |
+
|
631 |
+
def __init__(
|
632 |
+
self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
|
633 |
+
) -> None:
|
634 |
+
super().__init__()
|
635 |
+
self.weight = torch.nn.Parameter(torch.ones(size))
|
636 |
+
self.eps = eps
|
637 |
+
self.dim = dim
|
638 |
+
self.add_unit_offset = add_unit_offset
|
639 |
+
|
640 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
641 |
+
dtype = x.dtype
|
642 |
+
x = x.float()
|
643 |
+
# NOTE: the original RMSNorm paper implementation is not equivalent
|
644 |
+
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
|
645 |
+
x_normed = x * torch.rsqrt(norm_x + self.eps)
|
646 |
+
x_normed = x_normed.to(dtype=dtype)
|
647 |
+
if self.add_unit_offset:
|
648 |
+
# Gemma model requires a unit offset
|
649 |
+
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
|
650 |
+
return x_normed * (1 + self.weight)
|
651 |
+
return x_normed * self.weight
|
652 |
+
|
653 |
+
def reset_parameters(self) -> None:
|
654 |
+
torch.nn.init.ones_(self.weight)
|
litgpt/tokenizer.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
2 |
+
|
3 |
+
import json
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class Tokenizer:
|
11 |
+
def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
|
12 |
+
checkpoint_dir = Path(checkpoint_dir)
|
13 |
+
if not checkpoint_dir.exists():
|
14 |
+
raise NotADirectoryError(
|
15 |
+
f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
|
16 |
+
)
|
17 |
+
|
18 |
+
self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
|
19 |
+
self.bos_id = None
|
20 |
+
self.eos_id = None
|
21 |
+
|
22 |
+
# some checkpoints have both files, `.json` takes precedence
|
23 |
+
if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
|
24 |
+
from tokenizers import Tokenizer as HFTokenizer
|
25 |
+
|
26 |
+
self.processor = HFTokenizer.from_file(str(vocabulary_path))
|
27 |
+
self.backend = "huggingface"
|
28 |
+
|
29 |
+
if (
|
30 |
+
special_tokens_path := checkpoint_dir / "tokenizer_config.json"
|
31 |
+
).is_file():
|
32 |
+
with open(special_tokens_path, encoding="utf-8") as fp:
|
33 |
+
config = json.load(fp)
|
34 |
+
bos_token = config.get("bos_token")
|
35 |
+
eos_token = config.get("eos_token")
|
36 |
+
if bos_token is not None and isinstance(bos_token, dict):
|
37 |
+
bos_token = bos_token.get("content")
|
38 |
+
if eos_token is not None and isinstance(eos_token, dict):
|
39 |
+
eos_token = eos_token.get("content")
|
40 |
+
self.bos_id = (
|
41 |
+
self.token_to_id(bos_token) if bos_token is not None else None
|
42 |
+
)
|
43 |
+
self.eos_id = (
|
44 |
+
self.token_to_id(eos_token) if eos_token is not None else None
|
45 |
+
)
|
46 |
+
if (
|
47 |
+
special_tokens_path := checkpoint_dir / "generation_config.json"
|
48 |
+
).is_file():
|
49 |
+
with open(special_tokens_path, encoding="utf-8") as fp:
|
50 |
+
config = json.load(fp)
|
51 |
+
if self.bos_id is None:
|
52 |
+
self.bos_id = config.get("bos_token_id")
|
53 |
+
if self.eos_id is None:
|
54 |
+
self.eos_id = config.get("eos_token_id")
|
55 |
+
|
56 |
+
elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
|
57 |
+
from sentencepiece import SentencePieceProcessor
|
58 |
+
|
59 |
+
self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
|
60 |
+
self.backend = "sentencepiece"
|
61 |
+
self.bos_id = self.processor.bos_id()
|
62 |
+
self.eos_id = self.processor.eos_id()
|
63 |
+
else:
|
64 |
+
raise NotImplementedError
|
65 |
+
|
66 |
+
@property
|
67 |
+
def vocab_size(self) -> int:
|
68 |
+
if self.backend == "huggingface":
|
69 |
+
return self.processor.get_vocab_size(with_added_tokens=False)
|
70 |
+
if self.backend == "sentencepiece":
|
71 |
+
return self.processor.vocab_size()
|
72 |
+
raise RuntimeError
|
73 |
+
|
74 |
+
def token_to_id(self, token: str) -> int:
|
75 |
+
if self.backend == "huggingface":
|
76 |
+
id_ = self.processor.token_to_id(token)
|
77 |
+
elif self.backend == "sentencepiece":
|
78 |
+
id_ = self.processor.piece_to_id(token)
|
79 |
+
else:
|
80 |
+
raise RuntimeError
|
81 |
+
if id_ is None:
|
82 |
+
raise ValueError(f"token {token!r} not found in the collection.")
|
83 |
+
return id_
|
84 |
+
|
85 |
+
def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
|
86 |
+
if not (
|
87 |
+
tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
|
88 |
+
).is_file():
|
89 |
+
return False
|
90 |
+
with open(tokenizer_config_path, encoding="utf-8") as fp:
|
91 |
+
config = json.load(fp)
|
92 |
+
if "add_bos_token" in config:
|
93 |
+
return config["add_bos_token"]
|
94 |
+
# if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
|
95 |
+
# ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
|
96 |
+
return config.get("tokenizer_class") == "LlamaTokenizer"
|
97 |
+
|
98 |
+
def encode(
|
99 |
+
self,
|
100 |
+
string: str,
|
101 |
+
device: Optional[torch.device] = None,
|
102 |
+
bos: Optional[bool] = None,
|
103 |
+
eos: bool = False,
|
104 |
+
max_length: int = -1,
|
105 |
+
) -> torch.Tensor:
|
106 |
+
if self.backend == "huggingface":
|
107 |
+
tokens = self.processor.encode(string).ids
|
108 |
+
elif self.backend == "sentencepiece":
|
109 |
+
tokens = self.processor.encode(string)
|
110 |
+
else:
|
111 |
+
raise RuntimeError
|
112 |
+
if bos or (bos is None and self.use_bos):
|
113 |
+
bos_id = self.bos_id
|
114 |
+
if bos_id is None:
|
115 |
+
raise NotImplementedError(
|
116 |
+
"This tokenizer does not have a defined a bos token"
|
117 |
+
)
|
118 |
+
if tokens[0] != bos_id:
|
119 |
+
tokens = [bos_id] + tokens
|
120 |
+
if tokens is None:
|
121 |
+
raise ValueError("`tokens` is None")
|
122 |
+
|
123 |
+
if eos and (not tokens or tokens[-1] != self.eos_id):
|
124 |
+
tokens = tokens + [self.eos_id]
|
125 |
+
if max_length > 0:
|
126 |
+
tokens = tokens[:max_length]
|
127 |
+
return torch.tensor(tokens, dtype=torch.int, device=device)
|
128 |
+
|
129 |
+
def decode(self, tensor: torch.Tensor) -> str:
|
130 |
+
tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
|
131 |
+
return self.processor.decode(tokens)
|
litgpt/utils.py
ADDED
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
|
2 |
+
|
3 |
+
"""Utility functions for training and inference."""
|
4 |
+
import inspect
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import pickle
|
8 |
+
import shutil
|
9 |
+
import sys
|
10 |
+
from dataclasses import asdict, is_dataclass
|
11 |
+
from io import BytesIO
|
12 |
+
from pathlib import Path
|
13 |
+
from typing import (
|
14 |
+
TYPE_CHECKING,
|
15 |
+
Any,
|
16 |
+
Dict,
|
17 |
+
Iterable,
|
18 |
+
List,
|
19 |
+
Literal,
|
20 |
+
Mapping,
|
21 |
+
Optional,
|
22 |
+
TypeVar,
|
23 |
+
Union,
|
24 |
+
)
|
25 |
+
|
26 |
+
import lightning as L
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
import torch.utils._device
|
30 |
+
import yaml
|
31 |
+
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
|
32 |
+
from lightning.fabric.strategies import FSDPStrategy
|
33 |
+
from lightning.fabric.utilities.load import _lazy_load as lazy_load
|
34 |
+
from lightning.pytorch.loggers import WandbLogger
|
35 |
+
from lightning.pytorch.cli import instantiate_class
|
36 |
+
from torch.serialization import normalize_storage_type
|
37 |
+
from typing_extensions import Self
|
38 |
+
|
39 |
+
if TYPE_CHECKING:
|
40 |
+
from litgpt import GPT, Config
|
41 |
+
|
42 |
+
|
43 |
+
def init_out_dir(out_dir: Path) -> Path:
|
44 |
+
if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
|
45 |
+
return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
|
46 |
+
return out_dir
|
47 |
+
|
48 |
+
|
49 |
+
def find_resume_path(
|
50 |
+
resume: Union[bool, Literal["auto"], Path], out_dir: Path
|
51 |
+
) -> Optional[Path]:
|
52 |
+
if not resume or isinstance(resume, Path):
|
53 |
+
return resume
|
54 |
+
|
55 |
+
resume_path = max(
|
56 |
+
out_dir.rglob("step-*/*.pth"),
|
57 |
+
key=(lambda p: int(p.parent.name.split("-")[1])),
|
58 |
+
default=None,
|
59 |
+
)
|
60 |
+
if resume == "auto":
|
61 |
+
return resume_path
|
62 |
+
if resume is True and resume_path is None:
|
63 |
+
raise FileNotFoundError(
|
64 |
+
f"You passed `--resume=True`, but no checkpont file was found in `--out_dir={out_dir}`."
|
65 |
+
)
|
66 |
+
return resume_path
|
67 |
+
|
68 |
+
|
69 |
+
def find_multiple(n: int, k: int) -> int:
|
70 |
+
assert k > 0
|
71 |
+
if n % k == 0:
|
72 |
+
return n
|
73 |
+
return n + k - (n % k)
|
74 |
+
|
75 |
+
|
76 |
+
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
|
77 |
+
total = 0
|
78 |
+
for p in module.parameters():
|
79 |
+
if requires_grad is None or p.requires_grad == requires_grad:
|
80 |
+
if hasattr(p, "quant_state"):
|
81 |
+
# bitsandbytes 4bit layer support
|
82 |
+
total += math.prod(p.quant_state.shape)
|
83 |
+
else:
|
84 |
+
total += p.numel()
|
85 |
+
return total
|
86 |
+
|
87 |
+
|
88 |
+
def reset_parameters(module: nn.Module) -> None:
|
89 |
+
"""Calls `reset_parameters` on the module and all its submodules."""
|
90 |
+
for mod in module.modules():
|
91 |
+
if callable(getattr(mod, "reset_parameters", None)):
|
92 |
+
mod.reset_parameters()
|
93 |
+
|
94 |
+
|
95 |
+
def check_valid_checkpoint_dir(
|
96 |
+
checkpoint_dir: Path,
|
97 |
+
model_filename: str = "lit_model.pth",
|
98 |
+
verbose: bool = True,
|
99 |
+
raise_error: bool = False,
|
100 |
+
) -> None:
|
101 |
+
files = {
|
102 |
+
model_filename: (checkpoint_dir / model_filename).is_file(),
|
103 |
+
"model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
|
104 |
+
"tokenizer.json OR tokenizer.model": (
|
105 |
+
checkpoint_dir / "tokenizer.json"
|
106 |
+
).is_file()
|
107 |
+
or (checkpoint_dir / "tokenizer.model").is_file(),
|
108 |
+
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
|
109 |
+
}
|
110 |
+
if checkpoint_dir.is_dir():
|
111 |
+
if all(files.values()):
|
112 |
+
# we're good
|
113 |
+
return
|
114 |
+
problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
|
115 |
+
else:
|
116 |
+
problem = " is not a checkpoint directory"
|
117 |
+
|
118 |
+
# list locally available checkpoints
|
119 |
+
available = list(Path("checkpoints").glob("*/*"))
|
120 |
+
if available:
|
121 |
+
options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
|
122 |
+
extra = f"\nYou have downloaded locally:{options}\n"
|
123 |
+
else:
|
124 |
+
extra = ""
|
125 |
+
|
126 |
+
if verbose:
|
127 |
+
error_message = (
|
128 |
+
f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
|
129 |
+
"\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
|
130 |
+
f"{extra}\nSee all download options by running:\n litgpt download"
|
131 |
+
)
|
132 |
+
print(error_message, file=sys.stderr)
|
133 |
+
|
134 |
+
if raise_error:
|
135 |
+
raise FileNotFoundError(
|
136 |
+
f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
raise SystemExit(1)
|
140 |
+
|
141 |
+
|
142 |
+
class SavingProxyForStorage:
|
143 |
+
def __init__(self, obj, saver, protocol_version=5):
|
144 |
+
self.protocol_version = protocol_version
|
145 |
+
self.saver = saver
|
146 |
+
if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
|
147 |
+
raise TypeError(f"expected storage, not {type(obj)}")
|
148 |
+
|
149 |
+
# this logic is taken from PyTorch 2.0+ torch/serialization.py
|
150 |
+
if isinstance(obj, torch.storage.TypedStorage):
|
151 |
+
# PT upstream wants to deprecate this eventually...
|
152 |
+
storage = obj._untyped_storage
|
153 |
+
storage_type_str = obj._pickle_storage_type()
|
154 |
+
storage_type = getattr(torch, storage_type_str)
|
155 |
+
storage_numel = obj._size()
|
156 |
+
else:
|
157 |
+
storage = obj
|
158 |
+
storage_type = normalize_storage_type(type(obj))
|
159 |
+
storage_numel = storage.nbytes()
|
160 |
+
|
161 |
+
storage_key = saver._write_storage_and_return_key(storage)
|
162 |
+
location = torch.serialization.location_tag(storage)
|
163 |
+
|
164 |
+
self.storage_info = (
|
165 |
+
"storage",
|
166 |
+
storage_type,
|
167 |
+
storage_key,
|
168 |
+
location,
|
169 |
+
storage_numel,
|
170 |
+
)
|
171 |
+
|
172 |
+
def __reduce_ex__(self, protocol_version):
|
173 |
+
assert False, "this should be handled with out of band"
|
174 |
+
|
175 |
+
|
176 |
+
class SavingProxyForTensor:
|
177 |
+
def __init__(self, tensor, saver, protocol_version=5):
|
178 |
+
self.protocol_version = protocol_version
|
179 |
+
self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
|
180 |
+
if reduce_args[0] == torch._utils._rebuild_tensor_v2:
|
181 |
+
# for Tensors with Python attributes
|
182 |
+
(a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
|
183 |
+
assert isinstance(
|
184 |
+
storage, torch.storage.TypedStorage
|
185 |
+
), "Please check for updates"
|
186 |
+
storage_proxy = SavingProxyForStorage(
|
187 |
+
storage, saver, protocol_version=protocol_version
|
188 |
+
)
|
189 |
+
self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
|
190 |
+
else:
|
191 |
+
(storage, *other_reduce_args) = reduce_args
|
192 |
+
assert isinstance(
|
193 |
+
storage, torch.storage.TypedStorage
|
194 |
+
), "Please check for updates"
|
195 |
+
storage_proxy = SavingProxyForStorage(
|
196 |
+
storage, saver, protocol_version=protocol_version
|
197 |
+
)
|
198 |
+
self.reduce_args = (storage_proxy, *other_reduce_args)
|
199 |
+
|
200 |
+
def __reduce_ex__(self, protocol_version):
|
201 |
+
if protocol_version != self.protocol_version:
|
202 |
+
raise RuntimeError(
|
203 |
+
f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
|
204 |
+
)
|
205 |
+
return self.reduce_ret_fn, self.reduce_args
|
206 |
+
|
207 |
+
|
208 |
+
class IncrementalPyTorchPickler(pickle.Pickler):
|
209 |
+
def __init__(self, saver, *args, **kwargs):
|
210 |
+
super().__init__(*args, **kwargs)
|
211 |
+
self.storage_dtypes = {}
|
212 |
+
self.saver = saver
|
213 |
+
self.id_map = {}
|
214 |
+
|
215 |
+
# this logic is taken from PyTorch 2.0+ torch/serialization.py
|
216 |
+
def persistent_id(self, obj):
|
217 |
+
# FIXME: the docs say that persistent_id should only return a string
|
218 |
+
# but torch store returns tuples. This works only in the binary protocol
|
219 |
+
# see
|
220 |
+
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
|
221 |
+
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
|
222 |
+
if isinstance(obj, SavingProxyForStorage):
|
223 |
+
return obj.storage_info
|
224 |
+
|
225 |
+
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
|
226 |
+
if isinstance(obj, torch.storage.TypedStorage):
|
227 |
+
# TODO: Once we decide to break serialization FC, this case
|
228 |
+
# can be deleted
|
229 |
+
storage = obj._untyped_storage
|
230 |
+
storage_dtype = obj.dtype
|
231 |
+
storage_type_str = obj._pickle_storage_type()
|
232 |
+
storage_type = getattr(torch, storage_type_str)
|
233 |
+
storage_numel = obj._size()
|
234 |
+
|
235 |
+
else:
|
236 |
+
storage = obj
|
237 |
+
storage_dtype = torch.uint8
|
238 |
+
storage_type = normalize_storage_type(type(obj))
|
239 |
+
storage_numel = storage.nbytes()
|
240 |
+
|
241 |
+
# If storage is allocated, ensure that any other saved storages
|
242 |
+
# pointing to the same data all have the same dtype. If storage is
|
243 |
+
# not allocated, don't perform this check
|
244 |
+
if storage.data_ptr() != 0:
|
245 |
+
if storage.data_ptr() in self.storage_dtypes:
|
246 |
+
if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
|
247 |
+
raise RuntimeError(
|
248 |
+
"Cannot save multiple tensors or storages that view the same data as different types"
|
249 |
+
)
|
250 |
+
else:
|
251 |
+
self.storage_dtypes[storage.data_ptr()] = storage_dtype
|
252 |
+
|
253 |
+
storage_key = self.id_map.get(storage._cdata)
|
254 |
+
if storage_key is None:
|
255 |
+
storage_key = self.saver._write_storage_and_return_key(storage)
|
256 |
+
self.id_map[storage._cdata] = storage_key
|
257 |
+
location = torch.serialization.location_tag(storage)
|
258 |
+
|
259 |
+
return ("storage", storage_type, storage_key, location, storage_numel)
|
260 |
+
|
261 |
+
return None
|
262 |
+
|
263 |
+
|
264 |
+
class incremental_save:
|
265 |
+
def __init__(self, name):
|
266 |
+
self.name = name
|
267 |
+
self.zipfile = torch._C.PyTorchFileWriter(str(name))
|
268 |
+
self.has_saved = False
|
269 |
+
self.next_key = 0
|
270 |
+
|
271 |
+
def __enter__(self):
|
272 |
+
return self
|
273 |
+
|
274 |
+
def store_early(self, tensor):
|
275 |
+
if isinstance(tensor, torch.Tensor):
|
276 |
+
return SavingProxyForTensor(tensor, self)
|
277 |
+
raise TypeError(f"can only store tensors early, not {type(tensor)}")
|
278 |
+
|
279 |
+
def save(self, obj):
|
280 |
+
if self.has_saved:
|
281 |
+
raise RuntimeError("have already saved")
|
282 |
+
# Write the pickle data for `obj`
|
283 |
+
data_buf = BytesIO()
|
284 |
+
pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
|
285 |
+
pickler.dump(obj)
|
286 |
+
data_value = data_buf.getvalue()
|
287 |
+
self.zipfile.write_record("data.pkl", data_value, len(data_value))
|
288 |
+
self.has_saved = True
|
289 |
+
|
290 |
+
def _write_storage_and_return_key(self, storage):
|
291 |
+
if self.has_saved:
|
292 |
+
raise RuntimeError("have already saved")
|
293 |
+
key = self.next_key
|
294 |
+
self.next_key += 1
|
295 |
+
name = f"data/{key}"
|
296 |
+
if storage.device.type != "cpu":
|
297 |
+
storage = storage.cpu()
|
298 |
+
num_bytes = storage.nbytes()
|
299 |
+
self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
|
300 |
+
return key
|
301 |
+
|
302 |
+
def __exit__(self, type, value, traceback):
|
303 |
+
self.zipfile.write_end_of_file()
|
304 |
+
|
305 |
+
|
306 |
+
T = TypeVar("T")
|
307 |
+
|
308 |
+
|
309 |
+
def chunked_cross_entropy(
|
310 |
+
logits: Union[torch.Tensor, List[torch.Tensor]],
|
311 |
+
targets: torch.Tensor,
|
312 |
+
chunk_size: int = 128,
|
313 |
+
ignore_index: int = -100,
|
314 |
+
) -> torch.Tensor:
|
315 |
+
# with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
|
316 |
+
# the memory usage in fine-tuning settings with low number of parameters.
|
317 |
+
# as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
|
318 |
+
# the memory spike's magnitude
|
319 |
+
|
320 |
+
# lm_head was chunked (we are fine-tuning)
|
321 |
+
if isinstance(logits, list):
|
322 |
+
# don't want to chunk cross entropy
|
323 |
+
if chunk_size == 0:
|
324 |
+
logits = torch.cat(logits, dim=1)
|
325 |
+
logits = logits.reshape(-1, logits.size(-1))
|
326 |
+
targets = targets.reshape(-1)
|
327 |
+
return torch.nn.functional.cross_entropy(
|
328 |
+
logits, targets, ignore_index=ignore_index
|
329 |
+
)
|
330 |
+
|
331 |
+
# chunk cross entropy
|
332 |
+
logit_chunks = [
|
333 |
+
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
|
334 |
+
]
|
335 |
+
target_chunks = [
|
336 |
+
target_chunk.reshape(-1)
|
337 |
+
for target_chunk in targets.split(logits[0].size(1), dim=1)
|
338 |
+
]
|
339 |
+
loss_chunks = [
|
340 |
+
torch.nn.functional.cross_entropy(
|
341 |
+
logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
|
342 |
+
)
|
343 |
+
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
|
344 |
+
]
|
345 |
+
non_masked_elems = (targets != ignore_index).sum()
|
346 |
+
# See [non_masked_elems div note]
|
347 |
+
return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
|
348 |
+
torch.ones_like(non_masked_elems)
|
349 |
+
)
|
350 |
+
|
351 |
+
# no chunking at all
|
352 |
+
logits = logits.reshape(-1, logits.size(-1))
|
353 |
+
targets = targets.reshape(-1)
|
354 |
+
if chunk_size == 0:
|
355 |
+
return torch.nn.functional.cross_entropy(
|
356 |
+
logits, targets, ignore_index=ignore_index
|
357 |
+
)
|
358 |
+
|
359 |
+
# lm_head wasn't chunked, chunk cross entropy
|
360 |
+
logit_chunks = logits.split(chunk_size)
|
361 |
+
target_chunks = targets.split(chunk_size)
|
362 |
+
loss_chunks = [
|
363 |
+
torch.nn.functional.cross_entropy(
|
364 |
+
logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
|
365 |
+
)
|
366 |
+
for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
|
367 |
+
]
|
368 |
+
non_masked_elems = (targets != ignore_index).sum()
|
369 |
+
# [non_masked_elems div note]:
|
370 |
+
# max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
|
371 |
+
# results in a python int which is then passed back to torch division. By using the
|
372 |
+
# `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
|
373 |
+
return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
|
374 |
+
torch.ones_like(non_masked_elems)
|
375 |
+
)
|
376 |
+
|
377 |
+
|
378 |
+
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
|
379 |
+
for checkpoint_name, attribute_name in mapping.items():
|
380 |
+
full_checkpoint_name = prefix + checkpoint_name
|
381 |
+
if full_checkpoint_name in state_dict:
|
382 |
+
full_attribute_name = prefix + attribute_name
|
383 |
+
state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
|
384 |
+
return state_dict
|
385 |
+
|
386 |
+
|
387 |
+
def get_default_supported_precision(training: bool) -> str:
|
388 |
+
"""Return default precision that is supported by the hardware: either `bf16` or `16`.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
training: `-mixed` or `-true` version of the precision to use
|
392 |
+
|
393 |
+
Returns:
|
394 |
+
default precision that is suitable for the task and is supported by the hardware
|
395 |
+
"""
|
396 |
+
from lightning.fabric.accelerators import MPSAccelerator
|
397 |
+
|
398 |
+
if MPSAccelerator.is_available() or (
|
399 |
+
torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
|
400 |
+
):
|
401 |
+
return "16-mixed" if training else "16-true"
|
402 |
+
return "bf16-mixed" if training else "bf16-true"
|
403 |
+
|
404 |
+
|
405 |
+
def load_checkpoint(
|
406 |
+
fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
|
407 |
+
) -> None:
|
408 |
+
if isinstance(fabric.strategy, FSDPStrategy):
|
409 |
+
fabric.load_raw(checkpoint_path, model, strict=strict)
|
410 |
+
else:
|
411 |
+
state_dict = lazy_load(checkpoint_path)
|
412 |
+
state_dict = state_dict.get("model", state_dict)
|
413 |
+
model.load_state_dict(state_dict, strict=strict)
|
414 |
+
|
415 |
+
|
416 |
+
def flops_per_param(
|
417 |
+
max_seq_length: int, n_layer: int, n_embd: int, n_params: int
|
418 |
+
) -> int:
|
419 |
+
flops_per_token = (
|
420 |
+
2 * n_params
|
421 |
+
) # each parameter is used for a MAC (2 FLOPS) per network operation
|
422 |
+
# this assumes that all samples have a fixed length equal to the block size
|
423 |
+
# which is most likely false during finetuning
|
424 |
+
flops_per_seq = flops_per_token * max_seq_length
|
425 |
+
attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
|
426 |
+
return flops_per_seq + attn_flops_per_seq
|
427 |
+
|
428 |
+
|
429 |
+
def estimate_flops(model: "GPT", training: bool) -> int:
|
430 |
+
"""Measures estimated FLOPs for MFU.
|
431 |
+
|
432 |
+
Refs:
|
433 |
+
* https://ar5iv.labs.arxiv.org/html/2205.05198#A1
|
434 |
+
* https://ar5iv.labs.arxiv.org/html/2204.02311#A2
|
435 |
+
"""
|
436 |
+
# using all parameters for this is a naive over estimation because not all model parameters actually contribute to
|
437 |
+
# this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
|
438 |
+
# (~10%) compared to the measured FLOPs, making those lower but more realistic.
|
439 |
+
# For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
|
440 |
+
n_trainable_params = num_parameters(model, requires_grad=True)
|
441 |
+
trainable_flops = flops_per_param(
|
442 |
+
model.max_seq_length,
|
443 |
+
model.config.n_layer,
|
444 |
+
model.config.n_embd,
|
445 |
+
n_trainable_params,
|
446 |
+
)
|
447 |
+
# forward + backward + gradients (assumes no gradient accumulation)
|
448 |
+
ops_per_step = 3 if training else 1
|
449 |
+
n_frozen_params = num_parameters(model, requires_grad=False)
|
450 |
+
frozen_flops = flops_per_param(
|
451 |
+
model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
|
452 |
+
)
|
453 |
+
# forward + backward
|
454 |
+
frozen_ops_per_step = 2 if training else 1
|
455 |
+
return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
|
456 |
+
|
457 |
+
|
458 |
+
class CycleIterator:
|
459 |
+
"""An iterator that cycles through an iterable indefinitely.
|
460 |
+
|
461 |
+
Example:
|
462 |
+
>>> iterator = CycleIterator([1, 2, 3])
|
463 |
+
>>> [next(iterator) for _ in range(5)]
|
464 |
+
[1, 2, 3, 1, 2]
|
465 |
+
|
466 |
+
Note:
|
467 |
+
Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
|
468 |
+
"""
|
469 |
+
|
470 |
+
def __init__(self, iterable: Iterable) -> None:
|
471 |
+
self.iterable = iterable
|
472 |
+
self.epoch = 0
|
473 |
+
self._iterator = None
|
474 |
+
|
475 |
+
def __next__(self) -> Any:
|
476 |
+
if self._iterator is None:
|
477 |
+
self._iterator = iter(self.iterable)
|
478 |
+
try:
|
479 |
+
return next(self._iterator)
|
480 |
+
except StopIteration:
|
481 |
+
self._iterator = iter(self.iterable)
|
482 |
+
self.epoch += 1
|
483 |
+
return next(self._iterator)
|
484 |
+
|
485 |
+
def __iter__(self) -> Self:
|
486 |
+
return self
|
487 |
+
|
488 |
+
|
489 |
+
def copy_config_files(source_dir: Path, out_dir: Path) -> None:
|
490 |
+
"""Copies the specified configuration and tokenizer files into the output directory."""
|
491 |
+
|
492 |
+
config_files = ["config.json", "generation_config.json", "model_config.yaml"]
|
493 |
+
tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]
|
494 |
+
|
495 |
+
for file_name in config_files + tokenizer_files:
|
496 |
+
src_path = source_dir / file_name
|
497 |
+
if src_path.exists():
|
498 |
+
shutil.copy(src_path, out_dir)
|
499 |
+
|
500 |
+
|
501 |
+
def CLI(*args: Any, **kwargs: Any) -> Any:
|
502 |
+
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
|
503 |
+
|
504 |
+
set_docstring_parse_options(attribute_docstrings=True)
|
505 |
+
set_config_read_mode(urls_enabled=True)
|
506 |
+
|
507 |
+
return CLI(*args, **kwargs)
|
508 |
+
|
509 |
+
|
510 |
+
def capture_hparams() -> Dict[str, Any]:
|
511 |
+
"""Captures the local variables ('hyperparameters') from where this function gets called."""
|
512 |
+
caller_frame = inspect.currentframe().f_back
|
513 |
+
locals_of_caller = caller_frame.f_locals
|
514 |
+
hparams = {}
|
515 |
+
for name, value in locals_of_caller.items():
|
516 |
+
if value is None or isinstance(value, (int, float, str, bool, Path)):
|
517 |
+
hparams[name] = value
|
518 |
+
elif is_dataclass(value):
|
519 |
+
hparams[name] = asdict(value)
|
520 |
+
else:
|
521 |
+
hparams[name] = str(value)
|
522 |
+
return hparams
|
523 |
+
|
524 |
+
|
525 |
+
def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
|
526 |
+
"""Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
|
527 |
+
from jsonargparse import capture_parser
|
528 |
+
|
529 |
+
# TODO: Make this more robust
|
530 |
+
# This hack strips away the subcommands from the top-level CLI
|
531 |
+
# to parse the file as if it was called as a script
|
532 |
+
known_commands = [
|
533 |
+
("finetune_full",), # For subcommands, use `("finetune", "full")` etc
|
534 |
+
("finetune_lora",),
|
535 |
+
("finetune_adapter",),
|
536 |
+
("finetune_adapter_v2",),
|
537 |
+
("finetune",),
|
538 |
+
("pretrain",),
|
539 |
+
]
|
540 |
+
for known_command in known_commands:
|
541 |
+
unwanted = slice(1, 1 + len(known_command))
|
542 |
+
if tuple(sys.argv[unwanted]) == known_command:
|
543 |
+
sys.argv[unwanted] = []
|
544 |
+
|
545 |
+
parser = capture_parser(lambda: CLI(function))
|
546 |
+
config = parser.parse_args()
|
547 |
+
parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
|
548 |
+
|
549 |
+
|
550 |
+
def save_config(config: "Config", checkpoint_dir: Path) -> None:
|
551 |
+
config_dict = asdict(config)
|
552 |
+
with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
|
553 |
+
yaml.dump(config_dict, fp)
|
554 |
+
|
555 |
+
|
556 |
+
def parse_devices(devices: Union[str, int]) -> int:
|
557 |
+
if devices in (-1, "auto"):
|
558 |
+
return torch.cuda.device_count() or 1
|
559 |
+
if isinstance(devices, int) and devices > 0:
|
560 |
+
return devices
|
561 |
+
raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
|
562 |
+
|
563 |
+
|
564 |
+
def choose_logger(
|
565 |
+
logger_name: Literal["csv", "tensorboard", "wandb"],
|
566 |
+
out_dir: Path,
|
567 |
+
name: str,
|
568 |
+
log_interval: int = 1,
|
569 |
+
resume: Optional[bool] = None,
|
570 |
+
**kwargs: Any,
|
571 |
+
):
|
572 |
+
if logger_name == "csv":
|
573 |
+
return CSVLogger(
|
574 |
+
root_dir=(out_dir / "logs"),
|
575 |
+
name="csv",
|
576 |
+
flush_logs_every_n_steps=log_interval,
|
577 |
+
**kwargs,
|
578 |
+
)
|
579 |
+
if logger_name == "tensorboard":
|
580 |
+
return TensorBoardLogger(
|
581 |
+
root_dir=(out_dir / "logs"), name="tensorboard", **kwargs
|
582 |
+
)
|
583 |
+
if logger_name == "wandb":
|
584 |
+
return WandbLogger(project=name, resume=resume, **kwargs)
|
585 |
+
raise ValueError(
|
586 |
+
f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
|
587 |
+
)
|
588 |
+
|
589 |
+
|
590 |
+
def get_argument_names(cls):
|
591 |
+
sig = inspect.signature(cls.__init__)
|
592 |
+
return {
|
593 |
+
name
|
594 |
+
for name, param in sig.parameters.items()
|
595 |
+
if param.kind
|
596 |
+
in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
|
597 |
+
}
|
598 |
+
|
599 |
+
|
600 |
+
def instantiate_bnb_optimizer(optimizer, model_parameters):
|
601 |
+
if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
|
602 |
+
isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
|
603 |
+
):
|
604 |
+
raise ValueError(
|
605 |
+
"The chosen quantization format only supports the AdamW optimizer."
|
606 |
+
)
|
607 |
+
|
608 |
+
import bitsandbytes as bnb
|
609 |
+
|
610 |
+
if isinstance(optimizer, str):
|
611 |
+
optimizer = bnb.optim.PagedAdamW(model_parameters)
|
612 |
+
else:
|
613 |
+
optim_args = get_argument_names(bnb.optim.PagedAdamW)
|
614 |
+
allowed_kwargs = {
|
615 |
+
key: optimizer["init_args"][key]
|
616 |
+
for key in optim_args & optimizer["init_args"].keys()
|
617 |
+
}
|
618 |
+
optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
|
619 |
+
return optimizer
|
620 |
+
|
621 |
+
|
622 |
+
def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
|
623 |
+
if isinstance(optimizer, str):
|
624 |
+
optimizer_cls = getattr(torch.optim, optimizer)
|
625 |
+
optimizer = optimizer_cls(model_parameters, **kwargs)
|
626 |
+
else:
|
627 |
+
optimizer = dict(optimizer) # copy
|
628 |
+
optimizer["init_args"].update(kwargs)
|
629 |
+
optimizer = instantiate_class(model_parameters, optimizer)
|
630 |
+
return optimizer
|
631 |
+
|
632 |
+
|
633 |
+
def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
|
634 |
+
new_checkpoint_dir = "checkpoints" / checkpoint_dir
|
635 |
+
should_return_new_dir = (
|
636 |
+
not checkpoint_dir.is_dir()
|
637 |
+
and checkpoint_dir.parts[0] != "checkpoints"
|
638 |
+
and not checkpoint_dir.is_absolute()
|
639 |
+
and new_checkpoint_dir.exists()
|
640 |
+
)
|
641 |
+
return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
|
models/README.md
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
pipeline_tag: any-to-any
|
4 |
+
library_name: mini-omni2
|
5 |
+
---
|
6 |
+
|
7 |
+
# Mini-Omni2
|
8 |
+
|
9 |
+
<!-- <p align="center">
|
10 |
+
<img src="./data/figures/title.png" width="100%"/>
|
11 |
+
</p> -->
|
12 |
+
|
13 |
+
|
14 |
+
<p align="center">
|
15 |
+
🤗 <a href="https://huggingface.co/gpt-omni/mini-omni2">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni2">Github</a>
|
16 |
+
| 📑 <a href="https://arxiv.org/abs/2410.11190">Technical report</a>
|
17 |
+
</p>
|
18 |
+
|
19 |
+
Mini-Omni2 is an **omni-interactive** model. It can **understand image, audio and text inputs and has end-to-end voice conversations with users**. Featuring **real-time voice output**, **omni-capable multimodal understanding** and flexible interaction **ability with interruption mechanism while speaking**.
|
20 |
+
|
21 |
+
<p align="center">
|
22 |
+
<img src="./data/figures/framework.jpeg" width="100%"/>
|
23 |
+
</p>
|
24 |
+
|
25 |
+
|
26 |
+
## Updates
|
27 |
+
|
28 |
+
- **2024.10:** Release the model, technical report, inference and chat demo code.
|
29 |
+
|
30 |
+
## Features
|
31 |
+
✅ **Multimodal interaction**: with the ability to understand images, speech and text, just like GPT-4o.
|
32 |
+
|
33 |
+
✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required, just like [Mini-Omni](https://github.com/gpt-omni/mini-omni).
|
34 |
+
|
35 |
+
<!-- ✅ **Streaming audio output**: with first-chunk latency of audio stream less than 0.3s. -->
|
36 |
+
|
37 |
+
<!-- ✅ **Duplex interaction**: hearing while speaking, it can be interrupted by key words like "stop omni". -->
|
38 |
+
|
39 |
+
|
40 |
+
## Demo
|
41 |
+
|
42 |
+
NOTE: need to unmute first.
|
43 |
+
|
44 |
+
https://github.com/user-attachments/assets/ad97ca7f-f8b4-40c3-a7e8-fa54b4edf155
|
45 |
+
|
46 |
+
|
47 |
+
## ToDo
|
48 |
+
- [ ] update interruption mechanism
|
49 |
+
|
50 |
+
|
51 |
+
## Install
|
52 |
+
|
53 |
+
Create a new conda environment and install the required packages:
|
54 |
+
|
55 |
+
```sh
|
56 |
+
conda create -n omni python=3.10
|
57 |
+
conda activate omni
|
58 |
+
|
59 |
+
git clone https://github.com/gpt-omni/mini-omni2.git
|
60 |
+
cd mini-omni2
|
61 |
+
pip install -r requirements.txt
|
62 |
+
```
|
63 |
+
|
64 |
+
## Quick start
|
65 |
+
|
66 |
+
**Interactive demo**
|
67 |
+
|
68 |
+
- start server
|
69 |
+
|
70 |
+
NOTE: you need to start the server before running the streamlit or gradio demo with API_URL set to the server address.
|
71 |
+
|
72 |
+
```sh
|
73 |
+
sudo apt-get install ffmpeg
|
74 |
+
conda activate omni
|
75 |
+
cd mini-omni2
|
76 |
+
python3 server.py --ip '0.0.0.0' --port 60808
|
77 |
+
```
|
78 |
+
|
79 |
+
|
80 |
+
- run streamlit demo
|
81 |
+
|
82 |
+
NOTE: you need to run streamlit **locally** with PyAudio installed.
|
83 |
+
|
84 |
+
```sh
|
85 |
+
pip install PyAudio==0.2.14
|
86 |
+
API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
|
87 |
+
```
|
88 |
+
|
89 |
+
|
90 |
+
**Local test**
|
91 |
+
|
92 |
+
```sh
|
93 |
+
conda activate omni
|
94 |
+
cd mini-omni2
|
95 |
+
# test run the preset audio samples and questions
|
96 |
+
python inference_vision.py
|
97 |
+
```
|
98 |
+
|
99 |
+
## Mini-Omni2 Overview
|
100 |
+
|
101 |
+
**1. Multimodal Modeling**:
|
102 |
+
We use multiple sequences as the input and output of the model. In the input part, we will concatenate image, audio and text features to perform a series of comprehensive tasks, as shown in the following figures. In the output part, we use text-guided delayed parallel output to generate real-time speech responses.
|
103 |
+
<p align="center">
|
104 |
+
<img src="./data/figures/inputids.png" width="100%"/>
|
105 |
+
</p>
|
106 |
+
|
107 |
+
**2. Multi-stage Training**:
|
108 |
+
We propose an efficient alignment training method and conduct encoder adaptation, modal alignment, and multimodal fine-tuning respectively in the three-stage training.
|
109 |
+
<p align="center">
|
110 |
+
<img src="./data/figures/training.jpeg" width="100%"/>
|
111 |
+
</p>
|
112 |
+
|
113 |
+
<!-- **3. Cases**:
|
114 |
+
Here are more cases of Mini-Omni2:
|
115 |
+
<p align="center">
|
116 |
+
<img src="./data/figures/samples.png" width="100%"/>
|
117 |
+
</p> -->
|
118 |
+
|
119 |
+
## FAQ
|
120 |
+
|
121 |
+
**1. Does the model support other languages?**
|
122 |
+
|
123 |
+
No, the model is only trained on English. However, as we use whisper as the audio encoder, the model can understand other languages which is supported by whisper (like chinese), but the output is only in English.
|
124 |
+
|
125 |
+
**2. Error: can not run streamlit in local browser, with remote streamlit server**
|
126 |
+
|
127 |
+
You need start streamlit **locally** with PyAudio installed.
|
128 |
+
|
129 |
+
|
130 |
+
## Acknowledgements
|
131 |
+
|
132 |
+
- [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
|
133 |
+
- [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
|
134 |
+
- [whisper](https://github.com/openai/whisper/) for audio encoding.
|
135 |
+
- [clip](https://github.com/openai/CLIP) for image encoding.
|
136 |
+
- [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
|
137 |
+
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
|
138 |
+
- [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
|
139 |
+
|
140 |
+
<!-- ## Star History
|
141 |
+
|
142 |
+
[](https://star-history.com/#gpt-omni/mini-omni2&Date)
|
models/ViT-B-32.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
|
3 |
+
size 353976522
|
models/data/figures/framework.jpeg
ADDED
![]() |
Git LFS Details
|
models/data/figures/inputids.png
ADDED
![]() |
Git LFS Details
|
models/data/figures/samples.png
ADDED
![]() |
Git LFS Details
|
models/data/figures/title.png
ADDED
![]() |
Git LFS Details
|
models/data/figures/training.jpeg
ADDED
![]() |
Git LFS Details
|
models/data/omni2-demo.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0c2098124af391dca9c48854f5686143c137cc069f08b5e457675b9ba744bd2f
|
3 |
+
size 11784395
|
models/hub/.locks/models--hubertsiuzdak--snac_24khz/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40.lock
ADDED
File without changes
|
models/hub/.locks/models--hubertsiuzdak--snac_24khz/a9e7ef62bf7e1eb94d2713721029837aacab3b55.lock
ADDED
File without changes
|
models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
|
3 |
+
size 79488254
|
models/hub/models--hubertsiuzdak--snac_24khz/blobs/a9e7ef62bf7e1eb94d2713721029837aacab3b55
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"sampling_rate": 24000,
|
3 |
+
"encoder_dim": 48,
|
4 |
+
"encoder_rates": [2, 4, 8, 8],
|
5 |
+
"decoder_dim": 1024,
|
6 |
+
"decoder_rates": [8, 8, 4, 2],
|
7 |
+
"attn_window_size": null,
|
8 |
+
"codebook_size": 4096,
|
9 |
+
"codebook_dim": 8,
|
10 |
+
"vq_strides": [4, 2, 1],
|
11 |
+
"noise": true,
|
12 |
+
"depthwise": true
|
13 |
+
}
|
models/hub/models--hubertsiuzdak--snac_24khz/refs/main
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec
|
models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"sampling_rate": 24000,
|
3 |
+
"encoder_dim": 48,
|
4 |
+
"encoder_rates": [2, 4, 8, 8],
|
5 |
+
"decoder_dim": 1024,
|
6 |
+
"decoder_rates": [8, 8, 4, 2],
|
7 |
+
"attn_window_size": null,
|
8 |
+
"codebook_size": 4096,
|
9 |
+
"codebook_dim": 8,
|
10 |
+
"vq_strides": [4, 2, 1],
|
11 |
+
"noise": true,
|
12 |
+
"depthwise": true
|
13 |
+
}
|
models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
|
3 |
+
size 79488254
|
models/hub/version.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
1
|
models/lit_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98a0e4ad1912b0ee0081a5898b6c260f4b06a37aaef34e49b0719b46808c231c
|
3 |
+
size 2814623738
|
models/model_config.yaml
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
add_qkv_bias: true
|
2 |
+
asr_adapter: llamamlp
|
3 |
+
attn_dropout: 0.0
|
4 |
+
bias: false
|
5 |
+
block_size: 2048
|
6 |
+
force_align: false
|
7 |
+
gelu_approximate: none
|
8 |
+
head_size: 64
|
9 |
+
hf_config:
|
10 |
+
name: Qwen2-0.5B
|
11 |
+
org: Qwen
|
12 |
+
intermediate_size: 4864
|
13 |
+
lm_head_bias: false
|
14 |
+
mlp_class_name: LLaMAMLP
|
15 |
+
n_embd: 896
|
16 |
+
n_expert: 0
|
17 |
+
n_expert_per_token: 0
|
18 |
+
n_head: 14
|
19 |
+
n_layer: 24
|
20 |
+
n_query_groups: 2
|
21 |
+
name: Qwen2-0.5B
|
22 |
+
norm_class_name: RMSNorm
|
23 |
+
norm_eps: 1.0e-06
|
24 |
+
padded_vocab_size: 181120
|
25 |
+
padding_multiple: 512
|
26 |
+
parallel_residual: false
|
27 |
+
pos_type: rope
|
28 |
+
post_adapter: false
|
29 |
+
post_adapter_layers: 6
|
30 |
+
prompt_vocab_size: null
|
31 |
+
rope_base: 1000000
|
32 |
+
rope_condense_ratio: 1
|
33 |
+
rotary_percentage: 1
|
34 |
+
scale_embeddings: false
|
35 |
+
shared_attention_norm: false
|
36 |
+
tie_word_embeddings: true
|
37 |
+
use_pretrain_phoneme_emb: false
|
38 |
+
vocab_size: 50254
|
39 |
+
text_vocab_size: 152000
|
40 |
+
cat_audio_vocab_size: 29120
|
41 |
+
audio_vocab_size: 4160
|
42 |
+
whisper_adapter_dim: 768
|
43 |
+
vision_adapter_dim: 512
|
models/small.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2fb26a40bfcfbb3d7e41586205d21c90ffc1de552c15367efb4a723ce11f700f
|
3 |
+
size 483586606
|
models/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/tokenizer_config.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"added_tokens_decoder": {
|
4 |
+
"151643": {
|
5 |
+
"content": "<|endoftext|>",
|
6 |
+
"lstrip": false,
|
7 |
+
"normalized": false,
|
8 |
+
"rstrip": false,
|
9 |
+
"single_word": false,
|
10 |
+
"special": true
|
11 |
+
},
|
12 |
+
"151644": {
|
13 |
+
"content": "<|im_start|>",
|
14 |
+
"lstrip": false,
|
15 |
+
"normalized": false,
|
16 |
+
"rstrip": false,
|
17 |
+
"single_word": false,
|
18 |
+
"special": true
|
19 |
+
},
|
20 |
+
"151645": {
|
21 |
+
"content": "<|im_end|>",
|
22 |
+
"lstrip": false,
|
23 |
+
"normalized": false,
|
24 |
+
"rstrip": false,
|
25 |
+
"single_word": false,
|
26 |
+
"special": true
|
27 |
+
}
|
28 |
+
},
|
29 |
+
"additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
|
30 |
+
"bos_token": null,
|
31 |
+
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
32 |
+
"clean_up_tokenization_spaces": false,
|
33 |
+
"eos_token": "<|endoftext|>",
|
34 |
+
"errors": "replace",
|
35 |
+
"model_max_length": 32768,
|
36 |
+
"pad_token": "<|endoftext|>",
|
37 |
+
"split_special_tokens": false,
|
38 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
39 |
+
"unk_token": null
|
40 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.1
|
2 |
+
torchvision==0.18.1
|
3 |
+
torchaudio==2.3.1
|
4 |
+
litgpt==0.4.3
|
5 |
+
snac==1.2.0
|
6 |
+
soundfile==0.12.1
|
7 |
+
openai-whisper
|
8 |
+
tokenizers==0.19.1
|
9 |
+
streamlit==1.37.1
|
10 |
+
streamlit-webrtc
|
11 |
+
# PyAudio==0.2.14
|
12 |
+
pydub==0.25.1
|
13 |
+
onnxruntime==1.19.0
|
14 |
+
# numpy==1.26.3
|
15 |
+
librosa==0.10.2.post1
|
16 |
+
flask==3.0.3
|
17 |
+
fire
|
18 |
+
git+https://github.com/mini-omni/CLIP.git
|
19 |
+
gradio_webrtc[vad]==0.0.11
|
20 |
+
twilio
|