# bitLinear-phi-1.5 / replace_hf.py
# Uploaded by Mrw33554432 in commit 89f541f ("Upload 2 files"), 1.64 kB.
import gc
import torch
from torch import nn
from bitlinear import BitLinear
# Adapt from https://github.com/kyegomez/BitNet/blob/main/bitnet/replace_hf.py
def replace_linear_in_hf(model, keep_param: bool, *, linear_cls=None) -> None:
    """
    Replace every nn.Linear inside ``model`` with a BitLinear layer, in place.

    The whole module tree is searched recursively. An nn.Linear whose
    attribute name contains the substring ``"head"`` (e.g. ``lm_head``) is
    left untouched. Non-Linear modules are always recursed into, whatever
    their name.

    Each replacement is placed on the same device and given the same dtype
    as the layer it replaces. This keeps the layer usable inside a model
    that has already been moved or cast.

    Args:
        model (nn.Module): The model to modify in place.
        keep_param (bool): If True, copy the weight (and bias, if present) of
            each original nn.Linear into its replacement. If False, the
            replacement keeps its own freshly initialized parameters
            (intended for training).
        linear_cls (type, optional): Class used to build each replacement.
            It is called as ``linear_cls(in_features=..., out_features=...,
            bias=...)`` and must expose ``weight``/``bias`` like nn.Linear.
            Defaults to ``BitLinear``.

    Returns:
        None
    """
    if linear_cls is None:
        linear_cls = BitLinear
    _replace_linear_recursive(model, keep_param, linear_cls)
    # Reclaim memory of the discarded layers once, after the whole tree has
    # been rewritten, instead of paying for a full GC pass at every level.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _replace_linear_recursive(parent, keep_param, linear_cls):
    """Swap the nn.Linear children of ``parent`` for ``linear_cls``, recursing into all others."""
    for name, child in parent.named_children():
        if not isinstance(child, nn.Linear):
            # Not a Linear: descend into it looking for nested Linear layers.
            _replace_linear_recursive(child, keep_param, linear_cls)
            continue
        if "head" in name:
            # Substring match: keeps lm_head (and any other *head* Linear)
            # as a plain nn.Linear.
            continue

        replacement = linear_cls(
            in_features=child.in_features,
            out_features=child.out_features,
            bias=child.bias is not None,
        )
        # A freshly built layer lives on the default device/dtype. Match the
        # layer being replaced so the model stays consistent.
        replacement = replacement.to(
            device=child.weight.device, dtype=child.weight.dtype
        )
        if keep_param:
            with torch.no_grad():
                replacement.weight.copy_(child.weight)
                if child.bias is not None:
                    replacement.bias.copy_(child.bias)
        # `name` is already a child of `parent`, so this rebinds an existing
        # entry rather than adding a new one while we iterate.
        setattr(parent, name, replacement)