Update modeling_bitllama.py
modeling_bitllama.py CHANGED (+5 -1)
```diff
@@ -253,9 +253,13 @@ def weight_quant(w):
 
 
 class BitLinear(nn.Linear):
+    def __init__(self, in_features, out_features, bias=True):
+        super().__init__(in_features, out_features, bias=bias)
+        self.norm = LlamaRMSNorm(in_features)
+
     def forward(self, x):
         w = self.weight
-        x_norm =
+        x_norm = self.norm(x)
         x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
         w_quant = w + (weight_quant(w) - w).detach()
         return F.linear(x_quant, w_quant)
```
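The change makes the RMS normalization a persistent submodule: `LlamaRMSNorm(in_features)` is built once in `__init__` and reused as `self.norm` in `forward`. For context, here is a minimal, self-contained sketch of the resulting layer. The bodies of `activation_quant` and `weight_quant` (defined earlier in the file, per the hunk header) are assumptions following the usual BitNet b1.58 recipe (8-bit per-token activations, ternary weights), and `RMSNorm` is a stand-in for `transformers`' `LlamaRMSNorm`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def activation_quant(x):
    # Assumed: per-token absmax quantization of activations to 8 bits.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale


def weight_quant(w):
    # Assumed: ternary {-1, 0, +1} weights with a per-tensor absmean scale.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale


class RMSNorm(nn.Module):
    # Stand-in for transformers' LlamaRMSNorm.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class BitLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.norm = RMSNorm(in_features)  # LlamaRMSNorm in the actual file

    def forward(self, x):
        w = self.weight
        x_norm = self.norm(x)
        # Straight-through estimator: quantized values in the forward pass,
        # identity gradient in the backward pass.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x_quant, w_quant)


layer = BitLinear(64, 128)
out = layer(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 128])
```

The `x + (quant(x) - x).detach()` pattern is a straight-through estimator: the forward pass sees the quantized values, while gradients flow back as if quantization were the identity, which keeps the full-precision master weights trainable.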