Quick Start
Before following this page, install PyTriton as described on the Installation page.
The Quick Start shows how to run a Python model in the Triton Inference Server without changing your current working environment. This example uses a simple `Linear` PyTorch model.
Integrating the model requires the following elements:
- The model: a framework or Python model or function that handles inference requests
- Inference Callable: a function or class with a `__call__` method that handles the input data coming from Triton and returns the result
- Connection with Triton Inference Server: a binding for communication between Triton and the Inference Callable
This example requires PyTorch to be installed in your environment. You can install it by running:

```shell
pip install torch
```
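Next, define the `Linear` model:

```python
import torch

# A linear layer with 2 input and 3 output features, moved to the GPU
# (requires a CUDA device) and switched to inference mode
model = torch.nn.Linear(2, 3).to("cuda").eval()
```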
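For reference, `torch.nn.Linear(2, 3)` computes y = xWᵀ + b, where W has shape (3, 2) and b has length 3. Every input row of 2 features therefore becomes an output row of 3 values. The pure-Python sketch below reproduces that arithmetic to make the shapes concrete. `linear_forward` and the weight values are hypothetical illustrations. They are not part of PyTorch and are not the randomly initialized parameters of the model above.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear_forward(
    rows: Sequence[Sequence[float]],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> Matrix:
    """Compute y = x @ W.T + b row by row, mirroring what torch.nn.Linear does.

    rows:   batch of inputs, shape (batch, in_features)
    weight: shape (out_features, in_features)
    bias:   shape (out_features,)
    """
    if len(bias) != len(weight):
        raise ValueError("bias length must equal out_features")
    result: Matrix = []
    for x in rows:
        if any(len(w_row) != len(x) for w_row in weight):
            raise ValueError("each input row must have in_features values")
        # One output value per row of W: dot(W[j], x) + b[j]
        result.append(
            [sum(w * xi for w, xi in zip(w_row, x)) + b for w_row, b in zip(weight, bias)]
        )
    return result


# Hypothetical parameters for a 2 -> 3 layer (the real model's weights are random).
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [0.5, -0.5, 0.0]
print(linear_forward([[1.0, 2.0], [3.0, 4.0]], W, b))  # [[1.5, 1.5, 3.0], [3.5, 3.5, 7.0]]
```

A batch of two 2-feature rows yields two 3-value rows. This is why the output tensor declared later also has a variable-length shape per row.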
In the second step, create an inference callable as a function. The function receives the data from HTTP/gRPC requests as NumPy arrays and returns its results as NumPy arrays as well. You can define the inference callable as a function that uses the `@batch` decorator from PyTriton. This decorator combines the incoming requests into a batch that can be passed directly to the model, so the function receives its inputs as keyword arguments. You can read more about decorators in the PyTriton documentation.
Example implementation:

```python
import numpy as np
import torch
from pytriton.decorators import batch


@batch
def infer_fn(**inputs: np.ndarray):
    (input1_batch,) = inputs.values()  # the single model input, batched along the first axis
    input1_batch_tensor = torch.from_numpy(input1_batch).to("cuda")
    output1_batch_tensor = model(input1_batch_tensor)  # Calling the Python model inference
    output1_batch = output1_batch_tensor.cpu().detach().numpy()
    return [output1_batch]
```
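Conceptually, `@batch` merges the inputs of several requests into one batch along the first axis, calls the wrapped function once, and splits the outputs back per request. The sketch below illustrates only that idea. `simple_batch` is a hypothetical decorator, not PyTriton's implementation. It assumes plain Python lists instead of NumPy arrays and a wrapped function that returns a dictionary of outputs.

```python
from functools import wraps
from typing import Callable, Dict, List

Rows = List[List[float]]
Request = Dict[str, Rows]


def simple_batch(fn: Callable[..., Dict[str, Rows]]) -> Callable[[List[Request]], List[Request]]:
    """Hypothetical stand-in for a batching decorator (not PyTriton's @batch).

    Merges the rows of every request into one batch per input name, calls fn
    once with keyword arguments, then gives each request back its own rows.
    """

    @wraps(fn)
    def wrapper(requests: List[Request]) -> List[Request]:
        if not requests:
            return []
        # Number of rows contributed by each request (taken from its first input).
        sizes = [len(next(iter(req.values()))) for req in requests]
        names = list(requests[0].keys())
        # Concatenate along the batch (first) axis.
        batched = {name: [row for req in requests for row in req[name]] for name in names}
        outputs = fn(**batched)
        # Split every output back into per-request slices.
        responses: List[Request] = []
        start = 0
        for size in sizes:
            responses.append({name: rows[start:start + size] for name, rows in outputs.items()})
            start += size
        return responses

    return wrapper


@simple_batch
def double(**inputs: Rows) -> Dict[str, Rows]:
    (x,) = inputs.values()  # same unpacking pattern as infer_fn
    return {"OUTPUT_1": [[2 * v for v in row] for row in x]}


print(double([{"INPUT_1": [[1.0, 2.0]]}, {"INPUT_1": [[3.0, 4.0], [5.0, 6.0]]}]))
# [{'OUTPUT_1': [[2.0, 4.0]]}, {'OUTPUT_1': [[6.0, 8.0], [10.0, 12.0]]}]
```

The wrapped function sees one batch of three rows. Each caller still receives only the rows that correspond to its own request.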
Next, bind the inference callable to the Triton Inference Server with the `bind` method of the `Triton` object. This method takes the model name, the inference callable, the input and output tensors, and an optional model configuration object:
```python
import numpy as np
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

# Connecting inference callable with Triton Inference Server
with Triton() as triton:
    triton.bind(
        model_name="Linear",
        infer_func=infer_fn,
        inputs=[
            Tensor(dtype=np.float32, shape=(-1,)),
        ],
        outputs=[
            Tensor(dtype=np.float32, shape=(-1,)),
        ],
        config=ModelConfig(max_batch_size=128),
    )
    ...
```
Finally, serve the model with the Triton Inference Server:

```python
from pytriton.triton import Triton

with Triton() as triton:
    ...  # Load models here
    triton.serve()
```
The `bind` method creates a connection between the Triton Inference Server and `infer_fn`, which handles the inference queries. The `inputs` and `outputs` arguments describe the model inputs and outputs exposed in Triton. The `config` argument accepts additional parameters for model deployment, such as `max_batch_size`.
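Because batching is enabled with `max_batch_size=128`, the declared `shape=(-1,)` excludes the batch dimension, and `-1` accepts any size in that position. Assuming Triton's usual shape rules for batched models, the hypothetical helper below checks whether a request shape fits a declared tensor shape. It is an illustration only, not Triton's own validation code.

```python
from typing import Sequence


def shape_matches(declared: Sequence[int], actual: Sequence[int], max_batch_size: int) -> bool:
    """Hypothetical helper: does a request tensor shape fit a declared Tensor shape?

    With max_batch_size > 0, the declared shape excludes the batch dimension, so
    the actual shape must be [batch, *declared] with 1 <= batch <= max_batch_size.
    A -1 in the declared shape accepts any size in that position.
    """
    dims = list(actual)
    if max_batch_size > 0:
        if not dims or not 1 <= dims[0] <= max_batch_size:
            return False
        dims = dims[1:]  # drop the batch dimension
    if len(dims) != len(declared):
        return False
    return all(d == -1 or d == a for d, a in zip(declared, dims))


# Tensor(shape=(-1,)) with ModelConfig(max_batch_size=128), as in the bind() call
print(shape_matches((-1,), (1, 2), 128))    # True: one row with two features
print(shape_matches((-1,), (128, 2), 128))  # True: a full batch
print(shape_matches((-1,), (129, 2), 128))  # False: exceeds max_batch_size
print(shape_matches((-1,), (2,), 128))      # False: no feature axis after the batch axis
```

Because `-1` accepts any length, `shape=(-1,)` would also accept rows with three features. Such a request would pass this check but fail inside `torch.nn.Linear(2, 3)`. A fixed input `shape=(2,)` is a stricter alternative if you want such requests rejected before they reach the model.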
The `serve` method is blocking: at this point, the application waits for incoming HTTP/gRPC requests. From then on, the model is available in the Triton server under the name `Linear`, and inference queries sent to `localhost:8000/v2/models/Linear/infer` are passed to the `infer_fn` function.
To run Triton in background mode instead, use the `run` method. More details are available on the Deploying Models page.
Once the `serve` or `run` method has been called on the `Triton` object, you can check the server status with:

```shell
curl -v localhost:8000/v2/health/live
```
The model is loaded right after the server starts, and its status can be queried with:

```shell
curl -v localhost:8000/v2/models/Linear/ready
```
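In scripts or tests, it can be convenient to wait for the model before sending requests. The following sketch polls a readiness endpoint using only the Python standard library. `wait_until_ready` and its default timings are assumptions for illustration and are not part of PyTriton.

```python
import http.client
import time
import urllib.request


def wait_until_ready(url: str, timeout_s: float = 30.0, interval_s: float = 0.5) -> bool:
    """Poll a health/readiness URL until it returns HTTP 200 or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as response:
                if response.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            # URLError/HTTPError (connection refused, model not ready yet) and
            # socket timeouts are OSError subclasses; keep polling.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Example (needs the server from this page to be running):
# wait_until_ready("http://localhost:8000/v2/models/Linear/ready")
```

Polling `/v2/health/live` instead only confirms that the server process is up. It does not confirm that `Linear` is loaded.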
Finally, you can send an inference query to the model:

```shell
curl -X POST \
  -H "Content-Type: application/json" \
  -d @input.json \
  localhost:8000/v2/models/Linear/infer
```
The `input.json` file contains a sample query:

```json
{
  "id": "0",
  "inputs": [
    {
      "name": "INPUT_1",
      "shape": [1, 2],
      "datatype": "FP32",
      "parameters": {},
      "data": [[-0.04281254857778549, 0.6738349795341492]]
    }
  ]
}
```
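The same request can also be assembled in Python. The sketch below builds a body in the format of `input.json` and turns a JSON response into rows. It makes three assumptions: the tensors are two-dimensional FP32, the input uses the default `INPUT_1` name seen in `input.json`, and the response `data` arrives either flat in row-major order or already nested. All helper names are hypothetical, and `send_request` needs a running server.

```python
import json
import urllib.request
from typing import Dict, List, Sequence


def build_infer_request(rows: Sequence[Sequence[float]], input_name: str = "INPUT_1") -> dict:
    """Build a v2 inference request body for a 2-D FP32 batch, like input.json."""
    n_cols = len(rows[0]) if rows else 0
    if any(len(row) != n_cols for row in rows):
        raise ValueError("all rows must have the same number of features")
    return {
        "id": "0",
        "inputs": [
            {
                "name": input_name,
                "shape": [len(rows), n_cols],
                "datatype": "FP32",
                "data": [[float(v) for v in row] for row in rows],
            }
        ],
    }


def rows_from_data(data: Sequence, shape: Sequence[int]) -> List[List[float]]:
    """Reshape tensor data (flat row-major or already nested) into rows."""
    n_rows, n_cols = shape
    flat = [v for row in data for v in row] if data and isinstance(data[0], list) else list(data)
    if len(flat) != n_rows * n_cols:
        raise ValueError("data size does not match shape")
    return [flat[i * n_cols:(i + 1) * n_cols] for i in range(n_rows)]


def parse_infer_response(body: dict) -> Dict[str, List[List[float]]]:
    """Map each 2-D output tensor name in a v2 response to its rows."""
    return {out["name"]: rows_from_data(out["data"], out["shape"]) for out in body.get("outputs", [])}


def send_request(payload: dict, url: str = "http://localhost:8000/v2/models/Linear/infer") -> dict:
    """POST the payload as JSON and decode the JSON response (requires a running server)."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.load(response)


payload = build_infer_request([[-0.04281254857778549, 0.6738349795341492]])
print(json.dumps(payload))
```

With the server running, `parse_infer_response(send_request(payload))` would return the model's output rows keyed by output name.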
Read more about the HTTP/gRPC interface in the Triton Inference Server documentation.
You can also validate the deployed model with a simple client that sends inference requests:

```python
import torch
from pytriton.client import ModelClient

input1_data = torch.randn(128, 2).cpu().detach().numpy()

with ModelClient("localhost:8000", "Linear") as client:
    result_dict = client.infer_batch(input1_data)

print(result_dict)
```
The full example code can be found in examples/linear_random_pytorch.
More information about running the server and models can be found on the Deploying Models page.