tanthinhdt committed on
Commit 0bce006 · verified · 1 Parent(s): 0afb517

Upload feature extractor

Files changed (7)
  1. README.md +199 -0
  2. configuration.py +188 -0
  3. encoder.py +110 -0
  4. modelling.py +797 -0
  5. preprocessor_config.json +10 -0
  6. resnet.py +216 -0
  7. utils.py +187 -0
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 Transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
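
A minimal loading sketch for the "How to Get Started with the Model" section above. It assumes the model weights, `config.json`, and the corresponding `auto_map` entries are pushed to a Hub repository, and that the repository's Python dependencies (e.g. fairseq, peft) are installed; the repo id `user/avsp-llm` is a placeholder.

```python
from transformers import AutoFeatureExtractor, AutoModel

# Placeholder repo id; replace with the actual Hub repository.
repo_id = "user/avsp-llm"

# trust_remote_code is required because the config, feature extractor, and model
# classes (AVSPLLMConfig, AVSPLLMFeatureExtractor, AVSPLLMModel) live in this repo.
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(feature_extractor.size, feature_extractor.num_frames)  # 88 76
```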
configuration.py ADDED
@@ -0,0 +1,188 @@
1
+ from typing import Tuple
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class AVHubertConfig(PretrainedConfig):
6
+ model_type = "av_hubert"
7
+
8
+ def __init__(
9
+ self,
10
+ label_rate: int = 25,
11
+ sample_rate: int = 25,
12
+ input_modality: str = "video",
13
+ extractor_mode: str = "default",
14
+ encoder_layers: int = 24,
15
+ encoder_embed_dim: int = 1024,
16
+ encoder_ffn_embed_dim: int = 4096,
17
+ encoder_attention_heads: int = 16,
18
+ activation_fn: str = "gelu",
19
+ dropout: float = 0.1,
20
+ attention_dropout: float = 0.1,
21
+ activation_dropout: float = 0.1,
22
+ encoder_layerdrop: float = 0.0,
23
+ dropout_input: float = 0.0,
24
+ dropout_features: float = 0.0,
25
+ final_dim: int = 256,
26
+ untie_final_proj: bool = False,
27
+ layer_norm_first: bool = False,
28
+ conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
29
+ conv_bias: bool = False,
30
+ logit_temp: float = 0.1,
31
+ target_glu: bool = False,
32
+ feature_grad_mult: float = 1.0,
33
+ mask_length_audio: int = 10,
34
+ mask_prob_audio: float = 0.65,
35
+ mask_length_image: int = 10,
36
+ mask_prob_image: float = 0.65,
37
+ mask_selection: str = "static",
38
+ mask_other: float = 0.0,
39
+ no_mask_overlap: bool = False,
40
+ mask_min_space: int = 1,
41
+ mask_channel_length: int = 64,
42
+ mask_channel_prob: float = 0.5,
43
+ mask_channel_selection: str = "static",
44
+ mask_channel_other: float = 0.0,
45
+ no_mask_channel_overlap: bool = False,
46
+ mask_channel_min_space: int = 1,
47
+ conv_pos: int = 128,
48
+ conv_pos_groups: int = 16,
49
+ latent_temp: Tuple[float, float, float] = (2.0, 0.5, 0.999995),
50
+ skip_masked: bool = False,
51
+ skip_nomask: bool = False,
52
+ resnet_relu_type: str = "prelu",
53
+ resnet_weights: str = None,
54
+ sim_type: str = "cosine",
55
+ sub_encoder_layers: int = 0,
56
+ audio_feat_dim: int = 104,
57
+ modality_dropout: float = 0.0,
58
+ audio_dropout: float = 0.0,
59
+ modality_fuse: str = "concat",
60
+ selection_type: str = "same_other_seq",
61
+ masking_type: str = "input",
62
+ decoder_embed_dim: int = 2560,
63
+ decoder_ffn_embed_dim: int = 3072,
64
+ decoder_layers: int = 6,
65
+ decoder_layerdrop: float = 0.0,
66
+ decoder_attention_heads: int = 4,
67
+ decoder_learned_pos: bool = False,
68
+ decoder_normalize_before: bool = False,
69
+ no_token_positional_embeddings: bool = False,
70
+ decoder_dropout: float = 0.1,
71
+ decoder_attention_dropout: float = 0.1,
72
+ decoder_activation_dropout: float = 0.0,
73
+ max_target_positions: int = 2048,
74
+ share_decoder_input_output_embed: bool = False,
75
+ no_scale_embedding: bool = True,
76
+ num_classes: int = 2004,
77
+ feature_ds_rate: int = 1,
78
+ **kwargs,
79
+ ) -> None:
80
+ super().__init__(**kwargs)
81
+
82
+ self.label_rate = label_rate
83
+ self.sample_rate = sample_rate
84
+ self.input_modality = input_modality
85
+ self.extractor_mode = extractor_mode
86
+ self.encoder_layers = encoder_layers
87
+ self.encoder_embed_dim = encoder_embed_dim
88
+ self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
89
+ self.encoder_attention_heads = encoder_attention_heads
90
+ self.activation_fn = activation_fn
91
+ self.dropout = dropout
92
+ self.attention_dropout = attention_dropout
93
+ self.activation_dropout = activation_dropout
94
+ self.encoder_layerdrop = encoder_layerdrop
95
+ self.dropout_input = dropout_input
96
+ self.dropout_features = dropout_features
97
+ self.final_dim = final_dim
98
+ self.untie_final_proj = untie_final_proj
99
+ self.layer_norm_first = layer_norm_first
100
+ self.conv_feature_layers = conv_feature_layers
101
+ self.conv_bias = conv_bias
102
+ self.logit_temp = logit_temp
103
+ self.target_glu = target_glu
104
+ self.feature_grad_mult = feature_grad_mult
105
+ self.mask_length_audio = mask_length_audio
106
+ self.mask_prob_audio = mask_prob_audio
107
+ self.mask_length_image = mask_length_image
108
+ self.mask_prob_image = mask_prob_image
109
+ self.mask_selection = mask_selection
110
+ self.mask_other = mask_other
111
+ self.no_mask_overlap = no_mask_overlap
112
+ self.mask_min_space = mask_min_space
113
+ self.mask_channel_length = mask_channel_length
114
+ self.mask_channel_prob = mask_channel_prob
115
+ self.mask_channel_selection = mask_channel_selection
116
+ self.mask_channel_other = mask_channel_other
117
+ self.no_mask_channel_overlap = no_mask_channel_overlap
118
+ self.mask_channel_min_space = mask_channel_min_space
119
+ self.conv_pos = conv_pos
120
+ self.conv_pos_groups = conv_pos_groups
121
+ self.latent_temp = latent_temp
122
+ self.skip_masked = skip_masked
123
+ self.skip_nomask = skip_nomask
124
+ self.resnet_relu_type = resnet_relu_type
125
+ self.resnet_weights = resnet_weights
126
+ self.sim_type = sim_type
127
+ self.sub_encoder_layers = sub_encoder_layers
128
+ self.audio_feat_dim = audio_feat_dim
129
+ self.modality_dropout = modality_dropout
130
+ self.audio_dropout = audio_dropout
131
+ self.modality_fuse = modality_fuse
132
+ self.selection_type = selection_type
133
+ self.masking_type = masking_type
134
+ self.decoder_embed_dim = decoder_embed_dim
135
+ self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
136
+ self.decoder_layers = decoder_layers
137
+ self.decoder_layerdrop = decoder_layerdrop
138
+ self.decoder_attention_heads = decoder_attention_heads
139
+ self.decoder_learned_pos = decoder_learned_pos
140
+ self.decoder_normalize_before = decoder_normalize_before
141
+ self.no_token_positional_embeddings = no_token_positional_embeddings
142
+ self.decoder_dropout = decoder_dropout
143
+ self.decoder_attention_dropout = decoder_attention_dropout
144
+ self.decoder_activation_dropout = decoder_activation_dropout
145
+ self.max_target_positions = max_target_positions
146
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
147
+ self.no_scale_embedding = no_scale_embedding
148
+ self.num_classes = num_classes
149
+ self.feature_ds_rate = feature_ds_rate
150
+
151
+
152
+ class AVSPLLMConfig(AVHubertConfig):
153
+ model_type = "avsp_llm"
154
+
155
+ def __init__(
156
+ self,
157
+ llm_ckpt_path: str = "vilm/vinallama-2.7b",
158
+ cache_dir: str = "models/huggingface",
159
+ no_pretrained_weights: bool = False,
160
+ final_dropout: float = 0.1,
161
+ apply_mask: bool = False,
162
+ mask_length: int = 10,
163
+ mask_prob: float = 0.5,
164
+ masking_updates: int = 0,
165
+ layerdrop: float = 0.0,
166
+ normalize: bool = False,
167
+ data: str = None,
168
+ w2v_args: dict = None,
169
+ freeze_finetune_updates: int = 0,
170
+ km_path: str = "model.km",
171
+ **kwargs,
172
+ ) -> None:
173
+ super().__init__(**kwargs)
174
+
175
+ self.llm_ckpt_path = llm_ckpt_path
176
+ self.cache_dir = cache_dir
177
+ self.no_pretrained_weights = no_pretrained_weights
178
+ self.final_dropout = final_dropout
179
+ self.apply_mask = apply_mask
180
+ self.mask_length = mask_length
181
+ self.mask_prob = mask_prob
182
+ self.masking_updates = masking_updates
183
+ self.layerdrop = layerdrop
184
+ self.normalize = normalize
185
+ self.data = data
186
+ self.w2v_args = w2v_args
187
+ self.freeze_finetune_updates = freeze_finetune_updates
188
+ self.km_path = km_path
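
As a quick sanity check of the defaults above, the configs can be instantiated and serialized directly. This is a sketch; the flat `from configuration import ...` import assumes it is run from the repository directory.

```python
from configuration import AVSPLLMConfig

config = AVSPLLMConfig()               # inherits all AVHubertConfig defaults
print(config.model_type)               # "avsp_llm"
print(config.encoder_embed_dim)        # 1024
print(config.llm_ckpt_path)            # "vilm/vinallama-2.7b"

config.save_pretrained("avsp_llm")     # writes avsp_llm/config.json
reloaded = AVSPLLMConfig.from_pretrained("avsp_llm")
assert reloaded.decoder_embed_dim == 2560
```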
encoder.py ADDED
@@ -0,0 +1,110 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import List, Optional, Tuple
7
+ from .configuration import AVHubertConfig
8
+ from fairseq.utils import index_put
9
+ from fairseq.modules import LayerNorm, SamePad
10
+ from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
11
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
12
+
13
+
14
+ class TransformerEncoder(nn.Module):
15
+ def __init__(self, config: AVHubertConfig) -> None:
16
+ super().__init__()
17
+
18
+ self.dropout = config.dropout
19
+ self.embedding_dim = config.encoder_embed_dim
20
+
21
+ self.pos_conv = nn.Conv1d(
22
+ self.embedding_dim,
23
+ self.embedding_dim,
24
+ kernel_size=config.conv_pos,
25
+ padding=config.conv_pos // 2,
26
+ groups=config.conv_pos_groups,
27
+ )
28
+ dropout = 0
29
+ std = math.sqrt((4 * (1.0 - dropout)) / (config.conv_pos * self.embedding_dim))
30
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
31
+ nn.init.constant_(self.pos_conv.bias, 0)
32
+
33
+ self.pos_conv = nn.utils.weight_norm(
34
+ self.pos_conv, name="weight", dim=2
35
+ )
36
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(config.conv_pos), nn.GELU())
37
+
38
+ self.layers = nn.ModuleList(
39
+ [
40
+ TransformerSentenceEncoderLayer(
41
+ embedding_dim=self.embedding_dim,
42
+ ffn_embedding_dim=config.encoder_ffn_embed_dim,
43
+ num_attention_heads=config.encoder_attention_heads,
44
+ dropout=self.dropout,
45
+ attention_dropout=config.attention_dropout,
46
+ activation_dropout=config.activation_dropout,
47
+ activation_fn=config.activation_fn,
48
+ layer_norm_first=config.layer_norm_first,
49
+ )
50
+ for _ in range(config.encoder_layers)
51
+ ]
52
+ )
53
+
54
+ self.layer_norm_first = config.layer_norm_first
55
+ self.layer_norm = LayerNorm(self.embedding_dim)
56
+ self.layerdrop = config.encoder_layerdrop
57
+
58
+ self.apply(init_bert_params)
59
+
60
+ def forward(
61
+ self,
62
+ x: torch.Tensor,
63
+ padding_mask: Optional[torch.Tensor] = None,
64
+ layer: Optional[int] = None,
65
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
66
+ x, layer_results = self.extract_features(x, padding_mask, layer)
67
+ if self.layer_norm_first and layer is None:
68
+ x = self.layer_norm(x)
69
+ return x, layer_results
70
+
71
+ def extract_features(
72
+ self,
73
+ x: torch.Tensor,
74
+ padding_mask: Optional[torch.Tensor] = None,
75
+ tgt_layer: Optional[int] = None,
76
+ ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
77
+ if padding_mask is not None:
78
+ x = index_put(x, padding_mask, 0)
79
+
80
+ x_conv = self.pos_conv(x.transpose(1, 2))
81
+ x_conv = x_conv.transpose(1, 2)
82
+ x = x + x_conv
83
+
84
+ if not self.layer_norm_first:
85
+ x = self.layer_norm(x)
86
+
87
+ x = F.dropout(x, p=self.dropout, training=self.training)
88
+
89
+ # B x T x C -> T x B x C
90
+ x = x.transpose(0, 1)
91
+
92
+ layer_results = []
93
+ r = None
94
+ for i, layer in enumerate(self.layers):
95
+ dropout_probability = np.random.random()
96
+ if not self.training or (dropout_probability > self.layerdrop):
97
+ x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
98
+ if tgt_layer is not None:
99
+ layer_results.append((x, z))
100
+ if i == tgt_layer:
101
+ r = x
102
+ break
103
+
104
+ if r is not None:
105
+ x = r
106
+
107
+ # T x B x C -> B x T x C
108
+ x = x.transpose(0, 1)
109
+
110
+ return x, layer_results
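
A shape-level sketch of the encoder above, using a deliberately small config so it runs quickly (the shipped defaults use 24 layers with a 1024-dim embedding). It assumes fairseq is installed and the repository files are importable as a package, here hypothetically named `avsp_llm`, so the relative imports resolve.

```python
import torch
from avsp_llm.configuration import AVHubertConfig   # hypothetical package name
from avsp_llm.encoder import TransformerEncoder

config = AVHubertConfig(
    encoder_layers=2,
    encoder_embed_dim=64,
    encoder_ffn_embed_dim=128,
    encoder_attention_heads=4,
    conv_pos=16,
    conv_pos_groups=4,
)
encoder = TransformerEncoder(config).eval()

x = torch.randn(2, 50, 64)          # (batch, time, embed_dim)
with torch.no_grad():
    out, layer_results = encoder(x)
print(out.shape)                    # torch.Size([2, 50, 64]); layer_results is [] when layer=None
```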
modelling.py ADDED
@@ -0,0 +1,797 @@
1
+ import torch
2
+ import logging
3
+ import contextlib
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ from pathlib import Path
7
+ from .resnet import ResNetEncoder
8
+ from .encoder import TransformerEncoder
9
+ from .configuration import AVHubertConfig, AVSPLLMConfig
10
+ from .utils import compute_mask_indices, load_kmeans_model
11
+ from typing import Optional, Tuple, List, Dict, Any
12
+ from peft import get_peft_model, LoraConfig
13
+ from fairseq.modules import GradMultiply, LayerNorm
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+ from transformers import (
16
+ FeatureExtractionMixin,
17
+ PreTrainedModel,
18
+ BitsAndBytesConfig,
19
+ AutoModelForCausalLM,
20
+ GenerationConfig,
21
+ )
22
+
23
+
24
+ logging.root.setLevel(logging.WARNING)
25
+
26
+
27
+ class AVHubertFeatureExtractor(FeatureExtractionMixin):
28
+ def __init__(self, config: AVHubertConfig = AVHubertConfig(), **kwargs) -> None:
29
+ super().__init__(**kwargs)
30
+ self.audio_feat_dim = config.audio_feat_dim
31
+
32
+ self.size = 88
33
+ self.num_frames = 76
34
+ self.num_channels = 1
35
+
36
+
37
+ class AVSPLLMFeatureExtractor(AVHubertFeatureExtractor):
38
+ def __init__(self, config: AVSPLLMConfig = AVSPLLMConfig(), **kwargs) -> None:
39
+ super().__init__(config, **kwargs)
40
+
41
+
42
+ class AVHubertVideoFeatureEncoder(nn.Module):
43
+ def __init__(self, config: AVHubertConfig) -> None:
44
+ super().__init__()
45
+ self.resnet = ResNetEncoder(relu_type=config.resnet_relu_type)
46
+ self.proj = nn.Linear(self.resnet.backend_out, config.encoder_embed_dim)
47
+ self.encoder = (
48
+ TransformerEncoder(config)
49
+ if config.sub_encoder_layers > 0
50
+ else None
51
+ )
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ x = self.resnet(x)
55
+ x = self.proj(x.transpose(1, 2))
56
+ if self.encoder is not None:
57
+ x = self.encoder(x)[0].transpose(1, 2)
58
+ else:
59
+ x = x.transpose(1, 2)
60
+ return x
61
+
62
+
63
+ class AVHubertAudioFeatureEncoder(nn.Module):
64
+ def __init__(self, config: AVHubertConfig) -> None:
65
+ super().__init__()
66
+ self.proj = nn.Linear(config.audio_feat_dim, config.encoder_embed_dim)
67
+ self.encoder = (
68
+ TransformerEncoder(config)
69
+ if config.sub_encoder_layers > 0
70
+ else None
71
+ )
72
+
73
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
74
+ x = self.proj(x.transpose(1, 2))
75
+ if self.encoder is not None:
76
+ x = self.encoder(x)[0].transpose(1, 2)
77
+ else:
78
+ x = x.transpose(1, 2)
79
+ return x
80
+
81
+
82
+ class AVHubertModel(PreTrainedModel):
83
+ config_class = AVHubertConfig
84
+
85
+ def __init__(
86
+ self,
87
+ config: AVHubertConfig = AVHubertConfig(),
88
+ dictionaries: List = [None],
89
+ ) -> None:
90
+ super().__init__(config=config)
91
+ label_rate = config.label_rate
92
+ feature_ds_rate = config.feature_ds_rate
93
+ sample_rate = config.sample_rate
94
+ self.feat2tar_ratio = label_rate * feature_ds_rate / sample_rate
95
+
96
+ self.feature_extractor_video = AVHubertVideoFeatureEncoder(config)
97
+ self.feature_extractor_audio = AVHubertAudioFeatureEncoder(config)
98
+
99
+ if config.modality_fuse == "concat":
100
+ self.encoder_embed_dim = config.encoder_embed_dim * 2
101
+ elif config.modality_fuse == "add":
102
+ self.encoder_embed_dim = config.encoder_embed_dim
103
+
104
+ self.post_extract_proj = (
105
+ nn.Linear(self.encoder_embed_dim, config.encoder_embed_dim)
106
+ if self.encoder_embed_dim != config.encoder_embed_dim
107
+ else None
108
+ )
109
+
110
+ self.dropout_input = nn.Dropout(config.dropout_input)
111
+ self.dropout_features = nn.Dropout(config.dropout_features)
112
+
113
+ if self.config.final_dim > 0:
114
+ final_dim = config.final_dim
115
+ else:
116
+ final_dim = config.encoder_embed_dim
117
+
118
+ self.mask_emb = nn.Parameter(
119
+ torch.FloatTensor(config.audio_feat_dim).uniform_()
120
+ if config.masking_type == "input"
121
+ else torch.FloatTensor(config.encoder_embed_dim).uniform_()
122
+ )
123
+
124
+ self.encoder = TransformerEncoder(self.config)
125
+ self.layer_norm = LayerNorm(self.encoder_embed_dim)
126
+
127
+ self.target_glu = None
128
+ if config.target_glu:
129
+ self.target_glu = nn.Sequential(
130
+ nn.Linear(config.final_dim, config.final_dim * 2),
131
+ nn.GLU(),
132
+ )
133
+
134
+ if config.untie_final_proj:
135
+ self.final_proj = nn.Linear(
136
+ config.encoder_embed_dim,
137
+ final_dim * len(dictionaries),
138
+ )
139
+ else:
140
+ self.final_proj = nn.Linear(config.encoder_embed_dim, final_dim)
141
+
142
+ # modules below are not needed during fine-tuning
143
+ if any([d is None for d in dictionaries]):
144
+ self.num_classes = [config.num_classes]
145
+ else:
146
+ self.num_classes = [len(d) for d in dictionaries]
147
+ self.label_embs_concat = nn.Parameter(
148
+ torch.FloatTensor(sum(self.num_classes), final_dim)
149
+ )
150
+ nn.init.uniform_(self.label_embs_concat)
151
+
152
+ def apply_input_mask(
153
+ self,
154
+ x: torch.Tensor,
155
+ padding_mask: torch.Tensor,
156
+ target_list: List[torch.Tensor],
157
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
158
+ B, C, T = x.shape[:3]
159
+ is_audio = True if len(x.shape) == 3 else False
160
+
161
+ if is_audio:
162
+ mask_prob = self.config.mask_prob_audio
163
+ mask_length = self.config.mask_length_audio
164
+ else:
165
+ mask_prob = self.config.mask_prob_image
166
+ mask_length = self.config.mask_length_image
167
+
168
+ if mask_prob > 0:
169
+ mask_indices, starts, ends, batch_indexes = compute_mask_indices(
170
+ (B, T),
171
+ padding_mask,
172
+ mask_prob,
173
+ mask_length,
174
+ self.config.mask_selection,
175
+ self.config.mask_other,
176
+ min_masks=2,
177
+ no_overlap=self.config.no_mask_overlap,
178
+ min_space=self.config.mask_min_space,
179
+ )
180
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
181
+ x = x.transpose(1, 2).contiguous() # [B, T, C, H, W]
182
+ if B == 1:
183
+ x[mask_indices] = 0
184
+ elif is_audio:
185
+ x[mask_indices] = self.mask_emb
186
+ elif self.config.selection_type == "same_other_seq":
187
+ perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
188
+ x_perm = x[perm]
189
+ x[mask_indices] = x_perm[mask_indices]
190
+ elif self.config.selection_type == "same_seq":
191
+ batch_indexes_, other_indexes = [], []
192
+ for batch_index, start, end in zip(batch_indexes, starts, ends):
193
+ length = end - start
194
+ other_start = np.setdiff1d(
195
+ np.arange(T), np.arange(max(0, start - length), end)
196
+ )
197
+ if len(other_start) > 0:
198
+ other_start = np.random.choice(other_start, size=1)
199
+ else:
200
+ other_start = 0
201
+ other_end = other_start + length
202
+ other_indexes.append(
203
+ np.arange(other_start, other_end).clip(max=T - 1)
204
+ )
205
+ batch_indexes_.append(
206
+ np.zeros([length], dtype=np.int64) + batch_index
207
+ )
208
+ batch_indexes = np.concatenate(batch_indexes_)
209
+ other_indexes = np.concatenate(other_indexes)
210
+ x[mask_indices] = x[batch_indexes, other_indexes]
211
+ x = x.transpose(1, 2).contiguous()
212
+ else:
213
+ mask_indices = None
214
+
215
+ if self.config.mask_channel_prob > 0:
216
+ logging.warning("No mask channel prob for input masking")
217
+ return x, mask_indices
218
+
219
+ def apply_feature_mask(
220
+ self,
221
+ x: torch.Tensor,
222
+ padding_mask: torch.Tensor,
223
+ target_list: List[torch.Tensor],
224
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
225
+ B, T, C = x.shape
226
+ assert all((
227
+ self.config.mask_prob_audio == self.config.mask_prob_image,
228
+ self.config.mask_length_audio == self.config.mask_length_image,
229
+ )), "masking prob/length for image/audio be same for feature masking"
230
+
231
+ mask_prob = self.config.mask_prob_audio
232
+ mask_length = self.config.mask_length_image
233
+ if mask_prob > 0:
234
+ mask_indices, _, _, _ = compute_mask_indices(
235
+ (B, T),
236
+ padding_mask,
237
+ mask_prob,
238
+ mask_length,
239
+ self.config.mask_selection,
240
+ self.config.mask_other,
241
+ min_masks=2,
242
+ no_overlap=self.config.no_mask_overlap,
243
+ min_space=self.config.mask_min_space,
244
+ )
245
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
246
+ x[mask_indices] = self.mask_emb
247
+ else:
248
+ mask_indices = None
249
+
250
+ if self.config.mask_channel_prob > 0:
251
+ mask_channel_indices, _, _, _ = compute_mask_indices(
252
+ (B, C),
253
+ None,
254
+ self.config.mask_channel_prob,
255
+ self.config.mask_channel_length,
256
+ self.config.mask_channel_selection,
257
+ self.config.mask_channel_other,
258
+ no_overlap=self.config.no_mask_channel_overlap,
259
+ min_space=self.config.mask_channel_min_space,
260
+ )
261
+ mask_channel_indices = (
262
+ torch.from_numpy(mask_channel_indices)
263
+ .to(x.device)
264
+ .unsqueeze(1)
265
+ .expand(-1, T, -1)
266
+ )
267
+ x[mask_channel_indices] = 0
268
+
269
+ return x, mask_indices
270
+
271
+ def forward_features(
272
+ self,
273
+ source: Dict[str, torch.Tensor],
274
+ modality: str,
275
+ ) -> torch.Tensor:
276
+ extractor = getattr(self, f"feature_extractor_{modality}")
277
+ if self.config.feature_grad_mult > 0:
278
+ features = extractor(source)
279
+ if self.config.feature_grad_mult != 1.0:
280
+ features = GradMultiply.apply(features, self.config.feature_grad_mult)
281
+ else:
282
+ with torch.no_grad():
283
+ features = extractor(source)
284
+ return features
285
+
286
+ def forward_targets(
287
+ self,
288
+ features: torch.Tensor,
289
+ mask_indices: torch.Tensor,
290
+ target_list: List[torch.Tensor],
291
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
292
+ # Trim features to ensure labels exist and then get aligned labels
293
+ feat_tsz = features.size(2)
294
+ targ_tsz = min([t.size(1) for t in target_list])
295
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
296
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
297
+ features = features[..., :feat_tsz]
298
+ if mask_indices is not None:
299
+ mask_indices = mask_indices[..., :feat_tsz]
300
+ target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
301
+ target_list = [t[:, target_inds.long()] for t in target_list]
302
+ return features, mask_indices, target_list
303
+
304
+ def forward_padding_mask(
305
+ self,
306
+ features: torch.Tensor,
307
+ padding_mask: torch.Tensor,
308
+ ) -> torch.Tensor:
309
+ extra = padding_mask.size(1) % features.size(1)
310
+ if extra > 0:
311
+ padding_mask = padding_mask[:, :-extra]
312
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
313
+ padding_mask = padding_mask.all(-1)
314
+ return padding_mask
315
+
316
+ def compute_logits(self, feats: torch.Tensor, emb_mat: torch.Tensor) -> torch.Tensor:
317
+ # feats: [B, T, F], emb_mat: [V, F]
318
+ if self.config.sim_type == "dot":
319
+ logits = torch.matmul(feats, emb_mat.transpose(0, 1))
320
+ elif self.config.sim_type == "cosine":
321
+ batch_size, timesteps, emb_dim = feats.size()
322
+ feats_ = feats.view(-1, emb_dim)
323
+ # [B*T, V]
324
+ nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)
325
+ # [B*T, V]
326
+ denom = (
327
+ (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1)
328
+ * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0)
329
+ )
330
+ logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
331
+ else:
332
+ raise NotImplementedError
333
+ logits = logits / self.config.logit_temp
334
+ return logits
335
+
336
+ def forward(
337
+ self,
338
+ source: Dict[str, torch.Tensor],
339
+ target_list: Optional[List[torch.Tensor]] = None,
340
+ padding_mask: Optional[torch.Tensor] = None,
341
+ mask: bool = True,
342
+ features_only: bool = False,
343
+ output_layer: Optional[int] = None,
344
+ ) -> Dict[str, torch.Tensor]:
345
+ """output layer is 1-based"""
346
+ src_audio, src_video = source["audio"], source["video"]
347
+ if mask and self.config.masking_type == "input":
348
+ src_video, mask_indices_video = self.apply_input_mask(
349
+ src_video, padding_mask, target_list
350
+ )
351
+ src_audio, mask_indices_audio = self.apply_input_mask(
352
+ src_audio, padding_mask, target_list
353
+ )
354
+ mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
355
+ else:
356
+ src_audio, src_video, mask_indices = src_audio, src_video, None
357
+
358
+ # [B, F, T]
359
+ features_audio = self.forward_features(src_audio, modality="audio")
360
+ features_video = self.forward_features(src_video, modality="video")
361
+
362
+ if self.training:
363
+ modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
364
+ if modality_drop_prob < self.config.modality_dropout:
365
+ if audio_drop_prob < self.config.audio_dropout:
366
+ features_audio = 0 * features_audio
367
+ else:
368
+ features_video = 0 * features_video
369
+
370
+ if self.config.modality_fuse == "concat":
371
+ features = torch.cat([features_audio, features_video], dim=1)
372
+ elif self.config.modality_fuse == "add":
373
+ features = features_audio + features_video
374
+
375
+ if target_list is not None:
376
+ features, mask_indices, target_list = self.forward_targets(
377
+ features, mask_indices, target_list
378
+ )
379
+
380
+ features_pen = features.float().pow(2).mean()
381
+
382
+ features = features.transpose(1, 2)
383
+ features = self.layer_norm(features)
384
+
385
+ if padding_mask is not None:
386
+ padding_mask = self.forward_padding_mask(features, padding_mask)
387
+
388
+ if self.post_extract_proj is not None:
389
+ features = self.post_extract_proj(features)
390
+
391
+ features = self.dropout_input(features)
392
+ if self.config.masking_type == "feature" and mask:
393
+ x, mask_indices = self.apply_feature_mask(
394
+ features, padding_mask, target_list
395
+ )
396
+ else:
397
+ x = features
398
+
399
+ # feature: (B, T, D), float
400
+ # target: (B, T), long
401
+ # x: (B, T, D), float
402
+ # padding_mask: (B, T), bool
403
+ # mask_indices: (B, T), bool
404
+ x, _ = self.encoder(
405
+ x,
406
+ padding_mask=padding_mask,
407
+ layer=None if output_layer is None else output_layer - 1,
408
+ )
409
+
410
+ if features_only:
411
+ return {"x": x, "padding_mask": padding_mask, "features": features}
412
+
413
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
414
+ proj_x = self.final_proj(x)
415
+ if self.config.untie_final_proj:
416
+ proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
417
+ else:
418
+ proj_x_list = [proj_x for _ in self.num_classes]
419
+
420
+ # [[B*T, V]]
421
+ logit_list = [
422
+ self.compute_logits(proj, emb).view(-1, num_class)
423
+ for proj, emb, num_class in zip(
424
+ proj_x_list, label_embs_list, self.num_classes
425
+ )
426
+ ]
427
+
428
+ mask = torch.logical_and(mask_indices, ~padding_mask).view(-1)
429
+ unmask = torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
430
+ logit_m_list = [logit[mask] for logit in logit_list]
431
+ logit_u_list = [logit[unmask] for logit in logit_list]
432
+ target_m_list = [target.view(-1)[mask].long() for target in target_list]
433
+ target_u_list = [target.view(-1)[unmask].long() for target in target_list]
434
+
435
+ return {
436
+ "logit_m_list": logit_m_list,
437
+ "logit_u_list": logit_u_list,
438
+ "target_m_list": target_m_list,
439
+ "target_u_list": target_u_list,
440
+ "padding_mask": padding_mask,
441
+ "features_pen": features_pen,
442
+ }
443
+
444
+ def extract_features(
445
+ self,
446
+ source: Dict[str, torch.Tensor],
447
+ padding_mask: Optional[torch.Tensor] = None,
448
+ mask: bool = False,
449
+ ret_conv: bool = False,
450
+ output_layer: Optional[int] = None,
451
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
452
+ res = self.forward(
453
+ source,
454
+ padding_mask=padding_mask,
455
+ mask=mask,
456
+ features_only=True,
457
+ output_layer=output_layer,
458
+ )
459
+ feature = res["features"] if ret_conv else res["x"]
460
+ return feature, res["padding_mask"]
461
+
462
+ def extract_units(
463
+ self,
464
+ source: Dict[str, torch.Tensor],
465
+ padding_mask: torch.Tensor = None,
466
+ mask: bool = False,
467
+ ret_conv: bool = False,
468
+ output_layer: Optional[int] = None,
469
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
470
+ res = self.forward(
471
+ source,
472
+ padding_mask=padding_mask,
473
+ mask=mask,
474
+ features_only=True,
475
+ output_layer=None,
476
+ )
477
+
478
+ feature = res["features"] if ret_conv else res["x"]
479
+ proj_x = self.final_proj(feature)
480
+ # B T
481
+ units = (
482
+ torch
483
+ .matmul(proj_x, self.label_embs_concat.transpose(0, 1))
484
+ .argmax(dim=-1)
485
+ )
486
+ return units
487
+
488
+ def extract_finetune(
489
+ self,
490
+ source: Dict[str, torch.Tensor],
491
+ padding_mask: torch.Tensor = None,
492
+ mask: bool = False,
493
+ ret_conv: bool = False,
494
+ output_layer: Optional[int] = None,
495
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
496
+ src_audio, src_video = source["audio"], source["video"]
497
+ if mask and self.config.masking_type == "input":
498
+ src_video, _ = self.apply_input_mask(
499
+ src_video, padding_mask, target_list=None
500
+ )
501
+ src_audio, _ = self.apply_input_mask(
502
+ src_audio, padding_mask, target_list=None
503
+ )
504
+ else:
505
+ src_audio, src_video, _ = src_audio, src_video, None
506
+
507
+ # features: [B, F, T]
508
+ if src_audio is not None and src_video is None:
509
+ features_audio = self.forward_features(
510
+ src_audio, modality="audio"
511
+ )
512
+ features_video = features_audio.new_zeros(
513
+ features_audio.size(0),
514
+ self.encoder_embed_dim,
515
+ features_audio.size(-1)
516
+ )
517
+ elif src_audio is None and src_video is not None:
518
+ features_video = self.forward_features(src_video, modality="video")
519
+ features_audio = features_video.new_zeros(
520
+ features_video.size(0),
521
+ self.encoder_embed_dim,
522
+ features_video.size(-1)
523
+ )
524
+ elif src_audio is not None and src_video is not None:
525
+ features_video = self.forward_features(src_video, modality="video")
526
+ features_audio = self.forward_features(
527
+ src_audio, modality="audio"
528
+ )
529
+
530
+ if self.config.modality_fuse == "concat":
531
+ features = torch.cat([features_audio, features_video], dim=1)
532
+ elif self.config.modality_fuse == "add":
533
+ features = features_audio + features_video
534
+
535
+ features = features.transpose(1, 2)
536
+ features = self.layer_norm(features)
537
+ unmasked_features = features.clone()
538
+
539
+ if padding_mask is not None:
540
+ padding_mask = self.forward_padding_mask(features, padding_mask)
541
+
542
+ if self.post_extract_proj is not None:
543
+ features = self.post_extract_proj(features)
544
+
545
+ features = self.dropout_input(features)
546
+ unmasked_features = self.dropout_features(unmasked_features)
547
+
548
+ # feature: (B, T, D), float
549
+ # target: (B, T), long
550
+ # x: (B, T, D), float
551
+ # padding_mask: (B, T), bool
552
+ # mask_indices: (B, T), bool
553
+ x, _ = self.encoder(
554
+ features,
555
+ padding_mask=padding_mask,
556
+ layer=None if output_layer is None else output_layer - 1,
557
+ )
558
+
559
+ return x, padding_mask
560
+
561
+ def get_extra_losses(
562
+ self,
563
+ net_output: Dict[str, torch.Tensor],
564
+ ) -> Tuple[List[torch.Tensor], List[str]]:
565
+ extra_losses = []
566
+ names = []
567
+ if "features_pen" in net_output:
568
+ extra_losses.append(net_output["features_pen"])
569
+ names.append("features_pen")
570
+
571
+ return extra_losses, names
572
+
573
+ def remove_pretraining_modules(self) -> None:
574
+ self.target_glu = None
575
+ self.final_proj = None
576
+
577
+ def compute_nce(
578
+ self,
579
+ x: torch.Tensor,
580
+ pos: torch.Tensor,
581
+ negs: torch.Tensor,
582
+ ) -> torch.Tensor:
583
+ neg_is_pos = (pos == negs).all(-1)
584
+ pos = pos.unsqueeze(0)
585
+ targets = torch.cat([pos, negs], dim=0)
586
+
587
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
588
+ logits /= self.config.logit_temp
589
+ if neg_is_pos.any():
590
+ logits[1:][neg_is_pos] = float("-inf")
591
+ logits = logits.transpose(0, 1) # (num_x, num_cls+1)
592
+ return logits
593
+
594
+
595
+ class HubertEncoderWrapper(nn.Module):
596
+ def __init__(
597
+ self,
598
+ config: AVHubertConfig,
599
+ dictionaries: List = [None],
600
+ ) -> None:
601
+ super().__init__()
602
+ self.w2v_model = AVHubertModel(config, dictionaries)
603
+
604
+ def forward(
605
+ self,
606
+ source: Dict[str, torch.Tensor],
607
+ padding_mask: torch.Tensor,
608
+ **kwargs,
609
+ ) -> Dict[str, torch.Tensor]:
610
+ w2v_args = {
611
+ "source": source,
612
+ "padding_mask": padding_mask,
613
+ }
614
+ x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
615
+ return {
616
+ "encoder_out": x, # T x B x C
617
+ "encoder_padding_mask": padding_mask, # B x T
618
+ "padding_mask": padding_mask,
619
+ }
620
+
621
+ def reorder_encoder_out(
622
+ self,
623
+ encoder_out: Dict[str, torch.Tensor],
624
+ new_order: torch.Tensor,
625
+ ) -> Dict[str, torch.Tensor]:
626
+ if encoder_out["encoder_out"] is not None:
627
+ encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
628
+ 1, new_order
629
+ )
630
+ if encoder_out["encoder_padding_mask"] is not None:
631
+ encoder_out["encoder_padding_mask"] = encoder_out[
632
+ "encoder_padding_mask"
633
+ ].index_select(0, new_order)
634
+ if encoder_out["padding_mask"] is not None:
635
+ encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
636
+ 0, new_order
637
+ )
638
+ return encoder_out
639
+
640
+
641
+ class AVSPLLMModel(PreTrainedModel):
642
+ config_class = AVSPLLMConfig
643
+
644
+ def __init__(
645
+ self,
646
+ config: AVSPLLMConfig = AVSPLLMConfig(),
647
+ dictionaries: List = [None],
648
+ ) -> None:
649
+ super().__init__(config=config)
650
+ current_dir = Path(__file__).resolve().parent
651
+ self.km_path = current_dir / config.km_path
652
+ if not self.km_path.is_file():
653
+ repo_id = self.config._name_or_path
654
+ self.km_path = f"{repo_id}/model.km"
655
+ self.km_path = str(self.km_path)
656
+ self.C, self.Cnorm = load_kmeans_model(self.km_path)
657
+
658
+ self.encoder = HubertEncoderWrapper(config, dictionaries)
659
+ self.encoder.w2v_model.remove_pretraining_modules()
660
+
661
+ self.avfeat_to_llm = nn.Linear(
662
+ config.encoder_embed_dim, config.decoder_embed_dim
663
+ )
664
+
665
+ bnb_config = BitsAndBytesConfig(
666
+ load_in_4bit=True,
667
+ bnb_4bit_use_double_quant=True,
668
+ bnb_4bit_quant_type="nf4",
669
+ bnb_4bit_compute_dtype=torch.bfloat16,
670
+ )
671
+ decoder_4bit = AutoModelForCausalLM.from_pretrained(
672
+ config.llm_ckpt_path,
673
+ quantization_config=bnb_config,
674
+ cache_dir=config.cache_dir,
675
+ trust_remote_code=True,
676
+ )
677
+ lora_config = LoraConfig(
678
+ r=16,
679
+ lora_alpha=32,
680
+ target_modules=["q_proj", "v_proj", "k_proj"],
681
+ lora_dropout=0.05,
682
+ bias="none",
683
+ task_type="CAUSAL_LM",
684
+ )
685
+ self.decoder = get_peft_model(decoder_4bit, lora_config)
686
+ self.decoder.print_trainable_parameters()
687
+
688
+ def apply_kmeans(self, feat: torch.Tensor) -> torch.Tensor:
689
+ dist = (
690
+ feat.squeeze(0).pow(2).sum(1, keepdim=True)
691
+ - 2 * torch.matmul(feat.squeeze(0), self.C)
692
+ + self.Cnorm
693
+ )
694
+ cluster_counts = dist.argmin(dim=1)
695
+
696
+ current_counts = 1
697
+ counts = []
698
+ for i in range(1, len(cluster_counts)):
699
+ if cluster_counts[i] == cluster_counts[i - 1]:
700
+ current_counts += 1
701
+ else:
702
+ counts.append(current_counts)
703
+ current_counts = 1
704
+ counts.append(current_counts)
705
+
706
+ return torch.tensor(counts)
707
+
708
+ def deduplicate(
709
+ self,
710
+ feat: torch.Tensor,
711
+ cluster_counts: torch.Tensor,
712
+ ) -> torch.Tensor:
713
+ results_tensor = []
714
+ start_idx = 0
715
+ for clutser_num in cluster_counts:
716
+ end_idx = start_idx + clutser_num
717
+ slice = feat[:, start_idx:end_idx, :]
718
+ mean_tensor = torch.mean(slice, dim=1, keepdim=True)
719
+ results_tensor.append(mean_tensor)
720
+ start_idx = end_idx
721
+
722
+ assert cluster_counts.sum().item() == feat.size()[1], \
723
+ f"{cluster_counts.sum().item()} != {feat.size()[1]}"
724
+
725
+ return torch.cat(results_tensor, dim=1)
726
+
727
+ def embed(
728
+ self,
729
+ source: Dict[str, torch.Tensor],
730
+ padding_mask: torch.Tensor,
731
+ target_list: torch.Tensor = None,
732
+ **kwargs,
733
+ ) -> torch.Tensor:
734
+ ft = self.config.freeze_finetune_updates <= kwargs.get("num_updates", -1)
735
+ with torch.no_grad() if not ft else contextlib.ExitStack():
736
+ output = self.encoder(source, padding_mask, **kwargs)
737
+
738
+ cluster_counts = self.apply_kmeans(output["encoder_out"])
739
+
740
+ output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])
741
+
742
+ reduced_enc_out = self.deduplicate(output["encoder_out"], cluster_counts)
743
+ reduced_enc_out = reduced_enc_out.to(self.decoder.device)
744
+ B, T, D = reduced_enc_out.size()
745
+
746
+ instruction = source["text"]
747
+ instruction_embedding = self.decoder.model.model.embed_tokens(instruction)
748
+
749
+ llm_input = torch.cat((instruction_embedding, reduced_enc_out), dim=1)
750
+
751
+ if target_list is None:
752
+ return llm_input, None
753
+
754
+ labels = target_list.clone()
755
+ labels_embedding = self.decoder.model.model.embed_tokens(labels)
756
+
757
+ llm_input = torch.cat((llm_input, labels_embedding), dim=1)
758
+
759
+ llm_labels = labels.clone()
760
+ llm_labels[llm_labels == 0] = -100
761
+
762
+ _, instruction_embedding_t, _ = instruction_embedding.size()
763
+ target_ids = (
764
+ torch.full((B, T + instruction_embedding_t), -100).long().to(labels.device)
765
+ )
766
+ llm_labels = torch.cat((target_ids, llm_labels), dim=1)
767
+
768
+ return llm_input, llm_labels
769
+
770
+ def forward(
771
+ self,
772
+ source: Dict[str, torch.Tensor],
773
+ padding_mask: torch.Tensor,
774
+ target_list: torch.Tensor = None,
775
+ **kwargs,
776
+ ) -> CausalLMOutputWithPast:
777
+ llm_input, llm_labels = self.embed(
778
+ source, padding_mask, target_list, **kwargs
779
+ )
780
+ return self.decoder(
781
+ inputs_embeds=llm_input.to(torch.float16), labels=llm_labels, return_dict=True
782
+ )
783
+
784
+ @torch.no_grad()
785
+ def generate(
786
+ self,
787
+ inputs: Optional[Dict[str, torch.Tensor]] = None,
788
+ generation_config: Optional[GenerationConfig] = None,
789
+ **kwargs,
790
+ ) -> Any:
791
+ llm_input, _ = self.embed(**inputs, **kwargs)
792
+ self.decoder.config.use_cache = True
793
+ return self.decoder.generate(
794
+ inputs_embeds=llm_input,
795
+ generation_config=generation_config,
796
+ **kwargs,
797
+ )
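
A feature-extraction sketch for `AVHubertModel.extract_finetune` with a scaled-down config (the defaults build a 24-layer, 1024-dim encoder). It assumes fairseq and peft are installed and that the repository files are importable as a package, hypothetically named `avsp_llm`.

```python
import torch
from avsp_llm.configuration import AVHubertConfig
from avsp_llm.modelling import AVHubertModel

config = AVHubertConfig(
    encoder_layers=2,
    encoder_embed_dim=64,
    encoder_ffn_embed_dim=128,
    encoder_attention_heads=4,
    conv_pos=16,
    conv_pos_groups=4,
)
model = AVHubertModel(config).eval()

source = {
    "video": torch.randn(1, 1, 76, 88, 88),  # (B, C, T, H, W): 76 grayscale 88x88 lip frames
    "audio": torch.randn(1, 104, 76),        # (B, audio_feat_dim, T)
}
with torch.no_grad():
    feats, padding_mask = model.extract_finetune(source, padding_mask=None, mask=False)
print(feats.shape)                           # torch.Size([1, 76, 64])
```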
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "audio_feat_dim": 104,
3
+ "auto_map": {
4
+ "AutoFeatureExtractor": "modelling.AVSPLLMFeatureExtractor"
5
+ },
6
+ "feature_extractor_type": "AVSPLLMFeatureExtractor",
7
+ "num_channels": 1,
8
+ "num_frames": 76,
9
+ "size": 88
10
+ }
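
The `auto_map` above is what lets `AutoFeatureExtractor` resolve to the custom `AVSPLLMFeatureExtractor` class in `modelling.py`. A loading sketch follows; the repo id is a placeholder, `trust_remote_code=True` is mandatory for custom code, and the repository's dependencies (e.g. fairseq, peft) must be installed because the custom module is imported.

```python
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(
    "user/avsp-llm",           # placeholder repo id
    trust_remote_code=True,    # loads modelling.AVSPLLMFeatureExtractor from the repo
)
print(feature_extractor.size)            # 88
print(feature_extractor.num_frames)      # 76
print(feature_extractor.num_channels)    # 1
print(feature_extractor.audio_feat_dim)  # 104
```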
resnet.py ADDED
@@ -0,0 +1,216 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from collections import OrderedDict
5
+
6
+
7
+ def conv3x3(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
8
+ return nn.Conv2d(
9
+ in_channels=in_channels,
10
+ out_channels=out_channels,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=1,
14
+ bias=False
15
+ )
16
+
17
+
18
+ def downsample_basic_block(
19
+ in_channels: int,
20
+ out_channels: int,
21
+ stride: int,
22
+ ) -> nn.Sequential:
23
+ return nn.Sequential(
24
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
25
+ nn.BatchNorm2d(out_channels),
26
+ )
27
+
28
+
29
+ def downsample_basic_block_v2(
30
+ in_channels: int,
31
+ out_channels: int,
32
+ stride: int,
33
+ ) -> nn.Sequential:
34
+ return nn.Sequential(
35
+ nn.AvgPool2d(
36
+ kernel_size=stride,
37
+ stride=stride,
38
+ ceil_mode=True,
39
+ count_include_pad=False,
40
+ ),
41
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
42
+ nn.BatchNorm2d(out_channels),
43
+ )
44
+
45
+
46
+ class BasicBlock(nn.Module):
47
+ expansion = 1
48
+
49
+ def __init__(
50
+ self,
51
+ in_channels: int,
52
+ channels: int,
53
+ stride: int = 1,
54
+ downsample: nn.Sequential = None,
55
+ relu_type: str = "relu",
56
+ ) -> None:
57
+ super(BasicBlock, self).__init__()
58
+ assert relu_type in ["relu", "prelu"]
59
+
60
+ self.conv1 = conv3x3(in_channels, channels, stride)
61
+ self.bn1 = nn.BatchNorm2d(channels)
62
+
63
+ if relu_type == "relu":
64
+ self.relu1 = nn.ReLU(inplace=True)
65
+ self.relu2 = nn.ReLU(inplace=True)
66
+ elif relu_type == "prelu":
67
+ self.relu1 = nn.PReLU(num_parameters=channels)
68
+ self.relu2 = nn.PReLU(num_parameters=channels)
69
+ else:
70
+ raise Exception("relu type not implemented")
71
+
72
+ self.conv2 = conv3x3(channels, channels)
73
+ self.bn2 = nn.BatchNorm2d(channels)
74
+
75
+ self.downsample = downsample
76
+ self.stride = stride
77
+
78
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
79
+ residual = x
80
+ out = self.conv1(x)
81
+ out = self.bn1(out)
82
+ out = self.relu1(out)
83
+ out = self.conv2(out)
84
+ out = self.bn2(out)
85
+ if self.downsample is not None:
86
+ residual = self.downsample(x)
87
+ out += residual
88
+ out = self.relu2(out)
89
+ return out
90
+
91
+
92
+ class ResNet(nn.Module):
93
+ def __init__(
94
+ self,
95
+ block: nn.Module,
96
+ layers: list,
97
+ relu_type: str = "relu",
98
+ gamma_zero: bool = False,
99
+ avg_pool_downsample: bool = False,
100
+ ) -> None:
101
+ self.in_channels = 64
102
+ self.relu_type = relu_type
103
+ self.gamma_zero = gamma_zero
104
+ self.downsample_block = (
105
+ downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
106
+ )
107
+
108
+ super(ResNet, self).__init__()
109
+ self.layer1 = self._make_layer(block, 64, layers[0])
110
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
114
+
115
+ for m in self.modules():
116
+ if isinstance(m, nn.Conv2d):
117
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
119
+ elif isinstance(m, nn.BatchNorm2d):
120
+ m.weight.data.fill_(1)
121
+ m.bias.data.zero_()
122
+
123
+ if self.gamma_zero:
124
+ for m in self.modules():
125
+ if isinstance(m, BasicBlock):
126
+ m.bn2.weight.data.zero_()
127
+
128
+ def _make_layer(
129
+ self,
130
+ block: nn.Module,
131
+ channels: int,
132
+ n_blocks: int,
133
+ stride: int = 1,
134
+ ) -> nn.Sequential:
135
+ downsample = None
136
+ if stride != 1 or self.in_channels != channels * block.expansion:
137
+ downsample = self.downsample_block(
138
+ in_channels=self.in_channels,
139
+ out_channels=channels * block.expansion,
140
+ stride=stride,
141
+ )
142
+
143
+ layers = [
144
+ block(
145
+ self.in_channels, channels, stride, downsample, relu_type=self.relu_type
146
+ )
147
+ ]
148
+ self.in_channels = channels * block.expansion
149
+ for _ in range(1, n_blocks):
150
+ layers.append(block(self.in_channels, channels, relu_type=self.relu_type))
151
+
152
+ return nn.Sequential(*layers)
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ x = self.layer1(x)
156
+ x = self.layer2(x)
157
+ x = self.layer3(x)
158
+ x = self.layer4(x)
159
+ x = self.avgpool(x)
160
+ x = x.view(x.size(0), -1)
161
+ return x
162
+
163
+
164
+ class ResNetEncoder(nn.Module):
165
+ def __init__(self, relu_type: str, weight_file: str = None) -> None:
166
+ super(ResNetEncoder, self).__init__()
167
+ self.frontend_out = 64
168
+ self.backend_out = 512
169
+ frontend_relu = (
170
+ nn.PReLU(num_parameters=self.frontend_out)
171
+ if relu_type == "prelu"
172
+ else nn.ReLU()
173
+ )
174
+
175
+ self.frontend3D = nn.Sequential(
176
+ nn.Conv3d(
177
+ 1,
178
+ self.frontend_out,
179
+ kernel_size=(5, 7, 7),
180
+ stride=(1, 2, 2),
181
+ padding=(2, 3, 3),
182
+ bias=False,
183
+ ),
184
+ nn.BatchNorm3d(self.frontend_out),
185
+ frontend_relu,
186
+ nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
187
+ )
188
+ self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
189
+
190
+ if weight_file is not None:
191
+ model_state_dict = torch.load(weight_file, map_location=torch.device("cpu"))
192
+ model_state_dict = model_state_dict["model_state_dict"]
193
+ frontend_state_dict, trunk_state_dict = OrderedDict(), OrderedDict()
194
+ for key, val in model_state_dict.items():
195
+ new_key = ".".join(key.split(".")[1:])
196
+ if "frontend3D" in key:
197
+ frontend_state_dict[new_key] = val
198
+ if "trunk" in key:
199
+ trunk_state_dict[new_key] = val
200
+ self.frontend3D.load_state_dict(frontend_state_dict)
201
+ self.trunk.load_state_dict(trunk_state_dict)
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ B, C, T, H, W = x.size()
205
+ x = self.frontend3D(x)
206
+ Tnew = x.shape[2]
207
+ x = self.convert_3D_to_2D(x)
208
+ x = self.trunk(x)
209
+ x = x.view(B, Tnew, x.size(1))
210
+ x = x.transpose(1, 2).contiguous()
211
+ return x
212
+
213
+ def convert_3D_to_2D(self, x: torch.Tensor) -> torch.Tensor:
214
+ n_batches, n_channels, s_time, sx, sy = x.shape
215
+ x = x.transpose(1, 2).contiguous()
216
+ return x.reshape(n_batches * s_time, n_channels, sx, sy)
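
A quick shape check for the visual front-end above (a sketch; `resnet.py` has no relative imports, so it can be run directly from the repository directory).

```python
import torch
from resnet import ResNetEncoder

encoder = ResNetEncoder(relu_type="prelu").eval()

# (B, C, T, H, W): a batch of 76 grayscale 88x88 mouth-region frames,
# matching num_frames/size/num_channels in preprocessor_config.json.
frames = torch.randn(1, 1, 76, 88, 88)
with torch.no_grad():
    feats = encoder(frames)
print(feats.shape)  # torch.Size([1, 512, 76]) -> one 512-dim feature per frame
```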
utils.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ import joblib
3
+ import numpy as np
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+ from typing import Tuple, Optional
7
+ from huggingface_hub import HfFileSystem
8
+
9
+
10
+ def load_kmeans_model(km_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
11
+ """Load the k-means model."""
12
+ fs = HfFileSystem()
13
+
14
+ if Path(km_path).exists():
15
+ km_file = Path(km_path)
16
+ elif fs.exists(km_path):
17
+ km_file = BytesIO(fs.read_bytes(km_path))
18
+ else:
19
+ raise FileNotFoundError(f"K-means model not found at {km_path}")
20
+
21
+ kmeans_model = joblib.load(km_file)
22
+ C = torch.from_numpy(kmeans_model.cluster_centers_.transpose())
23
+ Cnorm = C.pow(2).sum(0, keepdim=True)
24
+ return C, Cnorm
25
+
26
+
27
+ def find_runs(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
28
+ """Find runs of consecutive items in an array."""
29
+
30
+ # ensure array
31
+ x = np.asanyarray(x)
32
+ if x.ndim != 1:
33
+ raise ValueError("only 1D array supported")
34
+ n = x.shape[0]
35
+
36
+ # handle empty array
37
+ if n == 0:
38
+ return np.array([]), np.array([]), np.array([])
39
+ else:
40
+ # find run starts
41
+ loc_run_start = np.empty(n, dtype=bool)
42
+ loc_run_start[0] = True
43
+ np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
44
+ run_starts = np.nonzero(loc_run_start)[0]
45
+
46
+ # find run values
47
+ run_values = x[loc_run_start]
48
+
49
+ # find run lengths
50
+ run_lengths = np.diff(np.append(run_starts, n))
51
+
52
+ return run_values, run_starts, run_lengths
53
+
54
+
55
+ def compute_mask_indices(
56
+ shape: Tuple[int, int],
57
+ padding_mask: Optional[torch.Tensor],
58
+ mask_prob: float,
59
+ mask_length: int,
60
+ mask_type: str = "static",
61
+ mask_other: float = 0.0,
62
+ min_masks: int = 0,
63
+ no_overlap: bool = False,
64
+ min_space: int = 0,
65
+ ) -> np.ndarray:
66
+ """
67
+ Computes random mask spans for a given shape
68
+ Args:
69
+ shape: the shape for which to compute masks.
70
+ should be of size 2 where first element is batch size and 2nd is timesteps
71
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
72
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
73
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
74
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
75
+ mask_type: how to compute mask lengths
76
+ static = fixed size
77
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
78
+ normal = sample from a normal distribution with mean mask_length and stdev mask_other; each mask is at least 1 element long
79
+ poisson = sample from a Poisson distribution with lambda = mask_length
80
+ min_masks: minimum number of masked spans
81
+ no_overlap: if true, switches to an alternative recursive algorithm that prevents spans from overlapping
82
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
83
+ """
84
+ bsz, all_sz = shape
85
+ mask = np.full((bsz, all_sz), False)
86
+
87
+ all_num_mask = int(
88
+ # add a random number for probabilistic rounding
89
+ mask_prob * all_sz / float(mask_length)
90
+ + np.random.rand()
91
+ )
92
+
93
+ all_num_mask = max(min_masks, all_num_mask)
94
+
95
+ mask_idcs = []
96
+ for i in range(bsz):
97
+ if padding_mask is not None:
98
+ sz = all_sz - padding_mask[i].long().sum().item()
99
+ num_mask = int(
100
+ # add a random number for probabilistic rounding
101
+ mask_prob * sz / float(mask_length)
102
+ + np.random.rand()
103
+ )
104
+ num_mask = max(min_masks, num_mask)
105
+ else:
106
+ sz = all_sz
107
+ num_mask = all_num_mask
108
+
109
+ if mask_type == "static":
110
+ lengths = np.full(num_mask, mask_length)
111
+ elif mask_type == "uniform":
112
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
113
+ elif mask_type == "normal":
114
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
115
+ lengths = [max(1, int(round(x))) for x in lengths]
116
+ elif mask_type == "poisson":
117
+ lengths = np.random.poisson(mask_length, size=num_mask)
118
+ lengths = [int(round(x)) for x in lengths]
119
+ else:
120
+ raise Exception("unknown mask selection " + mask_type)
121
+
122
+ if sum(lengths) == 0:
123
+ lengths[0] = min(mask_length, sz - 1)
124
+
125
+ if no_overlap:
126
+ mask_idc = []
127
+
128
+ def arrange(s, e, length, keep_length):
129
+ span_start = np.random.randint(s, e - length)
130
+ mask_idc.extend(span_start + i for i in range(length))
131
+
132
+ new_parts = []
133
+ if span_start - s - min_space >= keep_length:
134
+ new_parts.append((s, span_start - min_space + 1))
135
+ if e - span_start - keep_length - min_space > keep_length:
136
+ new_parts.append((span_start + length + min_space, e))
137
+ return new_parts
138
+
139
+ parts = [(0, sz)]
140
+ min_length = min(lengths)
141
+ for length in sorted(lengths, reverse=True):
142
+ lens = np.fromiter(
143
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
144
+ int,
145
+ )
146
+ l_sum = np.sum(lens)
147
+ if l_sum == 0:
148
+ break
149
+ probs = lens / np.sum(lens)
150
+ c = np.random.choice(len(parts), p=probs)
151
+ s, e = parts.pop(c)
152
+ parts.extend(arrange(s, e, length, min_length))
153
+ mask_idc = np.asarray(mask_idc)
154
+ else:
155
+ min_len = min(lengths)
156
+ if sz - min_len <= num_mask:
157
+ min_len = sz - num_mask - 1
158
+
159
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
160
+
161
+ mask_idc = np.asarray(
162
+ [
163
+ mask_idc[j] + offset
164
+ for j in range(len(mask_idc))
165
+ for offset in range(lengths[j])
166
+ ]
167
+ )
168
+
169
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
170
+
171
+ min_len = min([len(m) for m in mask_idcs])
172
+ batch_indexes, starts, ends = [], [], []
173
+ for i, mask_idc in enumerate(mask_idcs):
174
+ if len(mask_idc) > min_len:
175
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
176
+ mask[i, mask_idc] = True
177
+ vals, run_starts, run_lengths = find_runs(mask[i])
178
+ start_indices, lengths = run_starts[vals], run_lengths[vals]
179
+ starts.append(start_indices)
180
+ ends.append(start_indices + lengths)
181
+ batch_indexes.append(np.zeros([len(start_indices)]) + i)
182
+ return (
183
+ mask,
184
+ np.concatenate(starts).astype(np.int64),
185
+ np.concatenate(ends).astype(np.int64),
186
+ np.concatenate(batch_indexes).astype(np.int64),
187
+ )
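
A usage sketch for `compute_mask_indices`, mirroring how `apply_input_mask` calls it in `modelling.py`. `utils.py` only needs torch, numpy, joblib, and huggingface_hub, so it can be imported directly from the repository directory.

```python
import numpy as np
from utils import compute_mask_indices

np.random.seed(0)
mask, starts, ends, batch_indexes = compute_mask_indices(
    shape=(2, 100),      # 2 sequences of 100 frames
    padding_mask=None,
    mask_prob=0.65,
    mask_length=10,
    mask_type="static",
    min_masks=2,
)
print(mask.shape)        # (2, 100), boolean span mask
print(mask.sum(axis=1))  # same number of masked frames per sequence
print(starts[:3], ends[:3], batch_indexes[:3])  # per-span boundaries and batch ids
```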