JAX Implementation!

#24
by jrosseruk - opened

I ported DeepSeek-R1 to JAX as a fun weekend project/learning experience!

(Flax Linen, fully sharded, transferred the weights from HuggingFace)

Just the 1.5B distill for now, check it out :)
https://github.com/J-Rosser-UK/Torch2Jax-DeepSeek-R1-Distill-Qwen-1.5B
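
For anyone wondering what "transferred the weights" usually involves: here is a minimal sketch, not the code from the repo. It relies on one layout difference: PyTorch's `nn.Linear` stores its weight as (out_features, in_features), while Flax Linen's `nn.Dense` expects a kernel of shape (in_features, out_features), so each linear weight is transposed on the way over. The helper name `torch_linear_to_flax_dense` and the random stand-in tensors are hypothetical, and NumPy is used in place of real checkpoints so the snippet runs on its own.

```python
import numpy as np


def torch_linear_to_flax_dense(weight, bias=None):
    """Hypothetical helper: map a PyTorch-style Linear weight to a Flax-style Dense param dict.

    weight: array of shape (out_features, in_features), the PyTorch layout.
    bias:   optional array of shape (out_features,).
    """
    # Flax Dense computes x @ kernel, so the kernel is the transposed PyTorch weight.
    params = {"kernel": np.ascontiguousarray(weight.T)}
    if bias is not None:
        params["bias"] = np.asarray(bias)
    return params


# Random stand-ins for a checkpoint tensor (assumption: a real port would read these
# from the Hugging Face state dict instead).
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3)).astype(np.float32)  # (out_features=4, in_features=3)
b = rng.normal(size=(4,)).astype(np.float32)
x = rng.normal(size=(2, 3)).astype(np.float32)  # batch of 2 inputs

dense = torch_linear_to_flax_dense(w, b)

# PyTorch convention: y = x @ W^T + b. Flax convention: y = x @ kernel + bias.
torch_style = x @ w.T + b
flax_style = x @ dense["kernel"] + dense["bias"]
```

Comparing `torch_style` and `flax_style` with `np.allclose` is a cheap per-layer sanity check before comparing full model outputs.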

Looks interesting!
