Uploaded model

  • Developed by: jingwang
  • License: apache-2.0
  • Finetuned from model: unsloth/mistral-7b-v0.3-bnb-4bit

This Mistral model was trained 2x faster with Unsloth and Hugging Face's TRL library.

Install dependencies in Google Colab

!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes

Inference


from unsloth import FastLanguageModel
from typing import Any, Dict
import torch

class FormatPrompt_context_QA():
    '''format prompt class'''
    def __init__(self, eos_token:str='</s>') -> None:
        self.inputs = ['context','question'] # required input fields
        self.outputs = ['answer'] # target field: supplied for training, generated by the model at inference
        self.eos_token = eos_token

    def __call__(self, instance: Dict[str, Any]) -> str:
        '''
        Make the instance callable; delegates to `formatting_prompt_func`.
        Args:
            instance: dictionary with keys 'context' and 'question'; 'answer' is optional (training only)
        Returns:
            prompt: formatted prompt
        '''
        return self.formatting_prompt_func(instance)
    
    def formatting_prompt_func(self, instance: dict) -> str:
        '''Format a prompt for domain-specific QA.
        Note: this format is for fine-tuning a pre-trained (base) model.
        If starting from an instruction-tuned model, use `tokenizer.apply_chat_template(messages)` instead.
        '''

        # Fail early with a readable message if a required field is absent
        missing = [item for item in self.inputs if item not in instance]
        assert not missing, f"instance must have {self.inputs}; missing: {missing}"
        
        prompt = f"""<s> [INST] Answer following question based on Context: {str(instance["context"])}\
        Question: {str(instance["question"])} \
        Answer: [/INST]"""

        if 'answer' in instance:
            prompt += str(instance['answer']) + self.eos_token
        return prompt

formatting_func = FormatPrompt_context_QA()

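# Pull the fine-tuned model and tokenizer from the Hugging Face Hub
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "jingwang/mistral_context_qa",
    max_seq_length = 2048,
    dtype = None,         # None lets Unsloth auto-detect the dtype
    load_in_4bit = True,  # 4-bit quantization to reduce memory use
)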


# Switch the model to Unsloth's inference mode before generating
FastLanguageModel.for_inference(model)

example = {'question': 'What does the graph compare in terms of cumulative total return?',
 'context': 'the following graph shows a comparison, from january 1, 2019 through december 31, 2023, of the cumulative total return on our common stock, the nasdaq composite index and a group of all public companies sharing the same sic code as us, which is sic code 3711, “ motor vehicles and passenger car bodies ” ( motor vehicles and passenger car bodies public company group ). such returns are based on historical results and are not intended to suggest future performance. data for the nasdaq composite index and the motor vehicles and passenger car bodies public company group assumes an investment of $ 100 on january 1, 2019 and reinvestment of dividends. we have never declared or paid cash dividends on our common stock nor do we anticipate paying any such cash dividends in the foreseeable future. 31',
 'gold_answer': "The graph compares the cumulative total return from January 1, 2019, through December 31, 2023, of the company's common stock, the NASDAQ Composite Index, and a group of public companies with the same SIC code (3711 - Motor Vehicles and Passenger Car Bodies). The comparison assumes an initial investment of $100 on January 1, 2019, with reinvestment of dividends for the NASDAQ Composite Index and the Motor Vehicles and Passenger Car Bodies public company group.",
}

inputs = tokenizer([formatting_func(example)], return_tensors="pt", padding=False).to(model.device)
input_length = inputs.input_ids.shape[-1]  # number of prompt tokens, used to strip the prompt from the output

with torch.no_grad():
    output = model.generate(
        **inputs,
        do_sample=False,  # greedy decoding; sampling parameters such as temperature do not apply
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=False,
    )
    response = tokenizer.decode(
        output[0][input_length:],  # keep only the generated tokens, dropping the prompt
        skip_special_tokens=True,
    )
    print(response)

The graph compares the cumulative total return on our common stock, the NASDAQ Composite Index, and a group of all public companies sharing the same SIC code as us, which is SIC code 3711, "Motor Vehicles and Passenger Car Bodies."
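
The prompt template above uses backslash line continuations inside a triple-quoted f-string, so the indentation of the continued lines ends up inside the prompt as runs of spaces. The sketch below rebuilds that string with plain concatenation so the layout can be inspected without loading the model. `build_prompt`, `INDENT`, and the demo record are hypothetical names and data made up for illustration, and the 8-space width is an assumption read off the indentation of the class as shown here.

```python
# Sketch only: reproduces the card's prompt layout as a standalone function.
# `build_prompt` and `INDENT` are hypothetical names, not part of the model repo.

INDENT = " " * 8  # assumed width: indentation carried into the prompt by the "\" continuations


def build_prompt(instance: dict, eos_token: str = "</s>") -> str:
    """Build the [INST] ... [/INST] prompt; append the answer and EOS when present."""
    prompt = (
        "<s> [INST] Answer following question based on Context: "
        + str(instance["context"])
        + INDENT + "Question: " + str(instance["question"]) + " "
        + INDENT + "Answer: [/INST]"
    )
    if "answer" in instance:
        # Training-style prompt: the target answer follows the instruction block
        prompt += str(instance["answer"]) + eos_token
    return prompt


# Made-up record, for illustration only
demo = {"context": "Revenue grew 5% in 2023.", "question": "How much did revenue grow?"}

inference_prompt = build_prompt(demo)                    # no 'answer' key: prompt ends at [/INST]
training_prompt = build_prompt({**demo, "answer": "5%"})  # 'answer' key: answer + EOS appended

print(repr(inference_prompt))
print(repr(training_prompt))
```

Because the example record passed to the model stores its reference under `gold_answer` rather than `answer`, the prompt stops at `[/INST]` and the model generates the answer. If the model was fine-tuned on this exact layout, keeping the same spacing at inference is likely the safest choice.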
