import torch
import numpy as np
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel
# Convert the Flax GPT-2 checkpoint in the current directory to PyTorch,
# save the converted weights to ./pt, and print both models' logits for the
# same dummy batch so the two outputs can be compared by eye.

tokenizer = AutoTokenizer.from_pretrained(".")
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model_fx = FlaxGPT2LMHeadModel.from_pretrained(".")

# Optional step, currently disabled: upcast any bfloat16 Flax parameters to
# float32 and save the result to ./fx.
# def to_f32(t):
#     return jax.tree_util.tree_map(
#         lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
#     )
# model_fx.params = to_f32(model_fx.params)
# model_fx.save_pretrained("./fx")

# Load the same checkpoint into the PyTorch model class from its Flax weights
# and write the converted model out.
model_pt = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
model_pt.save_pretrained("./pt")

# Dummy batch: 2 sequences of 128 tokens, every token id set to 0.
input_ids = np.zeros((2, 128), dtype=np.int32)
# Embedding lookups in PyTorch conventionally take int64 indices; older
# releases reject int32, so cast explicitly instead of inheriting int32.
input_ids_pt = torch.tensor(input_ids, dtype=torch.long)

# Inference only: skip building the autograd graph for the forward pass.
with torch.no_grad():
    logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)

logits_fx = model_fx(input_ids).logits
print(logits_fx)