# trt-llm-rag-windows-main / trt_llama_api.py
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import gc
import time
import uuid
from typing import Any, Callable, Dict, Optional, Sequence

import torch
import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.llms.base import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    messages_to_prompt as generic_messages_to_prompt,
    stream_completion_response_to_chat_response,
)

from utils import (DEFAULT_HF_MODEL_DIRS, DEFAULT_PROMPT_TEMPLATES,
                   load_tokenizer, read_model_name, throttle_generator)

# LLaMA tokenizer end-of-sequence / padding token ids.
EOS_TOKEN = 2
PAD_TOKEN = 2

class TrtLlmAPI(CustomLLM):
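    """llama-index CustomLLM wrapper that serves completions from a local TensorRT-LLM engine."""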
model_path: Optional[str] = Field(
description="The path to the trt engine."
)
temperature: float = Field(description="The temperature to use for sampling.")
max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
context_window: int = Field(
description="The maximum number of context tokens for the model."
)
messages_to_prompt: Callable = Field(
description="The function to convert messages to a prompt.", exclude=True
)
completion_to_prompt: Callable = Field(
description="The function to convert a completion to a prompt.", exclude=True
)
generate_kwargs: Dict[str, Any] = Field(
default_factory=dict, description="Kwargs used for generation."
)
model_kwargs: Dict[str, Any] = Field(
default_factory=dict, description="Kwargs used for model initialization."
)
verbose: bool = Field(description="Whether to print verbose output.")
    _model: Any = PrivateAttr()
    _model_config: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _pad_id: Any = PrivateAttr()
    _end_id: Any = PrivateAttr()
    _new_max_token: Any = PrivateAttr()
    _sampling_config: Any = PrivateAttr()
    _verbose: Any = PrivateAttr()
def __init__(
self,
model_path: Optional[str] = None,
engine_name: Optional[str] = None,
tokenizer_dir: Optional[str] = None,
temperature: float = 0.1,
max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
context_window: int = DEFAULT_CONTEXT_WINDOW,
messages_to_prompt: Optional[Callable] = None,
completion_to_prompt: Optional[Callable] = None,
callback_manager: Optional[CallbackManager] = None,
generate_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
verbose: bool = False
) -> None:
model_kwargs = model_kwargs or {}
model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
#logger.set_level('verbose')
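        # Determine this process's MPI rank and which base model the engine was built from.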
runtime_rank = tensorrt_llm.mpi_rank()
model_name = read_model_name(model_path)
self._tokenizer, self._pad_id, self._end_id = load_tokenizer(
tokenizer_dir=tokenizer_dir,
#vocab_file=args.vocab_file,
model_name=model_name,
#tokenizer_type=args.tokenizer_type,
)
stop_words_list = None
bad_words_list = None
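        # Build a TensorRT-LLM ModelRunner on top of the prebuilt engine directory.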
runner_cls = ModelRunner
runner_kwargs = dict(engine_dir=model_path,
#lora_dir=args.lora_dir,
rank=runtime_rank,
debug_mode=True,
lora_ckpt_source='hf')
self._model = runner_cls.from_dir(**runner_kwargs)
messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
completion_to_prompt = completion_to_prompt or (lambda x: x)
generate_kwargs = generate_kwargs or {}
generate_kwargs.update(
{"temperature": temperature, "max_tokens": max_new_tokens}
)
#self._tokenizer = LlamaTokenizer.from_pretrained(tokenizer_dir, legacy=False)
self._new_max_token = max_new_tokens
super().__init__(
model_path=model_path,
temperature=temperature,
context_window=context_window,
max_new_tokens=max_new_tokens,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
callback_manager=callback_manager,
generate_kwargs=generate_kwargs,
model_kwargs=model_kwargs,
verbose=verbose,
)
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "TrtLlmAPI"
@property
def metadata(self) -> LLMMetadata:
"""LLM metadata."""
return LLMMetadata(
context_window=self.context_window,
num_output=self.max_new_tokens,
model_name=self.model_path,
)
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
prompt = self.messages_to_prompt(messages)
completion_response = self.complete(prompt, formatted=True, **kwargs)
return completion_response_to_chat_response(completion_response)
@llm_chat_callback()
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
prompt = self.messages_to_prompt(messages)
completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
return stream_completion_response_to_chat_response(completion_response)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
self.generate_kwargs.update({"stream": False})
is_formatted = kwargs.pop("formatted", False)
if not is_formatted:
prompt = self.completion_to_prompt(prompt)
input_text = [prompt]
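        # Tokenize the formatted prompt into a batch of int32 input-id tensors.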
batch_input_ids = self.parse_input(self._tokenizer,
input_text,
pad_id=self._pad_id,
)
input_lengths = [x.size(1) for x in batch_input_ids]
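        # Greedy decoding (top_k=1), so the temperature/top_p values passed below have no practical effect.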
with torch.no_grad():
outputs = self._model.generate(
batch_input_ids,
max_new_tokens=self._new_max_token,
end_id=self._end_id,
pad_id=self._pad_id,
temperature=1.0,
top_k=1,
top_p=0,
num_beams=1,
length_penalty=1.0,
repetition_penalty=1.0,
stop_words_list=None,
bad_words_list=None,
lora_uids=None,
prompt_table_path=None,
prompt_tasks=None,
streaming=False,
output_sequence_lengths=True,
return_dict=True)
torch.cuda.synchronize()
output_ids = outputs['output_ids']
sequence_lengths = outputs['sequence_lengths']
output_txt, output_token_ids = self.print_output(self._tokenizer,
output_ids,
input_lengths,
sequence_lengths)
        # Free cached GPU memory and run garbage collection after inference.
torch.cuda.empty_cache()
gc.collect()
return CompletionResponse(text=output_txt, raw=self.generate_completion_dict(output_txt))
    def parse_input(self,
                    tokenizer,
                    input_text=None,
                    prompt_template=None,
                    input_file=None,
                    add_special_tokens=True,
                    max_input_length=4096,
                    pad_id=None,
                    num_prepend_vtokens=None):
        # Encode each prompt (optionally through a template) into a list of token ids.
        if pad_id is None:
            pad_id = tokenizer.pad_token_id
batch_input_ids = []
for curr_text in input_text:
if prompt_template is not None:
curr_text = prompt_template.format(input_text=curr_text)
input_ids = tokenizer.encode(curr_text,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=max_input_length)
batch_input_ids.append(input_ids)
if num_prepend_vtokens:
assert len(num_prepend_vtokens) == len(batch_input_ids)
base_vocab_size = tokenizer.vocab_size - len(
tokenizer.special_tokens_map.get('additional_special_tokens', []))
for i, length in enumerate(num_prepend_vtokens):
batch_input_ids[i] = list(
range(base_vocab_size,
base_vocab_size + length)) + batch_input_ids[i]
batch_input_ids = [
torch.tensor(x, dtype=torch.int32).unsqueeze(0) for x in batch_input_ids
]
return batch_input_ids
    def remove_extra_eos_ids(self, outputs):
        # Drop any run of trailing EOS tokens, then append exactly one EOS.
        outputs.reverse()
        while outputs and outputs[0] == EOS_TOKEN:
            outputs.pop(0)
        outputs.reverse()
        outputs.append(EOS_TOKEN)
        return outputs
def print_output(self,
tokenizer,
output_ids,
input_lengths,
sequence_lengths,
output_csv=None,
output_npy=None,
context_logits=None,
generation_logits=None,
output_logits_npy=None):
output_text = ""
batch_size, num_beams, _ = output_ids.size()
if output_csv is None and output_npy is None:
for batch_idx in range(batch_size):
                inputs = output_ids[batch_idx][0][:input_lengths[batch_idx]].tolist()  # prompt tokens (not used)
for beam in range(num_beams):
output_begin = input_lengths[batch_idx]
output_end = sequence_lengths[batch_idx][beam]
outputs = output_ids[batch_idx][beam][
output_begin:output_end].tolist()
output_text = tokenizer.decode(outputs)
output_ids = output_ids.reshape((-1, output_ids.size(2)))
return output_text, output_ids
def get_output(self, output_ids, input_lengths, max_output_len, tokenizer):
num_beams = 1
output_text = ""
outputs = None
for b in range(input_lengths.size(0)):
for beam in range(num_beams):
output_begin = input_lengths[b]
output_end = input_lengths[b] + max_output_len
outputs = output_ids[b][beam][output_begin:output_end].tolist()
outputs = self.remove_extra_eos_ids(outputs)
output_text = tokenizer.decode(outputs)
return output_text, outputs
def generate_completion_dict(self, text_str):
"""
Generate a dictionary for text completion details.
Returns:
dict: A dictionary containing completion details.
"""
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
        # Report the engine path; self._model is a ModelRunner object, not a name string.
        model_name: str = self.model_path if self.model_path is not None else str(self._model)
return {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"choices": [
{
"text": text_str,
"index": 0,
"logprobs": None,
"finish_reason": 'stop'
}
],
"usage": {
"prompt_tokens": None,
"completion_tokens": None,
"total_tokens": None
}
}
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
is_formatted = kwargs.pop("formatted", False)
if not is_formatted:
prompt = self.completion_to_prompt(prompt)
input_text = [prompt]
        batch_input_ids = self.parse_input(self._tokenizer,
                                           input_text,
                                           pad_id=self._pad_id,
                                           )
input_lengths = [x.size(1) for x in batch_input_ids]
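        # Same greedy settings as complete(), but streaming=True yields partial results that gen() decodes incrementally.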
with torch.no_grad():
outputs = self._model.generate(
batch_input_ids,
max_new_tokens=self._new_max_token,
end_id=self._end_id,
pad_id=self._pad_id,
temperature=1.0,
top_k=1,
top_p=0,
num_beams=1,
length_penalty=1.0,
repetition_penalty=1.0,
stop_words_list=None,
bad_words_list=None,
lora_uids=None,
prompt_table_path=None,
prompt_tasks=None,
streaming=True,
output_sequence_lengths=True,
return_dict=True)
torch.cuda.synchronize()
previous_text = "" # To keep track of the previously yielded text
def gen() -> CompletionResponseGen:
nonlocal previous_text # Declare previous_text as nonlocal
            for curr_outputs in throttle_generator(outputs, 5):
output_ids = curr_outputs['output_ids']
sequence_lengths = curr_outputs['sequence_lengths']
output_txt, output_token_ids = self.print_output(self._tokenizer,
output_ids,
input_lengths,
sequence_lengths)
if output_txt.endswith("</s>"):
output_txt = output_txt[:-4]
pre_token_len = len(previous_text)
new_text = output_txt[pre_token_len:] # Get only the new text
yield CompletionResponse(delta=new_text, text=output_txt,
raw=self.generate_completion_dict(output_txt))
previous_text = output_txt # Update the previously yielded text after yielding
return gen()
    def unload_model(self):
        # Drop the TensorRT-LLM runner, then release cached GPU memory.
        if self._model is not None:
            del self._model
            self._model = None
        torch.cuda.empty_cache()
        gc.collect()
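

# Example usage (illustrative sketch only): the engine and tokenizer paths are
# placeholders and must point at a TensorRT-LLM engine directory and the matching
# Hugging Face tokenizer directory that the engine was built from.
if __name__ == "__main__":
    llm = TrtLlmAPI(
        model_path="./model/engine",        # placeholder: directory holding the TRT engine
        tokenizer_dir="./model/tokenizer",  # placeholder: HF tokenizer directory
        temperature=0.1,
        max_new_tokens=256,
        context_window=3900,
        verbose=False,
    )
    print(llm.complete("What is TensorRT-LLM?").text)
    llm.unload_model()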