File size: 2,970 Bytes
46455cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMP(nn.Module):
    """LSTM with projection.

    PyTorch does not support exporting LSTM with projection to ONNX.
    This class reimplements LSTM with projection using basic matrix-matrix
    and matrix-vector operations. It is not intended for training.

    The weights are copied out of the given ``nn.LSTM`` at construction time
    and registered as non-persistent buffers. Module conversions such as
    ``.to()``, ``.cuda()`` and ``.double()`` therefore move/cast them along
    with the module, while ``state_dict()`` stays empty.
    """

    def __init__(self, lstm: nn.LSTM):
        """
        Args:
          lstm:
            LSTM with proj_size. We support only uni-directional,
            1-layer LSTM with projection at present. ``batch_first`` must
            be False. An LSTM created with ``bias=False`` is supported;
            a zero bias is used in that case.
        """
        super().__init__()
        assert lstm.bidirectional is False, lstm.bidirectional
        assert lstm.num_layers == 1, lstm.num_layers
        assert 0 < lstm.proj_size < lstm.hidden_size, (
            lstm.proj_size,
            lstm.hidden_size,
        )

        assert lstm.batch_first is False, lstm.batch_first

        state_dict = lstm.state_dict()

        w_ih = state_dict["weight_ih_l0"]
        w_hh = state_dict["weight_hh_l0"]
        w_hr = state_dict["weight_hr_l0"]

        if lstm.bias:
            # Both biases are simply added to the gate pre-activations,
            # so they can be folded into a single vector.
            b = state_dict["bias_ih_l0"] + state_dict["bias_hh_l0"]
        else:
            # bias=False: nn.LSTM has no bias_* entries; use zeros with the
            # same length as the stacked gate rows of w_ih.
            b = w_ih.new_zeros(w_ih.size(0))

        self.input_size = lstm.input_size
        self.proj_size = lstm.proj_size
        self.hidden_size = lstm.hidden_size

        # Non-persistent buffers: they follow the module across devices and
        # dtypes but are not added to state_dict(). They remain accessible
        # as self.w_ih, self.w_hh, self.b and self.w_hr.
        self.register_buffer("w_ih", w_ih, persistent=False)
        self.register_buffer("w_hh", w_hh, persistent=False)
        self.register_buffer("b", b, persistent=False)
        self.register_buffer("w_hr", w_hr, persistent=False)

    def forward(
        self,
        input: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the projected LSTM over a whole sequence, one step at a time.

        Args:
          input:
            A tensor of shape (T, N, input_size).
          hx:
            Optional initial state. A tuple containing:
              - h0: a tensor of shape (1, N, proj_size)
              - c0: a tensor of shape (1, N, hidden_size)
            If None, both default to zeros with the dtype and device
            of ``input``.
        Returns:
          Return a tuple containing:
            - output: a tensor of shape (T, N, proj_size).
            - A tuple containing:
               - h: a tensor of shape (1, N, proj_size)
               - c: a tensor of shape (1, N, hidden_size)

        """
        x_list = input.unbind(dim=0)  # We use batch_first=False

        if hx is not None:
            h0, c0 = hx
        else:
            # Create the zero state on input's device/dtype. torch.zeros would
            # always produce float32 CPU tensors and break GPU/float64 inputs.
            batch_size = input.size(1)
            h0 = input.new_zeros((1, batch_size, self.proj_size))
            c0 = input.new_zeros((1, batch_size, self.hidden_size))
        h0 = h0.squeeze(0)
        c0 = c0.squeeze(0)
        y_list = []
        for x in x_list:
            gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh)
            # Gate order matches nn.LSTM: input, forget, cell, output.
            i, f, g, o = gates.chunk(4, dim=1)

            i = i.sigmoid()
            f = f.sigmoid()
            g = g.tanh()
            o = o.sigmoid()

            c = f * c0 + i * g
            h = o * c.tanh()

            # Project the hidden state down to proj_size.
            h = F.linear(h, self.w_hr)
            y_list.append(h)

            c0 = c
            h0 = h

        y = torch.stack(y_list, dim=0)

        return y, (h0.unsqueeze(0), c0.unsqueeze(0))