"""Convert a sharded PyTorch ``.bin`` checkpoint into one ``.safetensors`` file."""

import json
import os

import torch
from safetensors.torch import save_file

# Set paths
model_index_path = "./Janus-Pro-7B/pytorch_model.bin.index.json"  # Path to the model index file
output_path = "model.safetensors"  # Path for the output safetensors file


def load_sharded_checkpoint(index_file: str) -> dict:
    """
    Load sharded PyTorch model weights from .bin files.

    Args:
        index_file: Path to the checkpoint index JSON. Its ``"weight_map"``
            entry maps each weight name to the shard file that stores it.

    Returns:
        A single dict merging the contents of every shard, loaded on CPU.

    Raises:
        KeyError: If the index file has no ``"weight_map"`` entry.
        FileNotFoundError: If the index file or one of the shards is missing.
    """
    # Read the model index file
    with open(index_file, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    # The index lists shard names relative to its own location, so resolve
    # them against the index file's directory rather than the current working
    # directory. An empty dirname (index in the CWD) leaves names unchanged.
    shard_dir = os.path.dirname(index_file)

    # Many weights share a shard: de-duplicate, and sort so that the load
    # order (and therefore the log output) is the same on every run.
    shard_names = sorted(set(index_data["weight_map"].values()))

    # Load weights from all shards
    state_dict = {}
    for shard_name in shard_names:
        shard_path = os.path.join(shard_dir, shard_name)
        print(f"Loading shard: {shard_path}")
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code. Only convert checkpoints from a trusted source, or
        # pass weights_only=True if the installed torch version supports it.
        shard_data = torch.load(shard_path, map_location="cpu")
        state_dict.update(shard_data)
        # Drop the per-shard dict before the next (potentially large) load.
        del shard_data

    return state_dict


def convert_to_safetensors(index_file: str, output_file: str) -> None:
    """
    Convert sharded .bin model weights to a single .safetensors file.

    Args:
        index_file: Path to the checkpoint index JSON describing the shards.
        output_file: Destination path of the merged ``.safetensors`` file.
    """
    # Load weights from sharded checkpoint
    state_dict = load_sharded_checkpoint(index_file)

    # Save the weights to safetensors format
    print(f"Saving to {output_file}...")
    save_file(state_dict, output_file)
    print("Conversion complete!")


# Execute the conversion process only when run as a script, so importing this
# module for its helpers does not start a multi-gigabyte conversion.
if __name__ == "__main__":
    convert_to_safetensors(model_index_path, output_path)