Mrw33554432
committed on
Commit
•
89f541f
1
Parent(s):
d013df6
Upload 2 files
Browse files- bitlinear.py +31 -0
- replace_hf.py +49 -0
bitlinear.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
from torch import Tensor, nn
|
3 |
+
|
4 |
+
|
5 |
+
def weight_quant(w: Tensor) -> Tensor:
    """
    Quantize a weight tensor to three levels: {-1, 0, +1} / scale.

    From https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf.
    This differs slightly from the paper by dividing by ``scale`` at the end,
    which is crucial for training (7.5 loss vs 2.5).

    Args:
        w (Tensor): The weight tensor to quantize. It is not modified.

    Returns:
        Tensor: A new tensor of the same shape as ``w`` whose entries are
        -1/scale, 0 or +1/scale, where ``scale = 1 / mean(|w|)``.
    """
    # Per-tensor scale is the reciprocal of the mean absolute weight; the
    # lower bound keeps the division finite when the weights are all zero.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    # Normalise, snap to the nearest integer, restrict to [-1, 1], then undo
    # the normalisation so the result stays on the original weight scale.
    return (w * scale).round().clamp(-1, 1) / scale
|
14 |
+
|
15 |
+
|
16 |
+
class BitLinear(nn.Linear):
    """
    A modified version of BitLinear that only applies quantization to the weight.

    The layer keeps the parameters of ``nn.Linear``; the weight is quantized
    with ``weight_quant`` on every forward pass, while the input and the bias
    are used unchanged.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the BitLinear layer, applying quantization to weights.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        w = self.weight
        # Straight-through estimator: the forward value equals weight_quant(w),
        # but the correction term is detached, so the gradient reaches
        # self.weight as if no quantization had been applied.
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x, w_quant, self.bias)
|
replace_hf.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from bitlinear import BitLinear
|
7 |
+
|
8 |
+
|
9 |
+
# Adapted from https://github.com/kyegomez/BitNet/blob/main/bitnet/replace_hf.py
|
10 |
+
def replace_linear_in_hf(model, keep_param: bool):
    """
    Replace all instances of nn.Linear in the given model with BitLinear, in place.

    The model is walked recursively. A direct nn.Linear child whose attribute
    name contains 'head' (e.g. lm_head) is left untouched.

    Args:
        model (nn.Module): The model to modify.
        keep_param (bool): If true, each new layer receives a copy of the
            weight and bias of the layer it replaces. If false, the new layer
            keeps its random initial weights (for training from scratch).

    Returns:
        None
    """
    # Snapshot the children first: attributes of `model` are rebound below
    # while the children are being walked.
    for name, module in list(model.named_children()):
        if not isinstance(module, nn.Linear):
            # Recursively apply to child modules
            replace_linear_in_hf(module, keep_param)
            continue

        if 'head' in name:
            continue

        # Build the replacement on the same device and with the same dtype as
        # the layer it replaces, so a half-precision or GPU model does not end
        # up with a float32 CPU layer in the middle of it.
        bit_linear = BitLinear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )

        if keep_param:
            # Transfer the weights and bias from the original nn.Linear to the
            # new BitLinear without recording the copy in autograd.
            with torch.no_grad():
                bit_linear.weight.copy_(module.weight)
                if module.bias is not None:
                    bit_linear.bias.copy_(module.bias)

        # Replace the nn.Linear with the new BitLinear
        setattr(model, name, bit_linear)

    # Release the replaced layers' memory once per call.
    # NOTE(review): the original nesting of this cleanup was lost with the
    # file's indentation; confirm that running it after the loop is intended.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|