AlexHung29629 commited on
Commit
99423c9
1 Parent(s): 159da21

Update mllama_audio_model.py

Browse files
Files changed (1) hide show
  1. mllama_audio_model.py +2 -2
mllama_audio_model.py CHANGED
@@ -2,7 +2,7 @@ from typing import Optional, Tuple, Union
2
  import torch
3
  from torch import nn
4
  from transformers.modeling_outputs import BaseModelOutput
5
- from transformers import Wav2Vec2Model, Wav2Vec2Config, MllamaPreTrainedModel
6
  from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
7
  from .configuration_llama3 import Llama3Config
8
 
@@ -14,7 +14,7 @@ class AudioAdapter(nn.Module):
14
  self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
15
  else:
16
  self.proj = None
17
- self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
18
 
19
  self.kernel_size = config.adapter_kernel_size
20
  self.stride = config.adapter_stride
 
2
  import torch
3
  from torch import nn
4
  from transformers.modeling_outputs import BaseModelOutput
5
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, MllamaPreTrainedModel, Wav2Vec2BertConfig
6
  from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
7
  from .configuration_llama3 import Llama3Config
8
 
 
14
  self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
15
  else:
16
  self.proj = None
17
+ self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(Wav2Vec2BertConfig(adapter_kernel_size=config.adapter_kernel_size, adapter_stride=config.adapter_stride)) for _ in range(config.num_adapter_layers))
18
 
19
  self.kernel_size = config.adapter_kernel_size
20
  self.stride = config.adapter_stride