zhiyuan8 committed
Commit b3d89c3 · verified · 1 Parent(s): 70689a4

Upload wkv.py

Files changed (1):
  1. wkv.py +141 -98
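
Note: the main functional changes in this upload are (1) a fused addcmul op, fused_addcmul_rwkv7, imported from rwkvfla with a pure-PyTorch fallback; (2) the x_g token-shift parameter is now created and initialized only when wkv_has_gate is set; and (3) apply_wkv7_state and the forward methods gain output_final_state, cu_seqlens, sequence_mask, and use_cache handling for cache- and padding-aware inference. The snippet below is an illustrative sketch (not part of the commit) of the import-fallback pattern the diff extends; it assumes only that rwkvfla exposes the listed symbols.

# Minimal sketch of the fallback pattern used in the import block below.
try:
    # Triton-backed fused token-shift mixing (preferred when triton is available).
    from rwkvfla.ops.rwkv7 import fused_addcmul_rwkv7
except ImportError:
    # Fall back to the pure-PyTorch reference op when the fused kernel is unavailable.
    from rwkvfla.ops.rwkv7 import torch_addcmul_rwkv7
    fused_addcmul_rwkv7 = torch_addcmul_rwkv7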
wkv.py CHANGED
@@ -6,6 +6,8 @@ import math
 import torch.nn as nn
 from torch.nn import functional as F
 from .configuration_rwkv_hybrid import RwkvHybridConfig
+from typing import TYPE_CHECKING, Optional
+from transformers.cache_utils import Cache
 
 try:
     import triton
@@ -13,6 +15,7 @@ try:
         fused_recurrent_rwkv7,
         chunk_rwkv7,
         native_recurrent_rwkv7,
+        fused_addcmul_rwkv7,
     )  # pylint: disable=C0411
     from rwkvfla.ops.rwkv6 import (
         fused_recurrent_rwkv6,
@@ -22,11 +25,13 @@ try:
 except ImportError:
     from rwkvfla.ops.rwkv7 import native_recurrent_rwkv7  # pylint: disable=C0411
     from rwkvfla.ops.rwkv6 import native_recurrent_rwkv6
+    from rwkvfla.ops.rwkv7 import torch_addcmul_rwkv7
 
     fused_recurrent_rwkv7 = native_recurrent_rwkv7
     chunk_rwkv7 = native_recurrent_rwkv7
     chunk_rwkv6 = native_recurrent_rwkv6
     fused_recurrent_rwkv6 = native_recurrent_rwkv6
+    fused_addcmul_rwkv7 = torch_addcmul_rwkv7
 
 
 class Rwkv_Tmix_x070(nn.Module):
@@ -50,8 +55,7 @@ class Rwkv_Tmix_x070(nn.Module):
         self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
         self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
-        self.x_g = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
-
+
         D_DECAY_LORA = 64
         D_AAA_LORA = 64
         D_MV_LORA = 32
@@ -70,6 +74,7 @@ class Rwkv_Tmix_x070(nn.Module):
             self.v0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
 
         if self.args.wkv_has_gate:
+            self.x_g = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
             self.g1 = nn.Parameter(torch.Tensor(args.hidden_size, D_GATE_LORA))
             self.g2 = nn.Parameter(torch.Tensor(D_GATE_LORA, args.hidden_size))
 
@@ -78,7 +83,8 @@ class Rwkv_Tmix_x070(nn.Module):
         self.r_k = nn.Parameter(torch.Tensor(H, N))
 
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-        self.receptance = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
+        self.receptance = nn.Linear(
+            args.hidden_size, args.hidden_size, bias=False)
         self.key = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
         self.value = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
         self.output = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
@@ -90,7 +96,8 @@ class Rwkv_Tmix_x070(nn.Module):
 
     def post_init(self):
         with torch.no_grad():
-            ratio_0_to_1 = self.layer_id / (self.args.num_hidden_layers - 1)  # 0 to 1
+            ratio_0_to_1 = self.layer_id / \
+                (self.args.num_hidden_layers - 1)  # 0 to 1
             ratio_1_to_almost0 = 1.0 - (
                 self.layer_id / self.args.num_hidden_layers
             )  # 1 to ~0
@@ -99,39 +106,48 @@ class Rwkv_Tmix_x070(nn.Module):
             for i in range(self.args.hidden_size):
                 ddd[0, 0, i] = i / self.args.hidden_size
 
-            nn.init.constant_(self.x_r, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
-            nn.init.constant_(self.x_w, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
+            nn.init.constant_(
+                self.x_r, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
+            nn.init.constant_(
+                self.x_w, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
             nn.init.constant_(
                 self.x_k,
-                1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1),
+                1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) +
+                       0.4 * ratio_0_to_1),
             )
             nn.init.constant_(
                 self.x_v,
-                1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1),
+                1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) +
+                       0.6 * ratio_0_to_1),
            )
-            nn.init.constant_(self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
-            nn.init.constant_(self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
+            nn.init.constant_(
+                self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
+
 
             def ortho_init(x, scale):
                 shape = x.shape
                 original_dtype = x.dtype
                 x_fp32 = x.float()
                 if len(shape) == 2:
-                    gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
+                    gain = math.sqrt(shape[0] / shape[1]
+                                     ) if shape[0] > shape[1] else 1
                    nn.init.orthogonal_(x_fp32, gain=gain * scale)
                 elif len(shape) == 3:
-                    gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
+                    gain = math.sqrt(shape[1] / shape[2]
+                                     ) if shape[1] > shape[2] else 1
                     for i in range(shape[0]):
                         nn.init.orthogonal_(x_fp32[i], gain=gain * scale)
                 else:
-                    raise ValueError("ortho_init only supports 2D or 3D tensors")
+                    raise ValueError(
+                        "ortho_init only supports 2D or 3D tensors")
                 x.data.copy_(x_fp32.to(original_dtype))
                 return x
 
             D_DECAY_LORA = 64
             nn.init.zeros_(self.w1)
             self.w2 = nn.Parameter(
-                ortho_init(torch.zeros(D_DECAY_LORA, self.args.hidden_size), 0.1)
+                ortho_init(torch.zeros(
+                    D_DECAY_LORA, self.args.hidden_size), 0.1)
             )
 
             decay_speed = torch.ones(self.args.hidden_size)
@@ -161,8 +177,11 @@ class Rwkv_Tmix_x070(nn.Module):
             if self.args.wkv_has_gate:
                 nn.init.zeros_(self.g1)
                 self.g2 = nn.Parameter(
-                    ortho_init(torch.zeros(D_GATE_LORA, self.args.hidden_size), 0.1)
+                    ortho_init(torch.zeros(
+                        D_GATE_LORA, self.args.hidden_size), 0.1)
                 )
+                nn.init.constant_(
+                    self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
 
             nn.init.constant_(self.k_k, 0.85)
             nn.init.constant_(self.k_a, 1.0)
@@ -177,77 +196,68 @@ class Rwkv_Tmix_x070(nn.Module):
                 nn.init.ones_(self.ln_x.weight)
                 nn.init.zeros_(self.ln_x.bias)
 
-    def apply_wkv7_state(self, r, k, v, w, a, b, s):
-        r = rearrange(r, "b l (h d) -> b h l d", h=self.n_head)
-        k = rearrange(k, "b l (h d) -> b h l d", h=self.n_head)
-        v = rearrange(v, "b l (h d) -> b h l d", h=self.n_head)
-        w = rearrange(w, "b l (h d) -> b h l d", h=self.n_head)
-        a = rearrange(a, "b l (h d) -> b h l d", h=self.n_head)
-        b = rearrange(b, "b l (h d) -> b h l d", h=self.n_head)
+    def apply_wkv7_state(self, r, k, v, w, a, b, s,
+                         output_final_state,
+                         cu_seqlens,
+                         head_first
+                         ):
 
         if r.device.type == "cpu":
+            r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
             o, state = native_recurrent_rwkv7(
-                r,
-                k,
-                v,
-                w,
-                a,
-                b,
+                r=r, k=k, v=v, w=w,
+                a=a, b=b,
                 scale=1.0,
                 initial_state=s.transpose(-1, -2),
                 output_final_state=True,
-                use_log_w=False,
                 head_first=True,
             )
             state = state.transpose(-1, -2)
-        elif self.training:
-            o, state = chunk_rwkv7(
-                r,
-                k,
-                v,
-                w,
-                a,
-                b,
-                scale=1.0,
-                initial_state=s,
-                output_final_state=True,
-                use_log_w=False,
-                head_first=True,
-            )
+            x = rearrange(o, "b h l d -> b l (h d)")
         else:
-            o, state = fused_recurrent_rwkv7(
-                r,
-                k,
-                v,
-                w,
-                a,
-                b,
+            r, w, k, v, a, b = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
+            wkv7_func = chunk_rwkv7 if self.training else fused_recurrent_rwkv7
+            o, state = wkv7_func(
+                r=r, k=k, v=v, w=w,
+                a=a, b=b,
                 scale=1.0,
                 initial_state=s,
-                output_final_state=True,
-                use_log_w=False,
-                head_first=True,
+                output_final_state=output_final_state,
+                cu_seqlens=cu_seqlens,
+                head_first=head_first,
             )
-
-        x = rearrange(o, "b h l d -> b l (h d)")
+            x = rearrange(o, "b l h d -> b l (h d)")
         return x, state
 
-    def forward(self, x, last_state: TimeMixState):
+    def forward(
+        self,
+        hidden_states,
+        last_state: TimeMixState,
+        sequence_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
+        if sequence_mask is not None:
+            hidden_states = hidden_states.mul(
+                sequence_mask[:, -hidden_states.shape[-2]:, None])
+
         shift_state = last_state.shift_state
-        B, T, C = x.size()
-        H = self.n_head
+        B, T, C = hidden_states.size()
+
         if shift_state is not None:
-            xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
+            xx = torch.concat((shift_state.unsqueeze(
+                1), hidden_states[:, :-1]), dim=1) - hidden_states
         else:
-            xx = self.time_shift(x) - x
-        lx = x[:, -1]
+            xx = self.time_shift(hidden_states) - hidden_states
 
-        xr = x + xx * self.x_r
-        xw = x + xx * self.x_w
-        xk = x + xx * self.x_k
-        xv = x + xx * self.x_v
-        xa = x + xx * self.x_a
-        xg = x + xx * self.x_g
+        lx = hidden_states[:, -1]
+
+        if self.args.wkv_has_gate:
+            xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(
+                hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a, self.x_g)
+        else:
+            xr, xw, xk, xv, xa, _ = fused_addcmul_rwkv7(hidden_states, xx, self.x_r, self.x_w, self.x_k, self.x_v, self.x_a)
 
         r = self.receptance(xr)
         w = (
@@ -269,11 +279,11 @@ class Rwkv_Tmix_x070(nn.Module):
         if self.args.wkv_has_gate:
             g = torch.sigmoid(xg @ self.g1) @ self.g2
         kk = k * self.k_k
-        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
+        kk = F.normalize(kk.view(B, T, self.n_head, -1), dim=-1, p=2.0).view(B, T, C)
         k = k * (1 + (a - 1) * self.k_a)
 
         wkv_state = last_state.wkv_state
-        x, wkv_state = self.apply_wkv7_state(
+        hidden_states, wkv_state = self.apply_wkv7_state(
             r,
             k,
             v,
@@ -281,17 +291,22 @@ class Rwkv_Tmix_x070(nn.Module):
             -kk,
             (kk * a),
             s=wkv_state,
+            output_final_state=use_cache,
+            cu_seqlens=cu_seqlens,
+            head_first=False
         )
         if self.args.wkv_has_group_norm:
-            x = self.ln_x(x.view(B * T, C)).view(B, T, C)
-        x = x + (
-            (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(
+            hidden_states = self.ln_x(
+                hidden_states.view(B * T, C)).view(B, T, C)
+        hidden_states = hidden_states + (
+            (r.view(B, T, self.n_head, -1) * k.view(B, T, self.n_head, -1) * self.r_k).sum(
                 dim=-1, keepdim=True
             )
-            * v.view(B, T, H, -1)
+            * v.view(B, T, self.n_head, -1)
         ).view(B, T, C)
-        x = self.output(x * g) if self.args.wkv_has_gate else self.output(x)
-        return x, TimeMixState(lx, wkv_state)
+        hidden_states = self.output(
+            hidden_states * g) if self.args.wkv_has_gate else self.output(hidden_states)
+        return hidden_states, TimeMixState(lx, wkv_state)
 
 
 class Rwkv7Attention(nn.Module):
@@ -299,24 +314,43 @@ class Rwkv7Attention(nn.Module):
         super().__init__()
         self.args = args
         self.layer_idx = layer_id
-        self.time_mixer = Rwkv_Tmix_x070(args, layer_id, update_v_first, get_v_first)
-
-    def forward(self, hidden_states, past_key_value, **kwargs):
-        attn_output = hidden_states
-        batch_size, token_length, _ = attn_output.size()
+        self.time_mixer = Rwkv_Tmix_x070(
+            args, layer_id, update_v_first, get_v_first)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        sequence_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs
+    ):
+        if sequence_mask is not None:
+            assert len(sequence_mask.shape) == 2, (
+                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
+                "for padding purposes (0 indicating padding). "
+                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
+            )
+        batch_size, token_length, _ = hidden_states.shape
 
         if past_key_value is not None and len(past_key_value) > self.layer_idx:
             last_state = past_key_value[self.layer_idx][0]
         else:
             last_state = self.init_state(
-                batch_size, attn_output.device, attn_output.dtype
+                batch_size, hidden_states.device, hidden_states.dtype
             )
 
-        attn_output, states = self.time_mixer(attn_output, last_state.time_mix_state)
+        attn_output, states = self.time_mixer(hidden_states=hidden_states,
+                                              last_state=last_state.time_mix_state,
+                                              sequence_mask=sequence_mask,
+                                              use_cache=use_cache,
+                                              **kwargs)
        last_state.time_mix_state = states
 
         if past_key_value is not None:
             past_key_value.update(token_length, last_state, self.layer_idx)
+
         return attn_output, None
 
     def init_state(self, batch_size, device, dtype) -> BlockState:
@@ -357,9 +391,12 @@ class Rwkv_Tmix_x060(nn.Module):
                 ddd[0, 0, i] = i / args.hidden_size
 
             # fancy time_mix
-            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_x = nn.Parameter(
+                1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_w = nn.Parameter(
+                1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_k = nn.Parameter(
+                1.0 - torch.pow(ddd, ratio_1_to_almost0))
             self.time_maa_v = nn.Parameter(
                 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
             )
@@ -377,7 +414,8 @@ class Rwkv_Tmix_x060(nn.Module):
                 torch.zeros(args.hidden_size, D_MIX_LORA * 5)
             )
             self.time_maa_w2 = nn.Parameter(
-                torch.zeros(5, D_MIX_LORA, args.hidden_size).uniform_(-0.01, 0.01)
+                torch.zeros(5, D_MIX_LORA,
+                            args.hidden_size).uniform_(-0.01, 0.01)
             )
 
             # fancy time_decay
@@ -386,7 +424,8 @@ class Rwkv_Tmix_x060(nn.Module):
                 decay_speed[n] = -6 + 5 * (n / (args.head_size - 1)) ** (
                     0.7 + 1.3 * ratio_0_to_1
                 )
-            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.head_size))
+            self.time_decay = nn.Parameter(
+                decay_speed.reshape(1, 1, args.head_size))
 
             D_DECAY_LORA = 64
             if args.hidden_size == 4096:
@@ -401,13 +440,16 @@ class Rwkv_Tmix_x060(nn.Module):
            tmp = torch.zeros(args.head_size)
            for n in range(args.head_size):
                zigzag = ((n + 1) % 3 - 1) * 0.1
-               tmp[n] = ratio_0_to_1 * (1 - (n / (args.head_size - 1))) + zigzag
+               tmp[n] = ratio_0_to_1 * \
+                   (1 - (n / (args.head_size - 1))) + zigzag
 
-           self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+           self.time_faaaa = nn.Parameter(
+               tmp.reshape(self.n_head, self.head_size))
            # self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))
 
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-       self.receptance = nn.Linear(args.hidden_size, args.head_size, bias=False)
+       self.receptance = nn.Linear(
+           args.hidden_size, args.head_size, bias=False)
        self.key = nn.Linear(args.hidden_size, args.head_size, bias=False)
 
        self.value = nn.Linear(args.hidden_size, args.head_size, bias=False)
@@ -416,7 +458,8 @@ class Rwkv_Tmix_x060(nn.Module):
 
        if self.args.wkv_has_group_norm:
            self.ln_x = nn.GroupNorm(
-               self.n_head, args.head_size, eps=(1e-5) * (args.head_size_divisor**2)
+               self.n_head, args.head_size, eps=(
+                   1e-5) * (args.head_size_divisor**2)
            )
 
    def post_init(self):
@@ -433,7 +476,8 @@ class Rwkv_Tmix_x060(nn.Module):
        lx = x[:, -1]
 
        xxx = x + xx * self.time_maa_x
-       xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
+       xxx = torch.tanh(xxx @ self.time_maa_w1).view(B *
+                                                     T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
 
@@ -461,10 +505,7 @@ class Rwkv_Tmix_x060(nn.Module):
        return x, TimeMixState(lx, wkv_state)
 
    def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
-       r = rearrange(r, "b l (h d) -> b h l d", h=H)
-       k = rearrange(k, "b l (h d) -> b h l d", h=H)
-       v = rearrange(v, "b l (h d) -> b h l d", h=H)
-       w = rearrange(w, "b l (h d) -> b h l d", h=H)
+       r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v))
 
        if r.device.type == "cpu":
            wkv6_func = native_recurrent_rwkv6
@@ -504,7 +545,8 @@ class Rwkv6Attention(nn.Module):
            last_state = past_key_value[self.layer_idx][0]
        if last_state is None:
            wkv_states = torch.zeros(
-               (B, self.args.num_wkv_heads, self.args.head_size, self.args.head_size),
+               (B, self.args.num_wkv_heads,
+                self.args.head_size, self.args.head_size),
               device=attn_output.device,
               dtype=torch.float32,
           )
@@ -514,7 +556,8 @@ class Rwkv6Attention(nn.Module):
            time_state = TimeMixState(token_shift, wkv_states)
            channel_state = None
            last_state = BlockState(time_state, channel_state)
-       attn_output, states = self.time_mixer(attn_output, last_state.time_mix_state)
+       attn_output, states = self.time_mixer(
+           attn_output, last_state.time_mix_state)
        last_state.time_mix_state = states
 
        if past_key_value is not None:
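
For reference, a hypothetical call into the updated Rwkv7Attention.forward signature shown above. `config`, `layer_id`, `update_v_first`, and `get_v_first` stand in for objects constructed elsewhere in the modeling code, and the shapes are arbitrary; this is an illustrative sketch, not the repo's own test code.

import torch

# Hypothetical setup: `config` is assumed to be a RwkvHybridConfig instance and the
# two v_first callbacks are whatever the hybrid model normally wires in (assumed names).
attn = Rwkv7Attention(config, layer_id, update_v_first, get_v_first)

batch_size, seq_len = 2, 16
hidden = torch.randn(batch_size, seq_len, config.hidden_size)
mask = torch.ones(batch_size, seq_len)  # 0/1 padding mask of shape [batch_size, seq_len]

# Keyword-style call per the new signature: sequence_mask zeroes padded positions
# before mixing; use_cache controls whether the final wkv state is materialized.
attn_output, _ = attn(
    hidden_states=hidden,
    sequence_mask=mask,
    past_key_value=None,
    use_cache=False,
)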