from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Optional, List, Mapping, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.mistral.modeling_mistral import MistralForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from pydantic import Field


class CustomLLMMistral(LLM):
    # Declared as pydantic fields so the LangChain base class accepts them.
    model: MistralForCausalLM = Field(...)
    tokenizer: LlamaTokenizerFast = Field(...)

    def __init__(self):
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"

        # Load the model in 4-bit via bitsandbytes to fit on a single consumer GPU.
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Hand the loaded objects to the pydantic/LangChain base class.
        super().__init__(model=model, tokenizer=tokenizer)

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        messages = [
            {"role": "user", "content": prompt},
        ]

        # Wrap the prompt in Mistral's [INST] ... [/INST] chat template.
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)

        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=512,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            top_k=4,
            temperature=0.7,
        )
        decoded = self.tokenizer.batch_decode(generated_ids)

        # Keep only the completion after [/INST] and drop the end-of-sequence token.
        output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()

        # Truncate at the first stop word, as LangChain expects.
        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()

        # Mistral sometimes fails to close its Markdown code fences, which breaks
        # LangChain's output parsing, so pad the fence back to three backticks.
        while not output.endswith("```"):
            output += "`"

        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.model}
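# Usage sketch (not part of the original snippet; the prompts below are made up
# for illustration). Instantiating the wrapper downloads and 4-bit-quantizes
# Mistral-7B-Instruct-v0.3, so it needs a CUDA-capable GPU and access to the
# gated model on Hugging Face.
if __name__ == "__main__":
    llm = CustomLLMMistral()

    # LangChain routes .invoke() through _call under the hood.
    print(llm.invoke("Explain what 4-bit quantization does, in two sentences."))

    # The stop list is forwarded to _call, which truncates at the first stop word.
    print(llm.invoke("List three prime numbers.", stop=["\n\n"]))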