# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import torch
import torch.nn as nn
from nncore.nn import MODELS, build_model


@MODELS.register()
class R2Block(nn.Module):
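    """Fuse the last k levels of multi-level video and query features:
    project both into a shared space, pool query-conditioned patch context
    within each frame, blend levels with learnable gates, and apply
    query-aware temporal modeling."""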

    def __init__(self,
                 dims,
                 in_dims,
                 k=4,
                 dropout=0.5,
                 use_tef=True,
                 pos_cfg=None,
                 tem_cfg=None):
        super(R2Block, self).__init__()

        # yapf:disable
        self.video_map = nn.Sequential(
            nn.LayerNorm((in_dims[0] + 2) if use_tef else in_dims[0]),
            nn.Dropout(dropout),
            nn.Linear((in_dims[0] + 2) if use_tef else in_dims[0], dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))

        self.query_map = nn.Sequential(
            nn.LayerNorm(in_dims[1]),
            nn.Dropout(dropout),
            nn.Linear(in_dims[1], dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))
        # yapf:enable

        if k > 1:
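            # One learnable gate per adjacent pair of levels, used to blend
            # each level's output with the one carried from the previous step.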
            self.gate = nn.Parameter(torch.zeros([k - 1]))

        # Projections for the per-frame cross-attention; a learnable per-level
        # scale controls how much pooled context is injected back.
        self.v_map = nn.Linear(dims, dims)
        self.q_map = nn.Linear(dims, dims)
        self.scale = nn.Parameter(torch.zeros([k]))

        # Positional encoding and temporal modeling modules built from configs.
        self.pos = build_model(pos_cfg, dims=dims)
        self.tem = build_model(tem_cfg, dims=dims)

        self.dims = dims
        self.in_dims = in_dims
        self.k = k
        self.dropout = dropout
        self.use_tef = use_tef

    def forward(self, video_emb, query_emb, video_msk, query_msk):
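        # video_emb and query_emb are stacks of multi-level features; only the
        # last k levels are used (expected shapes: K * B * T * P * C for video
        # and K * B * L * C for the query).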
        video_emb = video_emb[-self.k:]
        query_emb = query_emb[-self.k:]

        _, b, t, p, _ = video_emb.size()

        if self.use_tef:
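            # Temporal endpoint features: the normalized start/end time of
            # each clip, appended to every patch feature.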
            tef_s = torch.arange(0, 1, 1 / t, device=video_emb.device)
            tef_e = tef_s + 1.0 / t
            tef = torch.stack((tef_s, tef_e), dim=1)
            tef = tef.unsqueeze(1).unsqueeze(0).unsqueeze(0).repeat(self.k, b, 1, p, 1)
            video_emb = torch.cat((video_emb, tef[:, :, :video_emb.size(2)]), dim=-1)

        coll_v, coll_q, last = [], [], None
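        # Process levels from the last one down to the first, carrying the
        # fused result ('last') across iterations.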
        for i in range(self.k - 1, -1, -1):
            v_emb = self.video_map(video_emb[i])  # B * T * P * C
            q_emb = self.query_map(query_emb[i])  # B * L * C

            coll_v.append(v_emb[:, :, 0])
            coll_q.append(q_emb)

            v_pool = v_emb.view(b * t, -1, self.dims)  # BT * P * C
            q_pool = q_emb.repeat_interleave(t, dim=0)  # BT * L * C

            v_pool_map = self.v_map(v_pool)  # BT * P * C
            q_pool_map = self.q_map(q_pool)  # BT * L * C

            # Per-frame cross-attention: query tokens attend over the P patches.
            att = torch.bmm(q_pool_map, v_pool_map.transpose(1, 2)) / self.dims**0.5
            att = att.softmax(-1)  # BT * L * P

            # Max-pool the attended context over query tokens and inject it
            # into each frame's first token, scaled by a learnable factor.
            o_pool = torch.bmm(att, v_pool) + q_pool  # BT * L * C
            o_pool = o_pool.amax(dim=1, keepdim=True)  # BT * 1 * C
            v_emb = v_pool[:, 0, None] + o_pool * self.scale[i].tanh()
            v_emb = v_emb.view(b, t, self.dims)  # B * T * C

            if i < self.k - 1:
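                # Blend with the output carried over from the previous level.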
                gate = self.gate[i].sigmoid()
                v_emb = gate * v_emb + (1 - gate) * last

            # Add positional encodings, then apply temporal modeling
            # conditioned on the query embedding.
            v_pe = self.pos(v_emb)
            last = self.tem(v_emb, q_emb, q_pe=v_pe, q_mask=video_msk, k_mask=query_msk)

        return last, q_emb, coll_v, coll_q