# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import enum
import subprocess
from typing import Callable, Optional, Sequence

import numpy as np

from pytriton.decorators import batch, sample
from pytriton.model_config import ModelConfig, Tensor


class Framework(enum.Enum):
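    """Deep-learning frameworks the test models may target."""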
TENSORFLOW = "tensorflow"
PYTORCH = "pytorch"


@dataclasses.dataclass(frozen=True)
class TestModelSpec:
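    """Specification of a test model: its name, target framework, I/O tensors,
    Triton model config, and a factory that builds the inference callable."""
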
name: str
framework: Optional[Framework]
create_infer_fn: Callable[..., Callable]
inputs: Sequence[Tensor]
outputs: Sequence[Tensor]
model_config: ModelConfig


def _create_add_sub_fn() -> Callable:
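    # @batch gathers the per-request input tensors into whole batches and passes
    # them to the wrapped function as keyword arguments keyed by input name.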
@batch
def _add_sub(**inputs):
a_batch, b_batch = inputs.values()
add_batch = a_batch + b_batch
sub_batch = a_batch - b_batch
return {"add": add_batch, "sub": sub_batch}
return _add_sub


ADD_SUB_PYTHON_MODEL = TestModelSpec(
name="AddSub",
framework=None,
create_infer_fn=_create_add_sub_fn,
inputs=(
Tensor(dtype=np.float32, shape=(-1,)),
Tensor(dtype=np.float32, shape=(-1,)),
),
outputs=(
Tensor(name="add", dtype=np.float32, shape=(-1,)),
Tensor(name="sub", dtype=np.float32, shape=(-1,)),
),
model_config=ModelConfig(max_batch_size=128),
)


def _create_identity_fn() -> Callable:
@batch
def _identity(**inputs):
(a_batch,) = inputs.values()
return {"identity": a_batch}
return _identity


IDENTITY_PYTHON_MODEL = TestModelSpec(
name="Identity",
framework=None,
create_infer_fn=_create_identity_fn,
inputs=(Tensor(dtype=np.float32, shape=(-1,)),),
outputs=(Tensor(name="identity", dtype=np.float32, shape=(-1,)),),
model_config=ModelConfig(max_batch_size=128),
)


def _create_tfhub_tensorflow_efficientdet_lite0_detection_fn() -> Callable:
subprocess.run(["pip", "install", "tensorflow_hub"], check=True)
import tensorflow_hub as hub # pytype: disable=import-error
detector = hub.load("https://tfhub.dev/tensorflow/efficientdet/lite0/detection/1")
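
    # The TF Hub EfficientDet-Lite0 detector returns (boxes, scores, classes,
    # num_detections) tensors for a batch of uint8 images.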
@batch
def _tfhub_tensorflow_efficientdet_lite0_detection(image: np.ndarray):
boxes, scores, classes, num_detections = detector(image)
return {
"boxes": boxes.numpy(),
"scores": scores.numpy(),
"classes": classes.numpy(),
"num_detections": num_detections.numpy(),
}
return _tfhub_tensorflow_efficientdet_lite0_detection


EFFICIENTDET_DETECTION_TF_MODEL = TestModelSpec(
name="EfficientDetDetector",
framework=Framework.TENSORFLOW,
create_infer_fn=_create_tfhub_tensorflow_efficientdet_lite0_detection_fn,
inputs=(Tensor(name="image", dtype=np.uint8, shape=(-1, -1, 3)),),
outputs=(
Tensor(
name="boxes",
dtype=np.float32,
shape=(-1, 4),
),
Tensor(
name="scores",
dtype=np.float32,
shape=(-1,),
),
Tensor(
name="classes",
dtype=np.int32,
shape=(-1,),
),
Tensor(name="num_detections", dtype=np.int32, shape=(1,)),
),
model_config=ModelConfig(max_batch_size=128),
)


def _create_hf_pytorch_bart_fn() -> Callable:
subprocess.run(["pip", "install", "transformers"], check=True)
import transformers # pytype: disable=import-error
classifier = transformers.pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)
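
    # @sample passes a single request's named tensors straight to the wrapped
    # function; batching is disabled in this spec's ModelConfig.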
@sample
def _hf_pytorch_bart_fn(sequence: np.ndarray, labels: np.ndarray):
sequence = sequence[0].decode("utf-8")
labels = [label.decode("utf-8") for label in labels]
classification_result = classifier(sequence, labels)
scores_batch = np.array(classification_result["scores"], dtype=np.float32)
return {"scores": scores_batch}
return _hf_pytorch_bart_fn


BART_CLASSIFIER_PYTORCH_MODEL = TestModelSpec(
name="BARTClassifier",
framework=Framework.PYTORCH,
create_infer_fn=_create_hf_pytorch_bart_fn,
inputs=(Tensor(name="sequence", dtype=object, shape=(-1,)), Tensor(name="labels", dtype=object, shape=(-1,))),
outputs=(Tensor(name="scores", dtype=np.float32, shape=(-1,)),),
model_config=ModelConfig(batching=False),
)


def _create_hf_jax_bert_fn() -> Callable:
subprocess.run(
[
"pip",
"install",
"--upgrade",
"jax[cuda12_pip]",
"-f",
"https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
],
check=True,
)
subprocess.run(["pip", "install", "transformers", "flax"], check=True)
import transformers # pytype: disable=import-error
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
model = transformers.FlaxBertModel.from_pretrained("bert-base-uncased")
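
    # Each batch item carries one UTF-8 encoded string (input shape (1,));
    # items are tokenized and encoded one at a time below.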
@batch
def _infer_fn(**inputs: np.ndarray):
(sequence_batch,) = inputs.values()
sequence_batch = sequence_batch.tolist()
last_hidden_states = []
for sequence_item in sequence_batch:
            # [0] extracts the str from the single-element numpy array before UTF-8 decoding
sequence_as_str = sequence_item[0].decode("utf-8")
tokenized_sequence = tokenizer(sequence_as_str, return_tensors="jax")
results = model(**tokenized_sequence)
last_hidden_states.append(results.last_hidden_state)
last_hidden_states = np.array(last_hidden_states, dtype=np.float32)
return [last_hidden_states]
return _infer_fn


BERT_JAX_MODEL = TestModelSpec(
name="BERTJax",
    # Framework has no JAX member; the TENSORFLOW tag is reused for this Flax/JAX model.
    framework=Framework.TENSORFLOW,
create_infer_fn=_create_hf_jax_bert_fn,
inputs=(Tensor(name="sequence", dtype=object, shape=(1,)),),
outputs=(Tensor(name="last_hidden_state", dtype=np.float32, shape=(-1,)),),
model_config=ModelConfig(max_batch_size=128),
)


MODELS_CATALOGUE = [
ADD_SUB_PYTHON_MODEL,
IDENTITY_PYTHON_MODEL,
EFFICIENTDET_DETECTION_TF_MODEL,
BART_CLASSIFIER_PYTORCH_MODEL,
BERT_JAX_MODEL,
]
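

# A minimal serving sketch (illustrative; not part of the original catalogue):
# binds one spec from MODELS_CATALOGUE to a pytriton server. The helper name
# `serve_spec` is hypothetical; Triton(), bind(), and serve() are pytriton's
# public API.
def serve_spec(spec: TestModelSpec) -> None:
    from pytriton.triton import Triton  # deferred import; only needed when serving

    with Triton() as triton:
        triton.bind(
            model_name=spec.name,
            infer_func=spec.create_infer_fn(),
            inputs=list(spec.inputs),
            outputs=list(spec.outputs),
            config=spec.model_config,
        )
        triton.serve()  # blocks, serving inference requests until interrupted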