
# Raven Fine-Tuned Gemma-2B

Raven is a fine-tuned version of google/gemma-2b that follows the same prompting style as gemma-2b-it. It was trained on a TPU VM v4-64 using EasyDeL.

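Since Raven follows the gemma-2b-it prompt style, each turn is wrapped in Gemma's turn markers. As a minimal sketch of the raw format (shown for reference only; the `GemmaPrompter` used below builds this automatically):

```python
# Gemma chat format: turns are delimited by <start_of_turn>/<end_of_turn>;
# <end_of_turn> is token id 107, which the server below uses as eos_token_id.
prompt = (
    "<start_of_turn>user\n"
    "What is JAX?<end_of_turn>\n"
    "<start_of_turn>model\n"
)
```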
Both the fine-tuning and serving code are available. Using the JAX/EasyDeL Gemma implementation is recommended, since the Hugging Face Gemma implementation is incorrect.

## Serving and Using Raven

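The snippet below assumes JAX, EasyDeL, and fjformer are installed (for example via `pip install easydel fjformer`; exact package names are an assumption and may vary by version) on a host with enough accelerator memory for a 2.5B-parameter model in fp16.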
```python
from EasyDel import JAXServer, JAXServerConfig
from EasyDel.serve.prompters import GemmaPrompter, Llama2Prompter, OpenChatPrompter, ChatMLPrompter
from EasyDel.serve.prompters.base_prompter import BasePrompter
from fjformer import get_dtype
import jax
from typing import List, Union

max_sequence_length = 8192           # maximum context length the server accepts
max_compile_tokens = 256             # tokens generated per compiled sampling step
max_new_tokens_ratio = 25            # max_new_tokens = max_compile_tokens * this ratio
dtype = "fp16"
prompter_type = "gemma"              # Raven uses the Gemma prompt format
sharding_axis_dims = (1, 1, 1, -1)   # (dp, fsdp, tp, sp); -1 = all remaining devices
pretrained_model_name_or_path = "erfanzar/Raven-v0.1"
attn_mechanism = "normal"            # plain attention (EasyDeL also ships fused variants)
scan_mlp_chunk_size = max_compile_tokens
use_scan_mlp = True                  # chunk the MLP with a scan to save memory
scan_ring_attention = True
block_k = 128                        # attention block sizes
block_q = 128
use_sharded_kv_caching = False

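# Server and generation configuration; max_new_tokens is derived from the per-step compile size.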
server_config = JAXServerConfig(
    max_sequence_length=max_sequence_length,
    max_compile_tokens=max_compile_tokens,
    max_new_tokens=max_compile_tokens * max_new_tokens_ratio,
    dtype=dtype,
    pre_compile=False,
    eos_token_id=107  # Gemma's <end_of_turn> token, so generation stops at end of turn
)

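# Available prompt formats; Raven was trained with the Gemma format.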
prompters = {
    "gemma": GemmaPrompter(),
    "llama": Llama2Prompter(),
    "openchat": OpenChatPrompter(),
    "chatml": ChatMLPrompter()
}

prompter: BasePrompter = prompters[prompter_type]

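# Subclass JAXServer so chat/instruct formatting goes through the selected prompter.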
class JAXServerC(JAXServer):
    @staticmethod
    def format_chat(history: List[List[str]], prompt: str, system: Union[str, None]) -> str:
        return prompter.format_message(
            history=history,
            prompt=prompt,
            system_message=system,
            prefix=None
        )

    @staticmethod
    def format_instruct(system: str, instruction: str) -> str:
        return prompter.format_message(
            prefix=None,
            system_message=system,
            prompt=instruction,
            history=[]
        )

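# Download the PyTorch checkpoint from the Hub and convert it to JAX parameters;
# weights are staged on CPU first, then sharded across the device mesh.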
server = JAXServerC.from_torch_pretrained(
    server_config=server_config,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    dtype=get_dtype(dtype=dtype),
    param_dtype=get_dtype(dtype=dtype),
    precision=jax.lax.Precision("fastest"),
    sharding_axis_dims=sharding_axis_dims,
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    input_shape=(1, server_config.max_sequence_length),
    model_config_kwargs=dict(
        fully_sharded_data_parallel=True,
        attn_mechanism=attn_mechanism,
        scan_mlp_chunk_size=max_compile_tokens,
        use_scan_mlp=use_scan_mlp,
        scan_ring_attention=scan_ring_attention,
        block_k=block_k,
        block_q=block_q,
        use_sharded_kv_caching=use_sharded_kv_caching
    )
)

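# Minimal chat REPL: stream tokens as they are generated and keep the history.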
history = []
while True:
    user_prompt = input("> ")
    model_prompt = server.format_chat(
        history,
        user_prompt,
        "You are an AI assistant. Be respectful and explain detailed questions step by step."
    )

    past_response_length = 0
    response = ""

    # server.sample streams the response-so-far; print only the newly generated part.
    for response, used_tokens in server.sample(
        model_prompt,
        greedy=False
    ):
        print(response[past_response_length:], end="", flush=True)
        past_response_length = len(response)

    print()  # newline after the streamed response
    history.append([user_prompt, response])
```

Instead of the CLI loop above, a Gradio chat UI can also be launched via `server.gradio_inference().launch()`, for example:

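```python
# Launch the bundled Gradio chat UI instead of the CLI loop above.
server.gradio_inference().launch()
```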