from typing import Any, List, Mapping, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from transformers.models.mistral.modeling_mistral import MistralForCausalLM


class CustomLLMMistral(LLM):
    """LangChain-compatible wrapper around a locally loaded, 4-bit quantized Mistral-7B-Instruct model."""

    model: MistralForCausalLM = Field(...)
    tokenizer: LlamaTokenizerFast = Field(...)

    def __init__(self):
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"
        # Load the weights in 4-bit so the 7B model fits on a single consumer GPU.
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The pydantic base class assigns the model and tokenizer fields.
        super().__init__(model=model, tokenizer=tokenizer)

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        messages = [
            {"role": "user", "content": prompt},
        ]
        # Wrap the prompt in Mistral's [INST] ... [/INST] chat template.
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)
        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=512,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            top_k=4,
            temperature=0.7,
        )
        decoded = self.tokenizer.batch_decode(generated_ids)
        # Keep only the completion: everything after [/INST], minus the end-of-sequence token.
        output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()

        # Truncate at the first stop sequence, as the LLM interface requires.
        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()

        # Mistral sometimes leaves a Markdown code fence unclosed, which confuses
        # downstream output parsers; pad with backticks until the fence is closed.
        while not output.endswith("```"):
            output += "`"

        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.model}