bjoernp committed · verified
Commit f3887a8 · 1 Parent(s): bf5f905

Update modeling_bitllama.py

Files changed (1)
  1. modeling_bitllama.py +5 -1
modeling_bitllama.py CHANGED
@@ -253,9 +253,13 @@ def weight_quant(w):
 
 
 class BitLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__(in_features, out_features, bias=bias)
+        self.norm = LlamaRMSNorm(in_features)
+
     def forward(self, x):
         w = self.weight
-        x_norm = LlamaRMSNorm(x)
+        x_norm = self.norm(x)
         x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
         w_quant = w + (weight_quant(w) - w).detach()
         return F.linear(x_quant, w_quant)
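
The bug this commit fixes: in the old forward pass, LlamaRMSNorm(x) constructed a new norm module with the activation tensor as its constructor argument instead of normalizing the activations. The new version instantiates the norm once in __init__ (sized to in_features) and applies it in forward. Below is a minimal, self-contained sketch of the corrected pattern; the RMSNorm stand-in and the activation_quant / weight_quant bodies are assumptions written in the style of common BitNet b1.58 reference code, not copied from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    # Stand-in for LlamaRMSNorm: scale each token by its reciprocal RMS, with a learned gain.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


def activation_quant(x):
    # Assumed BitNet-style per-token absmax quantization to the int8 range.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale


def weight_quant(w):
    # Assumed BitNet b1.58-style per-tensor scaling to ternary values {-1, 0, 1}.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale


class BitLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # The fix: build the norm once here instead of calling the class on x in forward().
        self.norm = RMSNorm(in_features)

    def forward(self, x):
        w = self.weight
        x_norm = self.norm(x)
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x_quant, w_quant)  # bias is not applied here, matching the diff


layer = BitLinear(16, 32, bias=False)
print(layer(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 32])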