AlexHung29629 committed
Commit • 159da21 • 1 Parent(s): af56b11
Update mllama_audio_model.py

mllama_audio_model.py CHANGED (+40 -3)
@@ -6,16 +6,52 @@ from transformers import Wav2Vec2Model, Wav2Vec2Config, MllamaPreTrainedModel
 from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
 from .configuration_llama3 import Llama3Config
 
+class AudioAdapter(nn.Module):
+    def __init__(self, config: Wav2Vec2Config):
+        super().__init__()
+        # feature dim might need to be down-projected
+        if config.output_hidden_size != config.hidden_size:
+            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+        else:
+            self.proj = None
+        self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+
+        self.kernel_size = config.adapter_kernel_size
+        self.stride = config.adapter_stride
+
+    def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
+        if seq_lens is None:
+            return seq_lens
+        pad = self.stride // 2
+        seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
+        return seq_lens.floor()
+
+    def forward(self, hidden_states, attention_mask=None):
+        # down project hidden_states if necessary
+        if self.proj is not None:
+            hidden_states = self.proj(hidden_states)
+
+        sub_sampled_lengths = None
+        if attention_mask is not None:
+            sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
+
+        for layer in self.layers:
+            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
+            hidden_states = layer(
+                hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
+            )
+
+        return hidden_states
 
 class Llama3Embedding(MllamaPreTrainedModel):
     config_class = Llama3Config
     base_model_prefix = "audio_model"
     def __init__(self, config: Llama3Config):
         super().__init__(config)
-        #assert config.audio_config.output_hidden_size * 2 == config.text_config.hidden_size
         self.text_embeddings = nn.Embedding(config.text_config.vocab_size, config.text_config.hidden_size, config.text_config.pad_token_id)
-        assert config.audio_config.add_adapter ==
-        self.
+        assert config.audio_config.add_adapter == False
+        self.audio_encoder = Wav2Vec2Model(config.audio_config)
+        self.audio_adapter = AudioAdapter(config.audio_config)
         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.text_config.hidden_size)), requires_grad=True)
         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.text_config.hidden_size)), requires_grad=True)
         self.text_config = config.text_config
@@ -30,6 +66,7 @@ class Llama3Embedding(MllamaPreTrainedModel):
             return input_embeddings
         bs, max_num_clip, l = audio_features.shape
         audio_embeddings = self.audio_model(input_values=audio_features.view((bs*max_num_clip, l)))['last_hidden_state']
+        audio_embeddings = self.audio_adapter(audio_embeddings)
        audio_embeddings = audio_embeddings.view((bs, max_num_clip, -1, self.start_of_audio.shape[-1]))
 
         for i in range(bs):
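For reference, below is a minimal, hypothetical sketch of how the new AudioAdapter could be exercised on its own; it is not part of the commit. The import path and all config values are assumptions, and Wav2Vec2BertConfig is used instead of the annotated Wav2Vec2Config because Wav2Vec2BertAdapterLayer reads adapter-specific fields (such as adapter_act and conformer_conv_dropout) that a plain Wav2Vec2Config may not define.

import torch
from transformers import Wav2Vec2BertConfig

from mllama_audio_model import AudioAdapter  # hypothetical import path for the module above

# Hypothetical config: output_hidden_size differs from hidden_size so the
# adapter's projection layer (self.proj) is exercised as well.
config = Wav2Vec2BertConfig(
    hidden_size=1024,
    output_hidden_size=4096,
    num_adapter_layers=2,
    adapter_kernel_size=8,
    adapter_stride=8,
)

adapter = AudioAdapter(config)

# Dummy (batch, frames, hidden_size) features standing in for the encoder's last_hidden_state.
features = torch.randn(2, 160, config.hidden_size)
with torch.no_grad():
    out = adapter(features)

# Each adapter layer subsamples the time axis with kernel_size=8, stride=8, padding=stride//2,
# i.e. floor((L + 2*4 - 8) / 8) + 1 frames: 160 -> 21 -> 3 over the two layers.
print(out.shape)  # expected with these settings: torch.Size([2, 3, 4096])

In the model itself, the adapter is applied to the Wav2Vec2 encoder's last_hidden_state before the per-clip embeddings are reshaped, as shown in the second hunk above.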