import gc

import torch
from torch import nn

from bitlinear import BitLinear


def replace_linear_in_hf(model, keep_param: bool):
    """
    Recursively replaces all instances of nn.Linear in the given model with
    BitLinear, except the lm_head (any child whose name contains 'head').

    Args:
        model (nn.Module): The model to modify in place.
        keep_param (bool): If True, the BitLinear layers copy their weights
            (and biases, if present) from the initial model. If False, they
            keep their random initialization (for training from scratch).

    Returns:
        None
    """
|
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Leave the output head (e.g. lm_head) in full precision.
            if 'head' in name:
                continue

            bit_linear = BitLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )

            if keep_param:
                # Carry the pretrained parameters over to the BitLinear layer.
                bit_linear.weight.data.copy_(module.weight.data)
                if module.bias is not None:
                    bit_linear.bias.data.copy_(module.bias.data)

            # Drop the local reference to the old layer, then swap it out.
            del module
            setattr(model, name, bit_linear)
        else:
            # Recurse into container modules (blocks, attention, MLPs, ...).
            replace_linear_in_hf(module, keep_param)

    # Release memory held by the replaced layers.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
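

# A minimal usage sketch, assuming a Hugging Face causal LM; "gpt2" is only
# an example checkpoint, not something this module prescribes.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    # keep_param=True carries the pretrained weights into the BitLinear layers.
    replace_linear_in_hf(model, keep_param=True)
    print(model)  # every nn.Linear except the head is now BitLinear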