import numpy as np
import mlx.core as mx
from glob import glob
from safetensors.numpy import save_file

# Replace tensors in every "model*.safetensors" file in the current directory
# with the same-named tensors from the patch archive, and write each result
# next to the input as "patched_<original name>". Input files are not modified.

patch_weights = mx.load("49-20240112-184735.npz")

# Keys from the patch archive that were actually written into some output
# file; used below to report patch entries that never matched anything.
applied_keys = set()

model_files = glob("model*.safetensors")
if not model_files:
    print("warning: no files match model*.safetensors; nothing to patch")

for file in model_files:
    print(f"{file=}")
    weights = mx.load(file)

    # Build a separate dict for the output rather than reassigning entries of
    # `weights` while iterating over it.
    patched = {}
    for k, v in weights.items():
        if k in patch_weights:
            print(f"patching {k}")
            v = patch_weights[k]
            applied_keys.add(k)
        # np.asarray converts without copying when it can. The older spelling
        # np.array(..., copy=False) raises ValueError under NumPy 2 whenever a
        # copy turns out to be unavoidable.
        patched[k] = np.asarray(v)

    save_file(patched, "patched_" + file)

# A patch key that matched no tensor usually indicates a naming mismatch;
# surface it instead of silently producing unpatched output.
unused_keys = sorted(set(patch_weights) - applied_keys)
if unused_keys:
    print(f"warning: {len(unused_keys)} patch key(s) matched no tensor: {unused_keys}")