File size: 4,149 Bytes
e3af00f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import re
from unittest.mock import PropertyMock

import pytest

from pytriton.exceptions import PyTritonValidationError
from pytriton.triton import (
    GROWTH_BACKEND_SHM_SIZE,
    INITIAL_BACKEND_SHM_SIZE,
    TRITONSERVER_DIST_DIR,
    Triton,
    TritonConfig,
)

# Patterns for the entries the tests expect in the server's "backend_config" list,
# listed in the order in which the entries are checked (pattern i vs. entry i).
# Each pattern is applied with re.match, so it is anchored only at the start of the
# corresponding entry; trailing text after the pattern is not rejected.
# NOTE(review): the "pytrtion" spelling presumably mirrors the prefix produced by
# pytriton.triton - confirm against the implementation before "correcting" it here.
EXPECTED_BACKEND_ARGS = [
    "python,shm-region-prefix-name=pytrtion[0-9]+",
    f"python,shm-default-byte-size={INITIAL_BACKEND_SHM_SIZE}",
    f"python,shm-growth-byte-size={GROWTH_BACKEND_SHM_SIZE}",
]


def test_triton_is_alive_return_false_when_not_initialized():
    """A Triton that was only constructed holds no server and reports not alive."""
    instance = Triton()

    # Nothing has been prepared or started yet, so no server object is expected.
    assert instance._triton_server is None
    assert instance.is_alive() is False


def test_triton_server_initialize_server_with_default_arguments(mocker):
    """Check the server config prepared by a Triton created without a config.

    After ``_prepare_triton_inference_server()`` the server config must carry a
    model repository, a backend directory and exactly three backend-config
    entries, each matching its counterpart in ``EXPECTED_BACKEND_ARGS``.

    The ``mocker`` fixture is requested but not used in the body.
    """
    instance = Triton()
    instance._prepare_triton_inference_server()
    server_config = instance._triton_server_config

    assert server_config["model_repository"] is not None
    assert server_config["backend_directory"] is not None
    assert len(server_config["backend_config"]) == 3
    # Entries are compared positionally: pattern i must match entry i.
    for position, pattern in enumerate(EXPECTED_BACKEND_ARGS):
        assert re.match(pattern, server_config["backend_config"][position])


def test_triton_server_initialize_server_with_custom_arguments(mocker):
    """Check that values given through ``TritonConfig`` reach the server config.

    The id, model repository (expected as the string ``"/tmp"``) and the
    ``allow_metrics`` flag passed in the config must show up in the prepared
    server config, alongside a backend directory and the backend-config entries
    described by ``EXPECTED_BACKEND_ARGS``.

    The ``mocker`` fixture is requested but not used in the body.
    """
    custom_config = TritonConfig(id="CustomId", model_repository=pathlib.Path("/tmp"), allow_metrics=False)
    instance = Triton(config=custom_config)
    instance._prepare_triton_inference_server()
    server_config = instance._triton_server_config

    assert server_config["id"] == "CustomId"
    assert server_config["model_repository"] == "/tmp"
    assert server_config["allow_metrics"] is False
    assert server_config["backend_directory"] is not None
    # Entries are compared positionally: pattern i must match entry i.
    for position, pattern in enumerate(EXPECTED_BACKEND_ARGS):
        assert re.match(pattern, server_config["backend_config"][position])


def test_triton_server_initialize_server_with_custom_arguments_and_env_variables(mocker):
    """Check how explicit config values and ``PYTRITON_TRITON_CONFIG_*`` variables combine.

    The environment provides a gRPC port and a model repository, while the
    ``TritonConfig`` provides an id, a different model repository and
    ``allow_metrics``. The assertions expect the explicitly configured model
    repository (``"/tmp"``) to be kept, and the gRPC port, which is set only in
    the environment, to be picked up as the integer ``8080``.
    """
    import os

    # Copy of the current environment extended with the two variables under test.
    patched_environ = dict(os.environ)
    patched_environ.update(
        {
            "PYTRITON_TRITON_CONFIG_GRPC_PORT": "8080",
            "PYTRITON_TRITON_CONFIG_MODEL_REPOSITORY": "/opt",
        }
    )
    # The patch substitutes whatever calling ``new_callable`` yields for ``os.environ``.
    mocker.patch("os.environ", new_callable=PropertyMock(return_value=patched_environ))

    custom_config = TritonConfig(id="CustomId", model_repository=pathlib.Path("/tmp"), allow_metrics=False)
    instance = Triton(config=custom_config)
    instance._prepare_triton_inference_server()
    server_config = instance._triton_server_config

    assert server_config["id"] == "CustomId"
    assert server_config["model_repository"] == "/tmp"
    assert server_config["grpc_port"] == 8080
    assert server_config["allow_metrics"] is False
    assert server_config["backend_directory"] is not None
    assert server_config["backend_config"] is not None
    # Entries are compared positionally: pattern i must match entry i.
    for position, pattern in enumerate(EXPECTED_BACKEND_ARGS):
        assert re.match(pattern, server_config["backend_config"][position])


def test_triton_bind_model_name_verification(mocker):
    """Check model-name validation performed by ``Triton.bind``.

    A name made of letters, digits, dashes, underscores and dots must be
    accepted; a name containing ``#`` and ``/`` and an empty name must each
    raise ``PyTritonValidationError`` with the corresponding message.
    """
    # Stub out server preparation so the test exercises only the bind() call.
    prepare_stub = mocker.patch.object(Triton, "_prepare_triton_inference_server")
    prepare_stub.return_value = TRITONSERVER_DIST_DIR

    instance = Triton()

    # Valid name: must not raise.
    instance.bind("AB-cd_90.1", lambda: None, [], [])

    invalid_characters_message = (
        "Model name can only contain alphanumeric characters, dots, underscores and dashes"
    )
    with pytest.raises(PyTritonValidationError, match=invalid_characters_message):
        instance.bind("AB#cd/90/1", lambda: None, [], [])

    empty_name_message = "Model name cannot be empty"
    with pytest.raises(PyTritonValidationError, match=empty_name_message):
        instance.bind("", lambda: None, [], [])