import contextlib
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from fairseq.modules import GradMultiply, LayerNorm
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    FeatureExtractionMixin,
    GenerationConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration import AVHubertConfig, AVSPLLMConfig
from .encoder import TransformerEncoder
from .resnet import ResNetEncoder
from .utils import compute_mask_indices, load_kmeans_model


logging.root.setLevel(logging.WARNING)

class AVHubertFeatureExtractor(FeatureExtractionMixin):
    def __init__(self, config: AVHubertConfig = AVHubertConfig(), **kwargs) -> None:
        super().__init__(**kwargs)
        self.audio_feat_dim = config.audio_feat_dim
        self.size = 88
        self.num_frames = 76
        self.num_channels = 1


class AVSPLLMFeatureExtractor(AVHubertFeatureExtractor):
    def __init__(self, config: AVSPLLMConfig = AVSPLLMConfig(), **kwargs) -> None:
        super().__init__(config, **kwargs)

class AVHubertVideoFeatureEncoder(nn.Module):
    def __init__(self, config: AVHubertConfig) -> None:
        super().__init__()
        self.resnet = ResNetEncoder(relu_type=config.resnet_relu_type)
        self.proj = nn.Linear(self.resnet.backend_out, config.encoder_embed_dim)
        self.encoder = (
            TransformerEncoder(config)
            if config.sub_encoder_layers > 0
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resnet(x)
        x = self.proj(x.transpose(1, 2))
        if self.encoder is not None:
            x = self.encoder(x)[0].transpose(1, 2)
        else:
            x = x.transpose(1, 2)
        return x

class AVHubertAudioFeatureEncoder(nn.Module):
    def __init__(self, config: AVHubertConfig) -> None:
        super().__init__()
        self.proj = nn.Linear(config.audio_feat_dim, config.encoder_embed_dim)
        self.encoder = (
            TransformerEncoder(config)
            if config.sub_encoder_layers > 0
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x.transpose(1, 2))
        if self.encoder is not None:
            x = self.encoder(x)[0].transpose(1, 2)
        else:
            x = x.transpose(1, 2)
        return x

class AVHubertModel(PreTrainedModel):
    config_class = AVHubertConfig

    def __init__(
        self,
        config: AVHubertConfig = AVHubertConfig(),
        dictionaries: List = [None],
    ) -> None:
        super().__init__(config=config)
        label_rate = config.label_rate
        feature_ds_rate = config.feature_ds_rate
        sample_rate = config.sample_rate
        self.feat2tar_ratio = label_rate * feature_ds_rate / sample_rate

        self.feature_extractor_video = AVHubertVideoFeatureEncoder(config)
        self.feature_extractor_audio = AVHubertAudioFeatureEncoder(config)

        if config.modality_fuse == "concat":
            self.encoder_embed_dim = config.encoder_embed_dim * 2
        elif config.modality_fuse == "add":
            self.encoder_embed_dim = config.encoder_embed_dim

        self.post_extract_proj = (
            nn.Linear(self.encoder_embed_dim, config.encoder_embed_dim)
            if self.encoder_embed_dim != config.encoder_embed_dim
            else None
        )

        self.dropout_input = nn.Dropout(config.dropout_input)
        self.dropout_features = nn.Dropout(config.dropout_features)

        if self.config.final_dim > 0:
            final_dim = config.final_dim
        else:
            final_dim = config.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(config.audio_feat_dim).uniform_()
            if config.masking_type == "input"
            else torch.FloatTensor(config.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(self.config)
        self.layer_norm = LayerNorm(self.encoder_embed_dim)

        self.target_glu = None
        if config.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(config.final_dim, config.final_dim * 2),
                nn.GLU(),
            )

        if config.untie_final_proj:
            self.final_proj = nn.Linear(
                config.encoder_embed_dim,
                final_dim * len(dictionaries),
            )
        else:
            self.final_proj = nn.Linear(config.encoder_embed_dim, final_dim)

        if any(d is None for d in dictionaries):
            self.num_classes = config.num_classes
        else:
            self.num_classes = sum(len(d) for d in dictionaries)
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(self.num_classes, final_dim)
        )
        nn.init.uniform_(self.label_embs_concat)

    def apply_input_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, C, T = x.shape[:3]
        is_audio = len(x.shape) == 3

        if is_audio:
            mask_prob = self.config.mask_prob_audio
            mask_length = self.config.mask_length_audio
        else:
            mask_prob = self.config.mask_prob_image
            mask_length = self.config.mask_length_image

        if mask_prob > 0:
            mask_indices, starts, ends, batch_indexes = compute_mask_indices(
                (B, T),
                padding_mask,
                mask_prob,
                mask_length,
                self.config.mask_selection,
                self.config.mask_other,
                min_masks=2,
                no_overlap=self.config.no_mask_overlap,
                min_space=self.config.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = x.transpose(1, 2).contiguous()
            if B == 1:
                x[mask_indices] = 0
            elif is_audio:
                x[mask_indices] = self.mask_emb
            elif self.config.selection_type == "same_other_seq":
                perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
                x_perm = x[perm]
                x[mask_indices] = x_perm[mask_indices]
            elif self.config.selection_type == "same_seq":
                batch_indexes_, other_indexes = [], []
                for batch_index, start, end in zip(batch_indexes, starts, ends):
                    length = end - start
                    # Pick a replacement span from the same sequence that does
                    # not overlap the masked span.
                    other_start = np.setdiff1d(
                        np.arange(T), np.arange(max(0, start - length), end)
                    )
                    if len(other_start) > 0:
                        other_start = np.random.choice(other_start, size=1)
                    else:
                        other_start = 0
                    other_end = other_start + length
                    other_indexes.append(
                        np.arange(other_start, other_end).clip(max=T - 1)
                    )
                    batch_indexes_.append(
                        np.zeros([length], dtype=np.int64) + batch_index
                    )
                batch_indexes = np.concatenate(batch_indexes_)
                other_indexes = np.concatenate(other_indexes)
                x[mask_indices] = x[batch_indexes, other_indexes]
            x = x.transpose(1, 2).contiguous()
        else:
            mask_indices = None

        if self.config.mask_channel_prob > 0:
            logging.warning("No mask channel prob for input masking")
        return x, mask_indices

    def apply_feature_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, T, C = x.shape
        assert all((
            self.config.mask_prob_audio == self.config.mask_prob_image,
            self.config.mask_length_audio == self.config.mask_length_image,
        )), "masking prob/length must be the same for audio and image in feature masking"

        mask_prob = self.config.mask_prob_audio
        mask_length = self.config.mask_length_image
        if mask_prob > 0:
            mask_indices, _, _, _ = compute_mask_indices(
                (B, T),
                padding_mask,
                mask_prob,
                mask_length,
                self.config.mask_selection,
                self.config.mask_other,
                min_masks=2,
                no_overlap=self.config.no_mask_overlap,
                min_space=self.config.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.config.mask_channel_prob > 0:
            mask_channel_indices, _, _, _ = compute_mask_indices(
                (B, C),
                None,
                self.config.mask_channel_prob,
                self.config.mask_channel_length,
                self.config.mask_channel_selection,
                self.config.mask_channel_other,
                no_overlap=self.config.no_mask_channel_overlap,
                min_space=self.config.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(
        self,
        source: Dict[str, torch.Tensor],
        modality: str,
    ) -> torch.Tensor:
        extractor = getattr(self, f"feature_extractor_{modality}")
        if self.config.feature_grad_mult > 0:
            features = extractor(source)
            if self.config.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.config.feature_grad_mult)
        else:
            with torch.no_grad():
                features = extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        mask_indices: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        # Trim features and targets to a common length, since features and
        # labels may run at different frame rates.
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
            if mask_indices is not None:
                mask_indices = mask_indices[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, mask_indices, target_list

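    # Illustrative note for forward_padding_mask below (hypothetical numbers,
    # not values taken from the configs): if padding_mask has 400 frame-level
    # entries per sample and features has 100 time steps, the mask is truncated
    # to a multiple of 100, reshaped to (B, 100, 4), and a feature step is
    # marked as padding only when all 4 underlying frames are padding.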
    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

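    # Sketch of the scoring in compute_logits below: with sim_type "dot" the
    # logits are feats @ emb_mat.T; with "cosine" they are the normalized dot
    # product nom / max(denom, 1e-6). Either way the result is divided by
    # config.logit_temp, so a smaller temperature sharpens the distribution
    # over the label embeddings.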
    def compute_logits(self, feats: torch.Tensor, emb_mat: torch.Tensor) -> torch.Tensor:
        if self.config.sim_type == "dot":
            logits = torch.matmul(feats, emb_mat.transpose(0, 1))
        elif self.config.sim_type == "cosine":
            batch_size, timesteps, emb_dim = feats.size()
            feats_ = feats.view(-1, emb_dim)

            nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1)
            denom = (
                (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1)
                * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0)
            )
            logits = (nom / denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
        else:
            raise NotImplementedError
        logits = logits / self.config.logit_temp
        return logits

    def forward(
        self,
        source: Dict[str, torch.Tensor],
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        src_audio, src_video = source["audio"], source["video"]
        if mask and self.config.masking_type == "input":
            src_video, mask_indices_video = self.apply_input_mask(
                src_video, padding_mask, target_list
            )
            src_audio, mask_indices_audio = self.apply_input_mask(
                src_audio, padding_mask, target_list
            )
            mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
        else:
            src_audio, src_video, mask_indices = src_audio, src_video, None

        features_audio = self.forward_features(src_audio, modality="audio")
        features_video = self.forward_features(src_video, modality="video")

        if self.training:
            # Modality dropout: occasionally zero out one modality so the model
            # does not rely on audio or video alone.
            modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
            if modality_drop_prob < self.config.modality_dropout:
                if audio_drop_prob < self.config.audio_dropout:
                    features_audio = 0 * features_audio
                else:
                    features_video = 0 * features_video

        if self.config.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.config.modality_fuse == "add":
            features = features_audio + features_video

        if target_list is not None:
            features, mask_indices, target_list = self.forward_targets(
                features, mask_indices, target_list
            )

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        if self.config.masking_type == "feature" and mask:
            x, mask_indices = self.apply_feature_mask(
                features, padding_mask, target_list
            )
        else:
            x = features

        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
        proj_x = self.final_proj(x)
        if self.config.untie_final_proj:
            proj_x_list = proj_x.chunk(len(label_embs_list), dim=-1)
        else:
            proj_x_list = [proj_x for _ in label_embs_list]

        logit_list = [
            self.compute_logits(proj, emb).view(-1, emb.size(0))
            for proj, emb in zip(proj_x_list, label_embs_list)
        ]

        mask = torch.logical_and(mask_indices, ~padding_mask).view(-1)
        unmask = torch.logical_and(~mask_indices, ~padding_mask).view(-1)
        logit_m_list = [logit[mask] for logit in logit_list]
        logit_u_list = [logit[unmask] for logit in logit_list]
        target_m_list = [target.view(-1)[mask].long() for target in target_list]
        target_u_list = [target.view(-1)[unmask].long() for target in target_list]

        return {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "target_m_list": target_m_list,
            "target_u_list": target_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

    def extract_features(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def extract_units(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: torch.Tensor = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=None,
        )

        feature = res["features"] if ret_conv else res["x"]
        proj_x = self.final_proj(feature)

        # Assign each frame to the label embedding with the highest dot-product score.
        units = (
            torch
            .matmul(proj_x, self.label_embs_concat.transpose(0, 1))
            .argmax(dim=-1)
        )
        return units

    def extract_finetune(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: torch.Tensor = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        src_audio, src_video = source["audio"], source["video"]
        if mask and self.config.masking_type == "input":
            src_video, _ = self.apply_input_mask(
                src_video, padding_mask, target_list=None
            )
            src_audio, _ = self.apply_input_mask(
                src_audio, padding_mask, target_list=None
            )

        # If one modality is missing, substitute zero features of the matching shape.
        if src_audio is not None and src_video is None:
            features_audio = self.forward_features(
                src_audio, modality="audio"
            )
            features_video = features_audio.new_zeros(
                features_audio.size(0),
                self.encoder_embed_dim,
                features_audio.size(-1)
            )
        elif src_audio is None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = features_video.new_zeros(
                features_video.size(0),
                self.encoder_embed_dim,
                features_video.size(-1)
            )
        elif src_audio is not None and src_video is not None:
            features_video = self.forward_features(src_video, modality="video")
            features_audio = self.forward_features(
                src_audio, modality="audio"
            )

        if self.config.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.config.modality_fuse == "add":
            features = features_audio + features_video

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        x, _ = self.encoder(
            features,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        return x, padding_mask

    def get_extra_losses(
        self,
        net_output: Dict[str, torch.Tensor],
    ) -> Tuple[List[torch.Tensor], List[str]]:
        extra_losses = []
        names = []
        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        return extra_losses, names

    def remove_pretraining_modules(self) -> None:
        self.target_glu = None
        self.final_proj = None

    def compute_nce(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        negs: torch.Tensor,
    ) -> torch.Tensor:
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.config.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)
        return logits


class HubertEncoderWrapper(nn.Module):
    def __init__(
        self,
        config: AVHubertConfig,
        dictionaries: List = [None],
    ) -> None:
        super().__init__()
        self.w2v_model = AVHubertModel(config, dictionaries)

    def forward(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
        }
        x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
        return {
            "encoder_out": x,
            "encoder_padding_mask": padding_mask,
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(
        self,
        encoder_out: Dict[str, torch.Tensor],
        new_order: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        if encoder_out["padding_mask"] is not None:
            encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
                0, new_order
            )
        return encoder_out


class AVSPLLMModel(PreTrainedModel):
    config_class = AVSPLLMConfig

    def __init__(
        self,
        config: AVSPLLMConfig = AVSPLLMConfig(),
        dictionaries: List = [None],
    ) -> None:
        super().__init__(config=config)
        # Resolve the k-means checkpoint either next to this file or in the hub repo.
        current_dir = Path(__file__).resolve().parent
        self.km_path = current_dir / config.km_path
        if not self.km_path.is_file():
            repo_id = self.config._name_or_path
            self.km_path = f"{repo_id}/model.km"
        self.km_path = str(self.km_path)
        self.C, self.Cnorm = load_kmeans_model(self.km_path)

        self.encoder = HubertEncoderWrapper(config, dictionaries)
        self.encoder.w2v_model.remove_pretraining_modules()

        # Project audio-visual features into the LLM embedding space.
        self.avfeat_to_llm = nn.Linear(
            config.encoder_embed_dim, config.decoder_embed_dim
        )

        # Load the LLM decoder in 4-bit NF4 and attach LoRA adapters to the
        # attention projections, so only a small set of parameters is trained.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        decoder_4bit = AutoModelForCausalLM.from_pretrained(
            config.llm_ckpt_path,
            quantization_config=bnb_config,
            cache_dir=config.cache_dir,
            trust_remote_code=True,
        )
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj", "k_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.decoder = get_peft_model(decoder_4bit, lora_config)
        self.decoder.print_trainable_parameters()

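    # Illustrative example for apply_kmeans below (hypothetical cluster ids,
    # not real model output): if the frame-wise nearest centroids come out as
    # [7, 7, 7, 2, 2, 9], the returned run lengths are tensor([3, 2, 1]).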
    def apply_kmeans(self, feat: torch.Tensor) -> torch.Tensor:
        # Squared Euclidean distance to every k-means centroid (C and Cnorm are
        # precomputed from the k-means model), then nearest-centroid assignment.
        dist = (
            feat.squeeze(0).pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(feat.squeeze(0), self.C)
            + self.Cnorm
        )
        cluster_ids = dist.argmin(dim=1)

        # Run-length encode consecutive identical cluster assignments.
        current_counts = 1
        counts = []
        for i in range(1, len(cluster_ids)):
            if cluster_ids[i] == cluster_ids[i - 1]:
                current_counts += 1
            else:
                counts.append(current_counts)
                current_counts = 1
        counts.append(current_counts)

        return torch.tensor(counts)

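    # Continuing the example above (hypothetical shapes): with run lengths
    # [3, 2, 1] and feat of shape (1, 6, D), deduplicate below averages frames
    # 0-2, 3-4, and 5 separately and returns a (1, 3, D) tensor, one embedding
    # per run of identical cluster assignments.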
    def deduplicate(
        self,
        feat: torch.Tensor,
        cluster_counts: torch.Tensor,
    ) -> torch.Tensor:
        # Collapse each run of identical cluster assignments into its mean feature.
        results_tensor = []
        start_idx = 0
        for cluster_num in cluster_counts:
            end_idx = start_idx + cluster_num
            segment = feat[:, start_idx:end_idx, :]
            mean_tensor = torch.mean(segment, dim=1, keepdim=True)
            results_tensor.append(mean_tensor)
            start_idx = end_idx

        assert cluster_counts.sum().item() == feat.size()[1], \
            f"{cluster_counts.sum().item()} != {feat.size()[1]}"

        return torch.cat(results_tensor, dim=1)

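    # Layout produced by embed below: the LLM input is the concatenation
    # [instruction embeddings | projected AV features | label embeddings] along
    # the time axis, and the matching labels are [-100 ... -100 | label ids],
    # so the loss in forward is computed on the label tokens only (padding id 0
    # is also mapped to -100).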
    def embed(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: torch.Tensor,
        target_list: torch.Tensor = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ft = self.config.freeze_finetune_updates <= kwargs.get("num_updates", -1)
        with torch.no_grad() if not ft else contextlib.ExitStack():
            output = self.encoder(source, padding_mask, **kwargs)

        cluster_counts = self.apply_kmeans(output["encoder_out"])

        output["encoder_out"] = self.avfeat_to_llm(output["encoder_out"])

        reduced_enc_out = self.deduplicate(output["encoder_out"], cluster_counts)
        reduced_enc_out = reduced_enc_out.to(self.decoder.device)
        B, T, D = reduced_enc_out.size()

        instruction = source["text"]
        instruction_embedding = self.decoder.model.model.embed_tokens(instruction)

        llm_input = torch.cat((instruction_embedding, reduced_enc_out), dim=1)

        if target_list is None:
            return llm_input, None

        labels = target_list.clone()
        labels_embedding = self.decoder.model.model.embed_tokens(labels)

        llm_input = torch.cat((llm_input, labels_embedding), dim=1)

        llm_labels = labels.clone()
        llm_labels[llm_labels == 0] = -100

        # Ignore the instruction and audio-visual prefix in the loss; only the
        # label tokens contribute.
        _, instruction_embedding_t, _ = instruction_embedding.size()
        target_ids = (
            torch.full((B, T + instruction_embedding_t), -100).long().to(labels.device)
        )
        llm_labels = torch.cat((target_ids, llm_labels), dim=1)

        return llm_input, llm_labels

    def forward(
        self,
        source: Dict[str, torch.Tensor],
        padding_mask: torch.Tensor,
        target_list: torch.Tensor = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        llm_input, llm_labels = self.embed(
            source, padding_mask, target_list, **kwargs
        )
        return self.decoder(
            inputs_embeds=llm_input.to(torch.float16),
            labels=llm_labels,
            return_dict=True,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[Dict[str, torch.Tensor]] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Any:
        llm_input, _ = self.embed(**inputs, **kwargs)
        self.decoder.config.use_cache = True
        return self.decoder.generate(
            inputs_embeds=llm_input,
            generation_config=generation_config,
            **kwargs,
        )
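
# Minimal usage sketch (illustrative only): the tensor shapes, pad id, and
# dummy variable names below are assumptions, and running it requires the
# k-means file and the LLM referenced by AVSPLLMConfig to be available locally.
#
#     config = AVSPLLMConfig()
#     model = AVSPLLMModel(config)
#     outputs = model(
#         source={
#             "audio": audio_feats,      # (B, audio_feat_dim, T)
#             "video": video_frames,     # (B, 1, T, H, W)
#             "text": prompt_token_ids,  # (B, T_prompt)
#         },
#         padding_mask=padding_mask,     # (B, T) bool, True at padded frames
#         target_list=label_token_ids,   # (B, T_label)
#     )
#     loss = outputs.loss
#
#     generated = model.generate(
#         inputs={"source": source, "padding_mask": padding_mask},
#         generation_config=GenerationConfig(max_new_tokens=128),
#     )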