from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMP(nn.Module):
    """LSTM with projection.

    PyTorch does not support exporting LSTM with projection to ONNX.
    This class reimplements LSTM with projection using basic matrix-matrix
    and matrix-vector operations. It is not intended for training.
    """

    def __init__(self, lstm: nn.LSTM):
        """
        Args:
          lstm:
            LSTM with proj_size. We support only uni-directional,
            1-layer LSTM with projection at present.
        """
        super().__init__()
        assert lstm.bidirectional is False, lstm.bidirectional
        assert lstm.num_layers == 1, lstm.num_layers
        assert 0 < lstm.proj_size < lstm.hidden_size, (
            lstm.proj_size,
            lstm.hidden_size,
        )
        assert lstm.batch_first is False, lstm.batch_first

        state_dict = lstm.state_dict()

        w_ih = state_dict["weight_ih_l0"]
        w_hh = state_dict["weight_hh_l0"]
        b_ih = state_dict["bias_ih_l0"]
        b_hh = state_dict["bias_hh_l0"]
        w_hr = state_dict["weight_hr_l0"]

        self.input_size = lstm.input_size
        self.proj_size = lstm.proj_size
        self.hidden_size = lstm.hidden_size

        self.w_ih = w_ih
        self.w_hh = w_hh
        # The two bias vectors are always summed inside the cell, so fold
        # them into a single bias once at construction time.
        self.b = b_ih + b_hh
        self.w_hr = w_hr

    def forward(
        self,
        input: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          input:
            A tensor of shape [T, N, input_size].
          hx:
            A tuple containing:
              - h0: a tensor of shape (1, N, proj_size)
              - c0: a tensor of shape (1, N, hidden_size)
        Returns:
          Return a tuple containing:
            - output: a tensor of shape (T, N, proj_size)
            - A tuple containing:
              - h: a tensor of shape (1, N, proj_size)
              - c: a tensor of shape (1, N, hidden_size)
        """
        x_list = input.unbind(dim=0)  # We use batch_first=False

        if hx is not None:
            h0, c0 = hx
        else:
            # Fix: create the default states with the input's dtype and
            # device. The original always allocated CPU float32 zeros,
            # which fails for CUDA or fp16 inputs (nn.LSTM matches the
            # input's device/dtype here).
            h0 = torch.zeros(
                1,
                input.size(1),
                self.proj_size,
                dtype=input.dtype,
                device=input.device,
            )
            c0 = torch.zeros(
                1,
                input.size(1),
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
        h0 = h0.squeeze(0)
        c0 = c0.squeeze(0)

        y_list = []
        for x in x_list:
            # One fused matmul per step produces all four gates:
            # shape (N, 4 * hidden_size), laid out as [i | f | g | o]
            # to match nn.LSTM's weight ordering.
            gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh)
            i, f, g, o = gates.chunk(4, dim=1)
            i = i.sigmoid()  # input gate
            f = f.sigmoid()  # forget gate
            g = g.tanh()  # cell candidate
            o = o.sigmoid()  # output gate

            c = f * c0 + i * g
            h = o * c.tanh()
            # Projection: hidden_size -> proj_size
            h = F.linear(h, self.w_hr)
            y_list.append(h)

            c0 = c
            h0 = h

        y = torch.stack(y_list, dim=0)

        return y, (h0.unsqueeze(0), c0.unsqueeze(0))