# Janus-Pro-7B-Safetensors / convert_to_safetensors.py
# Ryan
# init
# 91bcd07
import json
import os

import torch
from safetensors.torch import save_file
# Conversion settings.
# Input: JSON index of the sharded checkpoint to read.
model_index_path = "./Janus-Pro-7B/pytorch_model.bin.index.json"
# Output: the single safetensors file to write.
output_path = "model.safetensors"
def load_sharded_checkpoint(index_file):
    """
    Load sharded PyTorch model weights from .bin files.

    Args:
        index_file: Path to the checkpoint index JSON. Its "weight_map"
            entry maps each parameter name to the shard file that stores it.

    Returns:
        One dict merging the contents of every shard, with tensors mapped
        to CPU.

    Raises:
        KeyError: If the index has no "weight_map" entry.
    """
    # Read the model index file
    with open(index_file, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    # Shard names in the index are resolved against the directory that holds
    # the index itself, so the script works regardless of the current
    # working directory.
    index_dir = os.path.dirname(os.path.abspath(index_file))

    # Many parameters share a shard: de-duplicate the file names, and sort
    # them so the load/merge order is identical on every run.
    shards = sorted(set(index_data["weight_map"].values()))

    # Load weights from all shards
    state_dict = {}
    for shard_file in shards:
        # os.path.join leaves an absolute shard path untouched.
        shard_path = os.path.join(index_dir, shard_file)
        if not os.path.exists(shard_path) and os.path.exists(shard_file):
            # Backward compatibility: fall back to the name exactly as
            # written in the index (i.e. relative to the working directory).
            shard_path = shard_file
        print(f"Loading shard: {shard_file}")
        # NOTE(security): torch.load unpickles the file; only run this on
        # checkpoints from a trusted source.
        shard_data = torch.load(shard_path, map_location="cpu")
        state_dict.update(shard_data)
    return state_dict
def convert_to_safetensors(index_file, output_file):
    """
    Convert sharded .bin model weights to a single .safetensors file.

    Args:
        index_file: Path to the sharded checkpoint's index JSON.
        output_file: Path of the .safetensors file to write.
    """
    # Load weights from sharded checkpoint
    state_dict = load_sharded_checkpoint(index_file)

    # Save the weights to safetensors format
    print(f"Saving to {output_file}...")
    # Record the originating framework in the file header; loaders such as
    # transformers look for this "format" entry when opening the file.
    save_file(state_dict, output_file, metadata={"format": "pt"})
    print("Conversion complete!")
# Execute the conversion process only when run as a script, so that importing
# this module does not trigger a conversion as a side effect.
if __name__ == "__main__":
    convert_to_safetensors(model_index_path, output_path)