from typing import List, Union

import torch
import torch.nn.functional as F

from transformers import PreTrainedModel, BertTokenizer
from transformers.utils import is_remote_url, download_url

from pathlib import Path
from configuration_vgcn import VGCNConfig

import pickle as pkl
import numpy as np
import scipy.sparse as sp


def get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, gcn_config: VGCNConfig):

    def sparse_scipy2torch(coo_sparse):
        # Convert a scipy COO sparse matrix into a torch sparse tensor.
        i = torch.LongTensor(np.vstack((coo_sparse.row, coo_sparse.col)))
        v = torch.from_numpy(coo_sparse.data)
        return torch.sparse_coo_tensor(i, v, torch.Size(coo_sparse.shape))

    def normalize_adj(adj):
        """Symmetrically normalize the adjacency matrix: D^-0.5 * A * D^-0.5."""
        D_matrix = np.array(adj.sum(axis=1))
        D_inv_sqrt = np.power(D_matrix, -0.5).flatten()
        D_inv_sqrt[np.isinf(D_inv_sqrt)] = 0.
        d_mat_inv_sqrt = sp.diags(D_inv_sqrt)
        return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)

    # Drop edges below the configured thresholds.
    gcn_vocab_adj_tf.data *= (gcn_vocab_adj_tf.data > gcn_config.tf_threshold)
    gcn_vocab_adj_tf.eliminate_zeros()

    gcn_vocab_adj.data *= (gcn_vocab_adj.data > gcn_config.npmi_threshold)
    gcn_vocab_adj.eliminate_zeros()

    if gcn_config.vocab_type == 'pmi':
        gcn_vocab_adj_list = [gcn_vocab_adj]
    elif gcn_config.vocab_type == 'tf':
        gcn_vocab_adj_list = [gcn_vocab_adj_tf]
    elif gcn_config.vocab_type == 'all':
        gcn_vocab_adj_list = [gcn_vocab_adj_tf, gcn_vocab_adj]
    else:
        raise ValueError(f"vocab_type must be 'pmi', 'tf' or 'all', got {gcn_config.vocab_type}")

    norm_gcn_vocab_adj_list = []
    for adj in gcn_vocab_adj_list:
        adj = normalize_adj(adj)
        norm_gcn_vocab_adj_list.append(sparse_scipy2torch(adj.tocoo()))

    for t in norm_gcn_vocab_adj_list:
        t.requires_grad = False

    del gcn_vocab_adj_list

    return norm_gcn_vocab_adj_list
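
# Minimal illustration of how `get_torch_gcn` can be exercised on its own
# (comments only; in normal use the two scipy matrices come from the prepared
# vocabulary-graph pickle loaded in `load_adj_matrix` below, and the config
# field values shown here are assumptions):
#
#   toy = sp.random(100, 100, density=0.05, format="csr")
#   toy = toy + toy.T                                   # symmetric toy graph
#   cfg = VGCNConfig(vocab_type="all", tf_threshold=0.0, npmi_threshold=0.2)
#   adj_list = get_torch_gcn(toy.copy(), toy.copy(), cfg)
#   # -> a list of normalized torch sparse tensors, one per selected graph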


class VCGNModelForTextClassification(PreTrainedModel):
    config_class = VGCNConfig

    def __init__(self, config):
        super().__init__(config)

        self.pre_trained_model_name = ''
        self.remove_stop_words = False
        self.tokenizer = None
        self.norm_gcn_vocab_adj_list = None
        self.gcn_vocab_size = config.vocab_size

        self.load_adj_matrix(config.gcn_adj_matrix)

        self.model = VGCN_Bert(
            config,
            gcn_adj_matrix=self.norm_gcn_vocab_adj_list,
            gcn_adj_dim=config.vocab_size,
            gcn_adj_num=len(self.norm_gcn_vocab_adj_list),
            gcn_embedding_dim=config.gcn_embedding_dim,
        )

    def load_adj_matrix(self, adj_matrix):
        # Resolve the adjacency-matrix pickle: a local path, a path relative to
        # this file, or a remote URL.
        filename = None
        if Path(adj_matrix).is_file():
            filename = Path(adj_matrix)
        elif (Path(__file__).parent / Path(adj_matrix)).is_file():
            filename = Path(__file__).parent / Path(adj_matrix)
        elif is_remote_url(adj_matrix):
            filename = download_url(adj_matrix)
        else:
            raise FileNotFoundError(f"Could not resolve gcn_adj_matrix: {adj_matrix}")

        with open(filename, 'rb') as f:
            gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(f)

        self.pre_trained_model_name = adj_config['bert_model']
        self.remove_stop_words = adj_config['remove_stop_words']
        self.tokenizer = BertTokenizer.from_pretrained(self.pre_trained_model_name)
        self.norm_gcn_vocab_adj_list = get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)

    def _prep_batch(self, batch: torch.Tensor):
        vocab_size = self.tokenizer.vocab_size

        # One-hot "swop" matrix mapping sequence positions to vocabulary ids,
        # with the special tokens zeroed out.
        batch_gcn_swop_eye = F.one_hot(batch, vocab_size).float().to(self.device)
        batch_gcn_swop_eye = batch_gcn_swop_eye.transpose(1, 2)

        batch_gcn_swop_eye[:, self.tokenizer.pad_token_id, :] = 0
        batch_gcn_swop_eye[:, self.tokenizer.cls_token_id, :] = 0
        batch_gcn_swop_eye[:, self.tokenizer.sep_token_id, :] = 0

        batch_gcn_swop_eye = F.pad(batch_gcn_swop_eye, (0, self.config.gcn_embedding_dim, 0, 0, 0, 0), value=0)

        # Append gcn_embedding_dim extra slots to every sequence.
        batch = F.pad(batch, (0, self.config.gcn_embedding_dim), 'constant', 0)

        mask = torch.zeros(batch.shape[0], batch.shape[1] + 1, dtype=batch.dtype, device=self.device)
        mask2 = torch.zeros(batch.shape[0], batch.shape[1] + 1, dtype=batch.dtype, device=self.device)

        # The first padding position in each row marks where the placeholder slots start.
        pos_start = (batch == self.tokenizer.pad_token_id).int().argmax(1)

        mask[(torch.arange(batch.shape[0]), pos_start)] = 1
        mask2[(torch.arange(batch.shape[0]), pos_start + self.config.gcn_embedding_dim)] = 1

        mask = mask.cumsum(1)[:, :-1].bool()
        mask2 = mask2.cumsum(1)[:, :-1].bool()

        # Select exactly gcn_embedding_dim positions per row and fill them with [SEP],
        # reserving them for the graph embeddings.
        mask = mask & ~mask2

        batch.masked_fill_(mask, self.tokenizer.sep_token_id)

        return batch, batch_gcn_swop_eye
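
    # Illustrative layout after `_prep_batch` (comments only; values assumed):
    # with gcn_embedding_dim == 2, a tokenized row
    #     [CLS] w1 w2 w3 [SEP] [PAD] [PAD]
    # becomes
    #     [CLS] w1 w2 w3 [SEP] [SEP] [SEP] [PAD] [PAD]
    # i.e. two extra slots are appended and the first padding positions are filled
    # with [SEP] ids. `VGCNBertEmbeddings.forward` later overwrites the word
    # embeddings of the slots just before the final [SEP] with the VGCN output.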

    def text_to_batch(self, text: Union[List[str], str]):
        if isinstance(text, str):
            text = [text]
        encoded = self.tokenizer.batch_encode_plus(
            text,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=self.config.max_seq_len - self.config.gcn_embedding_dim,
        )
        return encoded['input_ids'].to(self.device)

    def forward(self, input: Union[torch.Tensor, List[str], str], labels=None):
        if not isinstance(input, torch.Tensor):
            input = self.text_to_batch(input)

        input, batch_gcn_swop_eye = self._prep_batch(input)

        segment_ids = torch.zeros_like(input).long().to(self.device)
        input_mask = (input > 0).int().to(self.device)

        logits = self.model(batch_gcn_swop_eye, input, segment_ids, input_mask)
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

    def predict(self, text: Union[List[str], str], as_dict=True):
        with torch.no_grad():
            logits = self.forward(text)['logits']
            if as_dict:
                label_id = torch.argmax(logits, dim=1).cpu().numpy()
                label = [self.config.id2label[l] for l in label_id]
                return {
                    "logits": logits,
                    "label_id": label_id,
                    "label": label,
                }
            else:
                return torch.argmax(logits, dim=1).cpu().numpy()

    @property
    def device(self):
        return next(self.parameters()).device


import torch.nn as nn
import torch.nn.init as init
import math

from transformers import BertModel
from transformers.models.bert.modeling_bert import BertEmbeddings, BertPooler, BertEncoder


class VocabGraphConvolution(nn.Module):
    """Vocabulary GCN module.

    Params:
        `adj_matrix`: The normalized adjacency matrices (or a list of them) of the vocabulary graph
        `voc_dim`: The size of the vocabulary graph
        `num_adj`: The number of adjacency matrices of the vocabulary graph
        `hid_dim`: The hidden dimension after XAW
        `out_dim`: The output dimension after ReLU(XAW)W
        `dropout_rate`: The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler.

    Inputs:
        `X_dv`: the features of the mini-batch documents; can be TF-IDF (batch, vocab)
            or word embeddings (batch, word_embedding_dim, vocab)

    Outputs:
        The graph embedding representation, of shape (batch, `out_dim`) or
        (batch, word_embedding_dim, `out_dim`).
    """

    def __init__(self, adj_matrix, voc_dim, num_adj, hid_dim, out_dim, dropout_rate=0.2):
        super(VocabGraphConvolution, self).__init__()
        if type(adj_matrix) is not list:
            self.adj_matrix = adj_matrix
        else:
            # Register the fixed adjacency matrices so they move with the module.
            self.adj_matrix = torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
            for p in self.adj_matrix:
                p.requires_grad = False

        self.voc_dim = voc_dim
        self.num_adj = num_adj
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        # One weight matrix W_vh per adjacency matrix.
        for i in range(self.num_adj):
            setattr(self, 'W%d_vh' % i, nn.Parameter(torch.randn(voc_dim, hid_dim)))

        self.fc_hc = nn.Linear(hid_dim, out_dim)
        self.act_func = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

        self.reset_parameters()

    def reset_parameters(self):
        for n, p in self.named_parameters():
            if n.startswith('W'):
                init.kaiming_uniform_(p, a=math.sqrt(5))

    def forward(self, X_dv, add_linear_mapping_term=False):
        for i in range(self.num_adj):
            # A @ W: propagate the weight matrix through this vocabulary graph.
            H_vh = self.adj_matrix[i].mm(getattr(self, 'W%d_vh' % i))
            H_vh = self.dropout(H_vh)
            # X @ (A @ W): project the document features through the graph.
            H_dh = X_dv.matmul(H_vh)

            if add_linear_mapping_term:
                H_linear = X_dv.matmul(getattr(self, 'W%d_vh' % i))
                H_linear = self.dropout(H_linear)
                H_dh += H_linear

            # Sum the contributions of all adjacency matrices.
            if i == 0:
                fused_H = H_dh
            else:
                fused_H += H_dh

        out = self.fc_hc(fused_H)
        return out
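
    # Shape sketch for `forward` (comments only; names and sizes are illustrative):
    #   X_dv : (batch, emb_dim, voc_dim)   document features in vocabulary order
    #   A_i  : (voc_dim, voc_dim)          i-th normalized vocabulary graph
    #   W_i  : (voc_dim, hid_dim)          corresponding graph weight matrix
    #   H_dh = X_dv @ (A_i @ W_i)       -> (batch, emb_dim, hid_dim)
    #   out  = fc_hc(sum_i H_dh)        -> (batch, emb_dim, out_dim)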


class VGCNBertEmbeddings(BertEmbeddings):
    """Construct the embeddings from word, VGCN graph, position and token_type embeddings.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model
        `gcn_adj_matrix`: The normalized adjacency matrices of the vocabulary graph
        `gcn_adj_dim`: The size of the vocabulary graph
        `gcn_adj_num`: The number of adjacency matrices of the vocabulary graph
        `gcn_embedding_dim`: The output dimension after VGCN

    Inputs:
        `gcn_swop_eye`: The transform matrix that maps the token sequence (sentence)
            to the vocabulary order (BoW order)
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary. Items in the batch should
            begin with the special "CLS" token.
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a
            `sentence A` token and type 1 corresponds to a `sentence B` token
            (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used if the input sequence
            length is smaller than the max input sequence length in the current batch.
            It's the mask that we typically use for attention when a batch has varying
            length sentences.

    Outputs:
        The word embeddings fused with the VGCN embedding, position embeddings and
        token_type embeddings.
    """

    def __init__(self, config, gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, gcn_embedding_dim):
        super(VGCNBertEmbeddings, self).__init__(config)
        assert gcn_embedding_dim >= 0
        self.gcn_adj_matrix = gcn_adj_matrix
        self.gcn_embedding_dim = gcn_embedding_dim
        self.vocab_gcn = VocabGraphConvolution(gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, 128, gcn_embedding_dim)

    def forward(self, gcn_swop_eye, input_ids, token_type_ids=None, attention_mask=None):
        words_embeddings = self.word_embeddings(input_ids)
        # Gather the word embeddings in vocabulary order: (batch, emb_dim, vocab).
        vocab_input = gcn_swop_eye.matmul(words_embeddings).transpose(1, 2)

        if self.gcn_embedding_dim > 0:
            gcn_vocab_out = self.vocab_gcn(vocab_input)

            # Write the gcn_embedding_dim graph-embedding vectors into the reserved
            # placeholder positions at the end of each sequence (just before the
            # final [SEP]); tmp_pos indexes those positions in the flattened batch.
            gcn_words_embeddings = words_embeddings.clone()
            for i in range(self.gcn_embedding_dim):
                tmp_pos = (
                    attention_mask.sum(-1) - 2 - self.gcn_embedding_dim + 1 + i
                ) + torch.arange(0, input_ids.shape[0]).to(input_ids.device) * input_ids.shape[1]
                gcn_words_embeddings.flatten(start_dim=0, end_dim=1)[tmp_pos, :] = gcn_vocab_out[:, :, i]

        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        if self.gcn_embedding_dim > 0:
            embeddings = gcn_words_embeddings + position_embeddings + token_type_embeddings
        else:
            embeddings = words_embeddings + position_embeddings + token_type_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class VGCN_Bert(BertModel):
    """VGCN-BERT model for text classification. It inherits from Huggingface's BertModel.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model
        `gcn_adj_matrix`: The normalized adjacency matrices of the vocabulary graph
        `gcn_adj_dim`: The size of the vocabulary graph
        `gcn_adj_num`: The number of adjacency matrices of the vocabulary graph
        `gcn_embedding_dim`: The output dimension after VGCN
        The number of classes (`config.num_labels`) and whether attention weights are
        returned (`config.output_attentions`) are taken from `config`.

    Inputs:
        `gcn_swop_eye`: The transform matrix that maps the token sequence (sentence)
            to the vocabulary order (BoW order)
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary. Items in the batch should
            begin with the special "CLS" token.
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a
            `sentence A` token and type 1 corresponds to a `sentence B` token
            (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used if the input sequence
            length is smaller than the max input sequence length in the current batch.
            It's the mask that we typically use for attention when a batch has varying
            length sentences.
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads]
            with values between 0 and 1, used to nullify selected heads of the transformer
            (0.0 => head is masked, 1.0 => head is kept).

    Outputs:
        The classification logits of shape [batch_size, num_labels]
        (preceded by the attention weights when `config.output_attentions` is True).
    """

    def __init__(self, config, gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, gcn_embedding_dim):
        super(VGCN_Bert, self).__init__(config)
        self.embeddings = VGCNBertEmbeddings(config, gcn_adj_matrix, gcn_adj_dim, gcn_adj_num, gcn_embedding_dim)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.gcn_adj_matrix = gcn_adj_matrix
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.will_collect_cls_states = False
        self.all_cls_states = []
        self.output_attentions = config.output_attentions

    def forward(self, gcn_swop_eye, input_ids, token_type_ids=None, attention_mask=None, output_hidden_states=False, head_mask=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        embedding_output = self.embeddings(gcn_swop_eye, input_ids, token_type_ids, attention_mask)

        # Expand the attention mask to (batch, 1, 1, seq) and turn padded positions
        # into large negative values so the softmax ignores them.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare the head mask: one entry per layer, broadcastable over heads.
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(embedding_output,
                                       attention_mask=extended_attention_mask,
                                       head_mask=head_mask,
                                       output_attentions=self.output_attentions,
                                       output_hidden_states=output_hidden_states)
        sequence_output = encoder_outputs.last_hidden_state

        pooled_output = self.pooler(sequence_output)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if self.output_attentions:
            return encoder_outputs.attentions, logits

        return logits
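

# Hedged usage sketch (illustrative, not part of the original module): how the
# wrapper class is expected to be used once a checkpoint and its vocabulary-graph
# pickle exist. The checkpoint path below is a placeholder, not a released model.
if __name__ == "__main__":
    model = VCGNModelForTextClassification.from_pretrained("path/to/vgcn-bert-checkpoint")
    model.eval()
    preds = model.predict(["an example sentence to classify", "and another one"])
    print(preds["label_id"], preds["label"])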