# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Triton Inference Server class. | |
The class provide functionality to run Triton Inference Server, load the Python models and serve the requests/response | |
for models inference. | |
Examples of use: | |
with Triton() as triton: | |
triton.bind( | |
model_name="BERT", | |
infer_func=_infer_fn, | |
inputs=[ | |
Tensor(dtype=np.bytes_, shape=(1,)), | |
], | |
outputs=[ | |
Tensor(dtype=np.float32, shape=(-1,)), | |
], | |
config=PythonModelConfig(max_batch_size=16), | |
) | |
triton.serve() | |
""" | |
import atexit
import codecs
import contextlib
import dataclasses
import logging
import os
import pathlib
import re
import shutil
import sys
import threading
import threading as th
import typing
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import typing_inspect

from pytriton.client import ModelClient
from pytriton.client.utils import TritonUrl, create_client_from_url, wait_for_server_ready
from pytriton.constants import DEFAULT_TRITON_STARTUP_TIMEOUT_S
from pytriton.decorators import TritonContext
from pytriton.exceptions import PyTritonValidationError
from pytriton.model_config.tensor import Tensor
from pytriton.models.manager import ModelManager
from pytriton.models.model import Model, ModelConfig, ModelEvent
from pytriton.proxy.telemetry import build_proxy_tracer_from_triton_config, get_telemetry_tracer, set_telemetry_tracer
from pytriton.server.python_backend_config import PythonBackendConfig
from pytriton.server.triton_server import TritonServer
from pytriton.server.triton_server_config import TritonServerConfig
from pytriton.utils import endpoint_utils
from pytriton.utils.dataclasses import kwonly_dataclass
from pytriton.utils.distribution import get_libs_path, get_root_module_path, get_stub_path
from pytriton.utils.workspace import Workspace
LOGGER = logging.getLogger(__name__)

TRITONSERVER_DIST_DIR = get_root_module_path() / "tritonserver"
MONITORING_PERIOD_S = 10.0
WAIT_FORM_MODEL_TIMEOUT_S = 60.0
INITIAL_BACKEND_SHM_SIZE = 4194304  # 4MB, Python Backend default is 64MB, but is automatically increased
GROWTH_BACKEND_SHM_SIZE = 1048576  # 1MB, Python Backend default is 64MB

MODEL_URL = "/v2/models/{model_name}"
MODEL_READY_URL = f"{MODEL_URL}/ready/"
MODEL_CONFIG_URL = f"{MODEL_URL}/config/"
MODEL_INFER_URL = f"{MODEL_URL}/infer/"
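# For a model bound under the name "BERT" (the name is used only for illustration), the templates above
# resolve to per-model endpoint paths such as:
#
#     MODEL_URL.format(model_name="BERT")         # "/v2/models/BERT"
#     MODEL_READY_URL.format(model_name="BERT")   # "/v2/models/BERT/ready/"
#     MODEL_INFER_URL.format(model_name="BERT")   # "/v2/models/BERT/infer/"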
# see https://github.com/triton-inference-server/server/blob/main/src/command_line_parser.cc for more details
@kwonly_dataclass
@dataclasses.dataclass
class TritonConfig:
    """Triton Inference Server configuration class for customization of server execution.

    The arguments are optional. If a value is not provided, the defaults for Triton Inference Server are used.
    Please refer to https://github.com/triton-inference-server/server/ for more details.

    Args:
        id: Identifier for this server.
        log_verbose: Set verbose logging level. Zero (0) disables verbose logging and
            values >= 1 enable verbose logging.
        log_file: Set the name of the log output file.
        exit_timeout_secs: Timeout (in seconds) when exiting to wait for in-flight inferences to finish.
        exit_on_error: Exit the inference server if an error occurs during initialization.
        strict_readiness: If true, the /v2/health/ready endpoint indicates ready only if the server is
            responsive and all models are available.
        allow_http: Allow the server to listen for HTTP requests.
        http_address: The address for the http server to bind to. Default is 0.0.0.0.
        http_port: The port for the server to listen on for HTTP requests. Default is 8000.
        http_header_forward_pattern: The regular expression pattern
            that will be used for forwarding HTTP headers as inference request parameters.
        http_thread_count: Number of threads handling HTTP requests.
        allow_grpc: Allow the server to listen for GRPC requests.
        grpc_address: The address for the grpc server to bind to. Default is 0.0.0.0.
        grpc_port: The port for the server to listen on for GRPC requests. Default is 8001.
        grpc_header_forward_pattern: The regular expression pattern that will be used
            for forwarding GRPC headers as inference request parameters.
        grpc_infer_allocation_pool_size: The maximum number of inference request/response objects
            that remain allocated for reuse. As long as the number of in-flight requests doesn't exceed
            this value there will be no allocation/deallocation of request/response objects.
        grpc_use_ssl: Use SSL authentication for GRPC requests. Default is false.
        grpc_use_ssl_mutual: Use mutual SSL authentication for GRPC requests.
            This option will preempt grpc_use_ssl if it is also specified. Default is false.
        grpc_server_cert: Path to file holding PEM-encoded server certificate. Ignored unless grpc_use_ssl is true.
        grpc_server_key: Path to file holding PEM-encoded server key. Ignored unless grpc_use_ssl is true.
        grpc_root_cert: Path to file holding PEM-encoded root certificate. Ignored unless grpc_use_ssl is true.
        grpc_infer_response_compression_level: The compression level to be used while returning the inference
            response to the peer. Allowed values are none, low, medium and high. Default is none.
        grpc_keepalive_time: The period (in milliseconds) after which a keepalive ping is sent on the transport.
        grpc_keepalive_timeout: The period (in milliseconds) the sender of the keepalive ping waits
            for an acknowledgement.
        grpc_keepalive_permit_without_calls: Allows keepalive pings to be sent even if there are no calls in flight.
        grpc_http2_max_pings_without_data: The maximum number of pings that can be sent when there is no
            data/header frame to be sent.
        grpc_http2_min_recv_ping_interval_without_data: If there are no data/header frames being sent on the
            transport, this channel argument on the server side controls the minimum time (in milliseconds) that
            gRPC Core would expect between receiving successive pings.
        grpc_http2_max_ping_strikes: Maximum number of bad pings that the server will tolerate before sending
            an HTTP2 GOAWAY frame and closing the transport.
        grpc_restricted_protocol: Specify restricted GRPC protocol setting.
            The format of this flag is `<protocols>,<key>=<value>`.
            Where `<protocol>` is a comma-separated list of protocols to be restricted.
            `<key>` will be an additional header key to be checked when a GRPC request
            is received, and `<value>` is the value expected to be matched.
        allow_metrics: Allow the server to provide prometheus metrics.
        allow_gpu_metrics: Allow the server to provide GPU metrics.
        allow_cpu_metrics: Allow the server to provide CPU metrics.
        metrics_interval_ms: Metrics will be collected once every `<metrics-interval-ms>` milliseconds.
        metrics_port: The port reporting prometheus metrics.
        metrics_address: The address for the metrics server to bind to. Default is the same as http_address.
        allow_sagemaker: Allow the server to listen for SageMaker requests.
        sagemaker_port: The port for the server to listen on for SageMaker requests.
        sagemaker_safe_port_range: Set the allowed port range for endpoints other than the SageMaker endpoints.
        sagemaker_thread_count: Number of threads handling SageMaker requests.
        allow_vertex_ai: Allow the server to listen for Vertex AI requests.
        vertex_ai_port: The port for the server to listen on for Vertex AI requests.
        vertex_ai_thread_count: Number of threads handling Vertex AI requests.
        vertex_ai_default_model: The name of the model to use for single-model inference requests.
        metrics_config: Specify a metrics-specific configuration setting.
            The format of this flag is `<setting>=<value>`. It can be specified multiple times.
        trace_config: Specify global or trace mode specific configuration setting.
            The format of this flag is `<mode>,<setting>=<value>`.
            Where `<mode>` is either 'triton' or 'opentelemetry'. The default is 'triton'.
            To specify global trace settings (level, rate, count, or mode), the format would be `<setting>=<value>`.
            For 'triton' mode, the server will use Triton's Trace APIs.
            For 'opentelemetry' mode, the server will use OpenTelemetry's APIs to generate,
            collect and export traces for individual inference requests.
            More details, including supported settings, can be found at the
            [Triton trace guide](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/trace.md).
        cache_config: Specify a cache-specific configuration setting.
            The format of this flag is `<cache_name>,<setting>=<value>`.
            Where `<cache_name>` is the name of the cache, such as 'local' or 'redis'.
            Example: `local,size=1048576` will configure a 'local' cache implementation
            with a fixed buffer pool of size 1048576 bytes.
        cache_directory: The global directory searched for cache shared libraries. Default is '/opt/tritonserver/caches'.
            This directory is expected to contain a cache implementation as a shared library with the
            name 'libtritoncache.so'.
        buffer_manager_thread_count: The number of threads used to accelerate copies and other operations
            required to manage input and output tensor contents.
    """
    model_repository: Optional[pathlib.Path] = None
    id: Optional[str] = None
    log_verbose: Optional[int] = None
    log_file: Optional[pathlib.Path] = None
    exit_timeout_secs: Optional[int] = None
    exit_on_error: Optional[bool] = None
    strict_readiness: Optional[bool] = None
    allow_http: Optional[bool] = None
    http_address: Optional[str] = None
    http_port: Optional[int] = None
    http_header_forward_pattern: Optional[str] = None
    http_thread_count: Optional[int] = None
    allow_grpc: Optional[bool] = None
    grpc_address: Optional[str] = None
    grpc_port: Optional[int] = None
    grpc_header_forward_pattern: Optional[str] = None
    grpc_infer_allocation_pool_size: Optional[int] = None
    grpc_use_ssl: Optional[bool] = None
    grpc_use_ssl_mutual: Optional[bool] = None
    grpc_server_cert: Optional[pathlib.Path] = None
    grpc_server_key: Optional[pathlib.Path] = None
    grpc_root_cert: Optional[pathlib.Path] = None
    grpc_infer_response_compression_level: Optional[str] = None
    grpc_keepalive_time: Optional[int] = None
    grpc_keepalive_timeout: Optional[int] = None
    grpc_keepalive_permit_without_calls: Optional[bool] = None
    grpc_http2_max_pings_without_data: Optional[int] = None
    grpc_http2_min_recv_ping_interval_without_data: Optional[int] = None
    grpc_http2_max_ping_strikes: Optional[int] = None
    allow_metrics: Optional[bool] = None
    allow_gpu_metrics: Optional[bool] = None
    allow_cpu_metrics: Optional[bool] = None
    metrics_interval_ms: Optional[int] = None
    metrics_port: Optional[int] = None
    metrics_address: Optional[str] = None
    allow_sagemaker: Optional[bool] = None
    sagemaker_port: Optional[int] = None
    sagemaker_safe_port_range: Optional[str] = None
    sagemaker_thread_count: Optional[int] = None
    allow_vertex_ai: Optional[bool] = None
    vertex_ai_port: Optional[int] = None
    vertex_ai_thread_count: Optional[int] = None
    vertex_ai_default_model: Optional[str] = None
    metrics_config: Optional[List[str]] = None
    trace_config: Optional[List[str]] = None
    cache_config: Optional[List[str]] = None
    cache_directory: Optional[str] = None
    buffer_manager_thread_count: Optional[int] = None
    def __post_init__(self):
        """Validate configuration for early error handling."""
        if self.allow_http not in [True, None] and self.allow_grpc not in [True, None]:
            raise PyTritonValidationError("The `http` or `grpc` endpoint has to be allowed.")

    def to_dict(self):
        """Map config object to dictionary."""
        return dataclasses.asdict(self)
    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "TritonConfig":
        """Creates a ``TritonConfig`` instance from an input dictionary. Values are converted into correct types.

        Args:
            config: a dictionary with all required fields

        Returns:
            a ``TritonConfig`` instance
        """
        fields: Dict[str, dataclasses.Field] = {field.name: field for field in dataclasses.fields(cls)}
        unknown_config_parameters = {name: value for name, value in config.items() if name not in fields}
        for name, value in unknown_config_parameters.items():
            LOGGER.warning(
                f"Ignoring {name}={value} as could not find matching config field. "
                f"Available fields: {', '.join(map(str, fields))}"
            )

        def _cast_value(_field, _value):
            field_type = _field.type
            is_optional = typing_inspect.is_optional_type(field_type)
            if is_optional:
                field_type = field_type.__args__[0]
            if hasattr(field_type, "__origin__") and field_type.__origin__ is list:
                if _value is None:
                    return None
                # comma-separated strings become list items instead of being split into single characters
                if isinstance(_value, str):
                    return _value.split(",")
                return list(_value)
            return field_type(_value)

        config_with_casted_values = {
            name: _cast_value(fields[name], value) for name, value in config.items() if name in fields
        }
        return cls(**config_with_casted_values)
    @classmethod
    def from_env(cls) -> "TritonConfig":
        """Creates TritonConfig from environment variables.

        Environment variables should start with the `PYTRITON_TRITON_CONFIG_` prefix. For example:

            PYTRITON_TRITON_CONFIG_GRPC_PORT=45436
            PYTRITON_TRITON_CONFIG_LOG_VERBOSE=4

        Typical use:

            triton_config = TritonConfig.from_env()

        Returns:
            TritonConfig class instantiated from environment variables.
        """
        prefix = "PYTRITON_TRITON_CONFIG_"
        config = {}
        list_pattern = re.compile(r"^(.+?)_(\d+)$")
        for name, value in os.environ.items():
            if name.startswith(prefix):
                key = name[len(prefix) :].lower()
                match = list_pattern.match(key)
                if match:
                    list_key, index = match.groups()
                    index = int(index)
                    if list_key not in config:
                        config[list_key] = []
                    if len(config[list_key]) <= index:
                        config[list_key].extend([None] * (index + 1 - len(config[list_key])))
                    config[list_key][index] = value
                else:
                    config[key] = value
        # Remove None values from lists (in case of non-sequential indexes)
        for key in config:
            if isinstance(config[key], list):
                config[key] = [item for item in config[key] if item is not None]
        return cls.from_dict(config)
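# A minimal sketch of configuring the server through environment variables and through code
# (illustrative only; the port and trace values below are arbitrary examples, not defaults):
#
#     os.environ["PYTRITON_TRITON_CONFIG_GRPC_PORT"] = "45436"
#     os.environ["PYTRITON_TRITON_CONFIG_TRACE_CONFIG_0"] = "level=TIMESTAMPS"
#     env_config = TritonConfig.from_env()      # grpc_port=45436, trace_config=["level=TIMESTAMPS"]
#
#     explicit_config = TritonConfig(http_port=8080)
#     triton = Triton(config=explicit_config)   # explicit values override env values, which override defaults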
@dataclasses.dataclass
class TritonLifecyclePolicy:
    """Triton Inference Server lifecycle policy.

    Indicates when the Triton server is launched and where the model store is located (locally or remotely managed
    by the Triton server).
    """

    launch_triton_on_startup: bool = True
    local_model_store: bool = False


DefaultTritonLifecyclePolicy = TritonLifecyclePolicy()
VertextAILifecyclePolicy = TritonLifecyclePolicy(launch_triton_on_startup=False, local_model_store=True)
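# A minimal sketch of selecting a non-default lifecycle policy (illustrative only):
#
#     deferred_policy = TritonLifecyclePolicy(launch_triton_on_startup=False, local_model_store=True)
#     with Triton(triton_lifecycle_policy=deferred_policy) as triton:
#         ...  # bind models here; the Triton server process is started later, e.g. by triton.serve()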
class _LogLevelChecker:
    """Check if log level is too verbose."""

    def __init__(self, url: str) -> None:
        """Initialize LogLevelChecker.

        Args:
            url: Triton Inference Server URL in form of <scheme>://<host>:<port>

        Raises:
            PyTritonClientInvalidUrlError: if url is invalid
        """
        self._log_settings = None
        self._url = url

    def check(self, skip_update: bool = False):
        """Check if log level is too verbose.

        Also waits for the server to be ready and obtains the log settings from the server if not already obtained.

        Raises:
            PyTritonClientTimeoutError: if timeout is reached
        """
        if self._log_settings is None and not skip_update:
            with contextlib.closing(create_client_from_url(self._url)) as client:
                wait_for_server_ready(client, timeout_s=DEFAULT_TRITON_STARTUP_TIMEOUT_S)
                self._log_settings = client.get_log_settings()
        if self._log_settings is not None:
            log_settings = self._log_settings
            log_verbose_level = 0
            if hasattr(log_settings, "settings"):  # grpc client
                for key, value in log_settings.settings.items():
                    if key == "log_verbose_level":
                        log_verbose_level = value.uint32_param
                        break
            else:  # http client
                log_settings = {key: str(value) for key, value in log_settings.items()}
                log_verbose_level = int(log_settings.get("log_verbose_level", 0))
            if log_verbose_level > 0:
                LOGGER.warning(
                    f"Triton Inference Server is running with enabled verbose logs (log_verbose_level={log_verbose_level}). "
                    "It may affect inference performance."
                )
class TritonBase:
    """Base class for Triton Inference Server."""

    def __init__(
        self,
        url: str,
        workspace: Union[Workspace, str, pathlib.Path, None] = None,
        triton_lifecycle_policy: TritonLifecyclePolicy = DefaultTritonLifecyclePolicy,
    ):
        """Initialize TritonBase.

        Args:
            url: Triton Inference Server URL in form of <scheme>://<host>:<port>
            workspace: Workspace for storing communication sockets and the other temporary files.
            triton_lifecycle_policy: policy indicating when the Triton server is launched and where the model store
                is located (locally or remotely managed by the Triton server).
        """
        self._triton_lifecycle_policy = triton_lifecycle_policy
        self._workspace = workspace if isinstance(workspace, Workspace) else Workspace(workspace)
        self._url = url
        _local_model_config_path = (
            self._workspace.model_store_path if triton_lifecycle_policy.local_model_store else None
        )
        self._model_manager = ModelManager(self._url, _local_model_config_path)
        self._cv = th.Condition()
        self._triton_context = TritonContext()
        self._log_level_checker = _LogLevelChecker(self._url)

        with self._cv:
            self._stopped = True
            self._connected = False
        atexit.register(self.stop)
    def bind(
        self,
        model_name: str,
        infer_func: Union[Callable, Sequence[Callable]],
        inputs: Sequence[Tensor],
        outputs: Sequence[Tensor],
        model_version: int = 1,
        config: Optional[ModelConfig] = None,
        strict: bool = False,
        trace_config: Optional[List[str]] = None,
    ) -> None:
        """Create a model with the given name and bind the inference callable into Triton Inference Server.

        More information about model configuration:
        https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md

        Args:
            infer_func: Inference callable to handle request/response from Triton Inference Server
                (or a list of inference callables for a multi-instance model)
            inputs: Definition of model inputs
            outputs: Definition of model outputs
            model_name: Name under which the model is available in Triton Inference Server. It can only contain
                alphanumeric characters, dots, underscores and dashes.
            model_version: Version of model
            config: Model configuration for Triton Inference Server deployment
            strict: Enable strict validation between model config outputs and inference function result
            trace_config: List of trace config parameters
        """
        self._validate_model_name(model_name)
        model_kwargs = {}
        if trace_config is None:
            triton_config = getattr(self, "_config", None)
            if triton_config is not None:
                trace_config = getattr(triton_config, "trace_config", None)
                if trace_config is not None:
                    LOGGER.info(f"Using trace config from TritonConfig: {trace_config}")
                    model_kwargs["trace_config"] = trace_config
        else:
            model_kwargs["trace_config"] = trace_config

        telemetry_tracer = get_telemetry_tracer()
        # Automatically set telemetry tracer if not set at the proxy side
        if telemetry_tracer is None and trace_config is not None:
            LOGGER.info("Setting telemetry tracer from TritonConfig")
            telemetry_tracer = build_proxy_tracer_from_triton_config(trace_config)
            set_telemetry_tracer(telemetry_tracer)

        model = Model(
            model_name=model_name,
            model_version=model_version,
            inference_fn=infer_func,
            inputs=inputs,
            outputs=outputs,
            config=config if config else ModelConfig(),
            workspace=self._workspace,
            triton_context=self._triton_context,
            strict=strict,
            **model_kwargs,
        )
        model.on_model_event(self._on_model_event)
        self._model_manager.add_model(model, self.is_connected())
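    # A minimal sketch of binding an inference callable (illustrative only; `numpy as np` and the
    # `@batch` decorator from pytriton.decorators are assumptions taken from typical PyTriton usage):
    #
    #     @batch
    #     def _add_sub(a, b):
    #         return {"add": a + b, "sub": a - b}
    #
    #     triton.bind(
    #         model_name="AddSub",
    #         infer_func=_add_sub,
    #         inputs=[Tensor(name="a", dtype=np.float32, shape=(-1,)), Tensor(name="b", dtype=np.float32, shape=(-1,))],
    #         outputs=[Tensor(name="add", dtype=np.float32, shape=(-1,)), Tensor(name="sub", dtype=np.float32, shape=(-1,))],
    #         config=ModelConfig(max_batch_size=128),
    #     )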
    def connect(self) -> None:
        """Connect to Triton Inference Server.

        Raises:
            TimeoutError: if Triton Inference Server is not ready after timeout
        """
        with self._cv:
            if self._connected:
                LOGGER.debug("Triton Inference already connected.")
                return

            self._wait_for_server()
            if self._triton_lifecycle_policy.local_model_store:
                self._model_manager.setup_models()
            else:
                self._model_manager.load_models()
            self._wait_for_models()
            self._connected = True
    def serve(self, monitoring_period_s: float = MONITORING_PERIOD_S) -> None:
        """Run Triton Inference Server and lock the thread for serving requests/responses.

        Args:
            monitoring_period_s: the interval for monitoring if Triton and the models are available.
                Every `monitoring_period_s` seconds the main thread wakes up, checks if the Triton server and proxy
                backends are still alive, and sleeps again. If Triton or a proxy is not alive, the method returns.
        """
        self.connect()
        with self._cv:
            try:
                while self.is_alive():
                    self._cv.wait(timeout=monitoring_period_s)
            except KeyboardInterrupt:
                LOGGER.info("SIGINT received, exiting.")
        self.stop()
    def stop(self) -> bool:
        """Stop Triton Inference Server and clean the workspace."""
        with self._cv:
            if self._stopped:
                LOGGER.debug("Triton Inference already stopped.")
                return False
            self._stopped = True
            self._connected = False
        atexit.unregister(self.stop)
        self._pre_stop_impl()
        self._model_manager.clean()
        self._workspace.clean()
        with self._cv:
            self._cv.notify_all()
        LOGGER.debug("Stopped Triton Inference server and proxy backends")
        self._log_level_checker.check(skip_update=True)

        return True
    def is_alive(self) -> bool:
        """Check if Triton Inference Server is alive."""
        if not self._is_alive_impl():
            return False
        for model in self._model_manager.models:
            if not model.is_alive():
                return False
        return True

    def is_connected(self) -> bool:
        """Check if Triton Inference Server is connected."""
        with self._cv:
            return self._connected

    def __enter__(self):
        """Connect to Triton server on __enter__.

        Returns:
            A Triton object
        """
        if self._triton_lifecycle_policy.launch_triton_on_startup:
            self.connect()
        return self

    def __exit__(self, *_) -> None:
        """Exit the context, stopping the process and cleaning the workspace.

        Args:
            *_: unused arguments
        """
        self.stop()

    def _is_alive_impl(self) -> bool:
        return True

    def _pre_stop_impl(self):
        pass

    def _post_stop_impl(self):
        pass
    def _wait_for_server(self) -> None:
        """Wait for Triton Inference Server to be ready."""
        self._log_level_checker.check()
        try:
            with contextlib.closing(create_client_from_url(self._url)) as client:
                wait_for_server_ready(client, timeout_s=DEFAULT_TRITON_STARTUP_TIMEOUT_S)
        except TimeoutError as e:
            LOGGER.warning(
                f"Could not verify locally if Triton Inference Server is ready using {self._url}. "
                "Please check the server logs for details."
            )
            raise TimeoutError("Triton Inference Server is not ready after timeout.") from e

    def _wait_for_models(self) -> None:
        """Log loaded models in console to show the available endpoints."""
        self._log_level_checker.check()
        try:
            for model in self._model_manager.models:
                with ModelClient(
                    url=self._url, model_name=model.model_name, model_version=str(model.model_version)
                ) as client:
                    # This waits only for tritonserver and the lightweight proxy backend to be ready.
                    # The timeout can be short as the model is loaded before execution of the Triton.start() method.
                    client.wait_for_model(timeout_s=WAIT_FORM_MODEL_TIMEOUT_S)
        except TimeoutError:
            LOGGER.warning(
                f"Could not verify locally if models are ready using {self._url}. "
                "Please check the server logs for details."
            )

        for model in self._model_manager.models:
            LOGGER.info(f"Infer function available as model: `{MODEL_URL.format(model_name=model.model_name)}`")
            LOGGER.info(f"  Status:       `GET  {MODEL_READY_URL.format(model_name=model.model_name)}`")
            LOGGER.info(f"  Model config: `GET  {MODEL_CONFIG_URL.format(model_name=model.model_name)}`")
            LOGGER.info(f"  Inference:    `POST {MODEL_INFER_URL.format(model_name=model.model_name)}`")
        LOGGER.info(
            """Read more about configuring and serving models in """
            """documentation: https://triton-inference-server.github.io/pytriton."""
        )
        LOGGER.info(f"(Press CTRL+C or use the command `kill -SIGINT {os.getpid()}` to send a SIGINT signal and quit)")
    def _on_model_event(self, model: Model, event: ModelEvent, context: typing.Optional[typing.Any] = None):
        LOGGER.info(f"Received {event} from {model}; context={context}")

        if event in [ModelEvent.RUNTIME_TERMINATING, ModelEvent.RUNTIME_TERMINATED]:
            threading.Thread(target=self.stop).start()

    @classmethod
    def _validate_model_name(cls, model_name: str) -> None:
        """Validate model name.

        Args:
            model_name: Model name

        Raises:
            PyTritonValidationError: if the model name is empty or contains invalid characters
        """
        if not model_name:
            raise PyTritonValidationError("Model name cannot be empty")

        if not re.match(r"^[a-zA-Z0-9._-]+$", model_name):
            raise PyTritonValidationError(
                "Model name can only contain alphanumeric characters, dots, underscores and dashes"
            )
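    # Examples of how the validation above behaves (illustrative only):
    #
    #     "bert-large_v1.0"  -> accepted (alphanumerics, dots, underscores and dashes)
    #     "bert large"       -> rejected with PyTritonValidationError (whitespace is not allowed)
    #     ""                 -> rejected with PyTritonValidationError (empty name)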
class Triton(TritonBase):
    """Triton Inference Server for Python models."""

    def __init__(
        self,
        *,
        config: Optional[TritonConfig] = None,
        workspace: Union[Workspace, str, pathlib.Path, None] = None,
        triton_lifecycle_policy: Optional[TritonLifecyclePolicy] = None,
    ):
        """Initialize Triton Inference Server context for starting server and loading models.

        Args:
            config: TritonConfig object with optional customizations for Triton Inference Server.
                Configuration can also be passed through environment variables.
                See [TritonConfig.from_env()][pytriton.triton.TritonConfig.from_env] class method for details.

                Order of precedence:

                  - config defined through `config` parameter of init method.
                  - config defined in environment variables
                  - default TritonConfig values
            workspace: workspace or path where the Triton Model Store and files used by pytriton will be created.
                If workspace is `None`, a random workspace will be created.
                The workspace is deleted in [Triton.stop()][pytriton.triton.Triton.stop].
            triton_lifecycle_policy: policy indicating when the Triton server is launched and where the model store
                is located (locally or remotely managed by the Triton server). If triton_lifecycle_policy is None,
                DefaultTritonLifecyclePolicy is used (the Triton server is launched on startup and the model store
                is not local). Only if triton_lifecycle_policy is None and config.allow_vertex_ai is True,
                VertextAILifecyclePolicy is used instead.
        """
        _triton_lifecycle_policy = (
            VertextAILifecyclePolicy
            if triton_lifecycle_policy is None and config is not None and config.allow_vertex_ai
            else triton_lifecycle_policy
        ) or DefaultTritonLifecyclePolicy

        def _without_none_values(_d):
            return {name: value for name, value in _d.items() if value is not None}

        default_config_dict = _without_none_values(TritonConfig().to_dict())
        env_config_dict = _without_none_values(TritonConfig.from_env().to_dict())
        explicit_config_dict = _without_none_values(config.to_dict() if config else {})
        config_dict = {**default_config_dict, **env_config_dict, **explicit_config_dict}
        self._config = TritonConfig(**config_dict)
        workspace_instance = workspace if isinstance(workspace, Workspace) else Workspace(workspace)
        self._prepare_triton_config(workspace_instance)
        endpoint_protocol = "http" if self._config.allow_http in [True, None] else "grpc"
        super().__init__(
            url=endpoint_utils.get_endpoint(self._triton_server_config, endpoint_protocol),
            workspace=workspace_instance,
            triton_lifecycle_policy=_triton_lifecycle_policy,
        )
        self._triton_server = None
    def __enter__(self) -> "Triton":
        """Entering the context launches the triton server.

        Returns:
            A Triton object
        """
        if self._triton_lifecycle_policy.launch_triton_on_startup:
            self._run_server()
        super().__enter__()
        return self

    def run(self) -> None:
        """Run Triton Inference Server."""
        self._run_server()
        self.connect()

    def serve(self, monitoring_period_s: float = MONITORING_PERIOD_S) -> None:
        """Run Triton Inference Server and lock the thread for serving requests/responses.

        Args:
            monitoring_period_s: the interval for monitoring if Triton and the models are available.
                Every `monitoring_period_s` seconds the main thread wakes up, checks if the Triton server and proxy
                backends are still alive, and sleeps again. If Triton or a proxy is not alive, the method returns.
        """
        self._run_server()
        super().serve(monitoring_period_s=monitoring_period_s)
    def _initialize_server(self) -> None:
        """Initialize Triton Inference Server before binary execution."""
        self._triton_inference_server_path = self._prepare_triton_inference_server()
        self._triton_server = TritonServer(
            path=(self._triton_inference_server_path / "bin" / "tritonserver").as_posix(),
            libs_path=get_libs_path(),
            config=self._triton_server_config,
        )
        url = (
            self._triton_server.get_endpoint("http")
            if (self._config.allow_http is None or self._config.allow_http)
            else self._triton_server.get_endpoint("grpc")
        )
        self._log_level_checker = _LogLevelChecker(url)

    def _prepare_triton_config(self, workspace: Workspace) -> None:
        self._triton_server_config = TritonServerConfig()
        config_data = self._config.to_dict()
        self._python_backend_config = PythonBackendConfig()
        python_backend_config_data = {
            "shm-region-prefix-name": self._shm_prefix(),
            "shm-default-byte-size": INITIAL_BACKEND_SHM_SIZE,
            "shm-growth-byte-size": GROWTH_BACKEND_SHM_SIZE,
        }
        for name, value in python_backend_config_data.items():
            if name not in PythonBackendConfig.allowed_keys() or value is None:
                continue
            if isinstance(value, pathlib.Path):
                value = value.as_posix()
            self._python_backend_config[name] = value
        for name, value in config_data.items():
            if name not in TritonServerConfig.allowed_keys() or value is None:
                continue
            if isinstance(value, pathlib.Path):
                value = value.as_posix()
            self._triton_server_config[name] = value
        self._triton_server_config["model_control_mode"] = "explicit"
        self._triton_server_config["load-model"] = "*"
        self._triton_server_config["backend_config"] = self._python_backend_config.to_list_args()
        if "model_repository" not in self._triton_server_config:
            self._triton_server_config["model_repository"] = workspace.model_store_path.as_posix()
    def _prepare_triton_inference_server(self) -> pathlib.Path:
        """Prepare binaries and libraries of Triton Inference Server for execution.

        Returns:
            Path where Triton binaries are ready for execution
        """
        triton_inference_server_path = self._workspace.path / "tritonserver"
        LOGGER.debug("Preparing Triton Inference Server binaries and libs for execution.")
        shutil.copytree(
            TRITONSERVER_DIST_DIR,
            triton_inference_server_path,
            ignore=shutil.ignore_patterns("python_backend_stubs", "triton_python_backend_stub"),
        )
        LOGGER.debug(f"Triton Inference Server binaries copied to {triton_inference_server_path} without stubs.")

        major = sys.version_info[0]
        minor = sys.version_info[1]
        version = f"{major}.{minor}"
        src_stub_path = get_stub_path(version)
        dst_stub_path = triton_inference_server_path / "backends" / "python" / "triton_python_backend_stub"
        LOGGER.debug(f"Copying stub for version {version} from {src_stub_path} to {dst_stub_path}")
        shutil.copy(src_stub_path, dst_stub_path)
        LOGGER.debug(f"Triton Inference Server binaries ready in {triton_inference_server_path}")

        self._triton_server_config["backend_directory"] = (triton_inference_server_path / "backends").as_posix()
        if "cache_directory" not in self._triton_server_config:
            self._triton_server_config["cache_directory"] = (triton_inference_server_path / "caches").as_posix()
        return triton_inference_server_path

    def _shm_prefix(self) -> str:
        """Generate a unique prefix for shared memory regions.

        Returns:
            String with prefix
        """
        hash = codecs.encode(os.urandom(4), "hex").decode()
        pid = os.getpid()
        return f"pytrtion{pid}-{hash}"
    def _run_server(self):
        """Run Triton Inference Server."""
        if self._triton_server is None:
            self._initialize_server()
        if not self._triton_server.is_alive():
            with self._cv:
                self._stopped = False
            LOGGER.debug("Starting Triton Inference")
            self._triton_server.register_on_exit(self._on_tritonserver_exit)
            self._triton_server.start()

    def _is_alive_impl(self) -> bool:
        """Verify whether the server and the deployed models are alive.

        Returns:
            True if server and loaded models are alive, False otherwise.
        """
        if not self._triton_server:
            return False
        return self._triton_server.is_alive()

    def _pre_stop_impl(self):
        if self._triton_server is not None:
            self._triton_server.unregister_on_exit(self._on_tritonserver_exit)
            self._triton_server.stop()

    def _on_tritonserver_exit(self, *_) -> None:
        """Handle the Triton Inference Server process exit.

        Args:
            _: unused arguments
        """
        LOGGER.debug("Got callback that tritonserver process finished")
        self.stop()
class RemoteTriton(TritonBase):
    """RemoteTriton connects to a Triton Inference Server running on a remote host."""

    def __init__(self, url: str, workspace: Union[Workspace, str, pathlib.Path, None] = None):
        """Initialize RemoteTriton.

        Args:
            url: Triton Inference Server URL in form of <scheme>://<host>:<port>
                If scheme is not provided, http is used as default.
                If port is not provided, 8000 is used as default for http and 8001 for grpc.
            workspace: path to be created where the files used by pytriton will be stored
                (e.g. socket files for communication).
                If workspace is `None`, a temporary workspace will be created.
                The workspace should be created on a filesystem shared between RemoteTriton
                and the Triton Inference Server to allow access to socket files
                (if you use containers, the folder must be shared between containers).
        """
        super().__init__(
            url=TritonUrl.from_url(url).with_scheme,
            workspace=workspace,
            triton_lifecycle_policy=TritonLifecyclePolicy(launch_triton_on_startup=True, local_model_store=False),
        )

        with self._cv:
            self._stopped = False

    def __enter__(self) -> "RemoteTriton":
        """Entering the context connects to the remote Triton server.

        Returns:
            A RemoteTriton object
        """
        super().__enter__()
        return self
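# A minimal sketch of attaching to an already running Triton server (illustrative only; the URL,
# shared workspace path and `_infer_fn` below are assumptions, not defaults of this module):
#
#     with RemoteTriton(url="grpc://triton-host:8001", workspace="/shared/pytriton-workspace") as triton:
#         triton.bind(
#             model_name="BERT",
#             infer_func=_infer_fn,
#             inputs=[Tensor(dtype=np.bytes_, shape=(1,))],
#             outputs=[Tensor(dtype=np.float32, shape=(-1,))],
#             config=ModelConfig(max_batch_size=16),
#         )
#         triton.serve()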