aizhibin commited on
Commit
f3264d3
·
1 Parent(s): 4740c9d

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +18 -0
  2. Dockerfile +39 -0
  3. LICENSE +21 -0
  4. README.md +13 -0
  5. __init__.py +4 -0
  6. audio_qa_out_cache.wav +3 -0
  7. data/figures/framework.jpeg +3 -0
  8. data/figures/inputids.png +3 -0
  9. data/figures/samples.png +3 -0
  10. data/figures/title_new.png +3 -0
  11. data/figures/training.jpeg +3 -0
  12. data/omni2-demo.mp4 +3 -0
  13. data/samples/output1.wav +0 -0
  14. data/samples/output2.wav +3 -0
  15. data/samples/output3.wav +0 -0
  16. data/samples/output4.wav +0 -0
  17. data/samples/output5.wav +3 -0
  18. data/samples/vision_qa_audio.wav +3 -0
  19. hotkey.txt +1 -0
  20. inference.py +705 -0
  21. inference_vision.py +259 -0
  22. litgpt/__init__.py +19 -0
  23. litgpt/config.py +181 -0
  24. litgpt/generate/__init__.py +0 -0
  25. litgpt/generate/base.py +795 -0
  26. litgpt/model.py +654 -0
  27. litgpt/tokenizer.py +131 -0
  28. litgpt/utils.py +641 -0
  29. models/README.md +142 -0
  30. models/ViT-B-32.pt +3 -0
  31. models/data/figures/framework.jpeg +3 -0
  32. models/data/figures/inputids.png +3 -0
  33. models/data/figures/samples.png +3 -0
  34. models/data/figures/title.png +3 -0
  35. models/data/figures/training.jpeg +3 -0
  36. models/data/omni2-demo.mp4 +3 -0
  37. models/hub/.locks/models--hubertsiuzdak--snac_24khz/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40.lock +0 -0
  38. models/hub/.locks/models--hubertsiuzdak--snac_24khz/a9e7ef62bf7e1eb94d2713721029837aacab3b55.lock +0 -0
  39. models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 +3 -0
  40. models/hub/models--hubertsiuzdak--snac_24khz/blobs/a9e7ef62bf7e1eb94d2713721029837aacab3b55 +13 -0
  41. models/hub/models--hubertsiuzdak--snac_24khz/refs/main +1 -0
  42. models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/config.json +13 -0
  43. models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/pytorch_model.bin +3 -0
  44. models/hub/version.txt +1 -0
  45. models/lit_model.pth +3 -0
  46. models/model_config.yaml +43 -0
  47. models/small.pt +3 -0
  48. models/tokenizer.json +0 -0
  49. models/tokenizer_config.json +40 -0
  50. requirements.txt +20 -0
.gitattributes CHANGED
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ audio_qa_out_cache.wav filter=lfs diff=lfs merge=lfs -text
37
+ data/figures/framework.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ data/figures/inputids.png filter=lfs diff=lfs merge=lfs -text
39
+ data/figures/samples.png filter=lfs diff=lfs merge=lfs -text
40
+ data/figures/title_new.png filter=lfs diff=lfs merge=lfs -text
41
+ data/figures/training.jpeg filter=lfs diff=lfs merge=lfs -text
42
+ data/omni2-demo.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ data/samples/output2.wav filter=lfs diff=lfs merge=lfs -text
44
+ data/samples/output5.wav filter=lfs diff=lfs merge=lfs -text
45
+ data/samples/vision_qa_audio.wav filter=lfs diff=lfs merge=lfs -text
46
+ models/data/figures/framework.jpeg filter=lfs diff=lfs merge=lfs -text
47
+ models/data/figures/inputids.png filter=lfs diff=lfs merge=lfs -text
48
+ models/data/figures/samples.png filter=lfs diff=lfs merge=lfs -text
49
+ models/data/figures/title.png filter=lfs diff=lfs merge=lfs -text
50
+ models/data/figures/training.jpeg filter=lfs diff=lfs merge=lfs -text
51
+ models/data/omni2-demo.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 filter=lfs diff=lfs merge=lfs -text
53
+ vision_qa_out_cache.wav filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
2
+
3
+ # Set environment variables
4
+ ENV PYTHONUNBUFFERED=1 \
5
+ DEBIAN_FRONTEND=noninteractive \
6
+ CUDA_HOME=/usr/local/cuda \
7
+ PATH=/usr/local/cuda/bin:$PATH \
8
+ LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
9
+ NVIDIA_VISIBLE_DEVICES=all \
10
+ NVIDIA_DRIVER_CAPABILITIES=compute,utility \
11
+ HF_HOME=/app/models
12
+
13
+ # Install system dependencies
14
+ RUN apt-get update && apt-get install -y --no-install-recommends \
15
+ python3 \
16
+ python3-pip \
17
+ python3-dev \
18
+ build-essential \
19
+ git \
20
+ ffmpeg \
21
+ libsndfile1 \
22
+ curl \
23
+ && rm -rf /var/lib/apt/lists/*
24
+
25
+ # Upgrade pip and install build tools
26
+ RUN python3 -m pip install --upgrade pip setuptools wheel uv
27
+
28
+ WORKDIR /app
29
+
30
+ COPY requirements.txt .
31
+
32
+ # Install other requirements
33
+ RUN python3 -m uv pip install --no-cache-dir -r requirements.txt --prerelease=allow
34
+
35
+ COPY . .
36
+
37
+ EXPOSE 8000
38
+
39
+ CMD ["python3", "server.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 gpt-omni
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - any-to-any
5
+ - omega
6
+ - omegalabs
7
+ - bittensor
8
+ - agi
9
+ ---
10
+
11
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
12
+
13
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
audio_qa_out_cache.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5be14b744cde16792c9bb0a66e4ce9b899ad39ce076996d8803cb41ce010bb9d
3
+ size 159788
data/figures/framework.jpeg ADDED

Git LFS Details

  • SHA256: bc668450030500a62ddbb7cf6ea170f0b53da7e3e5506d01a0dc6f2ec690fd1a
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
data/figures/inputids.png ADDED

Git LFS Details

  • SHA256: ad4cf663684c53f72952b13f52ea93fcbe19e287301b3decfcd917de9e23f312
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
data/figures/samples.png ADDED

Git LFS Details

  • SHA256: e63a8cbc2859304cb9c50b831366ac8804ad0326b6ae4897d08f8ab0e1eb63c6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
data/figures/title_new.png ADDED

Git LFS Details

  • SHA256: fd327145b6368a08a713164af9de7b1f9fc15a9077090586a1e65d915a82b538
  • Pointer size: 131 Bytes
  • Size of remote file: 355 kB
data/figures/training.jpeg ADDED

Git LFS Details

  • SHA256: fd49f75dbe5838a3e28f02c8f853dec34d0aad8573911d52bd827ab6dae8f9a1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
data/omni2-demo.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2098124af391dca9c48854f5686143c137cc069f08b5e457675b9ba744bd2f
3
+ size 11784395
data/samples/output1.wav ADDED
Binary file (62.2 kB). View file
 
data/samples/output2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b50c4df6f508a4367e5a49e90f974f8786c6d9ffb2599a8abcd25e693399735a
3
+ size 105176
data/samples/output3.wav ADDED
Binary file (70.4 kB). View file
 
data/samples/output4.wav ADDED
Binary file (67.6 kB). View file
 
data/samples/output5.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33f58e7cc49a4e4fd4809d20cde2fb22855054cf61558be8ffef347fc35ce8f2
3
+ size 114732
data/samples/vision_qa_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18eba79742ad8074074a113e6df56410bdf66e34a645d619a4ad7b8171f6d7d7
3
+ size 150572
hotkey.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 5CWzdquwP5vjnCphhkPMDo51G7MrWXdTJaEiKZYadbBoiiQ6
inference.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import glob
5
+ import time
6
+ from snac import SNAC
7
+ from litgpt import Tokenizer
8
+ from litgpt.utils import (
9
+ num_parameters,
10
+ )
11
+ from litgpt.generate.base import (
12
+ generate_AA,
13
+ generate_ASR,
14
+ generate_TA,
15
+ generate_TT,
16
+ generate_AT,
17
+ generate_TA_BATCH,
18
+ next_token_image_batch
19
+ )
20
+ import soundfile as sf
21
+ from litgpt.model import GPT, Config
22
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
23
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
24
+ from utils.snac_utils import get_snac, generate_audio_data
25
+ import whisper
26
+ from tqdm import tqdm
27
+ from huggingface_hub import snapshot_download
28
+
29
+
30
+ torch.set_printoptions(sci_mode=False)
31
+
32
+
33
+ # TODO
34
+ text_vocabsize = 151936
35
+ text_specialtokens = 64
36
+ audio_vocabsize = 4096
37
+ audio_specialtokens = 64
38
+
39
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
40
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
41
+
42
+ _eot = text_vocabsize
43
+ _pad_t = text_vocabsize + 1
44
+ _input_t = text_vocabsize + 2
45
+ _answer_t = text_vocabsize + 3
46
+ _asr = text_vocabsize + 4
47
+
48
+ _eoa = audio_vocabsize
49
+ _pad_a = audio_vocabsize + 1
50
+ _input_a = audio_vocabsize + 2
51
+ _answer_a = audio_vocabsize + 3
52
+ _split = audio_vocabsize + 4
53
+ _image = audio_vocabsize + 5
54
+ _eoimage = audio_vocabsize + 6
55
+
56
+
57
+ def get_input_ids_TA(text, text_tokenizer):
58
+ input_ids_item = [[] for _ in range(8)]
59
+ text_tokens = text_tokenizer.encode(text)
60
+ for i in range(7):
61
+ input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
62
+ layershift(_answer_a, i)
63
+ ]
64
+ input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
65
+ input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
66
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
67
+ return input_ids_item
68
+
69
+
70
+ def get_input_ids_TT(text, text_tokenizer):
71
+ input_ids_item = [[] for i in range(8)]
72
+ text_tokens = text_tokenizer.encode(text).tolist()
73
+
74
+ for i in range(7):
75
+ input_ids_item[i] = torch.tensor(
76
+ [layershift(_pad_a, i)] * (len(text_tokens) + 3)
77
+ ).unsqueeze(0)
78
+ input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
79
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
80
+
81
+ return input_ids_item
82
+
83
+
84
+ def get_input_ids_whisper(
85
+ mel, leng, whispermodel, device,
86
+ special_token_a=_answer_a, special_token_t=_answer_t,
87
+ ):
88
+
89
+ with torch.no_grad():
90
+ mel = mel.unsqueeze(0).to(device)
91
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
92
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
93
+
94
+ T = audio_feature.size(0)
95
+ input_ids = []
96
+ for i in range(7):
97
+ input_ids_item = []
98
+ input_ids_item.append(layershift(_input_a, i))
99
+ input_ids_item += [layershift(_pad_a, i)] * T
100
+ input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
101
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
102
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
103
+ input_ids.append(input_id_T.unsqueeze(0))
104
+ return audio_feature.unsqueeze(0), input_ids
105
+
106
+
107
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
108
+ with torch.no_grad():
109
+ mel = mel.unsqueeze(0).to(device)
110
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
111
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
112
+ T = audio_feature.size(0)
113
+ input_ids_AA = []
114
+ for i in range(7):
115
+ input_ids_item = []
116
+ input_ids_item.append(layershift(_input_a, i))
117
+ input_ids_item += [layershift(_pad_a, i)] * T
118
+ input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
119
+ input_ids_AA.append(torch.tensor(input_ids_item))
120
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
121
+ input_ids_AA.append(input_id_T)
122
+
123
+ input_ids_AT = []
124
+ for i in range(7):
125
+ input_ids_item = []
126
+ input_ids_item.append(layershift(_input_a, i))
127
+ input_ids_item += [layershift(_pad_a, i)] * T
128
+ input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
129
+ input_ids_AT.append(torch.tensor(input_ids_item))
130
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
131
+ input_ids_AT.append(input_id_T)
132
+
133
+ input_ids = [input_ids_AA, input_ids_AT]
134
+ stacked_inputids = [[] for _ in range(8)]
135
+ for i in range(2):
136
+ for j in range(8):
137
+ stacked_inputids[j].append(input_ids[i][j])
138
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
139
+ return torch.stack([audio_feature, audio_feature]), stacked_inputids
140
+
141
+
142
+ def load_audio(path):
143
+ audio = whisper.load_audio(path)
144
+ duration_ms = (len(audio) / 16000) * 1000
145
+ audio = whisper.pad_or_trim(audio)
146
+ mel = whisper.log_mel_spectrogram(audio)
147
+ return mel, int(duration_ms / 20) + 1
148
+
149
+
150
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
151
+ snacmodel, out_dir=None):
152
+ with fabric.init_tensor():
153
+ model.set_kv_cache(batch_size=2)
154
+ tokenlist = generate_TA_BATCH(
155
+ model,
156
+ audio_feature,
157
+ input_ids,
158
+ [leng, leng],
159
+ ["A1A2", "A1T2"],
160
+ max_returned_tokens=2048,
161
+ temperature=0.9,
162
+ top_k=1,
163
+ eos_id_a=_eoa,
164
+ eos_id_t=_eot,
165
+ pad_id_t=_pad_t,
166
+ shift=padded_text_vocabsize,
167
+ include_prompt=True,
168
+ generate_text=True,
169
+ )
170
+ text_tokenlist = tokenlist[-1]
171
+ if text_vocabsize in text_tokenlist:
172
+ text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
173
+ text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
174
+
175
+ audio_tokenlist = tokenlist[:-1]
176
+ audiolist = reconscruct_snac(audio_tokenlist)
177
+ audio = reconstruct_tensors(audiolist)
178
+ if out_dir is None:
179
+ out_dir = "./output/default/A1-A2-batch"
180
+ else:
181
+ out_dir = out_dir + "/A1-A2-batch"
182
+ if not os.path.exists(out_dir):
183
+ os.makedirs(out_dir)
184
+ with torch.inference_mode():
185
+ audio_hat = snacmodel.decode(audio)
186
+ sf.write(
187
+ f"{out_dir}/{step:02d}.wav",
188
+ audio_hat.squeeze().cpu().numpy(),
189
+ 24000,
190
+ )
191
+ model.clear_kv_cache()
192
+ return text
193
+
194
+
195
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
196
+ with fabric.init_tensor():
197
+ model.set_kv_cache(batch_size=1)
198
+ tokenlist = generate_AT(
199
+ model,
200
+ audio_feature,
201
+ input_ids,
202
+ [leng],
203
+ ["AT"],
204
+ max_returned_tokens=2048,
205
+ temperature=0.9,
206
+ top_k=1,
207
+ eos_id_a=_eoa,
208
+ eos_id_t=_eot,
209
+ pad_id_t=_pad_t,
210
+ shift=padded_text_vocabsize,
211
+ include_prompt=True,
212
+ generate_text=True,
213
+ )
214
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
215
+
216
+
217
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
218
+ snacmodel, out_dir=None):
219
+ with fabric.init_tensor():
220
+ model.set_kv_cache(batch_size=1)
221
+ tokenlist = generate_AA(
222
+ model,
223
+ audio_feature,
224
+ input_ids,
225
+ [leng],
226
+ ["A1T2"],
227
+ max_returned_tokens=2048,
228
+ temperature=0.9,
229
+ top_k=1,
230
+ eos_id_a=_eoa,
231
+ eos_id_t=_eot,
232
+ pad_id_t=_pad_t,
233
+ shift=padded_text_vocabsize,
234
+ include_prompt=True,
235
+ generate_text=True,
236
+ )
237
+ audiolist = reconscruct_snac(tokenlist)
238
+ tokenlist = tokenlist[-1]
239
+ if text_vocabsize in tokenlist:
240
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
241
+ if out_dir is None:
242
+ out_dir = "./output/default/A1-A2"
243
+ else:
244
+ out_dir = out_dir + "/A1-A2"
245
+ if not os.path.exists(out_dir):
246
+ os.makedirs(out_dir)
247
+
248
+ audio = reconstruct_tensors(audiolist)
249
+ with torch.inference_mode():
250
+ audio_hat = snacmodel.decode(audio)
251
+ sf.write(
252
+ f"{out_dir}/{step:02d}.wav",
253
+ audio_hat.squeeze().cpu().numpy(),
254
+ 24000,
255
+ )
256
+ model.clear_kv_cache()
257
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
258
+
259
+
260
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
261
+ with fabric.init_tensor():
262
+ model.set_kv_cache(batch_size=1)
263
+ tokenlist = generate_ASR(
264
+ model,
265
+ audio_feature,
266
+ input_ids,
267
+ [leng],
268
+ ["A1T1"],
269
+ max_returned_tokens=2048,
270
+ temperature=0.9,
271
+ top_k=1,
272
+ eos_id_a=_eoa,
273
+ eos_id_t=_eot,
274
+ pad_id_t=_pad_t,
275
+ shift=padded_text_vocabsize,
276
+ include_prompt=True,
277
+ generate_text=True,
278
+ )
279
+ model.clear_kv_cache()
280
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
281
+
282
+
283
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
284
+ snacmodel, out_dir=None):
285
+ with fabric.init_tensor():
286
+ model.set_kv_cache(batch_size=1)
287
+ tokenlist = generate_TA(
288
+ model,
289
+ None,
290
+ input_ids,
291
+ None,
292
+ ["T1A2"],
293
+ max_returned_tokens=2048,
294
+ temperature=0.9,
295
+ top_k=1,
296
+ eos_id_a=_eoa,
297
+ eos_id_t=_eot,
298
+ pad_id_t=_pad_t,
299
+ shift=padded_text_vocabsize,
300
+ include_prompt=True,
301
+ generate_text=True,
302
+ )
303
+
304
+ audiolist = reconscruct_snac(tokenlist)
305
+ tokenlist = tokenlist[-1]
306
+
307
+ if text_vocabsize in tokenlist:
308
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
309
+ audio = reconstruct_tensors(audiolist)
310
+ if out_dir is None:
311
+ out_dir = "./output/default/T1-A2"
312
+ else:
313
+ out_dir = out_dir + "/T1-A2"
314
+ if not os.path.exists(out_dir):
315
+ os.makedirs(out_dir)
316
+
317
+ with torch.inference_mode():
318
+ audio_hat = snacmodel.decode(audio)
319
+ sf.write(
320
+ f"{out_dir}/{step:02d}.wav",
321
+ audio_hat.squeeze().cpu().numpy(),
322
+ 24000,
323
+ )
324
+ model.clear_kv_cache()
325
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
326
+
327
+
328
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
329
+
330
+ with fabric.init_tensor():
331
+ model.set_kv_cache(batch_size=1)
332
+ tokenlist = generate_TT(
333
+ model,
334
+ None,
335
+ input_ids,
336
+ None,
337
+ ["T1T2"],
338
+ max_returned_tokens=2048,
339
+ temperature=0.9,
340
+ top_k=1,
341
+ eos_id_a=_eoa,
342
+ eos_id_t=_eot,
343
+ pad_id_t=_pad_t,
344
+ shift=padded_text_vocabsize,
345
+ include_prompt=True,
346
+ generate_text=True,
347
+ )
348
+ model.clear_kv_cache()
349
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
350
+
351
+
352
+ def load_model(ckpt_dir, device):
353
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
354
+ whisper_model_path = ckpt_dir + "/small.pt"
355
+ if not os.path.exists(whisper_model_path):
356
+ whisper_model_path = "small"
357
+ whispermodel = whisper.load_model(whisper_model_path).to(device)
358
+ text_tokenizer = Tokenizer(ckpt_dir)
359
+ fabric = L.Fabric(devices=1, strategy="auto")
360
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
361
+ config.post_adapter = False
362
+
363
+ with fabric.init_module(empty_init=False):
364
+ model = GPT(config)
365
+
366
+ model = fabric.setup(model)
367
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
368
+ model.load_state_dict(state_dict, strict=True)
369
+ model.to(device).eval()
370
+
371
+ return fabric, model, text_tokenizer, snacmodel, whispermodel
372
+
373
+
374
+ def download_model(ckpt_dir):
375
+ repo_id = "gpt-omni/mini-omni2"
376
+ snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
377
+
378
+
379
+ def get_text_stream(list_output, index, text_tokenizer):
380
+ text_tokens = list_output[-1][index:]
381
+ index += len(text_tokens)
382
+ is_text_end = False
383
+ if text_vocabsize in text_tokens:
384
+ text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
385
+ is_text_end = True
386
+ if len(text_tokens) == 0:
387
+ return "", index, is_text_end
388
+ res_text = text_tokenizer.decode(torch.tensor(text_tokens))
389
+ return res_text, index, is_text_end
390
+
391
+
392
+ class OmniInference:
393
+
394
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
395
+ self.device = device
396
+ if not os.path.exists(ckpt_dir):
397
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
398
+ download_model(ckpt_dir)
399
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
400
+
401
+ def warm_up(self, sample='./data/samples/output1.wav'):
402
+ for _ in self.run_AT_batch_stream(sample):
403
+ pass
404
+
405
+ @torch.inference_mode()
406
+ def run_AT_batch_stream(self,
407
+ audio_path,
408
+ stream_stride=4,
409
+ max_returned_tokens=2048,
410
+ temperature=0.9,
411
+ top_k=1,
412
+ top_p=1.0,
413
+ eos_id_a=_eoa,
414
+ eos_id_t=_eot,
415
+ save_path=None,
416
+ sample_rate=24000,
417
+ ):
418
+
419
+ assert os.path.exists(audio_path), f"audio file {audio_path} not found"
420
+ model = self.model
421
+
422
+ with self.fabric.init_tensor():
423
+ model.set_kv_cache(batch_size=2,device=self.device)
424
+
425
+ mel, leng = load_audio(audio_path)
426
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
427
+ T = input_ids[0].size(1)
428
+ device = input_ids[0].device
429
+
430
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
431
+
432
+ if model.max_seq_length < max_returned_tokens - 1:
433
+ raise NotImplementedError(
434
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
435
+ )
436
+
437
+ input_pos = torch.tensor([T], device=device)
438
+ list_output = [[] for i in range(8)]
439
+ tokens_A, token_T = next_token_image_batch(
440
+ model,
441
+ audio_feature.to(torch.float32).to(model.device),
442
+ None,
443
+ input_ids,
444
+ [T - 3, T - 3],
445
+ ["A1T2", "A1T2"],
446
+ input_pos=torch.arange(0, T, device=device),
447
+ temperature=temperature,
448
+ top_k=top_k,
449
+ top_p=top_p,
450
+ )
451
+
452
+ for i in range(7):
453
+ list_output[i].append(tokens_A[i].tolist()[0])
454
+ list_output[7].append(token_T.tolist()[0])
455
+
456
+ model_input_ids = [[] for i in range(8)]
457
+ for i in range(7):
458
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
459
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
460
+ model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
461
+ model_input_ids[i] = torch.stack(model_input_ids[i])
462
+
463
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
464
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
465
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
466
+
467
+ text_end = False
468
+ index = 1
469
+ nums_generate = stream_stride
470
+ begin_generate = False
471
+ current_index = 0
472
+
473
+ text_index = 0
474
+ is_text_end = False
475
+
476
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
477
+ tokens_A, token_T = next_token_image_batch(
478
+ model,
479
+ None,
480
+ None,
481
+ model_input_ids,
482
+ None,
483
+ None,
484
+ input_pos=input_pos,
485
+ temperature=temperature,
486
+ top_k=top_k,
487
+ top_p=top_p,
488
+ )
489
+
490
+ if text_end:
491
+ token_T = torch.tensor([_pad_t], device=device)
492
+
493
+ if tokens_A[-1] == eos_id_a:
494
+ break
495
+
496
+ if token_T == eos_id_t:
497
+ text_end = True
498
+
499
+ for i in range(7):
500
+ list_output[i].append(tokens_A[i].tolist()[0])
501
+ list_output[7].append(token_T.tolist()[0])
502
+
503
+ model_input_ids = [[] for i in range(8)]
504
+ for i in range(7):
505
+ tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
506
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
507
+ model_input_ids[i].append(
508
+ torch.tensor([layershift(4097, i)], device=device)
509
+ )
510
+ model_input_ids[i] = torch.stack(model_input_ids[i])
511
+
512
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
513
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
514
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
515
+
516
+ if index == 7:
517
+ begin_generate = True
518
+
519
+ if begin_generate:
520
+ current_index += 1
521
+ if current_index == nums_generate:
522
+ current_index = 0
523
+ snac = get_snac(list_output, index, nums_generate)
524
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
525
+ if is_text_end:
526
+ text_stream = ""
527
+ else:
528
+ text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
529
+
530
+ yield (audio_stream, text_stream)
531
+
532
+ input_pos = input_pos.add_(1)
533
+ index += 1
534
+ text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
535
+ print(f"text output: {text}")
536
+
537
+ if save_path is not None:
538
+ audiolist = reconscruct_snac(list_output)
539
+ audio = reconstruct_tensors(audiolist)
540
+ with torch.inference_mode():
541
+ audio_hat = self.snacmodel.decode(audio)
542
+ sf.write(save_path, audio_hat.squeeze().cpu().numpy(), sample_rate)
543
+
544
+ model.clear_kv_cache()
545
+ return list_output
546
+
547
+
548
+ def test_infer():
549
+ device = "cuda:0"
550
+ out_dir = f"./output/{get_time_str()}"
551
+ ckpt_dir = f"./checkpoint"
552
+ if not os.path.exists(ckpt_dir):
553
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
554
+ download_model(ckpt_dir)
555
+
556
+ fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
557
+
558
+ task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
559
+
560
+ # prepare test data
561
+ # TODO
562
+ test_audio_list = sorted(glob.glob('./data/samples/output*.wav'))
563
+ test_audio_transcripts = [
564
+ "What is your name?",
565
+ "what are your hobbies?",
566
+ "Do you like beijing",
567
+ "How are you feeling today?",
568
+ "what is the weather like today?",
569
+ ]
570
+ test_text_list = [
571
+ "What is your name?",
572
+ "How are you feeling today?",
573
+ "Can you describe your surroundings?",
574
+ "What did you do yesterday?",
575
+ "What is your favorite book and why?",
576
+ "How do you make a cup of tea?",
577
+ "What is the weather like today?",
578
+ "Can you explain the concept of time?",
579
+ "Can you tell me a joke?",
580
+ ]
581
+
582
+ # LOAD MODEL
583
+ with torch.no_grad():
584
+ if "A1A2" in task:
585
+ print("===============================================================")
586
+ print(" testing A1A2")
587
+ print("===============================================================")
588
+ step = 0
589
+ for path in test_audio_list:
590
+ try:
591
+ mel, leng = load_audio(path)
592
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
593
+ text = A1_A2(
594
+ fabric,
595
+ audio_feature,
596
+ input_ids,
597
+ leng,
598
+ model,
599
+ text_tokenizer,
600
+ step,
601
+ snacmodel,
602
+ out_dir=out_dir,
603
+ )
604
+ print(f"input: {test_audio_transcripts[step]}")
605
+ print(f"output: {text}")
606
+ step += 1
607
+ print(
608
+ "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
609
+ )
610
+ except:
611
+ print(f"[error] failed to process {path}")
612
+ print("===============================================================")
613
+
614
+ if 'asr' in task:
615
+ print("===============================================================")
616
+ print(" testing asr")
617
+ print("===============================================================")
618
+
619
+ index = 0
620
+ step = 0
621
+ for path in test_audio_list:
622
+ mel, leng = load_audio(path)
623
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
624
+ output = A1_T1(fabric, audio_feature, input_ids ,leng, model, text_tokenizer, index).lower().replace(',','').replace('.','').replace('?','')
625
+ print(f"audio_path: {path}")
626
+ print(f"audio transcript: {test_audio_transcripts[index]}")
627
+ print(f"asr output: {output}")
628
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
629
+ index += 1
630
+
631
+ if "T1A2" in task:
632
+ step = 0
633
+ print("\n")
634
+ print("===============================================================")
635
+ print(" testing T1A2")
636
+ print("===============================================================")
637
+ for text in test_text_list:
638
+ input_ids = get_input_ids_TA(text, text_tokenizer)
639
+ text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
640
+ snacmodel, out_dir=out_dir)
641
+ print(f"input: {text}")
642
+ print(f"output: {text_output}")
643
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
644
+ step += 1
645
+ print("===============================================================")
646
+
647
+ if "T1T2" in task:
648
+ step = 0
649
+ print("\n")
650
+ print("===============================================================")
651
+ print(" testing T1T2")
652
+ print("===============================================================")
653
+
654
+ for text in test_text_list:
655
+ input_ids = get_input_ids_TT(text, text_tokenizer)
656
+ text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
657
+ print(f" Input: {text}")
658
+ print(f"Output: {text_output}")
659
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
660
+ print("===============================================================")
661
+
662
+ if "AT" in task:
663
+ print("===============================================================")
664
+ print(" testing A1T2")
665
+ print("===============================================================")
666
+ step = 0
667
+ for path in test_audio_list:
668
+ mel, leng = load_audio(path)
669
+ audio_feature, input_ids = get_input_ids_whisper(
670
+ mel, leng, whispermodel, device,
671
+ special_token_a=_pad_a, special_token_t=_answer_t
672
+ )
673
+ text = A1_T2(
674
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
675
+ )
676
+ print(f"input: {test_audio_transcripts[step]}")
677
+ print(f"output: {text}")
678
+ step += 1
679
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
680
+ print("===============================================================")
681
+
682
+ if "AA-BATCH" in task:
683
+ print("===============================================================")
684
+ print(" testing A1A2-BATCH")
685
+ print("===============================================================")
686
+ step = 0
687
+ for path in test_audio_list:
688
+ mel, leng = load_audio(path)
689
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
690
+ text = A1_A2_batch(
691
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
692
+ snacmodel, out_dir=out_dir
693
+ )
694
+ print(f"input: {test_audio_transcripts[step]}")
695
+ print(f"output: {text}")
696
+ step += 1
697
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
698
+ print("===============================================================")
699
+
700
+ print("*********************** test end *****************************")
701
+
702
+
703
+
704
+ if __name__ == "__main__":
705
+ test_infer()
inference_vision.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from litgpt.generate.base import next_token_image_batch
4
+ import soundfile as sf
5
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
6
+ from utils.snac_utils import get_snac, generate_audio_data
7
+ import clip
8
+ import inference
9
+ from tqdm import tqdm
10
+ from inference import OmniInference, load_model, load_audio, download_model
11
+ from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
12
+ from PIL import Image
13
+
14
+
15
+ torch.set_printoptions(sci_mode=False)
16
+
17
+ _image = inference._image
18
+ _eoimage = inference._eoimage
19
+ _pad_t = inference._pad_t
20
+ _input_t = inference._input_t
21
+ _answer_t = inference._answer_t
22
+ _eot = inference._eot
23
+ _eoa = inference._eoa
24
+ _pad_a = inference._pad_a
25
+ _input_a = inference._input_a
26
+ _answer_a = inference._answer_a
27
+
28
+
29
+ def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
30
+
31
+ with torch.no_grad():
32
+ mel = mel.unsqueeze(0).to(device)
33
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
34
+
35
+ audio_len = audio_feature.size(0)
36
+
37
+ input_ids = []
38
+ input_ids_item = [[] for i in range(8)]
39
+ for i in range(7):
40
+ input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
41
+ input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
42
+ input_ids_item[i] += [layershift(_answer_a,i)]
43
+
44
+ input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
45
+ input_ids_item = [torch.tensor(item) for item in input_ids_item]
46
+
47
+ input_ids.append(input_ids_item)
48
+
49
+ input_ids_item = [[] for i in range(8)]
50
+ for i in range(7):
51
+ input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
52
+ input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
53
+
54
+ input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
55
+
56
+ input_ids_item = [torch.tensor(item) for item in input_ids_item]
57
+ input_ids.append(input_ids_item)
58
+
59
+ stacked_inputids = [[] for _ in range(8)]
60
+ for i in range(2):
61
+ for j in range(8):
62
+ stacked_inputids[j].append(input_ids[i][j])
63
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
64
+
65
+ return torch.stack([audio_feature,audio_feature]), stacked_inputids
66
+
67
+
68
+ def load_clip_model(ckpt_dir, device):
69
+ clip_model_path = ckpt_dir + "/ViT-B-32.pt"
70
+ if not os.path.exists(clip_model_path):
71
+ clip_model_path = "ViT-B/32"
72
+ clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
73
+ return clipmodel, clippreprocess
74
+
75
+
76
+ class OmniVisionInference(OmniInference):
77
+
78
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
79
+ self.device = device
80
+ if not os.path.exists(ckpt_dir):
81
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
82
+ download_model(ckpt_dir)
83
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
84
+ self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
85
+
86
+ def warm_up(self,
87
+ audio_sample='./data/samples/vision_qa_audio.wav',
88
+ image_sample='./data/samples/vision_qa_image.jpg'
89
+ ):
90
+ for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
91
+ save_path="./data/samples/vision_qa_output.wav",
92
+ warm_up=True):
93
+ pass
94
+
95
+ @torch.inference_mode()
96
+ def run_vision_AA_batch_stream(self, audio_path, image_path,
97
+ stream_stride=4,
98
+ max_returned_tokens=2048,
99
+ temperature=0.9,
100
+ top_k=1,
101
+ top_p=1.0,
102
+ eos_id_a=_eoa,
103
+ eos_id_t=_eot,
104
+ pad_id=_pad_t,
105
+ save_path=None,
106
+ warm_up=False
107
+ ):
108
+ with self.fabric.init_tensor():
109
+ self.model.set_kv_cache(batch_size=2)
110
+
111
+ model = self.model
112
+
113
+ mel, leng = load_audio(audio_path)
114
+ img = Image.open(image_path)
115
+
116
+ audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
117
+ ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
118
+ ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
119
+
120
+ ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
121
+ leng = [leng,leng]
122
+ task = ['ImageQA_A','ImageQA_AT']
123
+
124
+ T = input_ids[0].size(1)
125
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
126
+
127
+ if model.max_seq_length < max_returned_tokens - 1:
128
+ raise NotImplementedError(
129
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
130
+ )
131
+
132
+ list_output = [[] for i in range(8)]
133
+
134
+ tokens_A , token_T = next_token_image_batch(
135
+ model,
136
+ audio_feature.to(torch.float32).to(self.device),
137
+ ima_feature.to(torch.float32).to(self.device) ,
138
+ input_ids ,
139
+ whisper_lens = leng ,
140
+ task = task,
141
+ input_pos = torch.arange(0, T, device=self.device),
142
+ temperature=temperature,
143
+ top_k=top_k,
144
+ top_p=top_p
145
+ )
146
+ for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
147
+ list_output[7].append(token_T.tolist()[0])
148
+
149
+ text_end = False
150
+ index = 1
151
+ nums_generate = stream_stride
152
+ begin_generate = False
153
+ current_index = 0
154
+ input_pos = torch.tensor([T], device=self.device)
155
+
156
+ model_input_ids = [[] for i in range(8)]
157
+ for i in range(7):
158
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
159
+ model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
160
+ model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
161
+ model_input_ids[i] = torch.stack(model_input_ids[i])
162
+
163
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
164
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
165
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
166
+
167
+ text_index = 0
168
+ is_text_end = False
169
+
170
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
171
+
172
+ tokens_A , token_T = next_token_image_batch(model, None , None ,
173
+ input_ids = model_input_ids,
174
+ whisper_lens= None,
175
+ task = None,
176
+ input_pos = input_pos,
177
+ temperature=temperature,
178
+ top_k=top_k,
179
+ top_p=top_p)
180
+
181
+ if text_end:
182
+ token_T = torch.tensor([_pad_t], device=self.device)
183
+
184
+ if tokens_A[-1] == eos_id_a:
185
+ break
186
+ if token_T == eos_id_t:
187
+ text_end = True
188
+
189
+ for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
190
+ list_output[7].append(token_T.tolist()[0])
191
+
192
+
193
+ if index == 7:
194
+ begin_generate = True
195
+
196
+ if begin_generate:
197
+ current_index += 1
198
+ if current_index == nums_generate:
199
+ current_index = 0
200
+ snac = get_snac(list_output,index,nums_generate)
201
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
202
+ if is_text_end:
203
+ text_stream = ""
204
+ else:
205
+ text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
206
+
207
+ yield (audio_stream, text_stream)
208
+
209
+ if warm_up:
210
+ break
211
+
212
+ input_pos = input_pos.add_(1)
213
+ model_input_ids = [[] for i in range(8)]
214
+ for i in range(7):
215
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
216
+ model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
217
+ model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
218
+ model_input_ids[i] = torch.stack(model_input_ids[i])
219
+
220
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
221
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
222
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
223
+
224
+ index += 1
225
+
226
+ text_tokens = list_output[-1]
227
+ if text_vocabsize in text_tokens:
228
+ text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
229
+ res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
230
+ print(f"text output: {res_text}")
231
+
232
+ if save_path is not None:
233
+ audiolist = reconscruct_snac(list_output)
234
+ audio = reconstruct_tensors(audiolist)
235
+ with torch.inference_mode():
236
+ audio_hat = self.snacmodel.decode(audio)
237
+ sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
238
+
239
+ model.clear_kv_cache()
240
+
241
+
242
+ def test_vision_infer():
243
+ client = OmniVisionInference()
244
+ client.warm_up()
245
+ input_audio_path = './data/samples/vision_qa_audio.wav'
246
+ input_image_path = './data/samples/vision_qa_image.jpg'
247
+
248
+ res_text = ""
249
+ for audio_stream, text_stream in client.run_vision_AA_batch_stream(
250
+ input_audio_path,
251
+ input_image_path,
252
+ save_path="./vision_qa_output.wav"
253
+ ):
254
+ res_text += text_stream
255
+ print(f"text_output: {res_text}")
256
+
257
+
258
+ if __name__ == "__main__":
259
+ test_vision_infer()
litgpt/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import logging
4
+ import re
5
+ from litgpt.model import GPT # needs to be imported before config
6
+ from litgpt.config import Config
7
+ from litgpt.tokenizer import Tokenizer
8
+
9
+ # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
10
+ pattern = re.compile(".*Profiler function .* will be ignored")
11
+ logging.getLogger("torch._dynamo.variables.torch").addFilter(
12
+ lambda record: not pattern.search(record.getMessage())
13
+ )
14
+
15
+ # Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
16
+ logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
17
+ logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
18
+
19
+ __all__ = ["GPT", "Config", "Tokenizer"]
litgpt/config.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Optional, Type, Union
7
+
8
+ import torch
9
+ import yaml
10
+ from typing_extensions import Self
11
+
12
+ import litgpt.model
13
+ from litgpt.utils import find_multiple
14
+
15
+
16
+ @dataclass
17
+ class Config:
18
+ name: str = ""
19
+ hf_config: dict = field(default_factory=dict)
20
+ scale_embeddings: bool = False
21
+ block_size: int = 4096
22
+ vocab_size: int = 50254
23
+ padding_multiple: int = 512
24
+ padded_vocab_size: Optional[int] = None
25
+ n_layer: int = 16
26
+ n_head: int = 32
27
+ head_size: Optional[int] = None
28
+ n_embd: int = 4096
29
+ rotary_percentage: float = 0.25
30
+ parallel_residual: bool = True
31
+ bias: bool = True
32
+ lm_head_bias: bool = False
33
+ # to use multi-head attention (MHA), set this to `n_head` (default)
34
+ # to use multi-query attention (MQA), set this to 1
35
+ # to use grouped-query attention (GQA), set this to a value in between
36
+ # Example with `n_head=4`
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ │ │ │
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
42
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
43
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
44
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
45
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
46
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
47
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
48
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
49
+ # MHA GQA MQA
50
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
51
+ #
52
+ # credit https://arxiv.org/pdf/2305.13245.pdf
53
+ n_query_groups: Optional[int] = None
54
+ shared_attention_norm: bool = False
55
+ norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
56
+ norm_eps: float = 1e-5
57
+ mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
58
+ "GptNeoxMLP"
59
+ )
60
+ gelu_approximate: str = "none"
61
+ intermediate_size: Optional[int] = None
62
+ rope_condense_ratio: int = 1
63
+ rope_base: int = 10000
64
+ n_expert: int = 0
65
+ n_expert_per_token: int = 0
66
+
67
+ add_qkv_bias: Optional[bool] = None
68
+ prompt_vocab_size: Optional[int] = None
69
+ attn_dropout: float = 0.0
70
+ pos_type: str = "rope"
71
+ force_align: bool = False
72
+ use_pretrain_phoneme_emb: bool = False
73
+ tie_word_embeddings: bool = False
74
+
75
+ # setting for mini-omni
76
+ text_vocab_size:int = 152000
77
+ cat_audio_vocab_size: int = 29120
78
+ audio_vocab_size: int = 4160
79
+ whisper_adapter_dim: int = 768
80
+ vision_adapter_dim: int = 512
81
+
82
+ post_adapter: bool = False
83
+ post_adapter_layers: int = 6
84
+ asr_adapter: str = "llamamlp"
85
+
86
+ def __post_init__(self):
87
+ if not self.name:
88
+ self.name = self.hf_config.get("name", self.name)
89
+
90
+ if self.head_size is None:
91
+ assert self.n_embd % self.n_head == 0
92
+ self.head_size = self.n_embd // self.n_head
93
+
94
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
95
+ if self.padded_vocab_size is None:
96
+ self.padded_vocab_size = find_multiple(
97
+ self.vocab_size, self.padding_multiple
98
+ )
99
+ else:
100
+ # vocab size shouldn't be larger than padded vocab size
101
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
102
+
103
+ # compute the number of query groups
104
+ if self.n_query_groups is not None:
105
+ assert self.n_head % self.n_query_groups == 0
106
+ else:
107
+ self.n_query_groups = self.n_head
108
+
109
+ # compute the intermediate size for MLP if not set
110
+ if self.intermediate_size is None:
111
+ if self.mlp_class_name == "LLaMAMLP":
112
+ raise ValueError(
113
+ f"The config {self.name!r}, needs to set the `intermediate_size`"
114
+ )
115
+ self.intermediate_size = 4 * self.n_embd
116
+
117
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
118
+
119
+ if self.add_qkv_bias is None:
120
+ self.add_qkv_bias = self.bias
121
+
122
+ @classmethod
123
+ def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
124
+ if name not in name_to_config:
125
+ # search through all `config['hf_config']['name']`
126
+ try:
127
+ conf_dict = next(
128
+ config
129
+ for config in configs
130
+ if name == config["hf_config"]["name"]
131
+ or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
132
+ == name
133
+ )
134
+ except StopIteration:
135
+ raise ValueError(f"{name!r} is not a supported config name")
136
+ else:
137
+ conf_dict = name_to_config[name]
138
+
139
+ conf_dict = conf_dict.copy()
140
+ conf_dict.update(kwargs)
141
+ return cls(**conf_dict)
142
+
143
+ @classmethod
144
+ def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
145
+ with open(path, encoding="utf-8") as fp:
146
+ file_kwargs = yaml.safe_load(fp)
147
+ if file_kwargs is None:
148
+ raise ValueError(f"{path} is empty which is likely unexpected.")
149
+ file_kwargs.update(kwargs)
150
+ return cls(**file_kwargs)
151
+
152
+ @classmethod
153
+ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
154
+ """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
155
+ if (config_path := path / "model_config.yaml").is_file():
156
+ return cls.from_file(config_path, **kwargs)
157
+ if (model_name := path.name) in name_to_config:
158
+ return cls.from_name(model_name, **kwargs)
159
+ raise FileNotFoundError(
160
+ f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
161
+ )
162
+
163
+ @property
164
+ def mlp_class(self) -> Type:
165
+ # `self.mlp_class_name` cannot be the type to keep the config serializable
166
+ return getattr(litgpt.model, self.mlp_class_name)
167
+
168
+ @property
169
+ def norm_class(self) -> Type:
170
+ # `self.norm_class_name` cannot be the type to keep the config serializable
171
+ if self.norm_class_name == "RMSNorm":
172
+ from functools import partial
173
+
174
+ from litgpt.model import RMSNorm
175
+
176
+ return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
177
+ return getattr(torch.nn, self.norm_class_name)
178
+
179
+
180
+ configs = []
181
+ name_to_config = {config["name"]: config for config in configs}
litgpt/generate/__init__.py ADDED
File without changes
litgpt/generate/base.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ import torch
6
+ # import torch._dynamo.config
7
+ # import torch._inductor.config
8
+
9
+ from litgpt.model import GPT
10
+ from utils.snac_utils import layershift, snac_config
11
+ from tqdm import tqdm
12
+
13
+
14
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
15
+ if torch._dynamo.is_compiling():
16
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
17
+ distribution = torch.empty_like(probs).exponential_(1)
18
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
19
+ return torch.multinomial(probs, num_samples=1)
20
+
21
+
22
+ def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
23
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
24
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
25
+ # Example:
26
+ # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
27
+ # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
28
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
29
+ # Keep at least 1 token always to prevent the case where no token is selected
30
+ # In this case the most probable one is always kept
31
+ sorted_indices_to_remove[-1:] = 0
32
+ indices_to_remove = sorted_indices_to_remove.scatter(
33
+ 0, sorted_indices, sorted_indices_to_remove
34
+ )
35
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
36
+ return logits
37
+
38
+
39
+ def sample(
40
+ logits: torch.Tensor,
41
+ temperature: float = 1.0,
42
+ top_k: Optional[int] = None,
43
+ top_p: float = 1.0,
44
+ ) -> torch.Tensor:
45
+ if top_p < 0.0 or top_p > 1.0:
46
+ raise ValueError(f"top_p must be in [0, 1], got {top_p}")
47
+ logits = logits[0, -1]
48
+ # optionally crop the logits to only the top k options
49
+ if top_k is not None:
50
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
51
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
52
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
53
+ # optionally scale the logits and sample from a probability distribution
54
+ if temperature > 0.0 or top_p > 0.0:
55
+ if temperature > 0.0:
56
+ logits = logits / temperature
57
+ # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
58
+ if top_p < 1.0:
59
+ logits = sample_top_p(logits, top_p)
60
+ probs = torch.nn.functional.softmax(logits, dim=-1)
61
+ return multinomial_num_samples_1(probs)
62
+ return torch.argmax(logits, dim=-1, keepdim=True)
63
+
64
+
65
+ def next_token(
66
+ model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
67
+ ) -> torch.Tensor:
68
+ input_pos = input_pos.to(model.device)
69
+ logits_a, logit_t = model(None, x, None, input_pos)
70
+
71
+ next_audio_tokens = []
72
+ for logit_a in logits_a:
73
+ next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
74
+ next_audio_tokens.append(next_a)
75
+ next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
76
+ return next_audio_tokens, next_t
77
+
78
+
79
+ def next_token_asr(
80
+ model: GPT,
81
+ input_pos: torch.Tensor,
82
+ audio_features: torch.Tensor,
83
+ lens: int,
84
+ input_ids: list,
85
+ **kwargs: Any,
86
+ ) -> torch.Tensor:
87
+ input_pos = input_pos.to(model.device)
88
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
89
+ logits_a, logit_t = model(audio_features, input_ids, None, input_pos, whisper_lens=lens)
90
+
91
+ next_audio_tokens = []
92
+ for logit_a in logits_a:
93
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
94
+ next_audio_tokens.append(next_a)
95
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
96
+ return next_audio_tokens, next_t
97
+
98
+
99
+ def next_token_A1T2(
100
+ model: GPT,
101
+ audio_features: torch.Tensor,
102
+ input_ids: list,
103
+ whisper_lens: int,
104
+ task: list,
105
+ input_pos: torch.Tensor,
106
+ **kwargs: Any,
107
+ ) -> torch.Tensor:
108
+ input_pos = input_pos.to(model.device)
109
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
110
+ logits_a, logit_t = model(
111
+ audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
112
+ )
113
+
114
+ next_audio_tokens = []
115
+ for logit_a in logits_a:
116
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
117
+ next_audio_tokens.append(next_a)
118
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
119
+ return next_audio_tokens, next_t
120
+
121
+
122
+ def next_token_A1T1(
123
+ model: GPT,
124
+ audio_features: torch.Tensor,
125
+ input_ids: list,
126
+ whisper_lens: int,
127
+ task: list,
128
+ input_pos: torch.Tensor,
129
+ **kwargs: Any,
130
+ ) -> torch.Tensor:
131
+ input_pos = input_pos.to(model.device)
132
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
133
+ logits_a, logit_t = model(
134
+ audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
135
+ )
136
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
137
+ return next_t
138
+
139
+
140
+ def next_token_image_batch(model: GPT,
141
+ audio_features: torch.Tensor,
142
+ clip_features: torch.Tensor,
143
+ input_ids: list,
144
+ whisper_lens: int,
145
+ task: list,
146
+ input_pos: torch.Tensor,
147
+ **kwargs: Any) -> torch.Tensor:
148
+ input_pos = input_pos.to(model.device)
149
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
150
+ logits_a,logit_t = model(audio_features, input_ids, clip_features,
151
+ input_pos, whisper_lens=whisper_lens, task=task)
152
+
153
+ for i in range(7):
154
+ logits_a[i] = logits_a[i][0].unsqueeze(0)
155
+ logit_t = logit_t[1].unsqueeze(0)
156
+
157
+ next_audio_tokens = []
158
+ for logit_a in logits_a:
159
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
160
+ next_audio_tokens.append(next_a)
161
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
162
+ return next_audio_tokens, next_t
163
+
164
+
165
+ # torch._dynamo.config.automatic_dynamic_shapes = True
166
+ # torch._inductor.config.triton.unique_kernel_names = True
167
+ # torch._inductor.config.coordinate_descent_tuning = True
168
+ # next_token = torch.compile(next_token, mode="reduce-overhead")
169
+
170
+
171
+ @torch.inference_mode()
172
+ def generate(
173
+ model: GPT,
174
+ input_ids: list,
175
+ max_returned_tokens: int,
176
+ *,
177
+ temperature: float = 1.0,
178
+ top_k: Optional[int] = None,
179
+ top_p: float = 1.0,
180
+ eos_id_a: Optional[int] = None,
181
+ eos_id_t: Optional[int] = None,
182
+ pad_id: Optional[int] = None,
183
+ shift: Optional[int] = None,
184
+ include_prompt: bool = True,
185
+ generate_text=False,
186
+ ) -> torch.Tensor:
187
+ # print("eos_id_a:", eos_id_a)
188
+ # print("eos_id_t:", eos_id_t)
189
+ # print("pad_id:", pad_id)
190
+ """
191
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
192
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
193
+
194
+ Args:
195
+ model: The model to use.
196
+ input_ids: List of eight prompt tensors (seven layer-shifted SNAC audio streams plus one text stream) with the prompt token indices.
197
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
198
+ temperature: Scales the predicted logits by 1 / temperature.
199
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
200
+ top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
201
+ In top-p sampling, the next token is sampled from the highest probability tokens
202
+ whose cumulative probability exceeds the threshold `top_p`. When specified,
203
+ it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
204
+ to sampling the most probable token, while `top_p=1` samples from the whole distribution.
205
+ It can be used in conjunction with `top_k` and `temperature` with the following order
206
+ of application:
207
+
208
+ 1. `top_k` sampling
209
+ 2. `temperature` scaling
210
+ 3. `top_p` sampling
211
+
212
+ For more details, see https://arxiv.org/abs/1904.09751
213
+ or https://huyenchip.com/2024/01/16/sampling.html#top_p
214
+ eos_id_a / eos_id_t: If specified, stop generating more tokens once the corresponding audio / text <eos> token is produced.
215
+ include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
216
+ """
217
+ T = input_ids[0].size(0)
218
+ device = input_ids[0].device
219
+ assert max_returned_tokens > T
220
+ if model.max_seq_length < max_returned_tokens - 1:
221
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
222
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
223
+ # not support it to avoid negatively impacting the overall speed
224
+ raise NotImplementedError(
225
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
226
+ )
227
+
228
+ for input_id in input_ids:
229
+ input_id = [input_id]
230
+ (
231
+ tokens_A1,
232
+ tokens_A2,
233
+ tokens_A3,
234
+ tokens_A4,
235
+ tokens_A5,
236
+ tokens_A6,
237
+ tokens_A7,
238
+ tokens_T,
239
+ ) = input_ids
240
+
241
+ tokens_A1_output = [tokens_A1]
242
+ tokens_A2_output = [tokens_A2]
243
+ tokens_A3_output = [tokens_A3]
244
+ tokens_A4_output = [tokens_A4]
245
+ tokens_A5_output = [tokens_A5]
246
+ tokens_A6_output = [tokens_A6]
247
+ tokens_A7_output = [tokens_A7]
248
+ tokens_T_output = [tokens_T]
249
+
250
+ list_output = [
251
+ tokens_A1_output,
252
+ tokens_A2_output,
253
+ tokens_A3_output,
254
+ tokens_A4_output,
255
+ tokens_A5_output,
256
+ tokens_A6_output,
257
+ tokens_A7_output,
258
+ tokens_T_output,
259
+ ]
260
+
261
+ input_pos = torch.tensor([T], device=device)
262
+ model_input_ids = [
263
+ tokens_A1.view(1, -1),
264
+ tokens_A2.view(1, -1),
265
+ tokens_A3.view(1, -1),
266
+ tokens_A4.view(1, -1),
267
+ tokens_A5.view(1, -1),
268
+ tokens_A6.view(1, -1),
269
+ tokens_A7.view(1, -1),
270
+ tokens_T.view(1, -1),
271
+ ]
272
+
273
+ tokens_A, token_T = next_token(
274
+ model,
275
+ torch.arange(0, T, device=device),
276
+ model_input_ids,
277
+ temperature=temperature,
278
+ top_k=top_k,
279
+ top_p=top_p,
280
+ )
281
+ for i in range(7):
282
+ list_output[i].append(tokens_A[i].clone())
283
+ list_output[7].append(token_T.clone())
284
+
285
+ # prepare the input for the next iteration
286
+ for i in range(7):
287
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
288
+ token_T = token_T.clone()
289
+
290
+ text_end = False
291
+ max_returned_tokens = 1000
292
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
293
+ model_input_ids = [
294
+ token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
295
+ ] + [token_T.view(1, -1).to(torch.int32)]
296
+ tokens_A, token_T = next_token(
297
+ model,
298
+ input_pos,
299
+ model_input_ids,
300
+ temperature=temperature,
301
+ top_k=top_k,
302
+ top_p=top_p,
303
+ )
304
+ if text_end:
305
+ token_T = torch.tensor([pad_id], device=device)
306
+
307
+ for i in range(7):
308
+ list_output[i].append(tokens_A[i].clone())
309
+ list_output[7].append(token_T.clone())
310
+
311
+ if tokens_A[-1] == eos_id_a:
312
+ break
313
+ if token_T == eos_id_t:
314
+ if generate_text:
315
+ break
316
+ text_end = True
317
+
318
+ for i in range(7):
319
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
320
+ token_T = token_T.clone()
321
+ input_pos = input_pos.add_(1)
322
+
323
+ for i in range(len(list_output)):
324
+ list_output[i] = torch.cat(list_output[i])
325
+ return list_output
326
+
327
+
328
+ @torch.inference_mode()
329
+ def generate_TA_BATCH(
330
+ model: GPT,
331
+ audio_features: torch.Tensor,
332
+ input_ids: list,
333
+ leng,
334
+ task,
335
+ max_returned_tokens: int = 1000,
336
+ *,
337
+ temperature: float = 1.0,
338
+ top_k: Optional[int] = None,
339
+ top_p: float = 1.0,
340
+ eos_id_a: Optional[int] = None,
341
+ eos_id_t: Optional[int] = None,
342
+ pad_id_t: Optional[int] = None,
343
+ shift: Optional[int] = None,
344
+ include_prompt: bool = True,
345
+ generate_text=False,
346
+ ) -> torch.Tensor:
347
+
348
+ T = input_ids[0].size(1)
349
+ device = input_ids[0].device
350
+ assert max_returned_tokens > T
351
+ if model.max_seq_length < max_returned_tokens - 1:
352
+ raise NotImplementedError(
353
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
354
+ )
355
+
356
+ input_pos = torch.tensor([T], device=device)
357
+ model_input_ids = input_ids
358
+
359
+ list_output = [[] for i in range(8)]
360
+
361
+ tokens_A, token_T = next_token_image_batch(
362
+ model,
363
+ audio_features.to(torch.float32).to(model.device),
364
+ None,
365
+ input_ids,
366
+ [T - 3, T - 3],
367
+ ["A1T2", "A1T2"],
368
+ input_pos=torch.arange(0, T, device=device),
369
+ temperature=temperature,
370
+ top_k=top_k,
371
+ top_p=top_p,
372
+ )
373
+
374
+ for i in range(7):
375
+ list_output[i].append(tokens_A[i].tolist()[0])
376
+ list_output[7].append(token_T.tolist()[0])
377
+
378
+ model_input_ids = [[] for i in range(8)]
379
+ for i in range(7):
380
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
381
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
382
+ model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
383
+ model_input_ids[i] = torch.stack(model_input_ids[i])
384
+
385
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
386
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
387
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
388
+
389
+ text_end = False
390
+
391
+ for _ in range(2, max_returned_tokens - T + 1):
392
+ tokens_A, token_T = next_token_image_batch(
393
+ model,
394
+ None,
395
+ None,
396
+ model_input_ids,
397
+ None,
398
+ None,
399
+ input_pos=input_pos,
400
+ temperature=temperature,
401
+ top_k=top_k,
402
+ top_p=top_p,
403
+ )
404
+
405
+ if text_end:
406
+ token_T = torch.tensor([pad_id_t], device=device)
407
+
408
+ if tokens_A[-1] == eos_id_a:
409
+ break
410
+ if token_T == eos_id_t:
411
+ text_end = True
412
+
413
+ for i in range(7):
414
+ list_output[i].append(tokens_A[i].tolist()[0])
415
+ list_output[7].append(token_T.tolist()[0])
416
+
417
+ model_input_ids = [[] for i in range(8)]
418
+ for i in range(7):
419
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
420
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
421
+ model_input_ids[i].append(
422
+ torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
423
+ )
424
+ model_input_ids[i] = torch.stack(model_input_ids[i])
425
+
426
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
427
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
428
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
429
+
430
+ input_pos = input_pos.add_(1)
431
+
432
+ return list_output
433
+
434
+
435
+ @torch.inference_mode()
436
+ def generate_TT(
437
+ model: GPT,
438
+ audio_features: torch.Tensor,
439
+ input_ids: list,
440
+ leng,
441
+ task,
442
+ max_returned_tokens: int = 2048,
443
+ *,
444
+ temperature: float = 1.0,
445
+ top_k: Optional[int] = None,
446
+ top_p: float = 1.0,
447
+ eos_id_a: Optional[int] = None,
448
+ eos_id_t: Optional[int] = None,
449
+ pad_id_t: Optional[int] = None,
450
+ shift: Optional[int] = None,
451
+ include_prompt: bool = True,
452
+ generate_text=False,
453
+ ) -> torch.Tensor:
454
+
455
+ T = input_ids[0].size(1)
456
+ device = input_ids[0].device
457
+
458
+ output = []
459
+ token_T = next_token_A1T1(
460
+ model,
461
+ None,
462
+ input_ids,
463
+ None,
464
+ None,
465
+ input_pos=torch.arange(0, T, device=device),
466
+ temperature=temperature,
467
+ top_k=top_k,
468
+ top_p=top_p,
469
+ )
470
+
471
+ output.append(token_T.clone().tolist()[0])
472
+ input_pos = torch.tensor([T], device=device)
473
+
474
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
475
+ model_input_ids = []
476
+ for i in range(7):
477
+ model_input_ids.append(
478
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
479
+ .view(1, -1)
480
+ .to(torch.int32)
481
+ .to(device)
482
+ )
483
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
484
+ token_T = next_token_A1T1(
485
+ model,
486
+ None,
487
+ model_input_ids,
488
+ None,
489
+ None,
490
+ input_pos=input_pos,
491
+ temperature=temperature,
492
+ top_k=top_k,
493
+ top_p=top_p,
494
+ )
495
+ if token_T == eos_id_t:
496
+ break
497
+ output.append(token_T.clone().tolist()[0])
498
+ input_pos = input_pos.add_(1)
499
+ return output
500
+
501
+
502
+ @torch.inference_mode()
503
+ def generate_AT(
504
+ model: GPT,
505
+ audio_features: torch.Tensor,
506
+ input_ids: list,
507
+ leng,
508
+ task,
509
+ max_returned_tokens: int = 2048,
510
+ *,
511
+ temperature: float = 1.0,
512
+ top_k: Optional[int] = None,
513
+ top_p: float = 1.0,
514
+ eos_id_a: Optional[int] = None,
515
+ eos_id_t: Optional[int] = None,
516
+ pad_id_t: Optional[int] = None,
517
+ shift: Optional[int] = None,
518
+ include_prompt: bool = True,
519
+ generate_text=False,
520
+ ) -> torch.Tensor:
521
+
522
+ T = input_ids[0].size(1)
523
+ device = input_ids[0].device
524
+
525
+ output = []
526
+ token_T = next_token_A1T1(
527
+ model,
528
+ audio_features.to(torch.float32).to(model.device),
529
+ input_ids,
530
+ [T - 3],
531
+ ["AT"],
532
+ input_pos=torch.arange(0, T, device=device),
533
+ temperature=temperature,
534
+ top_k=top_k,
535
+ top_p=top_p,
536
+ )
537
+ output.append(token_T.clone().tolist()[0])
538
+ input_pos = torch.tensor([T], device=device)
539
+ text_end = False
540
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
541
+ model_input_ids = []
542
+ for i in range(7):
543
+ model_input_ids.append(
544
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
545
+ .view(1, -1)
546
+ .to(torch.int32)
547
+ .to(device)
548
+ )
549
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
550
+ token_T = next_token_A1T1(
551
+ model,
552
+ None,
553
+ model_input_ids,
554
+ None,
555
+ None,
556
+ input_pos=input_pos,
557
+ temperature=temperature,
558
+ top_k=top_k,
559
+ top_p=top_p,
560
+ )
561
+ if token_T == eos_id_t:
562
+ break
563
+ output.append(token_T.clone().tolist()[0])
564
+ input_pos = input_pos.add_(1)
565
+ return output
566
+
567
+
568
+ @torch.inference_mode()
569
+ def generate_TA(
570
+ model: GPT,
571
+ audio_features: torch.Tensor,
572
+ input_ids: list,
573
+ leng,
574
+ task,
575
+ max_returned_tokens: int = 2048,
576
+ *,
577
+ temperature: float = 1.0,
578
+ top_k: Optional[int] = None,
579
+ top_p: float = 1.0,
580
+ eos_id_a: Optional[int] = None,
581
+ eos_id_t: Optional[int] = None,
582
+ pad_id_t: Optional[int] = None,
583
+ shift: Optional[int] = None,
584
+ include_prompt: bool = True,
585
+ generate_text=False,
586
+ ) -> torch.Tensor:
587
+
588
+ T = input_ids[0].size(1)
589
+ device = input_ids[0].device
590
+
591
+ output = [[] for _ in range(8)]
592
+ tokens_A, token_T = next_token_A1T2(
593
+ model,
594
+ None,
595
+ input_ids,
596
+ None,
597
+ None,
598
+ input_pos=torch.arange(0, T, device=device),
599
+ temperature=temperature,
600
+ top_k=top_k,
601
+ top_p=top_p,
602
+ )
603
+ for i in range(7):
604
+ output[i].append(tokens_A[i].clone().tolist()[0])
605
+ output[7].append(token_T.clone().tolist()[0])
606
+
607
+ input_pos = torch.tensor([T], device=device)
608
+ text_end = False
609
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
610
+
611
+ model_input_ids = []
612
+ for i in range(7):
613
+ model_input_ids.append(
614
+ layershift(tokens_A[i].clone(), i)
615
+ .view(1, -1)
616
+ .to(torch.int32)
617
+ .to(device)
618
+ )
619
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
620
+
621
+ tokens_A, token_T = next_token_A1T2(
622
+ model,
623
+ None,
624
+ model_input_ids,
625
+ None,
626
+ None,
627
+ input_pos=input_pos,
628
+ temperature=temperature,
629
+ top_k=top_k,
630
+ top_p=top_p,
631
+ )
632
+
633
+ if text_end:
634
+ token_T = torch.tensor([pad_id_t], device=device)
635
+
636
+ if tokens_A[-1] == eos_id_a:
637
+ break
638
+
639
+ if token_T == eos_id_t:
640
+ text_end = True
641
+
642
+ for i in range(7):
643
+ output[i].append(tokens_A[i].clone().tolist()[0])
644
+ output[7].append(token_T.clone().tolist()[0])
645
+ input_pos = input_pos.add_(1)
646
+
647
+ return output
648
+
649
+
650
+ @torch.inference_mode()
651
+ def generate_AA(
652
+ model: GPT,
653
+ audio_features: torch.Tensor,
654
+ input_ids: list,
655
+ leng,
656
+ task,
657
+ max_returned_tokens: int = 2048,
658
+ *,
659
+ temperature: float = 1.0,
660
+ top_k: Optional[int] = None,
661
+ top_p: float = 1.0,
662
+ eos_id_a: Optional[int] = None,
663
+ eos_id_t: Optional[int] = None,
664
+ pad_id_t: Optional[int] = None,
665
+ shift: Optional[int] = None,
666
+ include_prompt: bool = True,
667
+ generate_text=False,
668
+ ) -> torch.Tensor:
669
+
670
+ T = input_ids[0].size(1)
671
+ device = input_ids[0].device
672
+
673
+ output = [[] for _ in range(8)]
674
+ tokens_A, token_T = next_token_A1T2(
675
+ model,
676
+ audio_features.to(torch.float32).to(model.device),
677
+ input_ids,
678
+ [T - 3],
679
+ ["A1T2"],
680
+ input_pos=torch.arange(0, T, device=device),
681
+ temperature=temperature,
682
+ top_k=top_k,
683
+ top_p=top_p,
684
+ )
685
+ for i in range(7):
686
+ output[i].append(tokens_A[i].clone().tolist()[0])
687
+ output[7].append(token_T.clone().tolist()[0])
688
+
689
+ input_pos = torch.tensor([T], device=device)
690
+
691
+ text_end = False
692
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
693
+
694
+ model_input_ids = []
695
+ for i in range(7):
696
+ model_input_ids.append(
697
+ layershift(tokens_A[i].clone(), i)
698
+ .view(1, -1)
699
+ .to(torch.int32)
700
+ .to(device)
701
+ )
702
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
703
+
704
+ tokens_A, token_T = next_token_A1T2(
705
+ model,
706
+ None,
707
+ model_input_ids,
708
+ None,
709
+ None,
710
+ input_pos=input_pos,
711
+ temperature=temperature,
712
+ top_k=top_k,
713
+ top_p=top_p,
714
+ )
715
+
716
+ if text_end:
717
+ token_T = torch.tensor([pad_id_t], device=device)
718
+
719
+ if tokens_A[-1] == eos_id_a:
720
+ break
721
+ if token_T == eos_id_t:
722
+ # print("text_end")
723
+ text_end = True
724
+
725
+ for i in range(7):
726
+ output[i].append(tokens_A[i].clone().tolist()[0])
727
+ output[7].append(token_T.clone().tolist()[0])
728
+ input_pos = input_pos.add_(1)
729
+
730
+ return output
731
+
732
+
733
+ @torch.inference_mode()
734
+ def generate_ASR(
735
+ model: GPT,
736
+ audio_features: torch.Tensor,
737
+ input_ids: list,
738
+ leng,
739
+ task,
740
+ max_returned_tokens: int = 1200,
741
+ *,
742
+ temperature: float = 1.0,
743
+ top_k: Optional[int] = None,
744
+ top_p: float = 1.0,
745
+ eos_id_a: Optional[int] = None,
746
+ eos_id_t: Optional[int] = None,
747
+ pad_id_t: Optional[int] = None,
748
+ shift: Optional[int] = None,
749
+ include_prompt: bool = True,
750
+ generate_text=False,
751
+ ) -> torch.Tensor:
752
+
753
+ T = input_ids[0].size(1)
754
+ device = input_ids[0].device
755
+ output = []
756
+ token_T = next_token_A1T1(
757
+ model,
758
+ audio_features.to(torch.float32).to(model.device),
759
+ input_ids,
760
+ [T - 3],
761
+ ["asr"],
762
+ input_pos=torch.arange(0, T, device=device),
763
+ temperature=temperature,
764
+ top_k=top_k,
765
+ top_p=top_p,
766
+ )
767
+ output.append(token_T.clone().tolist()[0])
768
+ input_pos = torch.tensor([T], device=device)
769
+ text_end = False
770
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
771
+ model_input_ids = []
772
+ for i in range(7):
773
+ model_input_ids.append(
774
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
775
+ .view(1, -1)
776
+ .to(torch.int32)
777
+ .to(device)
778
+ )
779
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
780
+ token_T = next_token_A1T1(
781
+ model,
782
+ None,
783
+ model_input_ids,
784
+ None,
785
+ None,
786
+ input_pos=input_pos,
787
+ temperature=temperature,
788
+ top_k=top_k,
789
+ top_p=top_p,
790
+ )
791
+ if token_T == eos_id_t:
792
+ break
793
+ output.append(token_T.clone().tolist()[0])
794
+ input_pos = input_pos.add_(1)
795
+ return output
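
The docstring of `generate` above spells out the sampling order that `sample` applies: top-k cropping first, then temperature scaling, then top-p (nucleus) filtering. Below is a minimal, self-contained sketch of that same recipe on toy logits; it is an illustration only (the names and values are invented for the example) and is independent of the GPT model and the seven SNAC audio layers.

import torch

def toy_sample(logits, temperature=0.8, top_k=3, top_p=0.9):
    # 1. top-k: keep only the k largest logits
    v, i = torch.topk(logits, min(top_k, logits.size(-1)))
    logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
    # 2. temperature: flatten (>1) or sharpen (<1) the distribution
    logits = logits / temperature
    # 3. top-p: mask the low-probability tail (ascending sort, as in sample_top_p)
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_to_remove = cumulative_probs <= (1 - top_p)
    sorted_to_remove[-1:] = 0  # always keep the most probable token
    to_remove = sorted_to_remove.scatter(0, sorted_indices, sorted_to_remove)
    logits = logits.masked_fill(to_remove, float("-inf"))
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

print(toy_sample(torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])))  # e.g. tensor([0])
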
litgpt/model.py ADDED
@@ -0,0 +1,654 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Full definition of a decoder-only transformer-based language model, all of it in this single file.
4
+
5
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
6
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
7
+ """
8
+
9
+ import math
10
+ from typing import Any, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing_extensions import Self
15
+ from litgpt.config import Config
16
+
17
+
18
+ class GPT(nn.Module):
19
+ def __init__(self, config: Config) -> None:
20
+ super().__init__()
21
+ assert config.padded_vocab_size is not None
22
+ self.config = config
23
+ if self.config.asr_adapter == "mlp":
24
+ print("Using MLP adapter for ASR feature")
25
+ self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
26
+ elif self.config.asr_adapter == "llamamlp":
27
+ print("Using LLaMA MLP adapter for ASR feature")
28
+ self.whisper_adapter = whisperMLP(config=config)
29
+ else:
30
+ raise ValueError("asr_adapter should be mlp or llamamlp")
31
+ self.lm_head = nn.Linear(
32
+ config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
33
+ )
34
+
35
+ self.vision_adapter = visionMLP(config = config)
36
+ if config.post_adapter:
37
+ self.transformer = nn.ModuleDict(
38
+ dict(
39
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
40
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
41
+ post_adapter=nn.ModuleList(
42
+ Block(config) for _ in range(config.post_adapter_layers)
43
+ ),
44
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
45
+ post_adapter_audio_ln=config.norm_class(
46
+ config.n_embd, eps=config.norm_eps
47
+ ),
48
+ post_adapter_audio_lm_head=nn.Linear(
49
+ config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
50
+ ),
51
+ )
52
+ )
53
+ else:
54
+ self.transformer = nn.ModuleDict(
55
+ dict(
56
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
57
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
58
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
59
+ )
60
+ )
61
+ self.max_seq_length = self.config.block_size
62
+ self.mask_cache: Optional[torch.Tensor] = None
63
+ if config.tie_word_embeddings:
64
+ self.lm_head.weight = self.transformer.wte.weight
65
+
66
+ @property
67
+ def max_seq_length(self) -> int:
68
+ return self._max_seq_length
69
+
70
+ @max_seq_length.setter
71
+ def max_seq_length(self, value: int) -> None:
72
+ """
73
+ When doing inference, the sequences used might be shorter than the model's context length.
74
+ This allows setting a smaller number to avoid allocating unused memory
75
+ """
76
+ if value > self.config.block_size:
77
+ raise ValueError(
78
+ f"Cannot attend to {value}, block size is only {self.config.block_size}"
79
+ )
80
+ self._max_seq_length = value
81
+ if not hasattr(self, "cos"):
82
+ # first call
83
+ cos, sin = self.rope_cache()
84
+ self.register_buffer("cos", cos, persistent=False)
85
+ self.register_buffer("sin", sin, persistent=False)
86
+ # override
87
+ elif value != self.cos.size(0):
88
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
89
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
90
+ # if the kv cache is expected
91
+
92
+ def reset_parameters(self) -> None:
93
+ # Trigger resetting the rope-cache
94
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
95
+
96
+ def _init_weights(self, module: nn.Module) -> None:
97
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
98
+ if isinstance(module, nn.Linear):
99
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
100
+ if module.bias is not None:
101
+ torch.nn.init.zeros_(module.bias)
102
+ elif isinstance(module, nn.Embedding):
103
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
104
+
105
+ def concat_feat(self, audio_feature, clip_feature, input_ids, T, task):
106
+
107
+ for j in range(len(T)):
108
+ if task[j] != 'T1T2' and task[j] != 'T1A2' and task[j]!='ImageQA_T' and not task[j] == 'ImageCAP' and not task[j] == 'ImageQA_A' and not task[j] == 'ImageQA_AT':
109
+ for i in range(7):
110
+ input_ids[i][j,1:T[j]+1,:] = audio_feature[j][:T[j]].clone()
111
+ assert task[j] != 'ImageQ', "ImageQ should be concat with audio feature"
112
+
113
+ elif task[j] == 'ImageQA_A' or task[j] == 'ImageQA_AT':
114
+ print("concat ImageQA_A feature")
115
+ for i in range(7):
116
+ input_ids[i][j,1:51,:] = clip_feature[j].clone()
117
+
118
+ input_ids[i][j,52 : 52 + T[j],:] = audio_feature[j][:T[j]].clone()
119
+
120
+ elif task[j] == 'ImageQA_T' or task[j] =='ImageCAP':
121
+ for i in range(7):
122
+ input_ids[i][j,1:51,:] = clip_feature[j].clone()
123
+
124
+ return input_ids
125
+
126
+ def forward(
127
+ self,
128
+ audio_features: torch.Tensor,
129
+ input_ids: torch.Tensor,
130
+ clip_features: torch.Tensor,
131
+ input_pos: Optional[torch.Tensor] = None,
132
+ whisper_lens: Optional[list] = None,
133
+ task: Optional[str] = None,
134
+ ) -> torch.Tensor:
135
+
136
+ show = False
137
+ T = input_ids[0].size(1)
138
+ if self.max_seq_length < T:
139
+ raise ValueError(
140
+ f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
141
+ )
142
+
143
+ if input_pos is not None: # use the kv cache
144
+ cos = self.cos.index_select(0, input_pos)
145
+ sin = self.sin.index_select(0, input_pos)
146
+ if self.mask_cache is None:
147
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
148
+ mask = self.mask_cache.index_select(2, input_pos)
149
+ else:
150
+ cos = self.cos[:T]
151
+ sin = self.sin[:T]
152
+ mask = None
153
+
154
+ if audio_features is not None:
155
+ # get whisper feature
156
+ x_a = self.whisper_adapter(audio_features)
157
+ if clip_features is not None:
158
+ x_v = self.vision_adapter(clip_features)
159
+ else:
160
+ x_v = None
161
+ # get input_ids embedding
162
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
163
+
164
+ x0 = self.transformer.wte(x0)
165
+ x1 = self.transformer.wte(x1)
166
+ x2 = self.transformer.wte(x2)
167
+ x3 = self.transformer.wte(x3)
168
+ x4 = self.transformer.wte(x4)
169
+ x5 = self.transformer.wte(x5)
170
+ x6 = self.transformer.wte(x6)
171
+ x7 = self.transformer.wte(x7)
172
+
173
+ # concat whisper feature
174
+ input_emb = self.concat_feat(
175
+ x_a, x_v, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
176
+ )
177
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
178
+
179
+ else:
180
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
181
+
182
+ x0 = self.transformer.wte(x0)
183
+ x1 = self.transformer.wte(x1)
184
+ x2 = self.transformer.wte(x2)
185
+ x3 = self.transformer.wte(x3)
186
+ x4 = self.transformer.wte(x4)
187
+ x5 = self.transformer.wte(x5)
188
+ x6 = self.transformer.wte(x6)
189
+ x7 = self.transformer.wte(x7)
190
+
191
+ x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
192
+
193
+ if self.config.scale_embeddings:
194
+ x = x * (self.config.n_embd**0.5)
195
+
196
+ for block in self.transformer.h:
197
+ x = block(x, cos, sin, mask, input_pos)
198
+
199
+
200
+ text_vocab_size = self.config.text_vocab_size
201
+ audio_vocab_size = self.config.audio_vocab_size
202
+
203
+ x_ori = x
204
+ x_ori = self.transformer.ln_f(x_ori)
205
+ x_ori = self.lm_head(x_ori) # (b, t, vocab_size)
206
+ xt = x_ori[..., :text_vocab_size]
207
+
208
+ if self.config.post_adapter:
209
+ for block in self.transformer.post_adapter:
210
+ x = block(x, cos, sin, mask, input_pos)
211
+ x = self.transformer.post_adapter_audio_ln(x)
212
+ x = self.transformer.post_adapter_audio_lm_head(x) # (b, t, vocab_size)
213
+ xa = []
214
+ for i in range(7):
215
+ xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
216
+ else:
217
+ xa = []
218
+ for i in range(7):
219
+ xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
220
+
221
+ return xa, xt
222
+
223
+ @classmethod
224
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
225
+ return cls(Config.from_name(name, **kwargs))
226
+
227
+ def rope_cache(
228
+ self, device: Optional[torch.device] = None
229
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
230
+ return build_rope_cache(
231
+ seq_len=self.max_seq_length,
232
+ n_elem=self.config.rope_n_elem,
233
+ device=device,
234
+ condense_ratio=self.config.rope_condense_ratio,
235
+ base=self.config.rope_base,
236
+ )
237
+
238
+ def set_kv_cache(
239
+ self,
240
+ batch_size: int,
241
+ rope_cache_length: Optional[int] = None,
242
+ device: Optional[torch.device] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ) -> None:
245
+ if rope_cache_length is None:
246
+ rope_cache_length = self.cos.size(-1)
247
+ max_seq_length = self.max_seq_length
248
+
249
+ # initialize the kv cache for all blocks
250
+ for block in self.transformer.h:
251
+ block.attn.kv_cache = block.attn.build_kv_cache(
252
+ batch_size, max_seq_length, rope_cache_length, device, dtype
253
+ )
254
+ if self.config.post_adapter:
255
+ for block in self.transformer.post_adapter:
256
+ block.attn.kv_cache = block.attn.build_kv_cache(
257
+ batch_size, max_seq_length, rope_cache_length, device, dtype
258
+ )
259
+
260
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
261
+ # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
262
+ # for the kv-cache support (only during inference), we only create it in that situation
263
+ self.mask_cache = build_mask_cache(max_seq_length, device)
264
+
265
+ def clear_kv_cache(self) -> None:
266
+ self.mask_cache = None
267
+ for block in self.transformer.h:
268
+ block.attn.kv_cache = None
269
+
270
+
271
+ class visionMLP(nn.Module):
272
+ def __init__(self, config: Config) -> None:
273
+ super().__init__()
274
+ vision_adapter_dim = config.vision_adapter_dim
275
+ self.fc_1 = nn.Linear(vision_adapter_dim, config.intermediate_size, bias=config.bias)
276
+ self.fc_2 = nn.Linear(vision_adapter_dim, config.intermediate_size, bias=config.bias)
277
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
278
+
279
+ self.config = config
280
+
281
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
282
+ x_fc_1 = self.fc_1(x)
283
+ x_fc_2 = self.fc_2(x)
284
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
285
+ return self.proj(x)
286
+
287
+
288
+ class Block(nn.Module):
289
+
290
+ def __init__(self, config: Config) -> None:
291
+ super().__init__()
292
+ if not config.parallel_residual and config.shared_attention_norm:
293
+ raise NotImplementedError(
294
+ "No checkpoint amongst the ones we support uses this configuration"
295
+ " (non-parallel residual and shared attention norm)."
296
+ )
297
+
298
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
299
+ self.attn = CausalSelfAttention(config)
300
+ self.norm_2 = (
301
+ None
302
+ if config.shared_attention_norm
303
+ else config.norm_class(config.n_embd, eps=config.norm_eps)
304
+ )
305
+ self.mlp = config.mlp_class(config)
306
+
307
+ self.config = config
308
+
309
+ def forward(
310
+ self,
311
+ x: torch.Tensor,
312
+ cos: torch.Tensor,
313
+ sin: torch.Tensor,
314
+ mask: Optional[torch.Tensor] = None,
315
+ input_pos: Optional[torch.Tensor] = None,
316
+ ) -> torch.Tensor:
317
+ """
318
+ Non-parallel residual Parallel residual
319
+ ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True,
320
+ │ ↓ │ ↓ ↓ the output from `norm_1` is reused
321
+ │ norm_1 │ norm_1 ───► norm_2
322
+ │ ↓ │ ↓ ↓
323
+ │ attn │ attn mlp
324
+ │ ↓ │ ↓ │
325
+ ┌─ └► + └► + ◄───────────┘
326
+ │ norm_2
327
+ │ ↓
328
+ │ mlp
329
+ │ ↓
330
+ └───► +
331
+ """
332
+
333
+ x_normed = self.norm_1(x)
334
+ attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
335
+
336
+ if self.config.parallel_residual:
337
+ x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
338
+ x = self.mlp(x_normed) + attention_output + x
339
+ else:
340
+ x = attention_output + x
341
+ x = self.mlp(self.norm_2(x)) + x
342
+ return x
343
+
344
+
345
+ class CausalSelfAttention(nn.Module):
346
+ def __init__(self, config: Config) -> None:
347
+ super().__init__()
348
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
349
+ # key, query, value projections for all heads, but in a batch
350
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
351
+ # output projection
352
+ # if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
353
+ self.proj = nn.Linear(
354
+ config.head_size * config.n_head, config.n_embd, bias=config.bias
355
+ )
356
+ # disabled by default
357
+ self.kv_cache: Optional[KVCache] = None
358
+
359
+ self.config = config
360
+
361
+ def forward(
362
+ self,
363
+ x: torch.Tensor,
364
+ cos: torch.Tensor,
365
+ sin: torch.Tensor,
366
+ mask: Optional[torch.Tensor] = None,
367
+ input_pos: Optional[torch.Tensor] = None,
368
+ ) -> torch.Tensor:
369
+ B, T, C = (
370
+ x.size()
371
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
372
+
373
+ qkv = self.attn(x)
374
+
375
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
376
+ q_per_kv = self.config.n_head // self.config.n_query_groups
377
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
378
+ qkv = qkv.view(
379
+ B, T, self.config.n_query_groups, total_qkv, self.config.head_size
380
+ )
381
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
382
+
383
+ # split batched computation into three
384
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
385
+
386
+ # maybe repeat k and v for the non multi-head attention cases
387
+ # training: flash attention requires it
388
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
389
+ if self.config.n_query_groups != self.config.n_head and (
390
+ input_pos is None or self.config.n_query_groups != 1
391
+ ):
392
+ k = k.expand(
393
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
394
+ )
395
+ v = v.expand(
396
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
397
+ )
398
+
399
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
400
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
401
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
402
+
403
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
404
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
405
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
406
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
407
+
408
+ if input_pos is not None:
409
+ if not isinstance(self.kv_cache, KVCache):
410
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
411
+ k, v = self.kv_cache(input_pos, k, v)
412
+
413
+ y = self.scaled_dot_product_attention(q, k, v, mask)
414
+
415
+ y = y.reshape(
416
+ B, T, self.config.head_size * self.config.n_head
417
+ ) # re-assemble all head outputs side by side
418
+
419
+ # output projection
420
+ return self.proj(y)
421
+
422
+ def scaled_dot_product_attention(
423
+ self,
424
+ q: torch.Tensor,
425
+ k: torch.Tensor,
426
+ v: torch.Tensor,
427
+ mask: Optional[torch.Tensor] = None,
428
+ ) -> torch.Tensor:
429
+ scale = 1.0 / math.sqrt(self.config.head_size)
430
+ y = torch.nn.functional.scaled_dot_product_attention(
431
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
432
+ )
433
+ return y.transpose(1, 2)
434
+
435
+ def build_kv_cache(
436
+ self,
437
+ batch_size: int,
438
+ max_seq_length: int,
439
+ rope_cache_length: Optional[int] = None,
440
+ device: Optional[torch.device] = None,
441
+ dtype: Optional[torch.dtype] = None,
442
+ ) -> "KVCache":
443
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
444
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
445
+ if rope_cache_length is None:
446
+ if self.config.rotary_percentage != 1.0:
447
+ raise TypeError(
448
+ "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
449
+ )
450
+ k_shape = v_shape
451
+ else:
452
+ k_shape = (
453
+ batch_size,
454
+ heads,
455
+ max_seq_length,
456
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
457
+ )
458
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
459
+
460
+
461
+ class GptNeoxMLP(nn.Module):
462
+ def __init__(self, config: Config) -> None:
463
+ super().__init__()
464
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
465
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
466
+
467
+ self.config = config
468
+
469
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
470
+ x = self.fc(x)
471
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
472
+ return self.proj(x)
473
+
474
+
475
+ class LLaMAMLP(nn.Module):
476
+ def __init__(self, config: Config) -> None:
477
+ super().__init__()
478
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
479
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
480
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
481
+
482
+ self.config = config
483
+
484
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
485
+ x_fc_1 = self.fc_1(x)
486
+ x_fc_2 = self.fc_2(x)
487
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
488
+ return self.proj(x)
489
+
490
+
491
+ class whisperMLP(nn.Module):
492
+ def __init__(self, config: Config) -> None:
493
+ super().__init__()
494
+ self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
495
+ self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
496
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
497
+
498
+ self.config = config
499
+
500
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
501
+ x_fc_1 = self.fc_1(x)
502
+ x_fc_2 = self.fc_2(x)
503
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
504
+ return self.proj(x)
505
+
506
+
507
+ class GemmaMLP(LLaMAMLP):
508
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
509
+ x_fc_1 = self.fc_1(x)
510
+ x_fc_2 = self.fc_2(x)
511
+ x = (
512
+ torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate)
513
+ * x_fc_2
514
+ )
515
+ return self.proj(x)
516
+
517
+
518
+ class LLaMAMoE(nn.Module):
519
+ def __init__(self, config: Config) -> None:
520
+ super().__init__()
521
+ self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
522
+ self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
523
+
524
+ self.config = config
525
+
526
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
527
+ """
528
+ Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
529
+ See also figure 1 in https://arxiv.org/abs/2211.15841
530
+ """
531
+ B, T, C = (
532
+ x.size()
533
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
534
+ x = x.view(-1, C) # (B*T, C)
535
+ router = self.gate(x) # (B*T, n_expert)
536
+ probs, indices = torch.topk(
537
+ router, self.config.n_expert_per_token
538
+ ) # (B*T, n_expert_per_token)
539
+ probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
540
+ masks = indices.unsqueeze(-1) == torch.arange(
541
+ self.config.n_expert, device=x.device
542
+ )
543
+ masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
544
+ y = torch.zeros_like(x) # (B*T, C)
545
+ for mask, expert in zip(masks, self.experts):
546
+ token_idx, expert_idx = torch.where(mask)
547
+ y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
548
+ return y.view(B, T, C)
549
+
550
+
551
+ def build_rope_cache(
552
+ seq_len: int,
553
+ n_elem: int,
554
+ device: Optional[torch.device] = None,
555
+ base: int = 10000,
556
+ condense_ratio: int = 1,
557
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
558
+ """Enhanced Transformer with Rotary Position Embedding.
559
+
560
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
561
+ transformers/rope/__init__.py. MIT License:
562
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
563
+ """
564
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
565
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
566
+
567
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
568
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
569
+
570
+ # Calculate the product of position index and $\theta_i$
571
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
572
+
573
+ return torch.cos(idx_theta), torch.sin(idx_theta)
574
+
575
+
576
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
577
+ head_size = x.size(-1)
578
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
579
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
580
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
581
+ roped = (x * cos) + (rotated * sin)
582
+ return roped.to(dtype=x.dtype)
583
+
584
+
585
+ class KVCache(nn.Module):
586
+ def __init__(
587
+ self,
588
+ k_shape: Tuple[int, int, int, int],
589
+ v_shape: Tuple[int, int, int, int],
590
+ device: Optional[torch.device] = None,
591
+ dtype: Optional[torch.dtype] = None,
592
+ ) -> None:
593
+ super().__init__()
594
+ self.register_buffer(
595
+ "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
596
+ )
597
+ self.register_buffer(
598
+ "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
599
+ )
600
+
601
+ def forward(
602
+ self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
603
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
604
+ # move the buffer to the activation dtype for when AMP is used
605
+ self.k = self.k.to(k.dtype)
606
+ self.v = self.v.to(v.dtype)
607
+ # update the cache
608
+ k = self.k.index_copy_(2, input_pos, k)
609
+ v = self.v.index_copy_(2, input_pos, v)
610
+ return k, v
611
+
612
+ def reset_parameters(self) -> None:
613
+ torch.nn.init.zeros_(self.k)
614
+ torch.nn.init.zeros_(self.v)
615
+
616
+
617
+ def build_mask_cache(
618
+ max_seq_length: int, device: Optional[torch.device] = None
619
+ ) -> torch.Tensor:
620
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
621
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
622
+
623
+
624
+ class RMSNorm(torch.nn.Module):
625
+ """Root Mean Square Layer Normalization.
626
+
627
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
628
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
629
+ """
630
+
631
+ def __init__(
632
+ self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
633
+ ) -> None:
634
+ super().__init__()
635
+ self.weight = torch.nn.Parameter(torch.ones(size))
636
+ self.eps = eps
637
+ self.dim = dim
638
+ self.add_unit_offset = add_unit_offset
639
+
640
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
641
+ dtype = x.dtype
642
+ x = x.float()
643
+ # NOTE: the original RMSNorm paper implementation is not equivalent
644
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
645
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
646
+ x_normed = x_normed.to(dtype=dtype)
647
+ if self.add_unit_offset:
648
+ # Gemma model requires a unit offset
649
+ # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
650
+ return x_normed * (1 + self.weight)
651
+ return x_normed * self.weight
652
+
653
+ def reset_parameters(self) -> None:
654
+ torch.nn.init.ones_(self.weight)
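
The `build_rope_cache` / `apply_rope` pair above implements standard rotary position embeddings: one frequency theta_i = base^(-2i/d) per channel pair, and a rotate-half trick applied to the first `rope_n_elem` channels of q and k. The following standalone sketch reproduces the same math on toy shapes; the function and variable names are invented for the illustration and are not part of the checkpoint.

import torch

def toy_rope_cache(seq_len, n_elem, base=10000):
    # theta_i = base^(-2i/d); idx_theta[t, i] = t * theta_i, duplicated for both halves
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    idx_theta = torch.outer(torch.arange(seq_len).float(), theta).repeat(1, 2)
    return torch.cos(idx_theta), torch.sin(idx_theta)

def toy_apply_rope(x, cos, sin):
    half = x.size(-1) // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)  # rotate-half trick
    return x * cos + rotated * sin

cos, sin = toy_rope_cache(seq_len=8, n_elem=16)
q = torch.randn(1, 2, 8, 16)           # (batch, heads, time, head_size)
q_roped = toy_apply_rope(q, cos, sin)  # cos/sin broadcast over batch and heads
print(q_roped.shape)                   # torch.Size([1, 2, 8, 16])
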
litgpt/tokenizer.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+
9
+
10
+ class Tokenizer:
11
+ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
12
+ checkpoint_dir = Path(checkpoint_dir)
13
+ if not checkpoint_dir.exists():
14
+ raise NotADirectoryError(
15
+ f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
16
+ )
17
+
18
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
19
+ self.bos_id = None
20
+ self.eos_id = None
21
+
22
+ # some checkpoints have both files, `.json` takes precedence
23
+ if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (
30
+ special_tokens_path := checkpoint_dir / "tokenizer_config.json"
31
+ ).is_file():
32
+ with open(special_tokens_path, encoding="utf-8") as fp:
33
+ config = json.load(fp)
34
+ bos_token = config.get("bos_token")
35
+ eos_token = config.get("eos_token")
36
+ if bos_token is not None and isinstance(bos_token, dict):
37
+ bos_token = bos_token.get("content")
38
+ if eos_token is not None and isinstance(eos_token, dict):
39
+ eos_token = eos_token.get("content")
40
+ self.bos_id = (
41
+ self.token_to_id(bos_token) if bos_token is not None else None
42
+ )
43
+ self.eos_id = (
44
+ self.token_to_id(eos_token) if eos_token is not None else None
45
+ )
46
+ if (
47
+ special_tokens_path := checkpoint_dir / "generation_config.json"
48
+ ).is_file():
49
+ with open(special_tokens_path, encoding="utf-8") as fp:
50
+ config = json.load(fp)
51
+ if self.bos_id is None:
52
+ self.bos_id = config.get("bos_token_id")
53
+ if self.eos_id is None:
54
+ self.eos_id = config.get("eos_token_id")
55
+
56
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
57
+ from sentencepiece import SentencePieceProcessor
58
+
59
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
60
+ self.backend = "sentencepiece"
61
+ self.bos_id = self.processor.bos_id()
62
+ self.eos_id = self.processor.eos_id()
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ @property
67
+ def vocab_size(self) -> int:
68
+ if self.backend == "huggingface":
69
+ return self.processor.get_vocab_size(with_added_tokens=False)
70
+ if self.backend == "sentencepiece":
71
+ return self.processor.vocab_size()
72
+ raise RuntimeError
73
+
74
+ def token_to_id(self, token: str) -> int:
75
+ if self.backend == "huggingface":
76
+ id_ = self.processor.token_to_id(token)
77
+ elif self.backend == "sentencepiece":
78
+ id_ = self.processor.piece_to_id(token)
79
+ else:
80
+ raise RuntimeError
81
+ if id_ is None:
82
+ raise ValueError(f"token {token!r} not found in the collection.")
83
+ return id_
84
+
85
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
86
+ if not (
87
+ tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
88
+ ).is_file():
89
+ return False
90
+ with open(tokenizer_config_path, encoding="utf-8") as fp:
91
+ config = json.load(fp)
92
+ if "add_bos_token" in config:
93
+ return config["add_bos_token"]
94
+ # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
95
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
96
+ return config.get("tokenizer_class") == "LlamaTokenizer"
97
+
98
+ def encode(
99
+ self,
100
+ string: str,
101
+ device: Optional[torch.device] = None,
102
+ bos: Optional[bool] = None,
103
+ eos: bool = False,
104
+ max_length: int = -1,
105
+ ) -> torch.Tensor:
106
+ if self.backend == "huggingface":
107
+ tokens = self.processor.encode(string).ids
108
+ elif self.backend == "sentencepiece":
109
+ tokens = self.processor.encode(string)
110
+ else:
111
+ raise RuntimeError
112
+ if bos or (bos is None and self.use_bos):
113
+ bos_id = self.bos_id
114
+ if bos_id is None:
115
+ raise NotImplementedError(
116
+ "This tokenizer does not have a defined bos token"
117
+ )
118
+ if tokens[0] != bos_id:
119
+ tokens = [bos_id] + tokens
120
+ if tokens is None:
121
+ raise ValueError("`tokens` is None")
122
+
123
+ if eos and (not tokens or tokens[-1] != self.eos_id):
124
+ tokens = tokens + [self.eos_id]
125
+ if max_length > 0:
126
+ tokens = tokens[:max_length]
127
+ return torch.tensor(tokens, dtype=torch.int, device=device)
128
+
129
+ def decode(self, tensor: torch.Tensor) -> str:
130
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
131
+ return self.processor.decode(tokens)
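
A hedged usage sketch of the `Tokenizer` wrapper above: it assumes a checkpoint directory that contains `tokenizer.json` (or `tokenizer.model`) together with `tokenizer_config.json`, as probed in `__init__` and `check_if_bos_token_used`. The directory name below is a placeholder, not a guaranteed path.

import torch
from litgpt.tokenizer import Tokenizer

tok = Tokenizer("models")  # placeholder checkpoint directory (assumption)
ids = tok.encode("hello world", device=torch.device("cpu"))  # bos/eos handling follows tokenizer_config.json
print(ids)              # 1-D tensor of int token ids
print(tok.decode(ids))  # decodes back to (approximately) the input string
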
litgpt/utils.py ADDED
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Utility functions for training and inference."""
4
+ import inspect
5
+ import math
6
+ import os
7
+ import pickle
8
+ import shutil
9
+ import sys
10
+ from dataclasses import asdict, is_dataclass
11
+ from io import BytesIO
12
+ from pathlib import Path
13
+ from typing import (
14
+ TYPE_CHECKING,
15
+ Any,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ Literal,
20
+ Mapping,
21
+ Optional,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import lightning as L
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.utils._device
30
+ import yaml
31
+ from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
32
+ from lightning.fabric.strategies import FSDPStrategy
33
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
34
+ from lightning.pytorch.loggers import WandbLogger
35
+ from lightning.pytorch.cli import instantiate_class
36
+ from torch.serialization import normalize_storage_type
37
+ from typing_extensions import Self
38
+
39
+ if TYPE_CHECKING:
40
+ from litgpt import GPT, Config
41
+
42
+
43
+ def init_out_dir(out_dir: Path) -> Path:
44
+ if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
45
+ return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
46
+ return out_dir
47
+
48
+
49
+ def find_resume_path(
50
+ resume: Union[bool, Literal["auto"], Path], out_dir: Path
51
+ ) -> Optional[Path]:
52
+ if not resume or isinstance(resume, Path):
53
+ return resume
54
+
55
+ resume_path = max(
56
+ out_dir.rglob("step-*/*.pth"),
57
+ key=(lambda p: int(p.parent.name.split("-")[1])),
58
+ default=None,
59
+ )
60
+ if resume == "auto":
61
+ return resume_path
62
+ if resume is True and resume_path is None:
63
+ raise FileNotFoundError(
64
+ f"You passed `--resume=True`, but no checkpoint file was found in `--out_dir={out_dir}`."
65
+ )
66
+ return resume_path
67
+
68
+
69
+ def find_multiple(n: int, k: int) -> int:
70
+ assert k > 0
71
+ if n % k == 0:
72
+ return n
73
+ return n + k - (n % k)
74
+
75
+
76
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
77
+ total = 0
78
+ for p in module.parameters():
79
+ if requires_grad is None or p.requires_grad == requires_grad:
80
+ if hasattr(p, "quant_state"):
81
+ # bitsandbytes 4bit layer support
82
+ total += math.prod(p.quant_state.shape)
83
+ else:
84
+ total += p.numel()
85
+ return total
86
+
87
+
88
+ def reset_parameters(module: nn.Module) -> None:
89
+ """Calls `reset_parameters` on the module and all its submodules."""
90
+ for mod in module.modules():
91
+ if callable(getattr(mod, "reset_parameters", None)):
92
+ mod.reset_parameters()
93
+
94
+
95
+ def check_valid_checkpoint_dir(
96
+ checkpoint_dir: Path,
97
+ model_filename: str = "lit_model.pth",
98
+ verbose: bool = True,
99
+ raise_error: bool = False,
100
+ ) -> None:
101
+ files = {
102
+ model_filename: (checkpoint_dir / model_filename).is_file(),
103
+ "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
104
+ "tokenizer.json OR tokenizer.model": (
105
+ checkpoint_dir / "tokenizer.json"
106
+ ).is_file()
107
+ or (checkpoint_dir / "tokenizer.model").is_file(),
108
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
109
+ }
110
+ if checkpoint_dir.is_dir():
111
+ if all(files.values()):
112
+ # we're good
113
+ return
114
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
115
+ else:
116
+ problem = " is not a checkpoint directory"
117
+
118
+ # list locally available checkpoints
119
+ available = list(Path("checkpoints").glob("*/*"))
120
+ if available:
121
+ options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
122
+ extra = f"\nYou have downloaded locally:{options}\n"
123
+ else:
124
+ extra = ""
125
+
126
+ if verbose:
127
+ error_message = (
128
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
129
+ "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
130
+ f"{extra}\nSee all download options by running:\n litgpt download"
131
+ )
132
+ print(error_message, file=sys.stderr)
133
+
134
+ if raise_error:
135
+ raise FileNotFoundError(
136
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
137
+ )
138
+ else:
139
+ raise SystemExit(1)
140
+
141
+
142
+ class SavingProxyForStorage:
143
+ def __init__(self, obj, saver, protocol_version=5):
144
+ self.protocol_version = protocol_version
145
+ self.saver = saver
146
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
147
+ raise TypeError(f"expected storage, not {type(obj)}")
148
+
149
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
150
+ if isinstance(obj, torch.storage.TypedStorage):
151
+ # PT upstream wants to deprecate this eventually...
152
+ storage = obj._untyped_storage
153
+ storage_type_str = obj._pickle_storage_type()
154
+ storage_type = getattr(torch, storage_type_str)
155
+ storage_numel = obj._size()
156
+ else:
157
+ storage = obj
158
+ storage_type = normalize_storage_type(type(obj))
159
+ storage_numel = storage.nbytes()
160
+
161
+ storage_key = saver._write_storage_and_return_key(storage)
162
+ location = torch.serialization.location_tag(storage)
163
+
164
+ self.storage_info = (
165
+ "storage",
166
+ storage_type,
167
+ storage_key,
168
+ location,
169
+ storage_numel,
170
+ )
171
+
172
+ def __reduce_ex__(self, protocol_version):
173
+ assert False, "this should be handled out of band"
174
+
175
+
176
+ class SavingProxyForTensor:
177
+ def __init__(self, tensor, saver, protocol_version=5):
178
+ self.protocol_version = protocol_version
179
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
180
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
181
+ # for Tensors with Python attributes
182
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
183
+ assert isinstance(
184
+ storage, torch.storage.TypedStorage
185
+ ), "Please check for updates"
186
+ storage_proxy = SavingProxyForStorage(
187
+ storage, saver, protocol_version=protocol_version
188
+ )
189
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
190
+ else:
191
+ (storage, *other_reduce_args) = reduce_args
192
+ assert isinstance(
193
+ storage, torch.storage.TypedStorage
194
+ ), "Please check for updates"
195
+ storage_proxy = SavingProxyForStorage(
196
+ storage, saver, protocol_version=protocol_version
197
+ )
198
+ self.reduce_args = (storage_proxy, *other_reduce_args)
199
+
200
+ def __reduce_ex__(self, protocol_version):
201
+ if protocol_version != self.protocol_version:
202
+ raise RuntimeError(
203
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
204
+ )
205
+ return self.reduce_ret_fn, self.reduce_args
206
+
207
+
208
+ class IncrementalPyTorchPickler(pickle.Pickler):
209
+ def __init__(self, saver, *args, **kwargs):
210
+ super().__init__(*args, **kwargs)
211
+ self.storage_dtypes = {}
212
+ self.saver = saver
213
+ self.id_map = {}
214
+
215
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
216
+ def persistent_id(self, obj):
217
+ # FIXME: the docs say that persistent_id should only return a string
218
+ # but torch store returns tuples. This works only in the binary protocol
219
+ # see
220
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
221
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
222
+ if isinstance(obj, SavingProxyForStorage):
223
+ return obj.storage_info
224
+
225
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
226
+ if isinstance(obj, torch.storage.TypedStorage):
227
+ # TODO: Once we decide to break serialization FC, this case
228
+ # can be deleted
229
+ storage = obj._untyped_storage
230
+ storage_dtype = obj.dtype
231
+ storage_type_str = obj._pickle_storage_type()
232
+ storage_type = getattr(torch, storage_type_str)
233
+ storage_numel = obj._size()
234
+
235
+ else:
236
+ storage = obj
237
+ storage_dtype = torch.uint8
238
+ storage_type = normalize_storage_type(type(obj))
239
+ storage_numel = storage.nbytes()
240
+
241
+ # If storage is allocated, ensure that any other saved storages
242
+ # pointing to the same data all have the same dtype. If storage is
243
+ # not allocated, don't perform this check
244
+ if storage.data_ptr() != 0:
245
+ if storage.data_ptr() in self.storage_dtypes:
246
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
247
+ raise RuntimeError(
248
+ "Cannot save multiple tensors or storages that view the same data as different types"
249
+ )
250
+ else:
251
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
252
+
253
+ storage_key = self.id_map.get(storage._cdata)
254
+ if storage_key is None:
255
+ storage_key = self.saver._write_storage_and_return_key(storage)
256
+ self.id_map[storage._cdata] = storage_key
257
+ location = torch.serialization.location_tag(storage)
258
+
259
+ return ("storage", storage_type, storage_key, location, storage_numel)
260
+
261
+ return None
262
+
263
+
264
+ class incremental_save:
265
+ def __init__(self, name):
266
+ self.name = name
267
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
268
+ self.has_saved = False
269
+ self.next_key = 0
270
+
271
+ def __enter__(self):
272
+ return self
273
+
274
+ def store_early(self, tensor):
275
+ if isinstance(tensor, torch.Tensor):
276
+ return SavingProxyForTensor(tensor, self)
277
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
278
+
279
+ def save(self, obj):
280
+ if self.has_saved:
281
+ raise RuntimeError("have already saved")
282
+ # Write the pickle data for `obj`
283
+ data_buf = BytesIO()
284
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
285
+ pickler.dump(obj)
286
+ data_value = data_buf.getvalue()
287
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
288
+ self.has_saved = True
289
+
290
+ def _write_storage_and_return_key(self, storage):
291
+ if self.has_saved:
292
+ raise RuntimeError("have already saved")
293
+ key = self.next_key
294
+ self.next_key += 1
295
+ name = f"data/{key}"
296
+ if storage.device.type != "cpu":
297
+ storage = storage.cpu()
298
+ num_bytes = storage.nbytes()
299
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
300
+ return key
301
+
302
+ def __exit__(self, type, value, traceback):
303
+ self.zipfile.write_end_of_file()
304
+
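+ # Minimal usage sketch (hedged; `state_dict` and the file name are illustrative,
+ # not a fixed API of this repo). Large tensors can be streamed out before pickling
+ # the rest of the object:
+ #
+ #     with incremental_save("lit_model.pth") as saver:
+ #         state_dict["lm_head.weight"] = saver.store_early(state_dict["lm_head.weight"])
+ #         saver.save(state_dict)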
305
+
306
+ T = TypeVar("T")
307
+
308
+
309
+ def chunked_cross_entropy(
310
+ logits: Union[torch.Tensor, List[torch.Tensor]],
311
+ targets: torch.Tensor,
312
+ chunk_size: int = 128,
313
+ ignore_index: int = -100,
314
+ ) -> torch.Tensor:
315
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
316
+ # the memory usage in fine-tuning settings with a low number of parameters.
317
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
318
+ # the memory spike's magnitude
319
+
320
+ # lm_head was chunked (we are fine-tuning)
321
+ if isinstance(logits, list):
322
+ # don't want to chunk cross entropy
323
+ if chunk_size == 0:
324
+ logits = torch.cat(logits, dim=1)
325
+ logits = logits.reshape(-1, logits.size(-1))
326
+ targets = targets.reshape(-1)
327
+ return torch.nn.functional.cross_entropy(
328
+ logits, targets, ignore_index=ignore_index
329
+ )
330
+
331
+ # chunk cross entropy
332
+ logit_chunks = [
333
+ logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
334
+ ]
335
+ target_chunks = [
336
+ target_chunk.reshape(-1)
337
+ for target_chunk in targets.split(logits[0].size(1), dim=1)
338
+ ]
339
+ loss_chunks = [
340
+ torch.nn.functional.cross_entropy(
341
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
342
+ )
343
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
344
+ ]
345
+ non_masked_elems = (targets != ignore_index).sum()
346
+ # See [non_masked_elems div note]
347
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
348
+ torch.ones_like(non_masked_elems)
349
+ )
350
+
351
+ # no chunking at all
352
+ logits = logits.reshape(-1, logits.size(-1))
353
+ targets = targets.reshape(-1)
354
+ if chunk_size == 0:
355
+ return torch.nn.functional.cross_entropy(
356
+ logits, targets, ignore_index=ignore_index
357
+ )
358
+
359
+ # lm_head wasn't chunked, chunk cross entropy
360
+ logit_chunks = logits.split(chunk_size)
361
+ target_chunks = targets.split(chunk_size)
362
+ loss_chunks = [
363
+ torch.nn.functional.cross_entropy(
364
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
365
+ )
366
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
367
+ ]
368
+ non_masked_elems = (targets != ignore_index).sum()
369
+ # [non_masked_elems div note]:
370
+ # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
371
+ # results in a python int which is then passed back to torch division. By using the
372
+ # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
373
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
374
+ torch.ones_like(non_masked_elems)
375
+ )
376
+
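+ # Quick sanity check (hedged; toy shapes only): the chunked loss matches the plain
+ # mean cross entropy over the non-ignored targets.
+ #
+ #     logits = torch.randn(2, 8, 16)
+ #     targets = torch.randint(0, 16, (2, 8))
+ #     targets[0, :2] = -100  # ignored positions
+ #     chunked = chunked_cross_entropy(logits, targets, chunk_size=4)
+ #     full = torch.nn.functional.cross_entropy(
+ #         logits.reshape(-1, 16), targets.reshape(-1), ignore_index=-100
+ #     )
+ #     torch.testing.assert_close(chunked, full)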
377
+
378
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
379
+ for checkpoint_name, attribute_name in mapping.items():
380
+ full_checkpoint_name = prefix + checkpoint_name
381
+ if full_checkpoint_name in state_dict:
382
+ full_attribute_name = prefix + attribute_name
383
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
384
+ return state_dict
385
+
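+ # Example (hedged; the key names are illustrative): with
+ #   mapping = {"attn.attn.weight": "attn.qkv.weight"} and prefix = "transformer.h.0.",
+ # the entry "transformer.h.0.attn.attn.weight" in `state_dict` is renamed to
+ # "transformer.h.0.attn.qkv.weight"; keys that are not present are left untouched.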
386
+
387
+ def get_default_supported_precision(training: bool) -> str:
388
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
389
+
390
+ Args:
391
+ training: `-mixed` or `-true` version of the precision to use
392
+
393
+ Returns:
394
+ default precision that is suitable for the task and is supported by the hardware
395
+ """
396
+ from lightning.fabric.accelerators import MPSAccelerator
397
+
398
+ if MPSAccelerator.is_available() or (
399
+ torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
400
+ ):
401
+ return "16-mixed" if training else "16-true"
402
+ return "bf16-mixed" if training else "bf16-true"
403
+
404
+
405
+ def load_checkpoint(
406
+ fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
407
+ ) -> None:
408
+ if isinstance(fabric.strategy, FSDPStrategy):
409
+ fabric.load_raw(checkpoint_path, model, strict=strict)
410
+ else:
411
+ state_dict = lazy_load(checkpoint_path)
412
+ state_dict = state_dict.get("model", state_dict)
413
+ model.load_state_dict(state_dict, strict=strict)
414
+
415
+
416
+ def flops_per_param(
417
+ max_seq_length: int, n_layer: int, n_embd: int, n_params: int
418
+ ) -> int:
419
+ flops_per_token = (
420
+ 2 * n_params
421
+ ) # each parameter is used for a MAC (2 FLOPS) per network operation
422
+ # this assumes that all samples have a fixed length equal to the block size
423
+ # which is most likely false during finetuning
424
+ flops_per_seq = flops_per_token * max_seq_length
425
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
426
+ return flops_per_seq + attn_flops_per_seq
427
+
428
+
429
+ def estimate_flops(model: "GPT", training: bool) -> int:
430
+ """Measures estimated FLOPs for MFU.
431
+
432
+ Refs:
433
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
434
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
435
+ """
436
+ # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
437
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
438
+ # (~10%) compared to the measured FLOPs, which are lower but more realistic.
439
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
440
+ n_trainable_params = num_parameters(model, requires_grad=True)
441
+ trainable_flops = flops_per_param(
442
+ model.max_seq_length,
443
+ model.config.n_layer,
444
+ model.config.n_embd,
445
+ n_trainable_params,
446
+ )
447
+ # forward + backward + gradients (assumes no gradient accumulation)
448
+ ops_per_step = 3 if training else 1
449
+ n_frozen_params = num_parameters(model, requires_grad=False)
450
+ frozen_flops = flops_per_param(
451
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
452
+ )
453
+ # forward + backward
454
+ frozen_ops_per_step = 2 if training else 1
455
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
456
+
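+ # Rough usage note (hedged; numbers are illustrative, taken from this repo's
+ # Qwen2-0.5B config: max_seq_length=2048, n_layer=24, n_embd=896). During training,
+ # each step counts 3x the per-sequence FLOPs of trainable parameters plus 2x those
+ # of frozen ones; the result can feed an MFU estimate:
+ #
+ #     flops_per_step = estimate_flops(model, training=True)
+ #     mfu = flops_per_step * steps_per_sec / device_peak_flops_per_sec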
457
+
458
+ class CycleIterator:
459
+ """An iterator that cycles through an iterable indefinitely.
460
+
461
+ Example:
462
+ >>> iterator = CycleIterator([1, 2, 3])
463
+ >>> [next(iterator) for _ in range(5)]
464
+ [1, 2, 3, 1, 2]
465
+
466
+ Note:
467
+ Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
468
+ """
469
+
470
+ def __init__(self, iterable: Iterable) -> None:
471
+ self.iterable = iterable
472
+ self.epoch = 0
473
+ self._iterator = None
474
+
475
+ def __next__(self) -> Any:
476
+ if self._iterator is None:
477
+ self._iterator = iter(self.iterable)
478
+ try:
479
+ return next(self._iterator)
480
+ except StopIteration:
481
+ self._iterator = iter(self.iterable)
482
+ self.epoch += 1
483
+ return next(self._iterator)
484
+
485
+ def __iter__(self) -> Self:
486
+ return self
487
+
488
+
489
+ def copy_config_files(source_dir: Path, out_dir: Path) -> None:
490
+ """Copies the specified configuration and tokenizer files into the output directory."""
491
+
492
+ config_files = ["config.json", "generation_config.json", "model_config.yaml"]
493
+ tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]
494
+
495
+ for file_name in config_files + tokenizer_files:
496
+ src_path = source_dir / file_name
497
+ if src_path.exists():
498
+ shutil.copy(src_path, out_dir)
499
+
500
+
501
+ def CLI(*args: Any, **kwargs: Any) -> Any:
502
+ from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
503
+
504
+ set_docstring_parse_options(attribute_docstrings=True)
505
+ set_config_read_mode(urls_enabled=True)
506
+
507
+ return CLI(*args, **kwargs)
508
+
509
+
510
+ def capture_hparams() -> Dict[str, Any]:
511
+ """Captures the local variables ('hyperparameters') from where this function gets called."""
512
+ caller_frame = inspect.currentframe().f_back
513
+ locals_of_caller = caller_frame.f_locals
514
+ hparams = {}
515
+ for name, value in locals_of_caller.items():
516
+ if value is None or isinstance(value, (int, float, str, bool, Path)):
517
+ hparams[name] = value
518
+ elif is_dataclass(value):
519
+ hparams[name] = asdict(value)
520
+ else:
521
+ hparams[name] = str(value)
522
+ return hparams
523
+
524
+
525
+ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
526
+ """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
527
+ from jsonargparse import capture_parser
528
+
529
+ # TODO: Make this more robust
530
+ # This hack strips away the subcommands from the top-level CLI
531
+ # to parse the file as if it was called as a script
532
+ known_commands = [
533
+ ("finetune_full",), # For subcommands, use `("finetune", "full")` etc
534
+ ("finetune_lora",),
535
+ ("finetune_adapter",),
536
+ ("finetune_adapter_v2",),
537
+ ("finetune",),
538
+ ("pretrain",),
539
+ ]
540
+ for known_command in known_commands:
541
+ unwanted = slice(1, 1 + len(known_command))
542
+ if tuple(sys.argv[unwanted]) == known_command:
543
+ sys.argv[unwanted] = []
544
+
545
+ parser = capture_parser(lambda: CLI(function))
546
+ config = parser.parse_args()
547
+ parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
548
+
549
+
550
+ def save_config(config: "Config", checkpoint_dir: Path) -> None:
551
+ config_dict = asdict(config)
552
+ with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
553
+ yaml.dump(config_dict, fp)
554
+
555
+
556
+ def parse_devices(devices: Union[str, int]) -> int:
557
+ if devices in (-1, "auto"):
558
+ return torch.cuda.device_count() or 1
559
+ if isinstance(devices, int) and devices > 0:
560
+ return devices
561
+ raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
562
+
563
+
564
+ def choose_logger(
565
+ logger_name: Literal["csv", "tensorboard", "wandb"],
566
+ out_dir: Path,
567
+ name: str,
568
+ log_interval: int = 1,
569
+ resume: Optional[bool] = None,
570
+ **kwargs: Any,
571
+ ):
572
+ if logger_name == "csv":
573
+ return CSVLogger(
574
+ root_dir=(out_dir / "logs"),
575
+ name="csv",
576
+ flush_logs_every_n_steps=log_interval,
577
+ **kwargs,
578
+ )
579
+ if logger_name == "tensorboard":
580
+ return TensorBoardLogger(
581
+ root_dir=(out_dir / "logs"), name="tensorboard", **kwargs
582
+ )
583
+ if logger_name == "wandb":
584
+ return WandbLogger(project=name, resume=resume, **kwargs)
585
+ raise ValueError(
586
+ f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
587
+ )
588
+
589
+
590
+ def get_argument_names(cls):
591
+ sig = inspect.signature(cls.__init__)
592
+ return {
593
+ name
594
+ for name, param in sig.parameters.items()
595
+ if param.kind
596
+ in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
597
+ }
598
+
599
+
600
+ def instantiate_bnb_optimizer(optimizer, model_parameters):
601
+ if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
602
+ isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
603
+ ):
604
+ raise ValueError(
605
+ "The chosen quantization format only supports the AdamW optimizer."
606
+ )
607
+
608
+ import bitsandbytes as bnb
609
+
610
+ if isinstance(optimizer, str):
611
+ optimizer = bnb.optim.PagedAdamW(model_parameters)
612
+ else:
613
+ optim_args = get_argument_names(bnb.optim.PagedAdamW)
614
+ allowed_kwargs = {
615
+ key: optimizer["init_args"][key]
616
+ for key in optim_args & optimizer["init_args"].keys()
617
+ }
618
+ optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
619
+ return optimizer
620
+
621
+
622
+ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
623
+ if isinstance(optimizer, str):
624
+ optimizer_cls = getattr(torch.optim, optimizer)
625
+ optimizer = optimizer_cls(model_parameters, **kwargs)
626
+ else:
627
+ optimizer = dict(optimizer) # copy
628
+ optimizer["init_args"].update(kwargs)
629
+ optimizer = instantiate_class(model_parameters, optimizer)
630
+ return optimizer
631
+
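+ # Both accepted forms, sketched with illustrative values (hedged):
+ #
+ #     instantiate_torch_optimizer("AdamW", model.parameters(), lr=3e-4)
+ #     instantiate_torch_optimizer(
+ #         {"class_path": "torch.optim.SGD", "init_args": {"lr": 0.1}},
+ #         model.parameters(),
+ #     )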
632
+
633
+ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
634
+ new_checkpoint_dir = "checkpoints" / checkpoint_dir
635
+ should_return_new_dir = (
636
+ not checkpoint_dir.is_dir()
637
+ and checkpoint_dir.parts[0] != "checkpoints"
638
+ and not checkpoint_dir.is_absolute()
639
+ and new_checkpoint_dir.exists()
640
+ )
641
+ return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
models/README.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ pipeline_tag: any-to-any
4
+ library_name: mini-omni2
5
+ ---
6
+
7
+ # Mini-Omni2
8
+
9
+ <!-- <p align="center">
10
+ <img src="./data/figures/title.png" width="100%"/>
11
+ </p> -->
12
+
13
+
14
+ <p align="center">
15
+ 🤗 <a href="https://huggingface.co/gpt-omni/mini-omni2">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni2">Github</a>
16
+ | 📑 <a href="https://arxiv.org/abs/2410.11190">Technical report</a>
17
+ </p>
18
+
19
+ Mini-Omni2 is an **omni-interactive** model. It can **understand image, audio and text inputs and hold end-to-end voice conversations with users**. It features **real-time voice output**, **omni-capable multimodal understanding**, and flexible interaction, **including an interruption mechanism while speaking**.
20
+
21
+ <p align="center">
22
+ <img src="./data/figures/framework.jpeg" width="100%"/>
23
+ </p>
24
+
25
+
26
+ ## Updates
27
+
28
+ - **2024.10:** Released the model, technical report, and inference and chat demo code.
29
+
30
+ ## Features
31
+ ✅ **Multimodal interaction**: understands images, speech and text, just like GPT-4o.
32
+
33
+ ✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required, just like [Mini-Omni](https://github.com/gpt-omni/mini-omni).
34
+
35
+ <!-- ✅ **Streaming audio output**: with first-chunk latency of audio stream less than 0.3s. -->
36
+
37
+ <!-- ✅ **Duplex interaction**: hearing while speaking, it can be interrupted by key words like "stop omni". -->
38
+
39
+
40
+ ## Demo
41
+
42
+ NOTE: need to unmute first.
43
+
44
+ https://github.com/user-attachments/assets/ad97ca7f-f8b4-40c3-a7e8-fa54b4edf155
45
+
46
+
47
+ ## ToDo
48
+ - [ ] update interruption mechanism
49
+
50
+
51
+ ## Install
52
+
53
+ Create a new conda environment and install the required packages:
54
+
55
+ ```sh
56
+ conda create -n omni python=3.10
57
+ conda activate omni
58
+
59
+ git clone https://github.com/gpt-omni/mini-omni2.git
60
+ cd mini-omni2
61
+ pip install -r requirements.txt
62
+ ```
63
+
64
+ ## Quick start
65
+
66
+ **Interactive demo**
67
+
68
+ - start server
69
+
70
+ NOTE: you need to start the server before running the streamlit or gradio demo, with `API_URL` set to the server address.
71
+
72
+ ```sh
73
+ sudo apt-get install ffmpeg
74
+ conda activate omni
75
+ cd mini-omni2
76
+ python3 server.py --ip '0.0.0.0' --port 60808
77
+ ```
78
+
79
+
80
+ - run streamlit demo
81
+
82
+ NOTE: you need to run streamlit **locally** with PyAudio installed.
83
+
84
+ ```sh
85
+ pip install PyAudio==0.2.14
86
+ API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
87
+ ```
88
+
89
+
90
+ **Local test**
91
+
92
+ ```sh
93
+ conda activate omni
94
+ cd mini-omni2
95
+ # test run the preset audio samples and questions
96
+ python inference_vision.py
97
+ ```
98
+
99
+ ## Mini-Omni2 Overview
100
+
101
+ **1. Multimodal Modeling**:
102
+ We use multiple sequences as the input and output of the model. On the input side, we concatenate image, audio and text features to perform a series of comprehensive tasks, as shown in the figures below. On the output side, we use text-guided delayed parallel output to generate real-time speech responses.
103
+ <p align="center">
104
+ <img src="./data/figures/inputids.png" width="100%"/>
105
+ </p>
106
+
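+ A minimal sketch of this input layout is shown below (hedged: the tensor names and shapes are illustrative, not the repository's actual API; 896 is the LLM width from `model_config.yaml`):
+
+ ```python
+ import torch
+
+ # Hypothetical per-modality features, already projected to the LLM width (sketch only).
+ img_feats = torch.randn(1, 50, 896)     # e.g. CLIP patch features after the vision adapter
+ audio_feats = torch.randn(1, 120, 896)  # e.g. Whisper features after the whisper adapter
+ text_emb = torch.randn(1, 16, 896)      # embedded text/prompt tokens
+
+ # The model consumes a single concatenated sequence: [image | audio | text].
+ inputs_embeds = torch.cat([img_feats, audio_feats, text_emb], dim=1)
+ print(inputs_embeds.shape)  # torch.Size([1, 186, 896])
+ ```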
107
+ **2. Multi-stage Training**:
108
+ We propose an efficient alignment training method: the three training stages perform encoder adaptation, modal alignment, and multimodal fine-tuning, respectively.
109
+ <p align="center">
110
+ <img src="./data/figures/training.jpeg" width="100%"/>
111
+ </p>
112
+
113
+ <!-- **3. Cases**:
114
+ Here are more cases of Mini-Omni2:
115
+ <p align="center">
116
+ <img src="./data/figures/samples.png" width="100%"/>
117
+ </p> -->
118
+
119
+ ## FAQ
120
+
121
+ **1. Does the model support other languages?**
122
+
123
+ No, the model is trained on English only. However, since we use Whisper as the audio encoder, the model can understand other languages supported by Whisper (such as Chinese), but it responds only in English.
124
+
125
+ **2. Error: can not run streamlit in local browser, with remote streamlit server**
126
+
127
+ You need to start streamlit **locally** with PyAudio installed.
128
+
129
+
130
+ ## Acknowledgements
131
+
132
+ - [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
133
+ - [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
134
+ - [whisper](https://github.com/openai/whisper/) for audio encoding.
135
+ - [clip](https://github.com/openai/CLIP) for image encoding.
136
+ - [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
137
+ - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
138
+ - [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
139
+
140
+ <!-- ## Star History
141
+
142
+ [![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni2&type=Date)](https://star-history.com/#gpt-omni/mini-omni2&Date)
models/ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
3
+ size 353976522
models/data/figures/framework.jpeg ADDED

Git LFS Details

  • SHA256: bc668450030500a62ddbb7cf6ea170f0b53da7e3e5506d01a0dc6f2ec690fd1a
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
models/data/figures/inputids.png ADDED

Git LFS Details

  • SHA256: ad4cf663684c53f72952b13f52ea93fcbe19e287301b3decfcd917de9e23f312
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
models/data/figures/samples.png ADDED

Git LFS Details

  • SHA256: e63a8cbc2859304cb9c50b831366ac8804ad0326b6ae4897d08f8ab0e1eb63c6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
models/data/figures/title.png ADDED

Git LFS Details

  • SHA256: 56194a7fd5cfd29d6e2ce574fc7628315d87220adbc7aa2949e579c1a63ed2a3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
models/data/figures/training.jpeg ADDED

Git LFS Details

  • SHA256: fd49f75dbe5838a3e28f02c8f853dec34d0aad8573911d52bd827ab6dae8f9a1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
models/data/omni2-demo.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2098124af391dca9c48854f5686143c137cc069f08b5e457675b9ba744bd2f
3
+ size 11784395
models/hub/.locks/models--hubertsiuzdak--snac_24khz/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40.lock ADDED
File without changes
models/hub/.locks/models--hubertsiuzdak--snac_24khz/a9e7ef62bf7e1eb94d2713721029837aacab3b55.lock ADDED
File without changes
models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
3
+ size 79488254
models/hub/models--hubertsiuzdak--snac_24khz/blobs/a9e7ef62bf7e1eb94d2713721029837aacab3b55 ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sampling_rate": 24000,
3
+ "encoder_dim": 48,
4
+ "encoder_rates": [2, 4, 8, 8],
5
+ "decoder_dim": 1024,
6
+ "decoder_rates": [8, 8, 4, 2],
7
+ "attn_window_size": null,
8
+ "codebook_size": 4096,
9
+ "codebook_dim": 8,
10
+ "vq_strides": [4, 2, 1],
11
+ "noise": true,
12
+ "depthwise": true
13
+ }
models/hub/models--hubertsiuzdak--snac_24khz/refs/main ADDED
@@ -0,0 +1 @@
 
 
1
+ d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec
models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sampling_rate": 24000,
3
+ "encoder_dim": 48,
4
+ "encoder_rates": [2, 4, 8, 8],
5
+ "decoder_dim": 1024,
6
+ "decoder_rates": [8, 8, 4, 2],
7
+ "attn_window_size": null,
8
+ "codebook_size": 4096,
9
+ "codebook_dim": 8,
10
+ "vq_strides": [4, 2, 1],
11
+ "noise": true,
12
+ "depthwise": true
13
+ }
models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
3
+ size 79488254
models/hub/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
models/lit_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98a0e4ad1912b0ee0081a5898b6c260f4b06a37aaef34e49b0719b46808c231c
3
+ size 2814623738
models/model_config.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_qkv_bias: true
2
+ asr_adapter: llamamlp
3
+ attn_dropout: 0.0
4
+ bias: false
5
+ block_size: 2048
6
+ force_align: false
7
+ gelu_approximate: none
8
+ head_size: 64
9
+ hf_config:
10
+ name: Qwen2-0.5B
11
+ org: Qwen
12
+ intermediate_size: 4864
13
+ lm_head_bias: false
14
+ mlp_class_name: LLaMAMLP
15
+ n_embd: 896
16
+ n_expert: 0
17
+ n_expert_per_token: 0
18
+ n_head: 14
19
+ n_layer: 24
20
+ n_query_groups: 2
21
+ name: Qwen2-0.5B
22
+ norm_class_name: RMSNorm
23
+ norm_eps: 1.0e-06
24
+ padded_vocab_size: 181120
25
+ padding_multiple: 512
26
+ parallel_residual: false
27
+ pos_type: rope
28
+ post_adapter: false
29
+ post_adapter_layers: 6
30
+ prompt_vocab_size: null
31
+ rope_base: 1000000
32
+ rope_condense_ratio: 1
33
+ rotary_percentage: 1
34
+ scale_embeddings: false
35
+ shared_attention_norm: false
36
+ tie_word_embeddings: true
37
+ use_pretrain_phoneme_emb: false
38
+ vocab_size: 50254
39
+ text_vocab_size: 152000
40
+ cat_audio_vocab_size: 29120
41
+ audio_vocab_size: 4160
42
+ whisper_adapter_dim: 768
43
+ vision_adapter_dim: 512
models/small.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fb26a40bfcfbb3d7e41586205d21c90ffc1de552c15367efb4a723ce11f700f
3
+ size 483586606
models/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/tokenizer_config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
30
+ "bos_token": null,
31
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "<|endoftext|>",
34
+ "errors": "replace",
35
+ "model_max_length": 32768,
36
+ "pad_token": "<|endoftext|>",
37
+ "split_special_tokens": false,
38
+ "tokenizer_class": "Qwen2Tokenizer",
39
+ "unk_token": null
40
+ }
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ torchaudio==2.3.1
4
+ litgpt==0.4.3
5
+ snac==1.2.0
6
+ soundfile==0.12.1
7
+ openai-whisper
8
+ tokenizers==0.19.1
9
+ streamlit==1.37.1
10
+ streamlit-webrtc
11
+ # PyAudio==0.2.14
12
+ pydub==0.25.1
13
+ onnxruntime==1.19.0
14
+ # numpy==1.26.3
15
+ librosa==0.10.2.post1
16
+ flask==3.0.3
17
+ fire
18
+ git+https://github.com/mini-omni/CLIP.git
19
+ gradio_webrtc[vad]==0.0.11
20
+ twilio