# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
import subprocess
from typing import Callable, Optional, Sequence

import numpy as np

from pytriton.decorators import batch, sample
from pytriton.model_config import ModelConfig, Tensor


class Framework(enum.Enum):
    TENSORFLOW = "tensorflow"
    PYTORCH = "pytorch"


@dataclasses.dataclass
class TestModelSpec:
    """Specification of a model used in the functional tests."""

    name: str
    framework: Optional[Framework]
    create_infer_fn: Callable[..., Callable]
    inputs: Sequence[Tensor]
    outputs: Sequence[Tensor]
    model_config: ModelConfig


def _create_add_sub_fn() -> Callable:
    @batch
    def _add_sub(**inputs):
        a_batch, b_batch = inputs.values()
        add_batch = a_batch + b_batch
        sub_batch = a_batch - b_batch
        return {"add": add_batch, "sub": sub_batch}

    return _add_sub


ADD_SUB_PYTHON_MODEL = TestModelSpec(
    name="AddSub",
    framework=None,
    create_infer_fn=_create_add_sub_fn,
    inputs=(
        Tensor(dtype=np.float32, shape=(-1,)),
        Tensor(dtype=np.float32, shape=(-1,)),
    ),
    outputs=(
        Tensor(name="add", dtype=np.float32, shape=(-1,)),
        Tensor(name="sub", dtype=np.float32, shape=(-1,)),
    ),
    model_config=ModelConfig(max_batch_size=128),
)


def _create_identity_fn() -> Callable:
    @batch
    def _identity(**inputs):
        (a_batch,) = inputs.values()
        return {"identity": a_batch}

    return _identity


IDENTITY_PYTHON_MODEL = TestModelSpec(
    name="Identity",
    framework=None,
    create_infer_fn=_create_identity_fn,
    inputs=(Tensor(dtype=np.float32, shape=(-1,)),),
    outputs=(Tensor(name="identity", dtype=np.float32, shape=(-1,)),),
    model_config=ModelConfig(max_batch_size=128),
)


def _create_tfhub_tensorflow_efficientdet_lite0_detection_fn() -> Callable:
    subprocess.run(["pip", "install", "tensorflow_hub"], check=True)

    import tensorflow_hub as hub  # pytype: disable=import-error

    detector = hub.load("https://tfhub.dev/tensorflow/efficientdet/lite0/detection/1")

    @batch
    def _tfhub_tensorflow_efficientdet_lite0_detection(image: np.ndarray):
        boxes, scores, classes, num_detections = detector(image)
        return {
            "boxes": boxes.numpy(),
            "scores": scores.numpy(),
            "classes": classes.numpy(),
            "num_detections": num_detections.numpy(),
        }

    return _tfhub_tensorflow_efficientdet_lite0_detection


EFFICIENTDET_DETECTION_TF_MODEL = TestModelSpec(
    name="EfficientDetDetector",
    framework=Framework.TENSORFLOW,
    create_infer_fn=_create_tfhub_tensorflow_efficientdet_lite0_detection_fn,
    inputs=(Tensor(name="image", dtype=np.uint8, shape=(-1, -1, 3)),),
    outputs=(
        Tensor(
            name="boxes",
            dtype=np.float32,
            shape=(-1, 4),
        ),
        Tensor(
            name="scores",
            dtype=np.float32,
            shape=(-1,),
        ),
        Tensor(
            name="classes",
            dtype=np.int32,
            shape=(-1,),
        ),
        Tensor(name="num_detections", dtype=np.int32, shape=(1,)),
    ),
    model_config=ModelConfig(max_batch_size=128),
)


def _create_hf_pytorch_bart_fn() -> Callable:
    subprocess.run(["pip", "install", "transformers"], check=True)

    import transformers  # pytype: disable=import-error

    classifier = transformers.pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)

    @sample
    def _hf_pytorch_bart_fn(sequence: np.ndarray, labels: np.ndarray):
        sequence = sequence[0].decode("utf-8")
        labels = [label.decode("utf-8") for label in labels]
        classification_result = classifier(sequence, labels)
        scores_batch = np.array(classification_result["scores"], dtype=np.float32)
        return {"scores": scores_batch}

    return _hf_pytorch_bart_fn


BART_CLASSIFIER_PYTORCH_MODEL = TestModelSpec(
    name="BARTClassifier",
    framework=Framework.PYTORCH,
    create_infer_fn=_create_hf_pytorch_bart_fn,
    inputs=(Tensor(name="sequence", dtype=object, shape=(-1,)), Tensor(name="labels", dtype=object, shape=(-1,))),
    outputs=(Tensor(name="scores", dtype=np.float32, shape=(-1,)),),
    model_config=ModelConfig(batching=False),
)


def _create_hf_jax_bert_fn() -> Callable:
    subprocess.run(
        [
            "pip",
            "install",
            "--upgrade",
            "jax[cuda12_pip]",
            "-f",
            "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
        ],
        check=True,
    )
    subprocess.run(["pip", "install", "transformers", "flax"], check=True)

    import transformers  # pytype: disable=import-error

    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    model = transformers.FlaxBertModel.from_pretrained("bert-base-uncased")

    @batch
    def _infer_fn(**inputs: np.ndarray):
        (sequence_batch,) = inputs.values()

        sequence_batch = sequence_batch.tolist()
        last_hidden_states = []
        for sequence_item in sequence_batch:
            # index 0 extracts the bytes object from the numpy array before decoding it as utf-8
            sequence_as_str = sequence_item[0].decode("utf-8")
            tokenized_sequence = tokenizer(sequence_as_str, return_tensors="jax")
            results = model(**tokenized_sequence)
            last_hidden_states.append(results.last_hidden_state)
        last_hidden_states = np.array(last_hidden_states, dtype=np.float32)
        return [last_hidden_states]

    return _infer_fn


BERT_JAX_MODEL = TestModelSpec(
    name="BERTJax",
    framework=Framework.TENSORFLOW,
    create_infer_fn=_create_hf_jax_bert_fn,
    inputs=(Tensor(name="sequence", dtype=object, shape=(1,)),),
    outputs=(Tensor(name="last_hidden_state", dtype=np.float32, shape=(-1,)),),
    model_config=ModelConfig(max_batch_size=128),
)


MODELS_CATALOGUE = [
    ADD_SUB_PYTHON_MODEL,
    IDENTITY_PYTHON_MODEL,
    EFFICIENTDET_DETECTION_TF_MODEL,
    BART_CLASSIFIER_PYTORCH_MODEL,
    BERT_JAX_MODEL,
]
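

# Illustrative sketch: how a spec from MODELS_CATALOGUE could be bound and served with
# PyTriton. The __main__ guard and the choice of ADD_SUB_PYTHON_MODEL are assumptions made
# for demonstration only; the functional test harness is assumed to wire the specs up elsewhere.
if __name__ == "__main__":
    from pytriton.triton import Triton

    spec = ADD_SUB_PYTHON_MODEL
    with Triton() as triton:
        # Bind the spec's inference callable and I/O signature under the spec's model name.
        triton.bind(
            model_name=spec.name,
            infer_func=spec.create_infer_fn(),
            inputs=spec.inputs,
            outputs=spec.outputs,
            config=spec.model_config,
        )
        # Start the server and block; clients can then send requests to the "AddSub" model.
        triton.serve()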