# Example of using Triton Server Wrapper with RAPIDS/CuPy library in Jupyter Notebook

### Pure Python/CuPy and Triton Wrapper equivalent of The RAPIDS-Triton Linear Example:
 https://github.com/rapidsai/rapids-triton-linear-example#the-rapids-triton-linear-example
 (Remark: The example above focuses on latency minimization; our equivalent focuses on ease of use.)

## Triton server setup with custom linear model

Install dependencies

In [None]:
# Install NumPy with pip, invoked through the interpreter that runs this
# kernel (sys.executable) rather than whichever `pip` is first on PATH.
# NOTE(review): only numpy is installed here, while later cells also import
# cupy and pytriton - presumably those are preinstalled; confirm.
import sys
!{sys.executable} -m pip install numpy

Required imports:

In [None]:
import numpy as np
import cupy as cp

from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton
from pytriton.decorators import batch

Define linear model (for simplicity, sample model parameters are defined in class initializer):

In [None]:
# Length of the vectors handled by the model; also sizes the `beta` offset.
VECTOR_SIZE = 10


class LinearModel:
    """Linear model computing ``u * alpha + v + beta`` with CuPy.

    The sample parameters are hard-coded in the initializer for simplicity:
    ``alpha`` is a scalar multiplier and ``beta`` is a CuPy vector produced by
    ``cp.arange(VECTOR_SIZE)``.
    """

    def __init__(self):
        self.alpha = 2
        self.beta = cp.arange(VECTOR_SIZE)

    @batch
    def linear(self, **inputs):
        """Evaluate the linear expression for a pair of input batches.

        Args:
            **inputs: Keyword arguments holding exactly two arrays; they are
                taken in the mapping's iteration order as ``u`` then ``v``.
                # assumes each is (batch, VECTOR_SIZE) so that ``beta``
                # broadcasts over the batch - TODO confirm against caller.

        Returns:
            A dict with the single key ``"lin"`` mapped to the result
            converted back to a NumPy array.

        Raises:
            ValueError: If ``inputs`` does not contain exactly two values
                (raised by the tuple unpacking).
        """
        u_host, v_host = inputs.values()

        # Convert the incoming arrays to CuPy arrays before doing the math.
        u_device = cp.asarray(u_host)
        v_device = cp.asarray(v_host)

        combined = u_device * self.alpha + v_device + self.beta

        # Hand the result back as a NumPy array under the output name "lin".
        return {"lin": cp.asnumpy(combined)}

Instantiate the Triton wrapper class and load the model with the defined callable:

In [None]:
# Create the PyTriton wrapper and the model instance whose method is served.
triton = Triton()
lin_model = LinearModel()

# Two float64 inputs of length VECTOR_SIZE, one per operand of the model.
linear_inputs = [
    Tensor(dtype=np.float64, shape=(VECTOR_SIZE,)),
    Tensor(dtype=np.float64, shape=(VECTOR_SIZE,)),
]

# Single float64 output; its name matches the "lin" key returned by
# LinearModel.linear.
linear_outputs = [Tensor(name="lin", dtype=np.float64, shape=(-1,))]

linear_config = ModelConfig(max_batch_size=128)

# Register the callable under the model name "Linear".
triton.bind(
    model_name="Linear",
    infer_func=lin_model.linear,
    inputs=linear_inputs,
    outputs=linear_outputs,
    config=linear_config,
    strict=True,
)

Run triton server with defined model inference callable

In [None]:
# Start the Triton server that serves the "Linear" model bound above.
# NOTE(review): later cells query the server from this same notebook, so this
# call presumably returns once the server is running - confirm in PyTriton docs.
triton.run()

## Example inference performed with ModelClient calling triton server

In [None]:
from pytriton.client import ModelClient

# Sample request dimensions: BATCH_SIZE vectors of VECTOR_SIZE elements each.
VECTOR_SIZE = 10
BATCH_SIZE = 2

# Both operands are float64 arrays of ones sharing the same 2-D shape.
batch_shape = (BATCH_SIZE, VECTOR_SIZE)
u_batch = np.ones(batch_shape, dtype=np.float64)
v_batch = np.ones(batch_shape, dtype=np.float64)

In [None]:
# Open a client for the "Linear" model on localhost and send both sample
# batches in one request; the context manager closes the client afterwards.
with ModelClient("localhost", "Linear") as client:
 result_batch = client.infer_batch(u_batch, v_batch)

# Print every returned output by name, converted to nested Python lists.
for output_name, data_batch in result_batch.items():
 print(f"{output_name}: {data_batch.tolist()}")

Stop triton server at the end

In [None]:
# Shut down the Triton server started earlier with triton.run().
triton.stop()