|
import json
import os

import torch
from safetensors.torch import save_file
|
|
|
|
|
# Index file of the sharded checkpoint; its "weight_map" names the shard files.
model_index_path: str = "./Janus-Pro-7B/pytorch_model.bin.index.json"
# Destination of the merged single-file checkpoint.
output_path: str = "model.safetensors"
|
|
|
def load_sharded_checkpoint(index_file):
    """
    Load sharded PyTorch model weights from .bin files.

    Reads the ``weight_map`` of a ``*.bin.index.json`` file and merges the
    contents of every shard it references into one dictionary.

    Args:
        index_file: Path to the JSON index file. Shard filenames listed in
            its ``weight_map`` are resolved relative to the directory that
            contains this file (absolute shard paths are used unchanged).

    Returns:
        A dict combining the entries loaded from all shards. Later shards
        (in sorted filename order) overwrite duplicate keys from earlier ones.

    Raises:
        KeyError: If the index has no ``"weight_map"`` entry.
        FileNotFoundError: If the index or a referenced shard does not exist.
    """
    with open(index_file, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    weight_map = index_data["weight_map"]
    # Shard names in the index are relative to the index file, not the CWD.
    base_dir = os.path.dirname(os.path.abspath(index_file))
    # Sorted for a reproducible load order; set iteration order varies between runs.
    shards = sorted(set(weight_map.values()))

    state_dict = {}
    for shard_file in shards:
        print(f"Loading shard: {shard_file}")
        shard_path = os.path.join(base_dir, shard_file)
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code on older torch defaults; only use trusted checkpoints.
        shard_data = torch.load(shard_path, map_location="cpu")
        state_dict.update(shard_data)
        # Drop the per-shard dict promptly; its tensors live on in state_dict.
        del shard_data

    return state_dict
|
|
|
def _prepare_for_safetensors(state_dict):
    """
    Return a copy of ``state_dict`` that ``safetensors`` can serialize.

    ``save_file`` rejects tensors that are non-contiguous or that share
    storage with another entry (e.g. tied weights, or views into a common
    buffer preserved by unpickling). Such tensors are replaced by contiguous
    clones; all other tensors pass through untouched, so the common case
    costs no extra memory.

    Args:
        state_dict: Mapping of parameter names to tensors.

    Returns:
        A new dict with the same keys in the same order.
    """
    seen_storages = set()
    prepared = {}
    for name, tensor in state_dict.items():
        storage_ptr = tensor.untyped_storage().data_ptr()
        if storage_ptr in seen_storages or not tensor.is_contiguous():
            # Give this entry its own compact buffer so nothing is aliased.
            tensor = tensor.clone(memory_format=torch.contiguous_format)
        else:
            seen_storages.add(storage_ptr)
        prepared[name] = tensor
    return prepared


def convert_to_safetensors(index_file, output_file):
    """
    Convert sharded .bin model weights to a single .safetensors file.

    Args:
        index_file: Path to the ``*.bin.index.json`` file of the sharded
            checkpoint.
        output_file: Destination path of the ``.safetensors`` file.
    """
    state_dict = _prepare_for_safetensors(load_sharded_checkpoint(index_file))

    print(f"Saving to {output_file}...")
    # Mark the file as PyTorch weights, as Hugging Face save_pretrained does;
    # some transformers versions refuse safetensors files without this header key.
    save_file(state_dict, output_file, metadata={"format": "pt"})
    print("Conversion complete!")
|
|
|
|
|
# Run the conversion only when executed as a script, so importing this module
# (e.g. to reuse load_sharded_checkpoint) does not start a multi-GB job.
if __name__ == "__main__":
    convert_to_safetensors(model_index_path, output_path)
|
|