# music2emo-youtube-link-ja/utils/transformer_modules.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
def _gen_bias_mask(max_length):
    """
    Generates bias values (-Inf) to mask future timesteps during attention.
    Returns a [1, 1, max_length, max_length] tensor that is 0 on and below the
    diagonal and -Inf strictly above it.
    """
np_mask = np.triu(np.full([max_length, max_length], -np.inf), 1)
torch_mask = torch.from_numpy(np_mask).type(torch.FloatTensor)
return torch_mask.unsqueeze(0).unsqueeze(1)
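# Illustrative usage sketch (not part of the original module): the returned mask is
# broadcastable over [batch, heads, query_len, key_len] attention logits, so it can
# simply be added to them before the softmax:
#
#     mask = _gen_bias_mask(4)                    # [1, 1, 4, 4], -Inf strictly above the diagonal
#     logits = torch.zeros(2, 8, 4, 4)            # [batch, heads, q_len, k_len]
#     weights = F.softmax(logits + mask, dim=-1)  # future positions receive zero weight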
def _gen_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
"""
Generates a [1, length, channels] timing signal consisting of sinusoids
Adapted from:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
"""
position = np.arange(length)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(float(num_timescales) - 1))
inv_timescales = min_timescale * np.exp(
np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, channels % 2]],
'constant', constant_values=[0.0, 0.0])
signal = signal.reshape([1, length, channels])
return torch.from_numpy(signal).type(torch.FloatTensor)
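# Illustrative usage sketch (not part of the original module): the signal is intended
# to be added to token embeddings before the first attention layer:
#
#     signal = _gen_timing_signal(100, 512)       # [1, 100, 512]
#     x = torch.randn(8, 100, 512)                # [batch, seq_len, model_dim]
#     x = x + signal[:, :x.shape[1], :].type_as(x)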
class LayerNorm(nn.Module):
# Borrowed from jekbradbury
# https://github.com/pytorch/pytorch/issues/1959
def __init__(self, features, eps=1e-6):
super(LayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(features))
self.beta = nn.Parameter(torch.zeros(features))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta
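# Note: unlike torch.nn.LayerNorm, this variant normalizes by (std + eps), using the
# unbiased std estimate, rather than by sqrt(var + eps). Illustrative usage sketch:
#
#     ln = LayerNorm(512)
#     y = ln(torch.randn(8, 100, 512))            # normalized over the last dimension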
class OutputLayer(nn.Module):
"""
Abstract base class for output layer.
Handles projection to output labels
"""
def __init__(self, hidden_size, output_size, probs_out=False):
super(OutputLayer, self).__init__()
self.output_size = output_size
self.output_projection = nn.Linear(hidden_size, output_size)
self.probs_out = probs_out
self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=int(hidden_size/2), batch_first=True, bidirectional=True)
self.hidden_size = hidden_size
def loss(self, hidden, labels):
raise NotImplementedError('Must implement {}.loss'.format(self.__class__.__name__))
class SoftmaxOutputLayer(OutputLayer):
"""
Implements a softmax based output layer
"""
def forward(self, hidden):
logits = self.output_projection(hidden)
probs = F.softmax(logits, -1)
# _, predictions = torch.max(probs, dim=-1)
topk, indices = torch.topk(probs, 2)
predictions = indices[:,:,0]
second = indices[:,:,1]
        if self.probs_out is True:
            # Note: despite the flag name, raw logits (not softmax probabilities) are returned.
            return logits
            # return probs
return predictions, second
def loss(self, hidden, labels):
logits = self.output_projection(hidden)
log_probs = F.log_softmax(logits, -1)
return F.nll_loss(log_probs.view(-1, self.output_size), labels.view(-1))
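# Illustrative usage sketch (hypothetical sizes): project per-timestep hidden states
# onto output_size classes and train with NLL loss.
#
#     head = SoftmaxOutputLayer(hidden_size=512, output_size=25)
#     hidden = torch.randn(8, 100, 512)           # [batch, seq_len, hidden_size]
#     preds, second = head(hidden)                # top-1 and top-2 class ids, each [8, 100]
#     loss = head.loss(hidden, torch.randint(25, (8, 100)))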
class MultiHeadAttention(nn.Module):
"""
Multi-head attention as per https://arxiv.org/pdf/1706.03762.pdf
Refer Figure 2
"""
def __init__(self, input_depth, total_key_depth, total_value_depth, output_depth,
num_heads, bias_mask=None, dropout=0.0, attention_map=False):
"""
Parameters:
            input_depth: Size of last dimension of input
            total_key_depth: Size of last dimension of keys. Must be divisible by num_heads
            total_value_depth: Size of last dimension of values. Must be divisible by num_heads
            output_depth: Size of last dimension of the final output
            num_heads: Number of attention heads
            bias_mask: Masking tensor to prevent connections to future elements
            dropout: Dropout probability (Should be non-zero only during training)
            attention_map: If True, forward() also returns the attention weights
"""
super(MultiHeadAttention, self).__init__()
# Checks borrowed from
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
if total_key_depth % num_heads != 0:
raise ValueError("Key depth (%d) must be divisible by the number of "
"attention heads (%d)." % (total_key_depth, num_heads))
if total_value_depth % num_heads != 0:
raise ValueError("Value depth (%d) must be divisible by the number of "
"attention heads (%d)." % (total_value_depth, num_heads))
self.attention_map = attention_map
self.num_heads = num_heads
self.query_scale = (total_key_depth // num_heads) ** -0.5
self.bias_mask = bias_mask
# Key and query depth will be same
self.query_linear = nn.Linear(input_depth, total_key_depth, bias=False)
self.key_linear = nn.Linear(input_depth, total_key_depth, bias=False)
self.value_linear = nn.Linear(input_depth, total_value_depth, bias=False)
self.output_linear = nn.Linear(total_value_depth, output_depth, bias=False)
self.dropout = nn.Dropout(dropout)
def _split_heads(self, x):
"""
        Split x so as to add an extra num_heads dimension
Input:
x: a Tensor with shape [batch_size, seq_length, depth]
Returns:
A Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
"""
if len(x.shape) != 3:
raise ValueError("x must have rank 3")
shape = x.shape
return x.view(shape[0], shape[1], self.num_heads, shape[2] // self.num_heads).permute(0, 2, 1, 3)
def _merge_heads(self, x):
"""
Merge the extra num_heads into the last dimension
Input:
x: a Tensor with shape [batch_size, num_heads, seq_length, depth/num_heads]
Returns:
A Tensor with shape [batch_size, seq_length, depth]
"""
if len(x.shape) != 4:
raise ValueError("x must have rank 4")
shape = x.shape
return x.permute(0, 2, 1, 3).contiguous().view(shape[0], shape[2], shape[3] * self.num_heads)
def forward(self, queries, keys, values):
# Do a linear for each component
queries = self.query_linear(queries)
keys = self.key_linear(keys)
values = self.value_linear(values)
# Split into multiple heads
queries = self._split_heads(queries)
keys = self._split_heads(keys)
values = self._split_heads(values)
# Scale queries
queries *= self.query_scale
# Combine queries and keys
logits = torch.matmul(queries, keys.permute(0, 1, 3, 2))
# Add bias to mask future values
if self.bias_mask is not None:
            logits += self.bias_mask[:, :, :logits.shape[-2], :logits.shape[-1]].type_as(logits)
        # Convert to probabilities
        weights = F.softmax(logits, dim=-1)
# Dropout
weights = self.dropout(weights)
# Combine with values to get context
contexts = torch.matmul(weights, values)
# Merge heads
contexts = self._merge_heads(contexts)
# contexts = torch.tanh(contexts)
# Linear to get output
outputs = self.output_linear(contexts)
if self.attention_map is True:
return outputs, weights
return outputs
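# Illustrative usage sketch (hypothetical sizes): causal self-attention over a
# [batch, seq_len, model_dim] tensor.
#
#     mha = MultiHeadAttention(input_depth=512, total_key_depth=512,
#                              total_value_depth=512, output_depth=512, num_heads=8,
#                              bias_mask=_gen_bias_mask(1000), dropout=0.1)
#     x = torch.randn(8, 100, 512)
#     y = mha(x, x, x)                            # [8, 100, 512]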
class Conv(nn.Module):
"""
Convenience class that does padding and convolution for inputs in the format
[batch_size, sequence length, hidden size]
"""
def __init__(self, input_size, output_size, kernel_size, pad_type):
"""
Parameters:
input_size: Input feature size
output_size: Output feature size
kernel_size: Kernel width
            pad_type: left -> pad on the left side (to mask future timesteps),
                      both -> pad on both sides
"""
super(Conv, self).__init__()
padding = (kernel_size - 1, 0) if pad_type == 'left' else (kernel_size // 2, (kernel_size - 1) // 2)
self.pad = nn.ConstantPad1d(padding, 0)
self.conv = nn.Conv1d(input_size, output_size, kernel_size=kernel_size, padding=0)
def forward(self, inputs):
inputs = self.pad(inputs.permute(0, 2, 1))
outputs = self.conv(inputs).permute(0, 2, 1)
return outputs
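# Illustrative usage sketch: with pad_type='left' the convolution is causal, so the
# sequence length is preserved and position t never sees positions t+1, t+2, ...
#
#     conv = Conv(input_size=512, output_size=512, kernel_size=3, pad_type='left')
#     y = conv(torch.randn(8, 100, 512))          # [8, 100, 512]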
class PositionwiseFeedForward(nn.Module):
"""
    Applies a Linear + ReLU + Linear transformation independently at each timestep (position-wise)
"""
def __init__(self, input_depth, filter_size, output_depth, layer_config='ll', padding='left', dropout=0.0):
"""
Parameters:
input_depth: Size of last dimension of input
filter_size: Hidden size of the middle layer
            output_depth: Size of last dimension of the final output
layer_config: ll -> linear + ReLU + linear
cc -> conv + ReLU + conv etc.
            padding: left -> pad on the left side (to mask future timesteps),
                     both -> pad on both sides
dropout: Dropout probability (Should be non-zero only during training)
"""
super(PositionwiseFeedForward, self).__init__()
layers = []
sizes = ([(input_depth, filter_size)] +
[(filter_size, filter_size)] * (len(layer_config) - 2) +
[(filter_size, output_depth)])
for lc, s in zip(list(layer_config), sizes):
if lc == 'l':
layers.append(nn.Linear(*s))
elif lc == 'c':
layers.append(Conv(*s, kernel_size=3, pad_type=padding))
else:
raise ValueError("Unknown layer type {}".format(lc))
self.layers = nn.ModuleList(layers)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
def forward(self, inputs):
x = inputs
for i, layer in enumerate(self.layers):
x = layer(x)
            # Apply ReLU and dropout between layers, but not after the final projection.
            if i < len(self.layers) - 1:
x = self.relu(x)
x = self.dropout(x)
return x
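# Illustrative usage sketch (hypothetical sizes): a convolutional position-wise
# feed-forward block ('cc' = conv + ReLU + conv).
#
#     ffn = PositionwiseFeedForward(input_depth=512, filter_size=2048, output_depth=512,
#                                   layer_config='cc', padding='both', dropout=0.1)
#     y = ffn(torch.randn(8, 100, 512))           # [8, 100, 512]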