Add 10k train step pytorch model
Browse files
config.json
CHANGED
@@ -32,6 +32,7 @@
|
|
32 |
"max_length": 50
|
33 |
}
|
34 |
},
|
|
|
35 |
"transformers_version": "4.16.0.dev0",
|
36 |
"use_cache": true,
|
37 |
"vocab_size": 50257
|
|
|
32 |
"max_length": 50
|
33 |
}
|
34 |
},
|
35 |
+
"torch_dtype": "float32",
|
36 |
"transformers_version": "4.16.0.dev0",
|
37 |
"use_cache": true,
|
38 |
"vocab_size": 50257
|
flax_model_to_pytorch.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, FlaxAutoModelForCausalLM, AutoTokenizer
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import jax
|
5 |
+
import jax.numpy as jnp
|
6 |
+
|
7 |
+
def to_f32(t):
|
8 |
+
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
|
9 |
+
|
10 |
+
jax.config.update('jax_platform_name', 'cpu')
|
11 |
+
MODEL_PATH = "./"
|
12 |
+
model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_PATH)
|
13 |
+
model.params = to_f32(model.params)
|
14 |
+
model.save_pretrained(MODEL_PATH)
|
15 |
+
|
16 |
+
pt_model = AutoModelForCausalLM.from_pretrained(
|
17 |
+
MODEL_PATH, from_flax=True).to('cpu')
|
18 |
+
|
19 |
+
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
|
20 |
+
input_ids_pt = torch.tensor(input_ids)
|
21 |
+
|
22 |
+
logits_pt = pt_model(input_ids_pt).logits
|
23 |
+
print(logits_pt)
|
24 |
+
logits_fx = model(input_ids).logits
|
25 |
+
print(logits_fx)
|
26 |
+
|
27 |
+
pt_model.save_pretrained(MODEL_PATH)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95dd9704b464a65105c2dd2da7c317a0dc11707cade45ab6b8dc99d99eae0a26
|
3 |
+
size 510401385
|
runs/events.out.tfevents.1642099734.t1v-n-42145f73-w-0.2317757.0.v2
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:78cfa1f897e391e235a903d1ff19d56c36817328c8b3c8b76f575958a16fdf68
|
3 |
+
size 1912863
|