File size: 986 Bytes
e58ac13 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import torch
# Demo model served by this script: a linear layer mapping 2 input features
# to 3 output features. No weights are loaded here, so it keeps PyTorch's
# default initialization. It is placed on the GPU and switched to eval mode.
model = torch.nn.Linear(in_features=2, out_features=3)
model.to("cuda")
model.eval()
import numpy as np
from pytriton.decorators import batch
@batch
def infer_fn(**inputs: np.ndarray):
    """Run one forward pass of the module-level ``model`` on a batch.

    Args:
        **inputs: Named input arrays for the request batch. Exactly one
            input is expected; the tuple-unpacking below raises
            ``ValueError`` if there are zero or several.

    Returns:
        A single-element list holding the model output as a NumPy array.
    """
    (input1_batch,) = inputs.values()
    input1_batch_tensor = torch.from_numpy(input1_batch).to("cuda")
    # Serving is inference-only: disable autograd so no graph is recorded
    # for every request (``.eval()`` alone does not turn gradients off).
    with torch.no_grad():
        output1_batch_tensor = model(input1_batch_tensor)  # Calling the Python model inference
    # The result does not require grad under no_grad, so it can be moved to
    # host memory and converted to NumPy directly.
    output1_batch = output1_batch_tensor.cpu().numpy()
    return [output1_batch]
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton
# Wire the inference callback into a Triton Inference Server instance and
# serve it. The context manager handles server start-up and shutdown.
with Triton() as triton:
    # One float32 input and one float32 output. Per PyTriton's Tensor
    # convention, -1 marks a variable-length dimension (see PyTriton docs).
    input_specs = [Tensor(dtype=np.float32, shape=(-1,))]
    output_specs = [Tensor(dtype=np.float32, shape=(-1,))]
    # Batching is configured with a maximum batch size of 128.
    model_config = ModelConfig(max_batch_size=128)

    # Register the callback under the model name "Linear".
    triton.bind(
        model_name="Linear",
        infer_func=infer_fn,
        inputs=input_specs,
        outputs=output_specs,
        config=model_config,
    )
    # Start handling inference requests.
    triton.serve()
|