KaleiNeely commited on
Commit
efe9ebb
·
1 Parent(s): 29e5afb

Update modeling_rwkv5.py

Browse files
Files changed (1) hide show
  1. modeling_rwkv5.py +0 -20
modeling_rwkv5.py CHANGED
@@ -92,7 +92,6 @@ def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_firs
92
  time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
93
  lxw = lxw.float()
94
  lxb = lxb.float()
95
- # if seq_mode:
96
  out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
97
  for t in range(T):
98
  rt = receptance[:,:,t:t+1,:]
@@ -106,25 +105,6 @@ def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_firs
106
  out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
107
  out = out.to(dtype=hidden.dtype) * gate
108
  out = out @ ow
109
- # else:
110
- # a = key @ value
111
- # # print('key.shape: ', key.shape)
112
- # # print('value.shape: ', value.shape)
113
- # # print('receptance.shape: ', receptance.shape)
114
- # # print('a.shape: ', a.shape)
115
- # # print('time_first.shape: ', time_first.shape)
116
- # # print('(time_first * a).shape: ', (time_first * a).shape)
117
- # # print('time_decay.shape: ', time_decay.shape)
118
- # # print('state.shape: ', state.shape)
119
- # out = receptance @ (time_first * a + state)
120
- # # print('out.shape: ', out.shape)
121
- # state = a + time_decay * state
122
- # # print('state.shape: ', state.shape)
123
- # out = out.reshape(B, H*S)
124
- # out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, 1, H*S)
125
- # out = out.to(dtype=hidden.dtype) * gate
126
- # out = out @ ow
127
-
128
 
129
  return out, state
130
 
 
92
  time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
93
  lxw = lxw.float()
94
  lxb = lxb.float()
 
95
  out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
96
  for t in range(T):
97
  rt = receptance[:,:,t:t+1,:]
 
105
  out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
106
  out = out.to(dtype=hidden.dtype) * gate
107
  out = out @ ow
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  return out, state
110