# MolmoE-1B-0924 / convert_to_hf.py
# Last commit: def21c3 ("Add") by Muennighoff (3.37 kB)
import argparse
import logging
import os
import torch
from hf_molmo.config_molmo import MolmoConfig
from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
from hf_molmo.modelling_molmo import MOLMoForCausalLM
from hf_molmo.preprocessing_molmo import MolmoProcessor
from olmo import ModelConfig
from olmo.mm_data.data_utils import build_tokenizer
# Module-level logger named after this module; progress messages from the
# conversion steps below are emitted through it.
logger = logging.getLogger(name=__name__)
def write_config(checkpoint_dir: str, output_dir: str):
    """Write a Hugging Face config and processor for a Molmo checkpoint.

    Loads the ``model`` section of ``<checkpoint_dir>/config.yaml`` into a
    ``ModelConfig``, maps its fields onto a ``MolmoConfig``, and saves that
    config plus a ``MolmoProcessor`` (image processor + tokenizer) into
    ``output_dir`` via ``save_pretrained``.

    Args:
        checkpoint_dir: Directory containing the original ``config.yaml``.
        output_dir: Directory the HF config and processor files are written to.

    Raises:
        ValueError: If the tokenizer identifier does not contain the ``"m:"``
            marker that is stripped to obtain the tokenizer name. This is
            checked before anything is written to ``output_dir``.
    """
    logger.info("Loading checkpoint from %s", checkpoint_dir)
    config_path = os.path.join(checkpoint_dir, "config.yaml")
    model_config = ModelConfig.load(config_path, key="model")

    # The tokenizer name is everything after the first "m:" in the identifier.
    # Validate it up front so a bad identifier does not leave a half-written
    # output directory behind.
    identifier = model_config.tokenizer.identifier
    _, marker, tokenizer_name = identifier.partition("m:")
    if not marker:
        raise ValueError(
            f"Tokenizer identifier {identifier!r} does not contain the expected 'm:' marker"
        )

    config = MolmoConfig(
        vocab_size=model_config.vocab_size,
        embedding_size=model_config.embedding_size,
        hidden_size=model_config.d_model,
        intermediate_size=model_config.mlp_hidden_size,
        num_hidden_layers=model_config.n_layers,
        num_attention_heads=model_config.n_heads,
        num_key_value_heads=model_config.n_kv_heads,
        # Fall back to max_sequence_length when max_position_embeddings is unset/falsy.
        max_position_embeddings=model_config.max_position_embeddings or model_config.max_sequence_length,
        initializer_range=model_config.initializer_range,
        use_cache=True,
        layer_norm_eps=model_config.layer_norm_eps,
        rope_theta=model_config.rope_theta,
        clip_qkv=model_config.clip_qkv,
        qkv_bias=model_config.qkv_bias,
        weight_tying=model_config.weight_tying,
        use_position_ids=True,
        # NOTE(review): HF-level tying is disabled while `weight_tying` is passed
        # through; presumably the Molmo model handles tying itself — confirm.
        tie_word_embeddings=False,
    )
    logger.info("Saving HF-compatible config to %s", os.path.join(output_dir, "config.json"))
    config.save_pretrained(output_dir)

    preprocessor = MolmoProcessor(
        # FIXME: only max_crops is taken from the checkpoint config; every other
        # image-preprocessing setting is assumed fixed (processor defaults).
        MolmoImageProcessor(max_crops=model_config.max_crops),
        build_tokenizer(tokenizer_name).tokenizer,
    )
    preprocessor.save_pretrained(output_dir)
def write_model(checkpoint_dir: str, output_dir: str, ignore_olmo_compatibility: bool = False):
    """Re-save ``model.pt`` as ``pytorch_model.bin`` with HF-prefixed keys.

    Every key of the loaded state dict is prefixed with
    ``MOLMoForCausalLM.base_model_prefix`` followed by ``"."``.

    Args:
        checkpoint_dir: Directory containing the original ``model.pt``.
        output_dir: Directory ``pytorch_model.bin`` is written to.
        ignore_olmo_compatibility: Currently unused; kept so existing callers
            passing it keep working.
    """
    # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
    # So, we explicitly store the model with the expected prefix.
    old_model_path = os.path.join(checkpoint_dir, "model.pt")
    new_model_path = os.path.join(output_dir, "pytorch_model.bin")
    # Load onto CPU so checkpoints saved from GPU can be converted on any
    # machine, and the re-saved weights are not tied to the original devices.
    state_dict = torch.load(old_model_path, map_location="cpu")
    prefix = f"{MOLMoForCausalLM.base_model_prefix}."
    new_state_dict = {prefix + key: val for key, val in state_dict.items()}
    torch.save(new_state_dict, new_model_path)
def convert_checkpoint(checkpoint_dir: str, output_dir: str):
    """Convert a checkpoint directory into Hugging Face-loadable files.

    Ensures ``output_dir`` exists, then writes the config/processor files
    followed by the re-keyed weights.

    Args:
        checkpoint_dir: Directory holding the original checkpoint.
        output_dir: Destination directory (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    # Order matters only for readability of the logs: config first, then weights.
    for writer in (write_config, write_model):
        writer(checkpoint_dir, output_dir)
def main():
    """Command-line entry point: convert a checkpoint directory to HF format."""
    # Without a configured handler the logger.info progress messages emitted
    # during conversion would be dropped (the root logger defaults to WARNING).
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description="Writes config.json, processor files, and pytorch_model.bin to the output "
        "directory, making it easier to load weights as HF models."
    )
    parser.add_argument("checkpoint_dir", help="Directory containing config.yaml and model.pt.")
    parser.add_argument(
        "output_dir", help="Directory to write the HF-compatible files to (created if missing)."
    )
    args = parser.parse_args()
    convert_checkpoint(args.checkpoint_dir, args.output_dir)


if __name__ == "__main__":
    main()