Update visual.py
Browse files
visual.py
CHANGED
@@ -189,7 +189,7 @@ class VisualAttention(nn.Module):
|
|
189 |
# query/key/value: [sq, b, h]
|
190 |
sq, b, _ = query.size()
|
191 |
|
192 |
-
assert query
|
193 |
sk = sq
|
194 |
mixed_x_layer = self.in_proj(query)
|
195 |
|
|
|
189 |
# query/key/value: [sq, b, h]
|
190 |
sq, b, _ = query.size()
|
191 |
|
192 |
+
assert torch.allclose(query, key), 'Only Support Self-Attention Currently'
|
193 |
sk = sq
|
194 |
mixed_x_layer = self.in_proj(query)
|
195 |
|