|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math


def _gen_bias_mask(max_length):
    """
    Generates bias values (-Inf) to mask future timesteps during attention.
    Returns a tensor of shape [1, 1, max_length, max_length].
    """
    np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)
    torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)
    return torch_mask.unsqueeze(0).unsqueeze(1)
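# Illustrative usage (a sketch, not part of the module API): the mask is meant to be
# broadcast-added to attention logits of shape [batch, heads, q_len, k_len].
#   mask = _gen_bias_mask(4)                      # [1, 1, 4, 4]; -inf above the diagonal
#   logits = torch.zeros(2, 8, 4, 4)
#   weights = F.softmax(logits + mask, dim=-1)    # future positions get ~0 weight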
|
|
|
def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids.
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (float(num_timescales) - 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)

    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])

    return torch.from_numpy(signal).type(torch.FloatTensor)
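# Illustrative usage (a sketch; `emb` and the 512-channel size are assumptions):
#   signal = _gen_timing_signal(1000, 512)        # FloatTensor of shape [1, 1000, 512]
#   emb = emb + signal[:, :emb.size(1), :]        # add positional information to embeddings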
|
|
|
class LayerNorm(nn.Module):
    """
    Layer normalization (https://arxiv.org/abs/1607.06450) over the last dimension,
    with learnable gain (gamma) and bias (beta).
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
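# Illustrative usage (a sketch; the feature size of 512 is an assumption):
#   norm = LayerNorm(512)
#   y = norm(torch.randn(2, 10, 512))   # normalized over the last dimension, same shape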
|
|
|
class OutputLayer(nn.Module):
    """
    Abstract base class for output layers.
    Handles projection from hidden states to output labels.
    """

    def __init__(self, hidden_size, output_size, probs_out=False):
        super(OutputLayer, self).__init__()
        self.output_size = output_size
        self.output_projection = nn.Linear(hidden_size, output_size)
        self.probs_out = probs_out
        self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size // 2,
                            batch_first=True, bidirectional=True)
        self.hidden_size = hidden_size

    def loss(self, hidden, labels):
        raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))


class SoftmaxOutputLayer(OutputLayer):
    """
    Implements a softmax-based output layer.
    """

    def forward(self, hidden):
        logits = self.output_projection(hidden)
        if self.probs_out:
            return logits

        probs = F.softmax(logits, -1)
        # Keep the two most likely labels per timestep.
        topk, indices = torch.topk(probs, 2)
        predictions = indices[:, :, 0]   # best label
        second = indices[:, :, 1]        # second-best label
        return predictions, second

    def loss(self, hidden, labels):
        logits = self.output_projection(hidden)
        log_probs = F.log_softmax(logits, -1)
        return F.nll_loss(log_probs.view(-1, self.output_size), labels.view(-1))
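# Illustrative usage (a sketch; the hidden size and the 10-label output space are assumptions):
#   out = SoftmaxOutputLayer(hidden_size=512, output_size=10)
#   hidden = torch.randn(2, 7, 512)                      # [batch, seq_len, hidden]
#   best, second_best = out(hidden)                      # each of shape [2, 7]
#   nll = out.loss(hidden, torch.randint(0, 10, (2, 7)))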
|
|
|
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf
    (see Figure 2).
    """

    def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
                 num_heads, bias_mask=None, dropout=0.0, attention_map=False):
        """
        Parameters:
            input_depth: Size of last dimension of input
            total_key_depth: Size of last dimension of keys. Must be divisible by num_heads
            total_value_depth: Size of last dimension of values. Must be divisible by num_heads
            output_depth: Size of last dimension of the final output
            num_heads: Number of attention heads
            bias_mask: Masking tensor to prevent connections to future elements
            dropout: Dropout probability (should be non-zero only during training)
            attention_map: If True, forward() also returns the attention weights
        """
        super(MultiHeadAttention, self).__init__()

        if total_key_depth % num_heads != 0:
            raise ValueError("Key depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_key_depth, num_heads))
        if total_value_depth % num_heads != 0:
            raise ValueError("Value depth (%d) must be divisible by the number of "
                             "attention heads (%d)." % (total_value_depth, num_heads))

        self.attention_map = attention_map
        self.num_heads = num_heads
        self.query_scale = (total_key_depth // num_heads) ** -0.5
        self.bias_mask = bias_mask

        self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
        self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
        self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x):
        """
        Split x to add an extra num_heads dimension.
        Input:
            x: a Tensor with shape [batch_size, seq_length, depth]
        Returns:
            A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
        """
        if len(x.shape) != 3:
            raise ValueError("x must have rank 3")
        shape = x.shape
        return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """
        Merge the extra num_heads dimension back into the last dimension.
        Input:
            x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
        Returns:
            A Tensor with shape [batch_size, seq_length, depth]
        """
        if len(x.shape) != 4:
            raise ValueError("x must have rank 4")
        shape = x.shape
        return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)

    def forward(self, queries, keys, values):
        # Project inputs to queries, keys and values.
        queries = self.query_linear(queries)
        keys = self.key_linear(keys)
        values = self.value_linear(values)

        # Split the projections into separate heads.
        queries = self._split_heads(queries)
        keys = self._split_heads(keys)
        values = self._split_heads(values)

        # Scale queries by 1/sqrt(per-head key depth).
        queries *= self.query_scale

        # Scaled dot-product attention logits: [batch, heads, q_len, k_len].
        logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))

        # Add the (-inf) bias mask so future positions receive ~zero weight after softmax.
        if self.bias_mask is not None:
            logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits.data)

        # Convert logits to attention weights and apply dropout.
        weights = F.softmax(logits, dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of values, then merge heads and project to the output depth.
        contexts = torch.matmul(weights, values)
        contexts = self._merge_heads(contexts)
        outputs = self.output_linear(contexts)

        if self.attention_map:
            return outputs, weights

        return outputs
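# Illustrative usage (a sketch; all sizes below are assumptions, not module defaults):
#   mha = MultiHeadAttention(input_depth=512, total_key_depth=512, total_value_depth=512,
#                            output_depth=512, num_heads=8, bias_mask=_gen_bias_mask(100))
#   x = torch.randn(2, 20, 512)          # [batch, seq_len, depth]
#   y = mha(x, x, x)                     # masked self-attention; y has shape [2, 20, 512]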
|
|
|
|
|
class Conv(nn.Module):
    """
    Convenience class that does padding and convolution for inputs in the format
    [batch_size, sequence_length, hidden_size].
    """

    def __init__(self, input_size, output_size, kernel_size, pad_type):
        """
        Parameters:
            input_size: Input feature size
            output_size: Output feature size
            kernel_size: Kernel width
            pad_type: left -> pad on the left side (to mask future data),
                      both -> pad on both sides
        """
        super(Conv, self).__init__()
        padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)
        self.pad = nn.ConstantPad1d(padding, 0)
        self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)

    def forward(self, inputs):
        # Conv1d expects [batch, channels, length]: permute, pad, convolve, permute back.
        inputs = self.pad(inputs.permute(0, 2, 1))
        outputs = self.conv(inputs).permute(0, 2, 1)
        return outputs
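# Illustrative usage (a sketch; sizes are assumptions):
#   conv = Conv(input_size=512, output_size=512, kernel_size=3, pad_type='left')
#   y = conv(torch.randn(2, 20, 512))    # padding preserves sequence length: [2, 20, 512]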
|
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """
    Does a Linear + ReLU + Linear on each of the timesteps.
    """

    def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):
        """
        Parameters:
            input_depth: Size of last dimension of input
            filter_size: Hidden size of the middle layer
            output_depth: Size of last dimension of the final output
            layer_config: ll -> linear + ReLU + linear
                          cc -> conv + ReLU + conv etc.
            padding: left -> pad on the left side (to mask future data),
                     both -> pad on both sides
            dropout: Dropout probability (should be non-zero only during training)
        """
        super(PositionwiseFeedForward, self).__init__()

        layers = []
        sizes = ([(input_depth, filter_size)] +
                 [(filter_size, filter_size)] * (len(layer_config) - 2) +
                 [(filter_size, output_depth)])

        for lc, s in zip(list(layer_config), sizes):
            if lc == 'l':
                layers.append(nn.Linear(*s))
            elif lc == 'c':
                layers.append(Conv(*s, kernel_size=3, pad_type=padding))
            else:
                raise ValueError("Unknown layer type {}".format(lc))

        self.layers = nn.ModuleList(layers)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Apply the non-linearity and dropout between layers, but not after the last one.
            if i < len(self.layers) - 1:
                x = self.relu(x)
                x = self.dropout(x)
        return x
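

# Minimal smoke test (a sketch; all sizes below are assumptions). Run this file
# directly to check that the layers compose and produce the expected shapes.
if __name__ == "__main__":
    x = torch.randn(2, 20, 512)

    ffn = PositionwiseFeedForward(input_depth=512, filter_size=2048, output_depth=512,
                                  layer_config='ll', dropout=0.1)
    print(ffn(x).shape)        # torch.Size([2, 20, 512])

    ffn_conv = PositionwiseFeedForward(input_depth=512, filter_size=2048, output_depth=512,
                                       layer_config='cc', padding='left')
    print(ffn_conv(x).shape)   # torch.Size([2, 20, 512])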