|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""StructFormer and transformer model.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import layers |
|
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput |
|
|
|
|
|
def cumprod(x, reverse=False, exclusive=False): |
|
"""cumulative product.""" |
|
if reverse: |
|
x = x.flip([-1]) |
|
|
|
if exclusive: |
|
x = F.pad(x[:, :, :-1], (1, 0), value=1) |
|
|
|
cx = x.cumprod(-1) |
|
|
|
if reverse: |
|
cx = cx.flip([-1]) |
|
return cx |
|
|
|
def cumsum(x, reverse=False, exclusive=False): |
|
"""cumulative sum.""" |
|
bsz, _, length = x.size() |
|
device = x.device |
|
if reverse: |
|
if exclusive: |
|
w = torch.ones([bsz, length, length], device=device).tril(-1) |
|
else: |
|
w = torch.ones([bsz, length, length], device=device).tril(0) |
|
cx = torch.bmm(x, w) |
|
else: |
|
if exclusive: |
|
w = torch.ones([bsz, length, length], device=device).triu(1) |
|
else: |
|
w = torch.ones([bsz, length, length], device=device).triu(0) |
|
cx = torch.bmm(x, w) |
|
return cx |
|
|
|
def cummin(x, reverse=False, exclusive=False, max_value=1e9): |
|
"""cumulative min.""" |
|
if reverse: |
|
if exclusive: |
|
x = F.pad(x[:, :, 1:], (0, 1), value=max_value) |
|
x = x.flip([-1]).cummin(-1)[0].flip([-1]) |
|
else: |
|
if exclusive: |
|
x = F.pad(x[:, :, :-1], (1, 0), value=max_value) |
|
x = x.cummin(-1)[0] |
|
return x |
|
|
|
class Transformer(nn.Module): |
|
"""Transformer model.""" |
|
|
|
def __init__(self, |
|
hidden_size, |
|
nlayers, |
|
ntokens, |
|
nhead=8, |
|
dropout=0.1, |
|
dropatt=0.1, |
|
relative_bias=True, |
|
pos_emb=False, |
|
pad=0): |
|
"""Initialization. |
|
|
|
Args: |
|
hidden_size: dimension of inputs and hidden states |
|
nlayers: number of layers |
|
ntokens: number of output categories |
|
nhead: number of self-attention heads |
|
dropout: dropout rate |
|
dropatt: drop attention rate |
|
relative_bias: bool, indicate whether use a relative position based |
|
attention bias |
|
pos_emb: bool, indicate whether use a learnable positional embedding |
|
pad: pad token index |
|
""" |
|
|
|
super(Transformer, self).__init__() |
|
|
|
self.drop = nn.Dropout(dropout) |
|
|
|
self.emb = nn.Embedding(ntokens, hidden_size) |
|
if pos_emb: |
|
self.pos_emb = nn.Embedding(500, hidden_size) |
|
|
|
self.layers = nn.ModuleList([ |
|
layers.TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout, |
|
dropatt=dropatt, relative_bias=relative_bias) |
|
for _ in range(nlayers)]) |
|
|
|
self.norm = nn.LayerNorm(hidden_size) |
|
|
|
self.output_layer = nn.Linear(hidden_size, ntokens) |
|
self.output_layer.weight = self.emb.weight |
|
|
|
self.init_weights() |
|
|
|
self.nlayers = nlayers |
|
self.nhead = nhead |
|
self.ntokens = ntokens |
|
self.hidden_size = hidden_size |
|
self.pad = pad |
|
|
|
def init_weights(self): |
|
"""Initialize token embedding and output bias.""" |
|
initrange = 0.1 |
|
self.emb.weight.data.uniform_(-initrange, initrange) |
|
if hasattr(self, 'pos_emb'): |
|
self.pos_emb.weight.data.uniform_(-initrange, initrange) |
|
self.output_layer.bias.data.fill_(0) |
|
|
|
def visibility(self, x, device): |
|
"""Mask pad tokens.""" |
|
visibility = (x != self.pad).float() |
|
visibility = visibility[:, None, :].expand(-1, x.size(1), -1) |
|
visibility = torch.repeat_interleave(visibility, self.nhead, dim=0) |
|
return visibility.log() |
|
|
|
def encode(self, x, pos): |
|
"""Standard transformer encode process.""" |
|
h = self.emb(x) |
|
if hasattr(self, 'pos_emb'): |
|
h = h + self.pos_emb(pos) |
|
h_list = [] |
|
visibility = self.visibility(x, x.device) |
|
|
|
for i in range(self.nlayers): |
|
h_list.append(h) |
|
h = self.layers[i]( |
|
h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1) |
|
|
|
output = h |
|
h_array = torch.stack(h_list, dim=2) |
|
|
|
return output, h_array |
|
|
|
def forward(self, x, pos): |
|
"""Pass the input through the encoder layer. |
|
|
|
Args: |
|
x: input tokens (required). |
|
pos: position for each token (optional). |
|
Returns: |
|
output: probability distributions for missing tokens. |
|
state_dict: parsing results and raw output |
|
""" |
|
|
|
batch_size, length = x.size() |
|
|
|
raw_output, _ = self.encode(x, pos) |
|
raw_output = self.norm(raw_output) |
|
raw_output = self.drop(raw_output) |
|
|
|
output = self.output_layer(raw_output) |
|
return output.view(batch_size * length, -1), {'raw_output': raw_output,} |
|
|
|
class StructFormer(Transformer): |
|
"""StructFormer model.""" |
|
|
|
def __init__(self, |
|
hidden_size, |
|
nlayers, |
|
ntokens, |
|
nhead=8, |
|
dropout=0.1, |
|
dropatt=0.1, |
|
relative_bias=False, |
|
pos_emb=False, |
|
pad=0, |
|
n_parser_layers=4, |
|
conv_size=9, |
|
relations=('head', 'child'), |
|
weight_act='softmax'): |
|
"""Initialization. |
|
|
|
Args: |
|
hidden_size: dimension of inputs and hidden states |
|
nlayers: number of layers |
|
ntokens: number of output categories |
|
nhead: number of self-attention heads |
|
dropout: dropout rate |
|
dropatt: drop attention rate |
|
relative_bias: bool, indicate whether use a relative position based |
|
attention bias |
|
pos_emb: bool, indicate whether use a learnable positional embedding |
|
pad: pad token index |
|
n_parser_layers: number of parsing layers |
|
conv_size: convolution kernel size for parser |
|
relations: relations that are used to compute self attention |
|
weight_act: relations distribution activation function |
|
""" |
|
|
|
super(StructFormer, self).__init__( |
|
hidden_size, |
|
nlayers, |
|
ntokens, |
|
nhead=nhead, |
|
dropout=dropout, |
|
dropatt=dropatt, |
|
relative_bias=relative_bias, |
|
pos_emb=pos_emb, |
|
pad=pad) |
|
|
|
self.parser_layers = nn.ModuleList([ |
|
nn.Sequential(layers.Conv1d(hidden_size, conv_size), |
|
nn.LayerNorm(hidden_size, elementwise_affine=False), |
|
nn.Tanh()) for i in range(n_parser_layers)]) |
|
|
|
self.distance_ff = nn.Sequential( |
|
layers.Conv1d(hidden_size, 2), |
|
nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(), |
|
nn.Linear(hidden_size, 1)) |
|
|
|
self.height_ff = nn.Sequential( |
|
nn.Linear(hidden_size, hidden_size), |
|
nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(), |
|
nn.Linear(hidden_size, 1)) |
|
|
|
n_rel = len(relations) |
|
self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel))) |
|
self._rel_weight.data.normal_(0, 0.1) |
|
|
|
self._scaler = nn.Parameter(torch.zeros(2)) |
|
|
|
self.n_parse_layers = n_parser_layers |
|
self.weight_act = weight_act |
|
self.relations = relations |
|
|
|
@property |
|
def scaler(self): |
|
return self._scaler.exp() |
|
|
|
@property |
|
def rel_weight(self): |
|
if self.weight_act == 'sigmoid': |
|
return torch.sigmoid(self._rel_weight) |
|
elif self.weight_act == 'softmax': |
|
return torch.softmax(self._rel_weight, dim=-1) |
|
|
|
def parse(self, x, pos): |
|
"""Parse input sentence. |
|
|
|
Args: |
|
x: input tokens (required). |
|
pos: position for each token (optional). |
|
Returns: |
|
distance: syntactic distance |
|
height: syntactic height |
|
""" |
|
|
|
mask = (x != self.pad) |
|
mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0) |
|
|
|
h = self.emb(x) |
|
for i in range(self.n_parse_layers): |
|
h = h.masked_fill(~mask[:, :, None], 0) |
|
h = self.parser_layers[i](h) |
|
|
|
height = self.height_ff(h).squeeze(-1) |
|
height.masked_fill_(~mask, -1e9) |
|
|
|
distance = self.distance_ff(h).squeeze(-1) |
|
distance.masked_fill_(~mask_shifted, 1e9) |
|
|
|
|
|
length = distance.size(1) |
|
height_max = height[:, None, :].expand(-1, length, -1) |
|
height_max = torch.cummax( |
|
height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9, |
|
dim=-1)[0].triu(0) |
|
|
|
margin_left = torch.relu( |
|
F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max) |
|
margin_right = torch.relu(distance[:, None, :] - height_max) |
|
margin = torch.where(margin_left > margin_right, margin_right, |
|
margin_left).triu(0) |
|
|
|
margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1) |
|
margin.masked_fill_(~margin_mask, 0) |
|
margin = margin.max() |
|
|
|
distance = distance - margin |
|
|
|
return distance, height |
|
|
|
def compute_block(self, distance, height): |
|
"""Compute constituents from distance and height.""" |
|
|
|
beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0] |
|
|
|
gamma = torch.sigmoid(-beta_logits) |
|
ones = torch.ones_like(gamma) |
|
|
|
block_mask_left = cummin( |
|
gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1) |
|
block_mask_left = block_mask_left - F.pad( |
|
block_mask_left[:, :, :-1], (1, 0), value=0) |
|
block_mask_left.tril_(0) |
|
|
|
block_mask_right = cummin( |
|
gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1) |
|
block_mask_right = block_mask_right - F.pad( |
|
block_mask_right[:, :, 1:], (0, 1), value=0) |
|
block_mask_right.triu_(0) |
|
|
|
block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :] |
|
block = cumsum(block_mask_left).tril(0) + cumsum( |
|
block_mask_right, reverse=True).triu(1) |
|
|
|
return block_p, block |
|
|
|
def compute_head(self, height): |
|
"""Estimate head for each constituent.""" |
|
|
|
_, length = height.size() |
|
head_logits = height * self.scaler[1] |
|
index = torch.arange(length, device=height.device) |
|
|
|
mask = (index[:, None, None] <= index[None, None, :]) * ( |
|
index[None, None, :] <= index[None, :, None]) |
|
head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1) |
|
head_logits.masked_fill_(~mask[None, :, :, :], -1e9) |
|
|
|
head_p = torch.softmax(head_logits, dim=-1) |
|
|
|
return head_p |
|
|
|
def generate_mask(self, x, distance, height): |
|
"""Compute head and cibling distribution for each token.""" |
|
|
|
bsz, length = x.size() |
|
|
|
eye = torch.eye(length, device=x.device, dtype=torch.bool) |
|
eye = eye[None, :, :].expand((bsz, -1, -1)) |
|
|
|
block_p, block = self.compute_block(distance, height) |
|
head_p = self.compute_head(height) |
|
head = torch.einsum('blij,bijh->blh', block_p, head_p) |
|
head = head.masked_fill(eye, 0) |
|
child = head.transpose(1, 2) |
|
cibling = torch.bmm(head, child).masked_fill(eye, 0) |
|
|
|
rel_list = [] |
|
if 'head' in self.relations: |
|
rel_list.append(head) |
|
if 'child' in self.relations: |
|
rel_list.append(child) |
|
if 'cibling' in self.relations: |
|
rel_list.append(cibling) |
|
|
|
rel = torch.stack(rel_list, dim=1) |
|
|
|
rel_weight = self.rel_weight |
|
|
|
dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel) |
|
att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length) |
|
|
|
return att_mask, cibling, head, block |
|
|
|
def encode(self, x, pos, att_mask): |
|
"""Structformer encoding process.""" |
|
|
|
visibility = self.visibility(x, x.device) |
|
h = self.emb(x) |
|
if hasattr(self, 'pos_emb'): |
|
assert pos.max() < 500 |
|
h = h + self.pos_emb(pos) |
|
for i in range(self.nlayers): |
|
h = self.layers[i]( |
|
h.transpose(0, 1), attn_mask=att_mask[i], |
|
key_padding_mask=visibility).transpose(0, 1) |
|
return h |
|
|
|
def forward(self, x, pos): |
|
"""Pass the input through the encoder layer. |
|
|
|
Args: |
|
x: input tokens (required). |
|
pos: position for each token (optional). |
|
Returns: |
|
output: probability distributions for missing tokens. |
|
state_dict: parsing results and raw output |
|
""" |
|
|
|
batch_size, length = x.size() |
|
|
|
distance, height = self.parse(x, pos) |
|
att_mask, cibling, head, block = self.generate_mask(x, distance, height) |
|
|
|
raw_output = self.encode(x, pos, att_mask) |
|
raw_output = self.norm(raw_output) |
|
raw_output = self.drop(raw_output) |
|
|
|
output = self.output_layer(raw_output) |
|
|
|
return output.view(batch_size * length, -1), \ |
|
{'raw_output': raw_output, 'distance': distance, 'height': height, |
|
'cibling': cibling, 'head': head, 'block': block} |
|
|
|
|
|
|
|
|
|
|
|
class ClassificationHead(nn.Module): |
|
"""Head for sentence-level classification tasks.""" |
|
def __init__(self, config): |
|
super(ClassificationHead, self).__init__() |
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.out_proj = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
def forward(self, features, **kwargs): |
|
x = features[:, 0, :] |
|
x = self.dropout(x) |
|
x = self.dense(x) |
|
x = torch.tanh(x) |
|
x = self.dropout(x) |
|
x = self.out_proj(x) |
|
return x |
|
|
|
|
|
|
|
|
|
class StructFormerConfig(PretrainedConfig): |
|
model_type = "structformer" |
|
|
|
def __init__( |
|
self, |
|
hidden_size=512, |
|
nlayers=8, |
|
ntokens=10_000, |
|
nhead=8, |
|
dropout=0.1, |
|
dropatt=0.1, |
|
relative_bias=False, |
|
pos_emb=False, |
|
pad=0, |
|
n_parser_layers=4, |
|
conv_size=9, |
|
relations=('head', 'child'), |
|
weight_act='softmax', |
|
num_labels=1, |
|
hidden_dropout_prob=0.1, |
|
initializer_range=0.02, |
|
**kwargs, |
|
): |
|
self.hidden_size = hidden_size |
|
self.nlayers = nlayers |
|
self.ntokens = ntokens |
|
self.nhead = nhead |
|
self.dropout = dropout |
|
self.dropatt = dropatt |
|
self.relative_bias = relative_bias |
|
self.pos_emb = pos_emb |
|
self.pad = pad |
|
self.n_parser_layers = n_parser_layers |
|
self.conv_size = conv_size |
|
self.relations = relations |
|
self.weight_act = weight_act |
|
self.num_labels = num_labels |
|
self.hidden_dropout_prob = hidden_dropout_prob |
|
self.initializer_range=initializer_range |
|
super().__init__(**kwargs) |
|
|
|
|
|
|
|
|
|
class StructFormerModel(PreTrainedModel): |
|
config_class = StructFormerConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.model = StructFormer( |
|
hidden_size=config.hidden_size, |
|
nlayers=config.nlayers, |
|
ntokens=config.ntokens, |
|
nhead=config.nhead, |
|
dropout=config.dropout, |
|
dropatt=config.dropatt, |
|
relative_bias=config.relative_bias, |
|
pos_emb=config.pos_emb, |
|
pad=config.pad, |
|
n_parser_layers=config.n_parser_layers, |
|
conv_size=config.conv_size, |
|
relations=config.relations, |
|
weight_act=config.weight_act |
|
) |
|
self.config = config |
|
|
|
def parse(self, input_ids, **kwargs): |
|
x = input_ids |
|
batch_size, length = x.size() |
|
pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length) |
|
|
|
sf_output = self.model(x, pos) |
|
|
|
return sf_output[1] |
|
|
|
def forward(self, input_ids, labels=None, **kwargs): |
|
x = input_ids |
|
batch_size, length = x.size() |
|
pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length) |
|
|
|
sf_output = self.model(x, pos) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct(sf_output[0], labels.reshape(-1)) |
|
|
|
return MaskedLMOutput( |
|
loss=loss, |
|
logits=sf_output[0].view(batch_size, length, -1), |
|
hidden_states=None, |
|
attentions=None |
|
) |
|
|
|
class StructFormerModelForSequenceClassification(PreTrainedModel): |
|
config_class = StructFormerConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.model = StructFormer( |
|
hidden_size=config.hidden_size, |
|
nlayers=config.nlayers, |
|
ntokens=config.ntokens, |
|
nhead=config.nhead, |
|
dropout=config.dropout, |
|
dropatt=config.dropatt, |
|
relative_bias=config.relative_bias, |
|
pos_emb=config.pos_emb, |
|
pad=config.pad, |
|
n_parser_layers=config.n_parser_layers, |
|
conv_size=config.conv_size, |
|
relations=config.relations, |
|
weight_act=config.weight_act |
|
) |
|
self.config = config |
|
self.model.classifier = ClassificationHead(config) |
|
|
|
def _init_weights(self, module): |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
elif isinstance(module, nn.LayerNorm): |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
|
|
def forward(self, input_ids, labels=None, **kwargs): |
|
x = input_ids |
|
batch_size, length = x.size() |
|
pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length) |
|
|
|
sf_output = self.model(x, pos) |
|
|
|
logits = self.model.classifier(sf_output[1]['raw_output']) |
|
loss = None |
|
if labels is not None: |
|
if self.config.problem_type is None: |
|
if self.num_labels == 1: |
|
self.config.problem_type = "regression" |
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): |
|
self.config.problem_type = "single_label_classification" |
|
else: |
|
self.config.problem_type = "multi_label_classification" |
|
|
|
if self.config.problem_type == "regression": |
|
loss_fct = nn.MSELoss() |
|
if self.num_labels == 1: |
|
loss = loss_fct(logits.squeeze(), labels.squeeze()) |
|
else: |
|
loss = loss_fct(logits, labels) |
|
elif self.config.problem_type == "single_label_classification": |
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
elif self.config.problem_type == "multi_label_classification": |
|
loss_fct = nn.BCEWithLogitsLoss() |
|
loss = loss_fct(logits, labels) |
|
|
|
return SequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=None, |
|
attentions=None, |
|
) |