yerang commited on
Commit
e3af00f
·
verified ·
1 Parent(s): cc89c8b

Upload 1110 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. stf/.DS_Store +0 -0
  2. stf/089.npz +3 -0
  3. stf/089.pth +3 -0
  4. stf/stf-api-alternative/.gitignore +160 -0
  5. stf/stf-api-alternative/.ipynb_checkpoints/README-checkpoint.md +1 -0
  6. stf/stf-api-alternative/.ipynb_checkpoints/poetry-checkpoint.lock +0 -0
  7. stf/stf-api-alternative/.ipynb_checkpoints/pyproject-checkpoint.toml +35 -0
  8. stf/stf-api-alternative/README.md +1 -0
  9. stf/stf-api-alternative/poetry.lock +0 -0
  10. stf/stf-api-alternative/pyproject.toml +35 -0
  11. stf/stf-api-alternative/pytriton/.flake8 +19 -0
  12. stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/bug_report.md +83 -0
  13. stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
  14. stf/stf-api-alternative/pytriton/.github/workflows/stale.yaml +35 -0
  15. stf/stf-api-alternative/pytriton/.gitignore +330 -0
  16. stf/stf-api-alternative/pytriton/.pre-commit-config.yaml +76 -0
  17. stf/stf-api-alternative/pytriton/CHANGELOG.md +239 -0
  18. stf/stf-api-alternative/pytriton/CONTRIBUTING.md +203 -0
  19. stf/stf-api-alternative/pytriton/COPYRIGHT +13 -0
  20. stf/stf-api-alternative/pytriton/LICENSE +174 -0
  21. stf/stf-api-alternative/pytriton/Makefile +124 -0
  22. stf/stf-api-alternative/pytriton/README.md +343 -0
  23. stf/stf-api-alternative/pytriton/build/lib/pytriton/__init__.py +27 -0
  24. stf/stf-api-alternative/pytriton/build/lib/pytriton/__main__.py +218 -0
  25. stf/stf-api-alternative/pytriton/build/lib/pytriton/check/__init__.py +14 -0
  26. stf/stf-api-alternative/pytriton/build/lib/pytriton/check/add_sub.py +139 -0
  27. stf/stf-api-alternative/pytriton/build/lib/pytriton/check/env_checks.py +201 -0
  28. stf/stf-api-alternative/pytriton/build/lib/pytriton/check/utils.py +555 -0
  29. stf/stf-api-alternative/pytriton/build/lib/pytriton/client/__init__.py +22 -0
  30. stf/stf-api-alternative/pytriton/build/lib/pytriton/client/asyncio_utils.py +308 -0
  31. stf/stf-api-alternative/pytriton/build/lib/pytriton/client/client.py +2033 -0
  32. stf/stf-api-alternative/pytriton/build/lib/pytriton/client/exceptions.py +92 -0
  33. stf/stf-api-alternative/pytriton/build/lib/pytriton/client/utils.py +384 -0
  34. stf/stf-api-alternative/pytriton/build/lib/pytriton/client/warnings.py +26 -0
  35. stf/stf-api-alternative/pytriton/build/lib/pytriton/constants.py +31 -0
  36. stf/stf-api-alternative/pytriton/build/lib/pytriton/decorators.py +678 -0
  37. stf/stf-api-alternative/pytriton/build/lib/pytriton/exceptions.py +80 -0
  38. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/__init__.py +17 -0
  39. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/common.py +93 -0
  40. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/generator.py +284 -0
  41. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/model_config.py +43 -0
  42. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/parser.py +258 -0
  43. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/tensor.py +57 -0
  44. stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/triton_model_config.py +68 -0
  45. stf/stf-api-alternative/pytriton/build/lib/pytriton/models/__init__.py +14 -0
  46. stf/stf-api-alternative/pytriton/build/lib/pytriton/models/manager.py +147 -0
  47. stf/stf-api-alternative/pytriton/build/lib/pytriton/models/model.py +335 -0
  48. stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/__init__.py +14 -0
  49. stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/communication.py +555 -0
  50. stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/data.py +1133 -0
stf/.DS_Store ADDED
Binary file (6.15 kB). View file
 
stf/089.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ce3fb07d8d15495eab879b47413c6b86bce114ca9ecd375b45b54777cf0e175
3
+ size 522605028
stf/089.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba4eb3437019d77abed141d60bcb5489b664f494cf965eec0bccf304c3d79b2a
3
+ size 1567401123
stf/stf-api-alternative/.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
stf/stf-api-alternative/.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1 @@
 
 
1
+ stf_api와 동일한 기능을 수행하는 라이브러리
stf/stf-api-alternative/.ipynb_checkpoints/poetry-checkpoint.lock ADDED
The diff for this file is too large to render. See raw diff
 
stf/stf-api-alternative/.ipynb_checkpoints/pyproject-checkpoint.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "stf-alternative"
3
+ version = "0.1.0"
4
+ description = "alternative version of stf-api"
5
+ authors = ["Kim Minjong <[email protected]>"]
6
+ readme = "README.md"
7
+ packages = [
8
+ {include = "stf_alternative", from="src"}
9
+ ]
10
+
11
+ [tool.poetry.dependencies]
12
+ python = "^3.10"
13
+ librosa = "0.8.1"
14
+ imageio = "2.13.5"
15
+ imageio-ffmpeg = "0.4.5"
16
+ Pillow = "9.1.0"
17
+ tqdm = "4.64.0"
18
+ numpy = "1.22.4"
19
+ addict = "2.4.0"
20
+ scipy = "1.12.0"
21
+ pandas = "1.3.5"
22
+ face_alignment = "1.3.5"
23
+ moviepy = "1.0.3"
24
+ transformers = "4.29.2"
25
+ facenet_pytorch = "2.5.2"
26
+ ffmpeg-python = "^0.2"
27
+ pydub = "^0.25"
28
+ av = "^11.0.0"
29
+ nvidia-pytriton = {extras = ["client"], version = "^0.4.2"}
30
+ asyncstdlib = "^3.10.9"
31
+
32
+
33
+ [build-system]
34
+ requires = ["poetry-core"]
35
+ build-backend = "poetry.core.masonry.api"
stf/stf-api-alternative/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ stf_api와 동일한 기능을 수행하는 라이브러리
stf/stf-api-alternative/poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
stf/stf-api-alternative/pyproject.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "stf-alternative"
3
+ version = "0.1.0"
4
+ description = "alternative version of stf-api"
5
+ authors = ["Kim Minjong <[email protected]>"]
6
+ readme = "README.md"
7
+ packages = [
8
+ {include = "stf_alternative", from="src"}
9
+ ]
10
+
11
+ [tool.poetry.dependencies]
12
+ python = "^3.10"
13
+ librosa = "0.8.1"
14
+ imageio = "2.13.5"
15
+ imageio-ffmpeg = "0.4.5"
16
+ Pillow = "9.1.0"
17
+ tqdm = "4.64.0"
18
+ numpy = "1.24.4"
19
+ addict = "2.4.0"
20
+ scipy = "1.12.0"
21
+ pandas = "1.3.5"
22
+ face_alignment = "1.3.5"
23
+ moviepy = "1.0.3"
24
+ transformers = "4.29.2"
25
+ facenet_pytorch = "2.5.2"
26
+ ffmpeg-python = "^0.2"
27
+ pydub = "^0.25"
28
+ av = "^11.0.0"
29
+ nvidia-pytriton = {extras = ["client"], version = "^0.4.2"}
30
+ asyncstdlib = "^3.10.9"
31
+
32
+
33
+ [build-system]
34
+ requires = ["poetry-core"]
35
+ build-backend = "poetry.core.masonry.api"
stf/stf-api-alternative/pytriton/.flake8 ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ [flake8]
15
+ exclude = docs,experiments,blueprints,pytriton/tritonserver,sandbox
16
+ ignore = E203, E266, E501, W503
17
+ max-line-length = 120
18
+ max-complexity = 18
19
+ select = B,C,D,E,F,W,T,N
stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Description**
11
+
12
+ A clear and concise description of the bug.
13
+
14
+ **To reproduce**
15
+
16
+ If relevant, add a minimal example so that we can reproduce the error, if necessary, by running the code. For example:
17
+
18
+ ```python
19
+ # server
20
+ from pytriton.decorators import batch
21
+ from pytriton.model_config import ModelConfig, Tensor
22
+ from pytriton.triton import Triton
23
+
24
+ @batch
25
+ def _infer_fn(**inputs):
26
+ ...
27
+ results_dict = model(**inputs) # ex note: the bug is here, we expect to receive ...
28
+ ...
29
+ # note: observing results_dict as dictionary of numpy arrays
30
+ return results_dict
31
+
32
+
33
+ with Triton() as triton:
34
+ triton.bind(
35
+ model_name="MyModel",
36
+ infer_func=_infer_fn,
37
+ inputs=[
38
+ Tensor(name="in1", dtype=np.float32, shape=(-1,)),
39
+ Tensor(name="in2", dtype=np.float32, shape=(-1,)),
40
+ ],
41
+ outputs=[
42
+ Tensor(name="out1", dtype=np.float32, shape=(-1,)),
43
+ Tensor(name="out2", dtype=np.float32, shape=(-1,)),
44
+ ],
45
+ config=ModelConfig(max_batch_size=128),
46
+ )
47
+ triton.serve()
48
+ ```
49
+
50
+ ```python
51
+ # client
52
+ import numpy as np
53
+ from pytriton.client import ModelClient
54
+
55
+ batch_size = 2
56
+ in1_batch = np.ones((batch_size, 1), dtype=np.float32)
57
+ in2_batch = np.ones((batch_size, 1), dtype=np.float32)
58
+
59
+ with ModelClient("localhost", "MyModel") as client:
60
+ result_batch = client.infer_batch(in1_batch, in2_batch)
61
+ ```
62
+
63
+ **Observed results and expected behavior**
64
+
65
+ Please describe the observed results as well as the expected results.
66
+ If possible, attach relevant log output to help analyze your problem.
67
+ If an error is raised, please paste the full traceback of the exception.
68
+
69
+ ```
70
+
71
+ ```
72
+
73
+ **Environment**
74
+
75
+ - OS/container version: [e.g., container nvcr.io/nvidia/pytorch:23.02-py3 / virtual machine with Ubuntu 22.04]
76
+ - glibc version: [e.g., 2.31; can be checked with `ldd --version`]
77
+ - Python interpreter distribution and version: [e.g., CPython 3.8 / conda 4.7.12 with Python 3.8 environment]
78
+ - pip version: [e.g., 23.1.2]
79
+ - PyTriton version: [e.g., 0.1.4 / custom build from source at commit ______]
80
+ - Deployment details: [e.g., multi-node multi-GPU setup on GKE / multi-GPU single-node setup in Jupyter Notebook]
81
+
82
+ **Additional context**
83
+ Add any other context about the problem here.
stf/stf-api-alternative/pytriton/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Is your feature request related to a problem? Please describe.**
11
+ A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12
+
13
+ **Describe the solution you'd like**
14
+ A clear and concise description of what you want to happen.
15
+
16
+ **Describe alternatives you've considered**
17
+ A clear and concise description of any alternative solutions or features you've considered.
18
+
19
+ **Additional context**
20
+ Add any other context or screenshots about the feature request here.
stf/stf-api-alternative/pytriton/.github/workflows/stale.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ name: 'Close stale issues and PRs'
15
+ on:
16
+ schedule:
17
+ - cron: "30 1 * * *"
18
+ jobs:
19
+ stale:
20
+ if: github.repository_owner == 'triton-inference-server'
21
+ runs-on: ubuntu-latest
22
+ permissions:
23
+ issues: write
24
+ pull-requests: write
25
+ steps:
26
+ - uses: actions/stale@v8
27
+ with:
28
+ days-before-stale: 21
29
+ days-before-close: 7
30
+ stale-issue-message: 'This issue is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
31
+ stale-pr-message: 'This PR is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.'
32
+ close-issue-message: 'This issue was closed because it has been stalled for 7 days with no activity.'
33
+ close-pr-message: 'This PR was closed because it has been stalled for 7 days with no activity.'
34
+ exempt-issue-labels: 'non-stale'
35
+ exempt-pr-labels: 'non-stale'
stf/stf-api-alternative/pytriton/.gitignore ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Created by https://www.toptal.com/developers/gitignore/api/pycharm+all,visualstudiocode,python,direnv,vim
15
+ # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm+all,visualstudiocode,python,direnv,vim
16
+
17
+ ### direnv ###
18
+ .direnv
19
+ .envrc
20
+
21
+ ### PyCharm+all ###
22
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
23
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
24
+
25
+ # User-specific stuff
26
+ .idea/**/workspace.xml
27
+ .idea/**/tasks.xml
28
+ .idea/**/usage.statistics.xml
29
+ .idea/**/dictionaries
30
+ .idea/**/shelf
31
+
32
+ # AWS User-specific
33
+ .idea/**/aws.xml
34
+
35
+ # Generated files
36
+ .idea/**/contentModel.xml
37
+
38
+ # Sensitive or high-churn files
39
+ .idea/**/dataSources/
40
+ .idea/**/dataSources.ids
41
+ .idea/**/dataSources.local.xml
42
+ .idea/**/sqlDataSources.xml
43
+ .idea/**/dynamic.xml
44
+ .idea/**/uiDesigner.xml
45
+ .idea/**/dbnavigator.xml
46
+
47
+ # Gradle
48
+ .idea/**/gradle.xml
49
+ .idea/**/libraries
50
+
51
+ # Gradle and Maven with auto-import
52
+ # When using Gradle or Maven with auto-import, you should exclude module files,
53
+ # since they will be recreated, and may cause churn. Uncomment if using
54
+ # auto-import.
55
+ # .idea/artifacts
56
+ # .idea/compiler.xml
57
+ # .idea/jarRepositories.xml
58
+ # .idea/modules.xml
59
+ # .idea/*.iml
60
+ # .idea/modules
61
+ # *.iml
62
+ # *.ipr
63
+
64
+ # CMake
65
+ cmake-build-*/
66
+
67
+ # Mongo Explorer plugin
68
+ .idea/**/mongoSettings.xml
69
+
70
+ # File-based project format
71
+ *.iws
72
+
73
+ # IntelliJ
74
+ out/
75
+
76
+ # mpeltonen/sbt-idea plugin
77
+ .idea_modules/
78
+
79
+ # JIRA plugin
80
+ atlassian-ide-plugin.xml
81
+
82
+ # Cursive Clojure plugin
83
+ .idea/replstate.xml
84
+
85
+ # SonarLint plugin
86
+ .idea/sonarlint/
87
+
88
+ # Crashlytics plugin (for Android Studio and IntelliJ)
89
+ com_crashlytics_export_strings.xml
90
+ crashlytics.properties
91
+ crashlytics-build.properties
92
+ fabric.properties
93
+
94
+ # Editor-based Rest Client
95
+ .idea/httpRequests
96
+
97
+ # Android studio 3.1+ serialized cache file
98
+ .idea/caches/build_file_checksums.ser
99
+
100
+ ### PyCharm+all Patch ###
101
+ # Ignore everything but code style settings and run configurations
102
+ # that are supposed to be shared within teams.
103
+
104
+ .idea/*
105
+
106
+ !.idea/codeStyles
107
+ !.idea/runConfigurations
108
+
109
+ ### Python ###
110
+ # Byte-compiled / optimized / DLL files
111
+ __pycache__/
112
+ *.py[cod]
113
+ *$py.class
114
+
115
+ # C extensions
116
+ *.so
117
+
118
+ # Distribution / packaging
119
+ .Python
120
+ build/
121
+ develop-eggs/
122
+ dist/
123
+ downloads/
124
+ eggs/
125
+ .eggs/
126
+ lib/
127
+ lib64/
128
+ parts/
129
+ sdist/
130
+ var/
131
+ wheels/
132
+ share/python-wheels/
133
+ *.egg-info/
134
+ .installed.cfg
135
+ *.egg
136
+ MANIFEST
137
+
138
+ # PyInstaller
139
+ # Usually these files are written by a python script from a template
140
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
141
+ *.manifest
142
+ *.spec
143
+
144
+ # Installer logs
145
+ pip-log.txt
146
+ pip-delete-this-directory.txt
147
+
148
+ # Unit test / coverage reports
149
+ htmlcov/
150
+ .tox/
151
+ .nox/
152
+ .coverage
153
+ .coverage.*
154
+ .cache
155
+ nosetests.xml
156
+ coverage.xml
157
+ *.cover
158
+ *.py,cover
159
+ .hypothesis/
160
+ .pytest_cache/
161
+ cover/
162
+
163
+ # Translations
164
+ *.mo
165
+ *.pot
166
+
167
+ # Django stuff:
168
+ *.log
169
+ local_settings.py
170
+ db.sqlite3
171
+ db.sqlite3-journal
172
+
173
+ # Flask stuff:
174
+ instance/
175
+ .webassets-cache
176
+
177
+ # Scrapy stuff:
178
+ .scrapy
179
+
180
+ # Sphinx documentation
181
+ docs/_build/
182
+
183
+ # PyBuilder
184
+ .pybuilder/
185
+ target/
186
+
187
+ # Jupyter Notebook
188
+ .ipynb_checkpoints
189
+
190
+ # IPython
191
+ profile_default/
192
+ ipython_config.py
193
+
194
+ # pyenv
195
+ # For a library or package, you might want to ignore these files since the code is
196
+ # intended to run in multiple environments; otherwise, check them in:
197
+ # .python-version
198
+
199
+ # pipenv
200
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
201
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
202
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
203
+ # install all needed dependencies.
204
+ #Pipfile.lock
205
+
206
+ # poetry
207
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
208
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
209
+ # commonly ignored for libraries.
210
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
211
+ #poetry.lock
212
+
213
+ # pdm
214
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
215
+ #pdm.lock
216
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
217
+ # in version control.
218
+ # https://pdm.fming.dev/#use-with-ide
219
+ .pdm.toml
220
+
221
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
222
+ __pypackages__/
223
+
224
+ # Celery stuff
225
+ celerybeat-schedule
226
+ celerybeat.pid
227
+
228
+ # SageMath parsed files
229
+ *.sage.py
230
+
231
+ # Environments
232
+ .env
233
+ .venv
234
+ env/
235
+ venv/
236
+ ENV/
237
+ env.bak/
238
+ venv.bak/
239
+
240
+ # Spyder project settings
241
+ .spyderproject
242
+ .spyproject
243
+
244
+ # Rope project settings
245
+ .ropeproject
246
+
247
+ # mkdocs documentation
248
+ /site
249
+
250
+ # mypy
251
+ .mypy_cache/
252
+ .dmypy.json
253
+ dmypy.json
254
+
255
+ # Pyre type checker
256
+ .pyre/
257
+
258
+ # pytype static type analyzer
259
+ .pytype/
260
+
261
+ # Cython debug symbols
262
+ cython_debug/
263
+
264
+ # PyCharm
265
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
266
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
267
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
268
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
269
+ #.idea/
270
+
271
+ ### Python Patch ###
272
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
273
+ poetry.toml
274
+
275
+ # ruff
276
+ .ruff_cache/
277
+
278
+ # LSP config files
279
+ pyrightconfig.json
280
+
281
+ ### Vim ###
282
+ # Swap
283
+ [._]*.s[a-v][a-z]
284
+ !*.svg # comment out if you don't need vector files
285
+ [._]*.sw[a-p]
286
+ [._]s[a-rt-v][a-z]
287
+ [._]ss[a-gi-z]
288
+ [._]sw[a-p]
289
+
290
+ # Session
291
+ Session.vim
292
+ Sessionx.vim
293
+
294
+ # Temporary
295
+ .netrwhist
296
+ *~
297
+ # Auto-generated tag files
298
+ tags
299
+ # Persistent undo
300
+ [._]*.un~
301
+
302
+ ### VisualStudioCode ###
303
+ .vscode/*
304
+ !.vscode/settings.json
305
+ !.vscode/tasks.json
306
+ !.vscode/launch.json
307
+ !.vscode/extensions.json
308
+ !.vscode/*.code-snippets
309
+
310
+ # Local History for Visual Studio Code
311
+ .history/
312
+
313
+ # Built Visual Studio Code Extensions
314
+ *.vsix
315
+
316
+ ### VisualStudioCode Patch ###
317
+ # Ignore all local history of files
318
+ .history
319
+ .ionide
320
+
321
+ # End of https://www.toptal.com/developers/gitignore/api/pycharm+all,visualstudiocode,python,direnv,vim
322
+ pytriton/tritonserver
323
+ docs/CHANGELOG.md
324
+ docs/CONTRIBUTING.md
325
+ docs/LICENSE.md
326
+ docs/examples.md
327
+
328
+ ### VisualStudioCode+all ##
329
+ .vscode
330
+ .devcontainer
stf/stf-api-alternative/pytriton/.pre-commit-config.yaml ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ exclude: kubernetes
15
+ repos:
16
+ - repo: https://github.com/ambv/black
17
+ rev: 23.11.0
18
+ hooks:
19
+ - id: black
20
+ - repo: https://github.com/pycqa/isort
21
+ rev: 5.12.0
22
+ hooks:
23
+ - id: isort
24
+ name: isort (python)
25
+ - repo: https://github.com/pre-commit/pre-commit-hooks
26
+ rev: v4.5.0
27
+ hooks:
28
+ - id: check-docstring-first
29
+ - id: check-executables-have-shebangs
30
+ - id: check-json
31
+ - id: check-merge-conflict
32
+ - id: detect-private-key
33
+ - id: check-shebang-scripts-are-executable
34
+ - id: check-toml
35
+ - id: check-yaml
36
+ - id: debug-statements
37
+ - id: end-of-file-fixer
38
+ types: [python]
39
+ - id: fix-byte-order-marker
40
+ - id: no-commit-to-branch
41
+ - id: requirements-txt-fixer
42
+ - id: trailing-whitespace
43
+ exclude: setup.cfg
44
+ - id: mixed-line-ending
45
+ args: [--fix=lf]
46
+ - repo: https://github.com/asottile/pyupgrade
47
+ rev: v3.15.0
48
+ hooks:
49
+ - id: pyupgrade
50
+ args: [--py36-plus]
51
+ - repo: https://github.com/pycqa/flake8
52
+ rev: 6.1.0
53
+ hooks:
54
+ - id: flake8
55
+ additional_dependencies:
56
+ - flake8-bugbear
57
+ - flake8-comprehensions
58
+ - flake8-print
59
+ - mccabe
60
+ - pep8-naming
61
+ - pycodestyle
62
+ - pyflakes
63
+ - repo: https://github.com/pycqa/pydocstyle
64
+ rev: 6.3.0
65
+ hooks:
66
+ - id: pydocstyle
67
+ name: Run pydocstyle
68
+ args:
69
+ - --convention=google
70
+ exclude: '(?:tests|examples)\/.*'
71
+ additional_dependencies: ['toml']
72
+ - repo: https://github.com/thlorenz/doctoc
73
+ rev: v2.2.0
74
+ hooks:
75
+ - id: doctoc
76
+ args: [ --github, --update-only ]
stf/stf-api-alternative/pytriton/CHANGELOG.md ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--
2
+ Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ -->
16
+
17
+ # Changelog
18
+
19
+ ## 0.4.2 (2023-12-05)
20
+
21
+ - New: You can create client from existing client instance or model configuration to avoid loading model configuration from server.
22
+ - New: Introduced warning system using the `warnings` module.
23
+ - Fix: Experimental client for decoupled models prevents sending another request, when responses from previous request are not consumed, blocks close until stream is stopped.
24
+ - Fix: Leak of ModelClient during Triton creation
25
+ - Fix: Fixed non-declared project dependencies (removed from use in code or added to package dependencies)
26
+ - Fix: Remote model is being unloaded from Triton when RemoteTriton is closed.
27
+
28
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
29
+
30
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.39.0](https://github.com/triton-inference-server/server/releases/tag/v2.39.0)
31
+
32
+ ## 0.4.1 (2023-11-09)
33
+
34
+ - New: Place where workspaces with temporary Triton model repositories and communication file sockets can be configured by `$PYTRITON_HOME` environment variable
35
+ - Fix: Recover handling `KeyboardInterrupt` in `triton.serve()`
36
+ - Fix: Remove limit for handling bytes dtype tensors
37
+ - Build scripts update
38
+ - Added support for arm64 platform builds
39
+
40
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
41
+
42
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.39.0](https://github.com/triton-inference-server/server/releases/tag/v2.39.0)
43
+
44
+ ## 0.4.0 (2023-10-20)
45
+
46
+ - New: Remote Mode - PyTriton can be used to connect to a remote Triton Inference Server
47
+ - Introduced RemoteTriton class which can be used to connect to a remote Triton Inference Server
48
+ running on the same machine, by passing triton url.
49
+ - Changed Triton lifecycle - now the Triton Inference Server is started while entering the context.
50
+ This allows to load models dynamically to the running server while calling the bind method.
51
+ It is still allowed to create Triton instance without entering the context and bind models before starting
52
+ the server (in this case the models are lazy loaded when calling run or serve method like it worked before).
53
+ - In RemoteTriton class, calling __enter__ or connect method connects to triton server, so we can safely load models
54
+ while binding inference functions (if RemoteTriton is used without context manager, models are lazy loaded
55
+ when calling connect or serve method).
56
+ - Change: `@batch` decorator raises a `ValueError` if any of the outputs have a different batch size than expected.
57
+ - fix: gevent resources leak in ``FuturesModelClient``
58
+
59
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
60
+
61
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)
62
+
63
+ ## 0.3.1 (2023-09-26)
64
+
65
+ - Change: `KeyboardInterrupt` is now handled in `triton.serve()`. PyTriton hosting scripts return an exit code of 0 instead of 130 when they receive a SIGINT signal.
66
+ - Fix: Addressed potential instability in shared memory management.
67
+
68
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
69
+
70
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)
71
+
72
+ ## 0.3.0 (2023-09-05)
73
+
74
+ - new: Support for multiple Python versions starting from 3.8+
75
+ - new: Added support for [decoupled models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/decoupled_models.md) enabling to support streaming models (alpha state)
76
+ - change: Upgraded Triton Inference Server binaries to version 2.36.0. Note that this Triton Inference Server requires glibc 2.35+ or a more recent version.
77
+
78
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
79
+
80
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.36.0](https://github.com/triton-inference-server/server/releases/tag/v2.36.0)
81
+
82
+
83
+ ## 0.2.5 (2023-08-24)
84
+
85
+ - new: Allow to execute multiple PyTriton instances in the same process and/or host
86
+ - fix: Invalid flags for Proxy Backend configuration passed to Triton
87
+
88
+
89
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
90
+
91
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
92
+
93
+ ## 0.2.4 (2023-08-10)
94
+
95
+ - new: Introduced `strict` flag in `Triton.bind` which enables data types and shapes validation of inference callable outputs
96
+ against model config
97
+ - new: `AsyncioModelClient` which works in FastAPI and other async frameworks
98
+ - fix: `FuturesModelClient` do not raise `gevent.exceptions.InvalidThreadUseError`
99
+ - fix: Do not throw TimeoutError if could not connect to server during model verification
100
+
101
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
102
+
103
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
104
+
105
+ ## 0.2.3 (2023-07-21)
106
+
107
+ - Improved verification of Proxy Backend environment when running under same Python interpreter
108
+ - Fixed pytriton.__version__ to represent currently installed version
109
+
110
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
111
+
112
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
113
+
114
+ ## 0.2.2 (2023-07-19)
115
+
116
+ - Added `inference_timeout_s` parameters to client classes
117
+ - Renamed `PyTritonClientUrlParseError` to `PyTritonClientInvalidUrlError`
118
+ - `ModelClient` and `FuturesModelClient` methods raise `PyTritonClientClosedError` when used after client is closed
119
+ - Pinned tritonclient dependency due to issues with tritonclient >= 2.34 on systems with glibc version lower than 2.34
120
+ - Added warning after Triton Server setup and teardown while using too verbose logging level as it may cause a significant performance drop in model inference
121
+
122
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
123
+
124
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
125
+
126
+ ## 0.2.1 (2023-06-28)
127
+
128
+ - Fixed handling `TritonConfig.cache_directory` option - the directory was always overwritten with the default value.
129
+ - Fixed tritonclient dependency - PyTriton need tritonclient supporting http headers and parameters
130
+ - Improved shared memory usage to match 64MB limit (default value for Docker, Kubernetes) reducing the initial size for PyTriton Proxy Backend.
131
+
132
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
133
+
134
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
135
+
136
+ ## 0.2.0 (2023-05-30)
137
+
138
+ - Added support for using custom HTTP/gRPC request headers and parameters.
139
+
140
+ This change breaks backward compatibility of the inference function signature.
141
+ The undecorated inference function now accepts a list of `Request` instances instead
142
+ of a list of dictionaries. The `Request` class contains data for inputs and parameters
143
+ for combined parameters and headers.
144
+
145
+ See [docs/custom_params.md](docs/custom_params.md) for further information
146
+
147
+ - Added `FuturesModelClient` which enables sending inference requests in a parallel manner.
148
+ - Added displaying documentation link after models are loaded.
149
+
150
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
151
+
152
+ - Version of [Triton Inference Server](https://github.com/triton-inference-server/) embedded in wheel: [2.33.0](https://github.com/triton-inference-server/server/releases/tag/v2.33.0)
153
+
154
+ ## 0.1.5 (2023-05-12)
155
+
156
+ - Improved `pytriton.decorators.group_by_values` function
157
+ - Modified the function to avoid calling the inference callable on each individual sample when grouping by string/bytes input
158
+ - Added `pad_fn` argument for easy padding and combining of the inference results
159
+ - Fixed Triton binaries search
160
+ - Improved Workspace management (remove workspace on shutdown)
161
+
162
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
163
+
164
+ - Version of external components used during testing:
165
+ - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
166
+ - Other component versions depend on the used framework and Triton Inference Server containers versions.
167
+ Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
168
+ for a detailed summary.
169
+
170
+ ## 0.1.4 (2023-03-16)
171
+
172
+ - Add validation of the model name passed to Triton bind method.
173
+ - Add monkey patching of `InferenceServerClient.__del__` method to prevent unhandled exceptions.
174
+
175
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
176
+
177
+ - Version of external components used during testing:
178
+ - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
179
+ - Other component versions depend on the used framework and Triton Inference Server containers versions.
180
+ Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
181
+ for a detailed summary.
182
+
183
+ ## 0.1.3 (2023-02-20)
184
+
185
+ - Fixed getting model config in `fill_optionals` decorator.
186
+
187
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
188
+
189
+ - Version of external components used during testing:
190
+ - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
191
+ - Other component versions depend on the used framework and Triton Inference Server containers versions.
192
+ Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
193
+ for a detailed summary.
194
+
195
+ ## 0.1.2 (2023-02-14)
196
+
197
+ - Fixed wheel build to support installations on operating systems with glibc version 2.31 or higher.
198
+ - Updated the documentation on custom builds of the package.
199
+ - Change: TritonContext instance is shared across bound models and contains model_configs dictionary.
200
+ - Fixed support of binding multiple models that uses methods of the same class.
201
+
202
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
203
+
204
+ - Version of external components used during testing:
205
+ - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
206
+ - Other component versions depend on the used framework and Triton Inference Server containers versions.
207
+ Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
208
+ for a detailed summary.
209
+
210
+ ## 0.1.1 (2023-01-31)
211
+
212
+ - Change: The `@first_value` decorator has been updated with new features:
213
+ - Renamed from `@first_values` to `@first_value`
214
+ - Added a `strict` flag to toggle the checking of equality of values on a single selected input of the request. Default is True
215
+ - Added a `squeeze_single_values` flag to toggle the squeezing of single value ND arrays to scalars. Default is True
216
+ - Fix: `@fill_optionals` now supports non-batching models
217
+ - Fix: `@first_value` fixed to work with optional inputs
218
+ - Fix: `@group_by_values` fixed to work with string inputs
219
+ - Fix: `@group_by_values` fixed to work per sample-wise
220
+
221
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
222
+
223
+ - Version of external components used during testing:
224
+ - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
225
+ - Other component versions depend on the used framework and Triton Inference Server containers versions.
226
+ Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
227
+ for a detailed summary.
228
+
229
+ ## 0.1.0 (2023-01-12)
230
+
231
+ - Initial release of PyTriton
232
+
233
+ [//]: <> (put here on external component update with short summary what change or link to changelog)
234
+
235
+ - Version of external components used during testing:
236
+ - [Triton Inference Server](https://github.com/triton-inference-server/): 2.29.0
237
+ - Other component versions depend on the used framework and Triton Inference Server containers versions.
238
+ Refer to its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
239
+ for a detailed summary.
stf/stf-api-alternative/pytriton/CONTRIBUTING.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--
2
+ Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ -->
16
+
17
+ # Contributing
18
+
19
+ Contributions are welcome, and they are much appreciated! Every little
20
+ helps, and we will always give credit.
21
+
22
+ ## Types of Contributions
23
+
24
+ ### Report Bugs
25
+
26
+ Report bugs at [https://github.com/triton-inference-server/pytriton/issues](https://github.com/triton-inference-server/pytriton/issues).
27
+
28
+ When reporting a bug, please include the following information:
29
+
30
+ * Your operating system name and version.
31
+ * Any details about your local setup that might be helpful in troubleshooting.
32
+ * Detailed steps to reproduce the bug.
33
+
34
+ ### Fix Bugs
35
+
36
+ Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
37
+ wanted" is open to whoever wants to implement it.
38
+
39
+ ### Implement Features
40
+
41
+ Browse through the GitHub issues for features. Anything tagged with "enhancement" and "help wanted" is open to whoever wants to implement it.
42
+
43
+ ### Write Documentation
44
+
45
+ The PyTriton could always use more documentation, whether as part of
46
+ the official PyTriton docs, in docstrings, or even on the web in blog posts,
47
+ articles, and such.
48
+
49
+ ### Submit Feedback
50
+
51
+ The best way to send feedback is to file an issue at [https://github.com/triton-inference-server/pytriton/issues](https://github.com/triton-inference-server/pytriton/issues).
52
+
53
+ If you are proposing a feature:
54
+
55
+ * Explain in detail how it would work.
56
+ * Keep the scope as narrow as possible to make it easier to implement.
57
+
58
+ ## Sign your Work
59
+
60
+ We require that all contributors "sign-off" on their commits. This certifies that
61
+ the contribution is your original work, or you have the rights to submit it under
62
+ the same license or a compatible license.
63
+
64
+ Any contribution which contains commits that are not Signed-Off will not be accepted.
65
+
66
+ To sign off on a commit, simply use the `--signoff` (or `-s`) option when committing your changes:
67
+
68
+ ```shell
69
+ $ git commit -s -m "Add a cool feature."
70
+ ```
71
+
72
+ This will append the following to your commit message:
73
+
74
+ ```
75
+ Signed-off-by: Your Name <[email protected]>
76
+ ```
77
+
78
+ By doing this, you certify the following:
79
+
80
+ ```
81
+ Developer Certificate of Origin
82
+ Version 1.1
83
+
84
+ Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
85
+ 1 Letterman Drive
86
+ Suite D4700
87
+ San Francisco, CA, 94129
88
+
89
+ Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
90
+
91
+
92
+ Developer's Certificate of Origin 1.1
93
+
94
+ By making a contribution to this project, I certify that:
95
+
96
+ (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
97
+
98
+ (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
99
+
100
+ (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
101
+
102
+ (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
103
+ ```
104
+
105
+ ## Get Started!
106
+
107
+ ### Local Development
108
+
109
+ Ready to contribute? Here's how to set up the `PyTriton` for local development.
110
+
111
+ 1. Fork the `PyTriton` repo on GitHub.
112
+ 2. Clone your fork locally:
113
+
114
+ ```shell
115
+ $ git clone [email protected]:your_name_here/pytriton.git
116
+ ```
117
+
118
+ 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, here's how you set up your fork for local development:
119
+
120
+ ```shell
121
+ $ mkvirtualenv pytriton
122
+ $ cd pytriton/
123
+ ```
124
+
125
+ If you do not use the virtualenvwrapper package, you can initialize a virtual environment using the pure Python command:
126
+
127
+ ```shell
128
+ $ python -m venv pytriton
129
+ $ cd pytriton/
130
+ $ source bin/activate
131
+ ```
132
+
133
+ Once the virtualenv is activated, install the development dependencies:
134
+
135
+ ```shell
136
+ $ make install-dev
137
+ ```
138
+
139
+ 4. Extract Triton Server to your environment so you can debug PyTriton while serving some models on Triton:
140
+
141
+ ```shell
142
+ $ make extract-triton
143
+ ```
144
+
145
+ 5. Install pre-commit hooks:
146
+
147
+ ```shell
148
+ $ pre-commit install
149
+ ```
150
+
151
+ 6. Create a branch for local development:
152
+
153
+ ```shell
154
+ $ git checkout -b name-of-your-bugfix-or-feature
155
+ ```
156
+
157
+ Now you can make your changes locally.
158
+
159
+ 7. When you're done making changes, check that your changes pass linters and the
160
+ tests, including testing other Python versions with tox:
161
+
162
+ ```shell
163
+ $ make lint # will run, among others, flake8 and pytype linters
164
+ $ make test # will run a test on your current virtualenv
165
+ ```
166
+
167
+ To run a subset of tests:
168
+
169
+ ```shell
170
+ $ pytest tests.test_subset
171
+ ```
172
+
173
+ 8. Commit your changes and push your branch to GitHub:
174
+
175
+ ```shell
176
+ $ git add .
177
+ $ git commit -s -m "Your detailed description of your changes."
178
+ $ git push origin name-of-your-bugfix-or-feature
179
+ ```
180
+
181
+ 9. Submit a pull request through the GitHub website.
182
+
183
+ ### Pull Request Guidelines
184
+
185
+ Before you submit a pull request, check that it meets these guidelines:
186
+
187
+ 1. The pull request should include tests.
188
+ 2. If the pull request adds functionality, you should update the docs. Put your new functionality into a function with a docstring and add the feature to the list in README.md.
189
+
190
+
191
+ ## Documentation
192
+
193
+ Add/update docstrings as defined in [Google Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings).
194
+
195
+ ## Contributor License Agreement (CLA)
196
+
197
+ PyTriton requires that all contributors (or their corporate entity) send
198
+ a signed copy of the [Contributor License
199
+ Agreement](https://github.com/NVIDIA/triton-inference-server/blob/master/Triton-CCLA-v1.pdf)
200
201
+
202
+ *NOTE*: Contributors with no company affiliation can fill `N/A` in the
203
+ `Corporation Name` and `Corporation Address` fields.
stf/stf-api-alternative/pytriton/COPYRIGHT ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
stf/stf-api-alternative/pytriton/LICENSE ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
stf/stf-api-alternative/pytriton/Makefile ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ .PHONY: clean clean-build clean-tritonserver clean-pyc clean-docs clean-test docs lint test coverage release dist build-triton extract-triton install install-dev help
15
+ .DEFAULT_GOAL := help
16
+
17
+ define BROWSER_PYSCRIPT
18
+ import os, webbrowser, sys
19
+
20
+ from urllib.request import pathname2url
21
+
22
+ webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
23
+ endef
24
+ export BROWSER_PYSCRIPT
25
+
26
+ define PRINT_HELP_PYSCRIPT
27
+ import re, sys
28
+
29
+ for line in sys.stdin:
30
+ match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
31
+ if match:
32
+ target, help = match.groups()
33
+ print("%-20s %s" % (target, help))
34
+ endef
35
+ export PRINT_HELP_PYSCRIPT
36
+
37
+ BROWSER := python -c "$$BROWSER_PYSCRIPT"
38
+ PIP_INSTALL := pip install --extra-index-url https://pypi.ngc.nvidia.com
39
+ TRITONSERVER_IMAGE_VERSION = 23.10
40
+ TRITONSERVER_IMAGE_NAME = nvcr.io/nvidia/tritonserver:$(TRITONSERVER_IMAGE_VERSION)-pyt-python-py3
41
+ TRITONSERVER_OUTPUT_DIR = ${PWD}/pytriton/tritonserver
42
+ TRITONSERVER_BASENAME = pytriton
43
+ PYTRITON_IMAGE_NAME = $(TRITONSERVER_BASENAME):$(TRITONSERVER_IMAGE_VERSION)
44
+ # to set PLATFORM from outside, use: make PLATFORM=linux/arm64;
45
+ # correct values are: linux/amd64 (default), linux/arm64
46
+ PLATFORM=linux/amd64
47
+
48
+ help:
49
+ @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
50
+
51
+ clean: clean-build clean-pyc clean-test clean-tritonserver clean-docs ## remove all build, tritonserver, test, docs, coverage and Python artifacts
52
+
53
+ clean-build: ## remove build artifacts
54
+ rm -fr build/
55
+ rm -fr dist/
56
+ rm -fr .eggs/
57
+ find . -name '*.egg-info' -exec rm -fr {} +
58
+ find . -name '*.egg' -exec rm -f {} +
59
+
60
+ clean-tritonserver:
61
+ rm -fr pytriton/tritonserver
62
+
63
+ clean-pyc: ## remove Python file artifacts
64
+ find . -name '*.pyc' -exec rm -f {} +
65
+ find . -name '*.pyo' -exec rm -f {} +
66
+ find . -name '*~' -exec rm -f {} +
67
+ find . -name '__pycache__' -exec rm -fr {} +
68
+
69
+ clean-docs: ## remove test and coverage artifacts
70
+ rm -rf site
71
+
72
+ clean-test: ## remove test and coverage artifacts
73
+ rm -fr .tox/
74
+ rm -f .coverage
75
+ rm -fr htmlcov/
76
+ rm -fr .pytest_cache
77
+ rm -fr .pytype/
78
+
79
+ docs: clean-docs ## generate site
80
+ cp CHANGELOG.md docs
81
+ cp CONTRIBUTING.md docs
82
+ cp LICENSE docs/LICENSE.md
83
+ cp examples/README.md docs/examples.md
84
+ mkdocs build --clean
85
+
86
+ docs-serve: docs
87
+ mkdocs serve
88
+
89
+ lint: ## check style with pre-commit and pytype
90
+ tox -e pytype,pre-commit --develop
91
+
92
+ test: ## run tests on every Python version with tox
93
+ tox --develop --skip-missing-interpreters
94
+
95
+ coverage: ## check code coverage quickly with the default Python
96
+ coverage run --source pytriton -m pytest
97
+ coverage report -m
98
+ coverage html
99
+ $(BROWSER) htmlcov/index.html
100
+
101
+ dist: clean build-triton extract-triton ## builds source and wheel package
102
+ bash ./scripts/build_wheel.sh $(PLATFORM)
103
+ ls -lh dist
104
+ find ./dist -iname *-linux*.whl -type f -exec bash ./scripts/add_libs_to_wheel.sh $(PYTRITON_IMAGE_NAME) $(TRITONSERVER_OUTPUT_DIR) {} $(PLATFORM) \;
105
+ find ./dist -iname *-linux*.whl -type f -delete
106
+ ls -lh dist
107
+ twine check dist/*
108
+
109
+ build-triton: ## build Triton with Python Stubs
110
+ bash ./scripts/build_triton.sh $(TRITONSERVER_IMAGE_NAME) $(PYTRITON_IMAGE_NAME) $(PLATFORM)
111
+ echo "export PYTRITON_IMAGE_NAME=$(PYTRITON_IMAGE_NAME)" > .env
112
+
113
+ extract-triton: build-triton ## extract Triton binaries and libraries
114
+ # changing dst path, change also in clean-build and pyproject.toml
115
+ bash ./scripts/extract_triton.sh $(PYTRITON_IMAGE_NAME) $(TRITONSERVER_OUTPUT_DIR) $(PLATFORM)
116
+
117
+
118
+ install: clean extract-triton ## install the package to the active Python's site-packages
119
+ $(PIP_INSTALL) --upgrade pip
120
+ $(PIP_INSTALL) .
121
+
122
+ install-dev: clean-build clean-pyc
123
+ $(PIP_INSTALL) --upgrade pip
124
+ $(PIP_INSTALL) -e .[dev]
stf/stf-api-alternative/pytriton/README.md ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--
2
+ Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ -->
16
+
17
+ # PyTriton
18
+
19
+ PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments.
20
+ The library allows serving Machine Learning models directly from Python through
21
+ NVIDIA's [Triton Inference Server](https://github.com/triton-inference-server).
22
+
23
+ <!-- START doctoc generated TOC please keep comment here to allow auto update -->
24
+ <!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->
25
+
26
+ - [Documentation](#documentation)
27
+ - [Feature matrix](#feature-matrix)
28
+ - [How it works?](#how-it-works)
29
+ - [Installation](#installation)
30
+ - [Prerequisites](#prerequisites)
31
+ - [Install from `pypi`](#install-from-pypi)
32
+ - [Setting Up Python Environment](#setting-up-python-environment)
33
+ - [Building binaries from source](#building-binaries-from-source)
34
+ - [Quick Start](#quick-start)
35
+ - [Architecture](#architecture)
36
+ - [Examples](#examples)
37
+ - [Streaming (alpha)](#streaming-alpha)
38
+ - [Profiling model](#profiling-model)
39
+ - [Version management](#version-management)
40
+ - [Useful Links](#useful-links)
41
+
42
+ <!-- END doctoc generated TOC please keep comment here to allow auto update -->
43
+
44
+ ## Documentation
45
+
46
+ Read how to customize the Triton Inference Server, load models, deploy on clusters, and the API reference
47
+ can be found in the [documentation](https://triton-inference-server.github.io/pytriton). The below sections provide
48
+ brief information about the product and quick start guide.
49
+
50
+ ## Feature matrix
51
+
52
+ | Feature | Description |
53
+ | ------- | ----------- |
54
+ | Native Python support | You can create any Python function and expose it as an HTTP/gRPC API. |
55
+ | Framework-agnostic | You can run any Python code with any framework of your choice, such as: PyTorch, TensorFlow, or JAX. |
56
+ | Performance optimization | You can benefit from dynamic batching, response cache, model pipelining, and GPU/CPU inference. |
57
+ | Easy installation and setup | You can use a simple and familiar interface based on Flask/FastAPI for easy installation and setup. |
58
+ | Model clients | You can access high-level model clients for HTTP/gRPC requests with configurable options and both synchronous and asynchronous API. |
59
+ | Streaming (alpha) | You can stream partial responses from a model by serving it in a decoupled mode. |
60
+
61
+ ## How it works?
62
+
63
+ In PyTriton, like in Flask or FastAPI, you can define any Python function that executes a Machine Learning model prediction and exposes
64
+ it through an HTTP/gRPC API. PyTriton installs Triton Inference Server in your environment and uses it for handling
65
+ HTTP/gRPC requests and responses. Our library provides a Python API that allows you to attach a Python function to Triton
66
+ and a communication layer to send/receive data between Triton and the function. The solution enables using the
67
+ performance features of Triton Inference Server, such as dynamic batching or response cache, without changing your model
68
+ environment. Thus, it improves the performance of running inference on GPU for models implemented in Python. The solution is
69
+ framework-agnostic and can be used along with frameworks like PyTorch, TensorFlow, or JAX.
70
+
71
+ ## Installation
72
+
73
+ We assume that you are comfortable with the Python programming language and familiar with Machine Learning models.
74
+ Using [Docker](https://www.docker.com/) is an option, but not mandatory.
75
+
76
+ The library can be installed in:
77
+
78
+ - system environment
79
+ - virtualenv
80
+ - [Docker](https://www.docker.com/) image
81
+
82
+ NVIDIA optimized Docker images for Python frameworks can be obtained from the [NVIDIA NGC Catalog](https://catalog.ngc.nvidia.com/containers).
83
+
84
+ If you want to use the Docker runtime, we recommend that you install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/overview.html) to
85
+ enable running model inference on NVIDIA GPU.
86
+
87
+ ### Prerequisites
88
+
89
+ Before installing the library, ensure that you meet the following requirements:
90
+
91
+ - An operating system with glibc >= `2.35`.
92
+ - Triton Inference Server and PyTriton have **only** been rigorously tested on Ubuntu 22.04.
93
+ - Other supported operating systems include Ubuntu Debian 11+, Rocky Linux 9+, and Red Hat Universal Base Image 9+.
94
+ - To check your glibc version, run `ldd --version`
95
+ - Python version >= `3.8`
96
+ - Use `pip >= `20.3`
97
+ - Install `libpython3.*.so` in the operating system (appropriate for Python version).
98
+
99
+ ### Install from `pypi`
100
+
101
+ The PyTriton can be installed from [pypi.org](https://pypi.org/project/nvidia-pytriton/) by running the following command:
102
+
103
+ ```shell
104
+ pip install -U nvidia-pytriton
105
+ ```
106
+
107
+ **Important**: The Triton Inference Server binary is installed as part of the PyTriton package.
108
+
109
+ More details about installation can be found in the [documentation](https://triton-inference-server.github.io/pytriton/latest/installation/).
110
+
111
+
112
+ ### Setting Up Python Environment
113
+
114
+ The PyTriton requires installation and linking `libpython3.*.so`. Read more in "[Setting Up Python Environment](https://triton-inference-server.github.io/pytriton/latest/installation#setting-up-python-environment)"
115
+ for additional information how to configure system for different Python versions.
116
+
117
+ ### Building binaries from source
118
+
119
+ The binary package can be built from the source, allowing access to unreleased hotfixes, the ability to modify the PyTriton code, and compatibility with various Triton Inference Server versions, including custom server builds.
120
+ For further information on building the PyTriton binary, refer to the [Building](https://triton-inference-server.github.io/pytriton/latest/building/) page of documentation.
121
+
122
+ ## Quick Start
123
+
124
+ The quick start presents how to run Python model in Triton Inference Server without need to change the current working
125
+ environment. In the example we are using a simple `Linear` PyTorch model.
126
+
127
+ The requirement for the example is to have installed PyTorch in your environment. You can do it running:
128
+
129
+ ```shell
130
+ pip install torch
131
+ ```
132
+
133
+ The integration of model requires to provide following elements:
134
+
135
+ - The model - framework or Python model or function that handle inference requests
136
+ - Inference callback - a lambda or function which handle the input data coming from Triton and return the result
137
+ - Python function connection with Triton Inference Server - a binding for communication between Triton and Python
138
+ callback
139
+
140
+ In the next step define the `Linear` model:
141
+
142
+ ```python
143
+ import torch
144
+
145
+ model = torch.nn.Linear(2, 3).to("cuda").eval()
146
+ ```
147
+
148
+ In the second step, create an inference callable as a function. The function obtains the HTTP/gRPC request data as an argument, which should be in the form of a NumPy array. The expected return object should also be a NumPy array. You can define an inference callable as a function that uses the `@batch` decorator from PyTriton. This decorator converts the input request into a more suitable format that can be directly passed to the model. You can read more about [decorators here](docs/decorators.md).
149
+
150
+ Example implementation:
151
+
152
+ <!--pytest-codeblocks:cont-->
153
+
154
+ ```python
155
+ import numpy as np
156
+ from pytriton.decorators import batch
157
+
158
+
159
+ @batch
160
+ def infer_fn(**inputs: np.ndarray):
161
+ (input1_batch,) = inputs.values()
162
+ input1_batch_tensor = torch.from_numpy(input1_batch).to("cuda")
163
+ output1_batch_tensor = model(input1_batch_tensor) # Calling the Python model inference
164
+ output1_batch = output1_batch_tensor.cpu().detach().numpy()
165
+ return [output1_batch]
166
+ ```
167
+
168
+ In the next step, you can create the binding between the inference callable and Triton Inference Server using the `bind` method from pyTriton. This method takes the model name, the inference callable, the inputs and outputs tensors, and an optional model configuration object.
169
+
170
+ <!--pytest-codeblocks:cont-->
171
+
172
+ ```python
173
+ from pytriton.model_config import ModelConfig, Tensor
174
+ from pytriton.triton import Triton
175
+
176
+ # Connecting inference callable with Triton Inference Server
177
+ with Triton() as triton:
178
+ # Load model into Triton Inference Server
179
+ triton.bind(
180
+ model_name="Linear",
181
+ infer_func=infer_fn,
182
+ inputs=[
183
+ Tensor(dtype=np.float32, shape=(-1,)),
184
+ ],
185
+ outputs=[
186
+ Tensor(dtype=np.float32, shape=(-1,)),
187
+ ],
188
+ config=ModelConfig(max_batch_size=128)
189
+ )
190
+ ...
191
+ ```
192
+
193
+ Finally, serve the model with the Triton Inference Server:
194
+
195
+ <!--pytest.mark.skip-->
196
+
197
+ ```python
198
+ from pytriton.triton import Triton
199
+
200
+ with Triton() as triton:
201
+ ... # Load models here
202
+ triton.serve()
203
+ ```
204
+
205
+ The `bind` method creates a connection between the Triton Inference Server and the `infer_fn`, which handles
206
+ the inference queries. The `inputs` and `outputs` describe the model inputs and outputs that are exposed in
207
+ Triton. The config field allows more parameters for model deployment.
208
+
209
+ The `serve` method is blocking, and at this point, the application waits for incoming HTTP/gRPC requests. From that
210
+ moment, the model is available under the name `Linear` in the Triton server. The inference queries can be sent to
211
+ `localhost:8000/v2/models/Linear/infer`, which are passed to the `infer_fn` function.
212
+
213
+ If you would like to use Triton in the background mode, use `run`. More about that can be found
214
+ in the [Deploying Models](https://triton-inference-server.github.io/pytriton/latest/initialization/) page.
215
+
216
+ Once the `serve` or `run` method is called on the `Triton` object, the server status can be obtained using:
217
+
218
+ <!--pytest.mark.skip-->
219
+
220
+ ```shell
221
+ curl -v localhost:8000/v2/health/live
222
+ ```
223
+
224
+ The model is loaded right after the server starts, and its status can be queried using:
225
+
226
+ <!--pytest.mark.skip-->
227
+
228
+ ```shell
229
+ curl -v localhost:8000/v2/models/Linear/ready
230
+ ```
231
+
232
+ Finally, you can send an inference query to the model:
233
+
234
+ <!--pytest.mark.skip-->
235
+
236
+ ```shell
237
+ curl -X POST \
238
+ -H "Content-Type: application/json" \
239
+ -d @input.json \
240
+ localhost:8000/v2/models/Linear/infer
241
+ ```
242
+
243
+ The `input.json` with sample query:
244
+
245
+ ```json
246
+ {
247
+ "id": "0",
248
+ "inputs": [
249
+ {
250
+ "name": "INPUT_1",
251
+ "shape": [1, 2],
252
+ "datatype": "FP32",
253
+ "parameters": {},
254
+ "data": [[-0.04281254857778549, 0.6738349795341492]]
255
+ }
256
+ ]
257
+ }
258
+ ```
259
+
260
+ Read more about the HTTP/gRPC interface in the Triton Inference Server
261
+ [documentation](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#httprest-and-grpc-protocols).
262
+
263
+ You can also validate the deployed model using a simple client that can perform inference requests:
264
+
265
+ <!--pytest.mark.skip-->
266
+
267
+ ```python
268
+ import torch
269
+ from pytriton.client import ModelClient
270
+
271
+ input1_data = torch.randn(128, 2).cpu().detach().numpy()
272
+
273
+ with ModelClient("localhost:8000", "Linear") as client:
274
+ result_dict = client.infer_batch(input1_data)
275
+
276
+ print(result_dict)
277
+ ```
278
+
279
+ The full example code can be found in [examples/linear_random_pytorch](examples/linear_random_pytorch).
280
+
281
+ You can learn more about client usage in the [Clients](https://triton-inference-server.github.io/pytriton/latest/clients/) document.
282
+
283
+ More information about running the server and models can be found
284
+ in [Deploying Models](https://triton-inference-server.github.io/pytriton/latest/initialization/) page of documentation.
285
+
286
+ ## Architecture
287
+
288
+ The diagram below presents the schema of how the Python models are served through Triton Inference Server using
289
+ PyTriton. The solution consists of two main components:
290
+
291
+ - Triton Inference Server: for exposing the HTTP/gRPC API and benefiting from performance features like dynamic batching
292
+ or response cache.
293
+ - Python Model Environment: your environment where the Python model is executed.
294
+
295
+ The Triton Inference Server binaries are provided as part of the PyTriton installation. The Triton Server is
296
+ installed in your current environment (system or container). The PyTriton controls the Triton Server process
297
+ through the `Triton Controller`.
298
+
299
+ Exposing the model through PyTriton requires the definition of an `Inference Callable` - a Python function that is
300
+ connected to Triton Inference Server and executes the model or ensemble for predictions. The integration layer binds
301
+ the `Inference Callable` to Triton Server and exposes it through the Triton HTTP/gRPC API under a provided `<model name>`. Once
302
+ the integration is done, the defined `Inference Callable` receives data sent to the HTTP/gRPC API endpoint
303
+ `v2/models/<model name>/infer`. Read more about HTTP/gRPC interface in Triton Inference Server
304
+ [documentation](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#httprest-and-grpc-protocols).
305
+
306
+ The HTTP/gRPC requests sent to `v2/models/<model name>/infer` are handled by Triton
307
+ Inference Server. The server batches requests and passes them to the `Proxy Backend`, which sends the batched requests to the appropriate
308
+ `Inference Callable`. The data is sent as a `numpy` array. Once the `Inference Callable` finishes execution of
309
+ the model prediction, the result is returned to the `Proxy Backend`, and a response is created by Triton Server.
310
+
311
+ ![High Level Design](docs/assets/hld.svg)
312
+
313
+
314
+
315
+
316
+ ## Examples
317
+
318
+ The [examples](examples) page presents various cases of serving models using PyTriton. You can find simple examples of
319
+ running PyTorch, TensorFlow2, JAX, and simple Python models. Additionally, we have prepared more advanced scenarios like online
320
+ learning, multi-node models, or deployment on Kubernetes using PyTriton. Each example contains instructions describing
321
+ how to build and run the example. Learn more about how to use PyTriton by reviewing our [examples](examples).
322
+
323
+ ### Streaming (alpha)
324
+
325
+ We introduced new alpha feature to PyTriton that allows to stream partial responses from a model. It is based on NVIDIA Triton Inference deocoupled models feature. Look at example in [examples/huggingface_dialogpt_streaming_pytorch](examples/huggingface_dialogpt_streaming_pytorch).
326
+
327
+ ### Profiling model
328
+
329
+ The [Perf Analyzer](https://github.com/triton-inference-server/client/blob/main/src/c++/perf_analyzer/README.md) can be
330
+ used to profile models served through PyTriton. We have prepared an example of
331
+ using the Perf Analyzer to profile the BART PyTorch model. The example code can be found
332
+ in [examples/perf_analyzer](examples/perf_analyzer).
333
+
334
+ ## Version management
335
+
336
+ PyTriton follows the [Semantic Versioning](https://semver.org/) scheme for versioning. Official releases can be found on [PyPI](https://pypi.org/project/nvidia-pytriton/) and [GitHub releases](https://github.com/triton-inference-server/pytriton/releases). The most up-to-date development version is available on the `main` branch, which may include hotfixes that have not yet been released through the standard channels. To install the latest development version, refer to the instructions in the
337
+ [building binaries from source](#building-binaries-from-source) section.
338
+
339
+ ## Useful Links
340
+
341
+ - [Changelog](CHANGELOG.md)
342
+ - [Known Issues](https://triton-inference-server.github.io/pytriton/latest/known_issues)
343
+ - [Contributing](CONTRIBUTING.md)
stf/stf-api-alternative/pytriton/build/lib/pytriton/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
15
+ from importlib.metadata import PackageNotFoundError, version
16
+
17
+ try:
18
+ __version__ = version("nvidia-pytriton")
19
+ except PackageNotFoundError:
20
+ # package is not installed
21
+ pass
22
+
23
+ from pytriton import (
24
+ client, # noqa: F401
25
+ model_config, # noqa: F401
26
+ triton, # noqa: F401
27
+ )
stf/stf-api-alternative/pytriton/build/lib/pytriton/__main__.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Pytriton check module."""
15
+
16
+ import logging
17
+ import os
18
+ import pathlib
19
+ import shutil
20
+ import tempfile
21
+ from typing import Optional
22
+
23
+ import typer
24
+ from typing_extensions import Annotated
25
+
26
+ from pytriton.check.add_sub import add_sub_example, add_sub_example_thread
27
+ from pytriton.check.env_checks import env_checks
28
+
29
+ warning_message = """
30
+ +---------------------------------------------------------------+
31
+ | WARNING |
32
+ +---------------------------------------------------------------+
33
+ | Command may collect sensitive information, please review the |
34
+ | log and the ZIP before sharing. |
35
+ +---------------------------------------------------------------+
36
+ """
37
+
38
+
39
+ app = typer.Typer(help="Pytriton check tool.\n\nThis tool is used to check the environment and run examples.")
40
+
41
+
42
+ class CheckEnvironment:
43
+ """Check environment class.
44
+
45
+ Args:
46
+ workspace_path: Path to workspace
47
+ name: Name of the sub_workspace
48
+ zip_results: Flag if results should be zipped
49
+ check_workspace_exist: Flag if workspace should be checked if exists
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ workspace_path: Optional[pathlib.Path],
55
+ name: str,
56
+ zip_results: bool = True,
57
+ check_workspace_exist: bool = True,
58
+ ):
59
+ """Initialize class."""
60
+ self.name = name
61
+ self._zip_results = zip_results
62
+ self._temp_workspace = None
63
+
64
+ self.logger = logging.getLogger(name)
65
+ if check_workspace_exist and workspace_path is not None and workspace_path.exists():
66
+ self.logger.error(f"Workspace path {workspace_path} already exists")
67
+ raise typer.Exit(code=1)
68
+ if workspace_path is None:
69
+ self._temp_workspace = tempfile.TemporaryDirectory(prefix="pytriton_workspace_")
70
+ workspace_path = pathlib.Path(self._temp_workspace.name)
71
+ else:
72
+ workspace_path.mkdir(parents=True, exist_ok=True)
73
+ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
74
+ self.logger.addHandler(logging.FileHandler(workspace_path / (name + "_log.txt")))
75
+ self.workspace_path = workspace_path
76
+ self.sub_workspace = workspace_path / name
77
+
78
+ def __enter__(self):
79
+ """Enter method."""
80
+ return self
81
+
82
+ def __exit__(self, exc_type, exc_val, exc_tb):
83
+ """Exit method zips results if required."""
84
+ self.zip_results()
85
+
86
+ def zip_results(self):
87
+ """Zip results."""
88
+ if self._zip_results:
89
+ if self.workspace_path.exists():
90
+ if self._temp_workspace is not None:
91
+ output_file_base = pathlib.Path(os.getcwd()) / self.workspace_path.name
92
+ else:
93
+ output_file_base = self.workspace_path
94
+ self.logger.info(f"Zipping {self.workspace_path} to {output_file_base}.zip")
95
+ shutil.make_archive(str(output_file_base.resolve()), "zip", str(self.workspace_path.resolve()))
96
+ else:
97
+ self.logger.error(f"Workspace path {self.workspace_path} does not exist")
98
+
99
+
100
+ @app.command("example-add-sub-script")
101
+ def example_add_sub_script(
102
+ workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
103
+ zip_results: Annotated[bool, typer.Option("--zip")] = True,
104
+ ):
105
+ """Run example using external script.
106
+
107
+ Args:
108
+ workspace: Workspace path that will be created to store testing output (should not exist)
109
+ zip_results: flag if output should be zipped
110
+ """
111
+ with CheckEnvironment(workspace, "example_add_sub_script", zip_results) as ce:
112
+ try:
113
+ add_sub_example_thread(ce.sub_workspace, ce.logger)
114
+ except Exception as e:
115
+ ce.logger.error(f"Error occurred in command: {e}")
116
+
117
+
118
+ @app.command("example-add-sub")
119
+ def example_add_sub(
120
+ workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
121
+ zip_results: Annotated[bool, typer.Option("--zip")] = True,
122
+ ):
123
+ """Run example.
124
+
125
+ Args:
126
+ workspace: Workspace path that will be created to store testing output (should not exist)
127
+ zip_results: flag if output should be zipped
128
+ """
129
+ with CheckEnvironment(workspace, "example_add_sub", zip_results) as ce:
130
+ try:
131
+ add_sub_example(ce.sub_workspace, ce.logger)
132
+ except Exception as e:
133
+ ce.logger.error(f"Error occurred in command: {e}")
134
+
135
+
136
+ @app.command("examples")
137
+ def examples(
138
+ workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
139
+ zip_results: Annotated[bool, typer.Option("--zip")] = True,
140
+ ):
141
+ """Run example in the same process.
142
+
143
+ Args:
144
+ workspace: Workspace path that will be created to store testing output (should not exist)
145
+ zip_results: flag if output should be zipped
146
+ """
147
+ with CheckEnvironment(workspace, "example_add_sub", zip_results) as ce:
148
+ try:
149
+ add_sub_example(ce.sub_workspace, ce.logger)
150
+ except Exception as e:
151
+ ce.logger.error(f"Error occurred in command: {e}")
152
+
153
+ with CheckEnvironment(workspace, "example_add_sub_script", zip_results, check_workspace_exist=False) as ce:
154
+ try:
155
+ add_sub_example_thread(ce.sub_workspace, ce.logger)
156
+ except Exception as e:
157
+ ce.logger.error(f"Error occurred in command: {e}")
158
+
159
+
160
+ @app.command("env")
161
+ def env_check(
162
+ workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
163
+ zip_results: Annotated[bool, typer.Option("--zip")] = True,
164
+ ):
165
+ """Run all environment checks.
166
+
167
+ It may collect sensitive system information in the log. Please review the log before sharing.
168
+
169
+ Args:
170
+ workspace: Workspace path that will be created to store testing output (should not exist)
171
+ zip_results: flag if output should be zipped
172
+ """
173
+ with CheckEnvironment(workspace, "env_checks", zip_results) as ce:
174
+ try:
175
+ env_checks(ce.logger)
176
+ except Exception as e:
177
+ ce.logger.error(f"Error occurred in command: {e}")
178
+
179
+
180
+ @app.command("check")
181
+ def check(
182
+ workspace: Annotated[Optional[pathlib.Path], typer.Option("--workspace", "-w")] = None,
183
+ zip_results: Annotated[bool, typer.Option("--zip")] = True,
184
+ ):
185
+ """Run all checks.
186
+
187
+ Args:
188
+ workspace: Workspace path that will be created to store testing output (should not exist)
189
+ zip_results: flag if output should be zipped
190
+ """
191
+ with CheckEnvironment(workspace, "all_checks", zip_results) as ce:
192
+ try:
193
+ ce.logger.info("Running all common checks")
194
+ env_check(ce.workspace_path / "env", False)
195
+ examples(ce.workspace_path / "examples", False)
196
+ except Exception as e:
197
+ ce.logger.error(f"Error occurred in command: {e}")
198
+
199
+
200
+ @app.callback(invoke_without_command=True)
201
+ def default_command(ctx: typer.Context):
202
+ """Default command."""
203
+ if ctx.invoked_subcommand is None:
204
+ check()
205
+
206
+
207
+ def main():
208
+ """Main function."""
209
+ logger = logging.getLogger("PyTriton-Check")
210
+ try:
211
+ logger.warning(warning_message)
212
+ app()
213
+ finally:
214
+ logger.warning(warning_message)
215
+
216
+
217
+ if __name__ == "__main__":
218
+ main()
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/add_sub.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Add_sub example model for checking corectness of triton environment."""
16
+
17
+ import argparse
18
+ import logging
19
+ import pathlib
20
+ import signal
21
+ import sys
22
+
23
+ import numpy as np
24
+
25
+ from pytriton.check.utils import ScriptThread
26
+ from pytriton.client import ModelClient
27
+ from pytriton.decorators import batch
28
+ from pytriton.model_config import ModelConfig, Tensor
29
+ from pytriton.triton import Triton
30
+
31
+ logger = logging.getLogger("check.add_sub_example")
32
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
33
+ add_script_path = [sys.executable, "pytriton/check/add_sub.py"]
34
+
35
+
36
+ @batch
37
+ def _add_sub(**inputs):
38
+ a_batch, b_batch = inputs.values()
39
+ add_batch = a_batch + b_batch
40
+ sub_batch = a_batch - b_batch
41
+ return {"add": add_batch, "sub": sub_batch}
42
+
43
+
44
+ def prepare_triton(workspace: pathlib.Path):
45
+ """Prepare triton server with AddSub model."""
46
+ triton = Triton(workspace=str(workspace.resolve()))
47
+ triton.run()
48
+ logger.info("Loading AddSub model")
49
+ triton.bind(
50
+ model_name="AddSub",
51
+ infer_func=_add_sub,
52
+ inputs=[
53
+ Tensor(dtype=np.float32, shape=(-1,)),
54
+ Tensor(dtype=np.float32, shape=(-1,)),
55
+ ],
56
+ outputs=[
57
+ Tensor(name="add", dtype=np.float32, shape=(-1,)),
58
+ Tensor(name="sub", dtype=np.float32, shape=(-1,)),
59
+ ],
60
+ config=ModelConfig(max_batch_size=128),
61
+ strict=True,
62
+ )
63
+ return triton
64
+
65
+
66
+ def infer_add_sub_model():
67
+ """Infer AddSub model."""
68
+ batch_size = 2
69
+ a_batch = np.ones((batch_size, 1), dtype=np.float32)
70
+ b_batch = np.ones((batch_size, 1), dtype=np.float32)
71
+
72
+ logger.info(f"a: {a_batch.tolist()}")
73
+ logger.info(f"b: {b_batch.tolist()}")
74
+
75
+ with ModelClient("localhost", "AddSub") as client:
76
+ logger.info("Sending inference request")
77
+ result_batch = client.infer_batch(a_batch, b_batch)
78
+
79
+ for output_name, data_batch in result_batch.items():
80
+ logger.info(f"{output_name}: {data_batch.tolist()}")
81
+
82
+
83
+ def serve_triton(workspace: pathlib.Path):
84
+ """Serve triton server with AddSub model."""
85
+ triton = prepare_triton(workspace)
86
+ logger.info("Serving AddSub model")
87
+ triton.serve()
88
+
89
+
90
+ def add_sub_example_thread(workspace: pathlib.Path, logger: logging.Logger):
91
+ """Run example using external script.
92
+
93
+ Args:
94
+ workspace: Workspace path that will be created to store testing output (should not exist)
95
+ logger: logger instance
96
+ """
97
+ logger.info("Running example model using external script")
98
+
99
+ with ScriptThread(add_script_path + ["--workspace", str(workspace.resolve())], name="server") as server_thread:
100
+ import time
101
+
102
+ time.sleep(3)
103
+ infer_add_sub_model()
104
+
105
+ if server_thread.process:
106
+ server_thread.process.send_signal(signal.SIGINT)
107
+
108
+ server_thread.join()
109
+ logger.error(server_thread.output)
110
+ if server_thread.returncode not in [
111
+ 0,
112
+ -2,
113
+ ]:
114
+ logger.error(f"Server failed - return code {server_thread.returncode}")
115
+
116
+
117
+ def add_sub_example(workspace: pathlib.Path, logger: logging.Logger):
118
+ """Run example in the same process.
119
+
120
+ Args:
121
+ workspace: Workspace path that will be created to store testing output (should not exist)
122
+ logger: logger instance
123
+ """
124
+ logger.info("Running example model")
125
+ triton = prepare_triton(workspace)
126
+ infer_add_sub_model()
127
+ triton.stop()
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument("--workspace", help="Workspace path", type=str)
133
+ parser.add_argument("--infer", default=False, help="Infer AddSub model", action="store_true")
134
+ args = parser.parse_args()
135
+
136
+ if args.infer:
137
+ infer_add_sub_model()
138
+ else:
139
+ serve_triton(pathlib.Path(args.workspace))
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/env_checks.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Environment checks."""
15
+
16
+ import logging
17
+ import os
18
+ import pathlib
19
+ import platform
20
+ import re
21
+ import sys
22
+
23
+ import psutil
24
+
25
+ from pytriton.check.utils import ScriptThread
26
+
27
+
28
+ def nvidia_smi(logger):
29
+ """Run nvidia-smi.
30
+
31
+ Args:
32
+ logger: logger instance
33
+ """
34
+ logger.info("Running nvidia-smi")
35
+ with ScriptThread(["nvidia-smi"], name="nvidia-smi") as nvidia_smi_thread:
36
+ nvidia_smi_thread.join()
37
+ logger.info(nvidia_smi_thread.output)
38
+ if nvidia_smi_thread.returncode != 0:
39
+ logger.error("nvidia-smi failed - possible cause: no GPU available or driver not installed")
40
+ logger.error(
41
+ "If running in WSL wit sudo, make sure to add nvidia-smi folder (e.g. /usr/lib/wsl/lib) to sudoers file!"
42
+ )
43
+
44
+
45
+ def get_platform_info(logger):
46
+ """Get platform information (OS, python, etc.).
47
+
48
+ Args:
49
+ logger: logger instance
50
+ """
51
+ logger.info("Checking OS version")
52
+ logger.info("Script is running in docker:" + str(pathlib.Path("/.dockerenv").exists()))
53
+
54
+ os_release_path = pathlib.Path("/etc/os-release")
55
+ if os_release_path.exists():
56
+ with os_release_path.open() as f:
57
+ os_release = f.read()
58
+ logger.info("OS release")
59
+ logger.info(os_release)
60
+ for line in os_release.split("\n"):
61
+ if "PRETTY_NAME" in line:
62
+ os_version = line.split("=")[1].strip()
63
+ logger.info(f"OS version: {os_version}")
64
+ else:
65
+ logger.warning("OS release file not found (not available on some systems")
66
+
67
+ logger.info("Get platform info")
68
+ logger.info(f"Platform: {platform.platform()}")
69
+ logger.info(f"System: {platform.system()}")
70
+ logger.info(f"Release: {platform.release()}")
71
+ logger.info(f"Version: {platform.version()}")
72
+ logger.info(f"Machine: {platform.machine()}")
73
+ logger.info(f"Processor: {platform.processor()}")
74
+ logger.info(f"Python version: {platform.python_version()}")
75
+ logger.info(f"Python implementation: {platform.python_implementation()}")
76
+ logger.info(f"Python compiler: {platform.python_compiler()}")
77
+ logger.info(f"Python build: {platform.python_build()}")
78
+ logger.info(f"libc_ver: {platform.libc_ver()}")
79
+
80
+
81
+ def check_psutil_stats(logger):
82
+ """Check psutil stats.
83
+
84
+ Args:
85
+ logger: logger instance
86
+ """
87
+ logger.info("Checking psutil stats")
88
+ logger.info("Memory stats")
89
+ logger.info(psutil.virtual_memory())
90
+ logger.info("Swap stats")
91
+ logger.info(psutil.swap_memory())
92
+ logger.info("Disk stats")
93
+ logger.info(psutil.disk_usage("/"))
94
+ logger.info("Disk io countwers")
95
+ logger.info(psutil.disk_io_counters())
96
+ logger.info("CPU stats")
97
+ logger.info(psutil.cpu_times())
98
+ logger.info("Network stats")
99
+ logger.info(psutil.net_io_counters())
100
+
101
+
102
+ def get_listening_processes(logger):
103
+ """Get listening processes.
104
+
105
+ Args:
106
+ logger: logger instance
107
+ """
108
+ logger.info("Listening processes")
109
+ processes = {proc.pid: proc.name for proc in psutil.process_iter(["pid", "name"])}
110
+ connections = psutil.net_connections()
111
+ listening_sockets = [conn for conn in connections if conn.status == "LISTEN"]
112
+
113
+ for listening_socket in listening_sockets:
114
+ process_name = None
115
+ if listening_socket.pid is not None and listening_socket.pid in processes:
116
+ process_name = processes[listening_socket.pid]
117
+ logger.info(
118
+ f"Process ID: {listening_socket.pid}, Name: {process_name}, Local Address: {listening_socket.laddr}, Remote Address: {listening_socket.raddr}, Status: {listening_socket.status}"
119
+ )
120
+
121
+
122
+ def installed_packages(logger):
123
+ """Get installed packages.
124
+
125
+ Args:
126
+ logger: logger instance
127
+ """
128
+ logger.info("Checking installed packages")
129
+ import importlib_metadata
130
+
131
+ packages = importlib_metadata.distributions()
132
+
133
+ installed_pkg = sorted([f"{package.metadata['Name']}=={package.version} ({package._path})" for package in packages])
134
+ installed_pkg_str = "\n[\n\t" + ",\n\t".join(installed_pkg) + "\n]"
135
+ logger.info(installed_pkg_str)
136
+
137
+
138
+ def check_compiler_and_clib(logger):
139
+ """Check compiler and C libraries.
140
+
141
+ Args:
142
+ logger: logger instance
143
+ """
144
+ logger.info("Checking compiler and C libraries")
145
+ with ScriptThread(["gcc", "--version"], name="gcc_version") as gcc_version_thread:
146
+ gcc_version_thread.join()
147
+ logger.info("GCC version:")
148
+ logger.info(gcc_version_thread.output)
149
+ if gcc_version_thread.returncode != 0:
150
+ logger.error("gcc failed")
151
+
152
+ logger.info("Python version:")
153
+ logger.info(sys.version)
154
+
155
+ try:
156
+ logger.info(os.confstr("CS_GNU_LIBC_VERSION"))
157
+ except AttributeError as e:
158
+ logger.error(f"Failed to get glibc version {e}")
159
+
160
+
161
+ def log_env_variables(logger):
162
+ """Log environment variables.
163
+
164
+ Args:
165
+ logger: logger instance
166
+ """
167
+ logger.info("Environment variables")
168
+
169
+ env_vars = os.environ.items()
170
+ blacklist_patterns = [
171
+ r".*token.*",
172
+ r".*secret.*",
173
+ r".*key.*",
174
+ r".*password.*",
175
+ ]
176
+
177
+ patterns = [re.compile(pattern, re.IGNORECASE) for pattern in blacklist_patterns]
178
+ filtered_env_vars = [
179
+ f"{key}={value}"
180
+ for key, value in env_vars
181
+ if not any(pattern.search(key) or pattern.search(value) for pattern in patterns)
182
+ ]
183
+
184
+ env_vars_str = "\n".join(filtered_env_vars)
185
+ logger.info(env_vars_str)
186
+
187
+
188
+ def env_checks(logger: logging.Logger):
189
+ """Run all environment checks.
190
+
191
+ Args:
192
+ logger: logger instance
193
+ """
194
+ logger.info("Running all environment checks")
195
+ get_platform_info(logger)
196
+ nvidia_smi(logger)
197
+ installed_packages(logger)
198
+ check_psutil_stats(logger)
199
+ get_listening_processes(logger)
200
+ check_compiler_and_clib(logger)
201
+ log_env_variables(logger)
stf/stf-api-alternative/pytriton/build/lib/pytriton/check/utils.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utils."""
15
+
16
+ import contextlib
17
+ import fcntl
18
+ import logging
19
+ import os
20
+ import pathlib
21
+ import re
22
+ import select
23
+ import socket
24
+ import subprocess
25
+ import threading
26
+ import typing
27
+
28
+ LOGGER = logging.getLogger(__name__)
29
+ DEFAULT_LOG_FORMAT = "%(asctime)s - %(levelname)8s - %(process)8d - %(threadName)s - %(name)s: %(message)s"
30
+
31
+
32
+ def _read_outputs(_process, _logger, _outputs):
33
+ # Set stdout and stderr file descriptors to non-blocking mode
34
+ try:
35
+ fcntl.fcntl(_process.stdout, fcntl.F_SETFL, os.O_NONBLOCK)
36
+ fcntl.fcntl(_process.stderr, fcntl.F_SETFL, os.O_NONBLOCK)
37
+ except ValueError: # when selecting on closed files
38
+ return
39
+
40
+ buffers = {_process.stdout: "", _process.stderr: ""}
41
+ rds = [_process.stdout, _process.stderr]
42
+ while rds:
43
+ try:
44
+ readable, _, _ = select.select(rds, [], [], 1)
45
+ except ValueError: # when selecting on closed files
46
+ break
47
+
48
+ for rd in readable:
49
+ try:
50
+ data = os.read(rd.fileno(), 4096)
51
+ if not data:
52
+ rds.remove(rd)
53
+ continue
54
+
55
+ decoded_data = data.decode("utf-8")
56
+ buffers[rd] += decoded_data
57
+ lines = buffers[rd].splitlines(keepends=True)
58
+
59
+ if buffers[rd].endswith("\n"):
60
+ complete_lines = lines
61
+ buffers[rd] = ""
62
+ else:
63
+ complete_lines = lines[:-1]
64
+ buffers[rd] = lines[-1]
65
+
66
+ for line in complete_lines:
67
+ line = line.rstrip()
68
+ _logger.info(line)
69
+ _outputs.append(line)
70
+ except OSError: # Reading from an empty non-blocking file
71
+ pass
72
+
73
+
74
+ class ScriptThread(threading.Thread):
75
+ """A class that runs external script in a separate thread."""
76
+
77
+ def __init__(self, cmd, workdir=None, group=None, target=None, name=None, args=(), kwargs=None) -> None:
78
+ """Initializes the ScriptThread object."""
79
+ super().__init__(group, target, name, args, kwargs, daemon=True)
80
+ self.cmd = cmd
81
+ self.workdir = workdir
82
+ self._process_spawned_or_spawn_error_flag = None
83
+ self.active = False
84
+ self._process = None
85
+ self.returncode = None
86
+ self._output = []
87
+ self._logger = logging.getLogger(self.name)
88
+
89
+ def __enter__(self):
90
+ """Starts the script thread."""
91
+ self.start(threading.Event())
92
+ self._process_spawned_or_spawn_error_flag.wait()
93
+ return self
94
+
95
+ def __exit__(self, *args):
96
+ """Stops the script thread and waits for it to join."""
97
+ self.stop()
98
+ self.join()
99
+ self._process_spawned_or_spawn_error_flag = None
100
+
101
+ def start(self, flag: typing.Optional[threading.Event] = None) -> None:
102
+ """Starts the script thread."""
103
+ if flag is None:
104
+ flag = threading.Event()
105
+ self._logger.info(f"Starting {self.name} script with \"{' '.join(self.cmd)}\" cmd")
106
+ self._process_spawned_or_spawn_error_flag = flag
107
+ super().start()
108
+
109
+ def stop(self):
110
+ """Sets the active flag to False to stop the script thread."""
111
+ self._logger.info(f"Stopping {self.name} script")
112
+ self.active = False
113
+
114
+ def run(self):
115
+ """Runs the script in a separate process."""
116
+ import psutil
117
+
118
+ self.returncode = None
119
+ self._output = []
120
+ self._process = None
121
+
122
+ os.environ.setdefault("PYTHONUNBUFFERED", "1") # to not buffer logs
123
+ try:
124
+ with psutil.Popen(
125
+ self.cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0, cwd=self.workdir
126
+ ) as process:
127
+ self._process = process
128
+ self.active = True
129
+ if self._process_spawned_or_spawn_error_flag:
130
+ self._process_spawned_or_spawn_error_flag.set()
131
+ while self.active and process.poll() is None and process.returncode is None:
132
+ try:
133
+ _read_outputs(process, self._logger, self._output)
134
+ except KeyboardInterrupt:
135
+ self.stop()
136
+
137
+ finally:
138
+ if self._process_spawned_or_spawn_error_flag:
139
+ self._process_spawned_or_spawn_error_flag.set()
140
+ if self.process:
141
+ while self.process.poll() is None:
142
+ _read_outputs(self.process, self._logger, self._output)
143
+ _read_outputs(self.process, self._logger, self._output)
144
+ self.returncode = process.wait() # pytype: disable=name-error
145
+ self._logger.info(f"{self.name} process finished with {self.returncode}")
146
+
147
+ self.active = False
148
+ self._process = None
149
+
150
+ @property
151
+ def output(self):
152
+ """Return process stream output."""
153
+ return "\n".join(self._output)
154
+
155
+ @property
156
+ def process(self):
157
+ """Return process object."""
158
+ return self._process
159
+
160
+
161
+ def find_free_port() -> int:
162
+ """Finds a free port on the local machine."""
163
+ with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
164
+ s.bind(("", 0))
165
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
166
+ return s.getsockname()[1]
167
+
168
+
169
+ class ProcessMonitoring:
170
+ """A class that dumps the state of a process and its children.
171
+
172
+ This class uses the py-spy tool to dump the stack trace of a process and its
173
+ children recursively. It also dumps the process information such as the parent
174
+ and the command line. It allows registering custom monitors that can perform
175
+ additional actions on the process.
176
+
177
+ Attributes:
178
+ _logger (logging.Logger): The logger object to write messages.
179
+ _process (psutil.Process): The process object to monitor.
180
+ _children_processes (list[psutil.Process]): The list of child processes to monitor.
181
+ _log (logging.Logger.method): The logging method to use for messages.
182
+ _remove_color (bool): Whether to remove ANSI escape sequences from the output.
183
+ _ansi_escape (re.Pattern): The regular expression object to match ANSI escape sequences.
184
+ _custom_monitors (list[typing.Callable[[int], None]]): The list of custom monitor functions to execute on each dump cycle.
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ pid: int,
190
+ logger: typing.Optional[logging.Logger] = None,
191
+ loglevel: int = logging.INFO,
192
+ remove_color: bool = False,
193
+ ):
194
+ """Initializes the ProcessMonitoring object.
195
+
196
+ Args:
197
+ pid (int): The process ID of the process to monitor.
198
+ logger (typing.Optional[logging.Logger], optional): The logger object to write messages. Defaults to None.
199
+ loglevel (int, optional): The logging level to use for messages. Defaults to logging.INFO.
200
+ remove_color (bool, optional): Whether to remove ANSI escape sequences from the output. Defaults to False.
201
+ """
202
+ import re
203
+
204
+ import psutil
205
+
206
+ self._logger = logger or logging.getLogger("monitoring")
207
+ self._process = psutil.Process(pid)
208
+ self._children_processes = list(self._process.children(recursive=True))
209
+ self._log = {
210
+ logging.DEBUG: self._logger.debug,
211
+ logging.INFO: self._logger.info,
212
+ logging.WARNING: self._logger.warning,
213
+ logging.ERROR: self._logger.error,
214
+ }[loglevel]
215
+ self._log(f"Initial list of children processes: {self._children_processes}")
216
+ self._remove_color = remove_color
217
+ pattern = r"\x1b\[.*?m"
218
+ self._ansi_escape = re.compile(pattern)
219
+ self._custom_monitors = []
220
+
221
+ def register_custom_monitor(self, custom_monitor: typing.Callable[[int], None]) -> None:
222
+ """Registers a custom monitor for the process.
223
+
224
+ This method adds a custom monitor function to the list of monitors that are
225
+ executed on each dump cycle. A custom monitor function should take an integer
226
+ as an argument (the process ID) and return None.
227
+
228
+ Args:
229
+ custom_monitor (typing.Callable[[int], None]): The custom monitor function to register.
230
+ """
231
+ self._custom_monitors.append(custom_monitor)
232
+
233
+ def dump_state(self) -> None:
234
+ """Dumps the state of the process and its children.
235
+
236
+ This method calls the _dump_processes_stacktrace and _dump_child_processes
237
+ methods to dump the stack trace and the process information of the process
238
+ and its children recursively.
239
+ """
240
+ self._dump_processes_stacktrace()
241
+ self._dump_child_processes()
242
+
243
+ def _dump_processes_stacktrace(self):
244
+ import psutil
245
+ import sh
246
+
247
+ self._log("==== Dump process stacktrace")
248
+ pyspy_cmd = sh.Command("py-spy")
249
+
250
+ for process in [self._process] + self.children:
251
+ try:
252
+ result = pyspy_cmd("dump", "-ll", "--nonblocking", "-p", str(process.pid))
253
+ if self._remove_color:
254
+ result = self._ansi_escape.sub("", str(result))
255
+ self._log(f"Dump stack trace for process (pid={process.pid}) with cmd {process.cmdline()}")
256
+ for custom_monitor in self._custom_monitors:
257
+ custom_monitor(process.pid)
258
+ self._log(result)
259
+ except psutil.NoSuchProcess as e:
260
+ self._log(f"Error during handling process: {e}")
261
+ except sh.ErrorReturnCode_1 as e:
262
+ self._log(f"Error during calling py-spy process: {e}")
263
+
264
+ def _dump_child_processes(self):
265
+ import psutil
266
+
267
+ self._log("==== Dump process info (with its children)")
268
+ for process in [self._process] + self.children:
269
+ try:
270
+ self._log(f"{process} parent={process.parent()} ")
271
+ except psutil.NoSuchProcess:
272
+ self._log(f"{process} is missing in process table")
273
+
274
+ @property
275
+ def children(self):
276
+ """Returns the list of child processes to monitor.
277
+
278
+ This property returns the list of child processes to monitor, and updates it
279
+ with any new children that are created by the process.
280
+
281
+ Returns:
282
+ list[psutil.Process]: The list of child processes to monitor.
283
+ """
284
+ import psutil
285
+
286
+ try:
287
+ children = list(self._process.children(recursive=True))
288
+ self._children_processes = list(set(self._children_processes + children))
289
+ except psutil.NoSuchProcess:
290
+ pass
291
+ return self._children_processes
292
+
293
+
294
+ def get_current_container_version():
295
+ """Returns the version of the current container."""
296
+ container_version = os.environ.get("NVIDIA_PYTORCH_VERSION") or os.environ.get("NVIDIA_TENSORFLOW_VERSION")
297
+ if container_version and "-" in container_version:
298
+ container_version = container_version.split("-")[0] # TF version has format <year_month_version>-<tf_version>
299
+ return container_version
300
+
301
+
302
+ def verify_docker_image_in_readme_same_as_tested(readme_path, image_name_with_version):
303
+ """Verify that the docker image is the same as described in the readme file."""
304
+ image_name, _image_version = image_name_with_version.split(":")
305
+ framework_name = image_name.split("/")[-1]
306
+ readme_payload = pathlib.Path(readme_path).read_text()
307
+ match_iterator = re.finditer(
308
+ rf"(?P<container_registry>[\w/.\-:]+)/{framework_name}:(?P<image_version_with_python_version>[\w.-]+)",
309
+ readme_payload,
310
+ )
311
+ for entry in match_iterator:
312
+ assert entry.group() == image_name_with_version, f"{entry.group()} != {image_name_with_version}"
313
+
314
+
315
+ def search_warning_on_too_verbose_log_level(logs: str):
316
+ """Search warnings."""
317
+ pattern = r"Triton Inference Server is running with enabled verbose logs.*It may affect inference performance."
318
+ return re.search(pattern, logs)
319
+
320
+
321
+ class ProcessMonitoringThread:
322
+ """A class that creates a thread to monitor a process.
323
+
324
+ This class uses the ProcessMonitoring class to dump the state of a process
325
+ and its children periodically. It also allows registering custom monitors
326
+ that can perform additional actions on the process.
327
+
328
+ Attributes:
329
+ _monitoring (ProcessMonitoring): The ProcessMonitoring object that handles the dumping logic.
330
+ _stop_event (threading.Event): The event object that signals the thread to stop its loop.
331
+ _thread (threading.Thread): The thread object that runs the _run method in a loop.
332
+ _interval (float): The interval in seconds between each dump cycle.
333
+ """
334
+
335
+ def __init__(self, monitoring: ProcessMonitoring, interval: float = 60):
336
+ """Initializes the ProcessMonitoringThread object.
337
+
338
+ Args:
339
+ monitoring (ProcessMonitoring): The ProcessMonitoring object that handles the dumping logic.
340
+ interval (float, optional): The interval in seconds between each dump cycle. Defaults to 60.
341
+ """
342
+ self._monitoring = monitoring
343
+ self._interval = interval
344
+
345
+ def start(self) -> None:
346
+ """Starts the monitoring thread.
347
+
348
+ This method creates a new thread that runs the _run method in a loop until
349
+ the stop method is called or an exception occurs. It also sets the stop event
350
+ object that can be used to signal the thread to stop gracefully.
351
+ """
352
+ self._stop_event = threading.Event()
353
+ self._thread = threading.Thread(target=self._run, daemon=True)
354
+ self._thread.start()
355
+
356
+ def stop(self) -> None:
357
+ """Stops the monitoring thread.
358
+
359
+ This method sets the stop event object that signals the thread to stop its loop.
360
+ It also waits for the thread to join before returning.
361
+ """
362
+ self._stop_event.set()
363
+ self._thread.join()
364
+
365
+ def __enter__(self):
366
+ """Enters the context manager for the monitoring thread."""
367
+ self.start()
368
+ return self
369
+
370
+ def __exit__(self, *args):
371
+ """Exits the context manager for the monitoring thread."""
372
+ self.stop()
373
+
374
+ def _run(self):
375
+ logging.info("Monitoring process")
376
+ self._monitoring.dump_state()
377
+ while not self._stop_event.wait(self._interval):
378
+ logging.info("Monitoring process")
379
+ self._monitoring.dump_state()
380
+
381
+
382
+ class TestMonitoringContext:
383
+ """A context manager that monitors test processes.
384
+
385
+ This context manager creates threads to monitor the test processes and dumps
386
+ their state periodically. It can extend argparse args with additional arguments.
387
+ It supports splitting log into different files. The standard output log can have one level
388
+ and the file log can have another level. It uses log rotation.
389
+ """
390
+
391
+ @staticmethod
392
+ def extend_args(parser):
393
+ """Extends argparse args with additional arguments."""
394
+ parser.add_argument(
395
+ "--verbose",
396
+ action="store_true",
397
+ help="Provide verbose logs",
398
+ )
399
+ parser.add_argument(
400
+ "--log-path",
401
+ type=str,
402
+ default=None,
403
+ help="Provide the path of external log for rotation",
404
+ )
405
+ parser.add_argument(
406
+ "--compress-logs",
407
+ action="store_true",
408
+ help="Enable logs compression",
409
+ )
410
+ parser.add_argument(
411
+ "--maximum-log-file",
412
+ type=int,
413
+ default=10 * 1024 * 1024,
414
+ help="Maximum logfile size before rotation is started",
415
+ required=False,
416
+ )
417
+ parser.add_argument(
418
+ "--enable-fault-handler",
419
+ action="store_true",
420
+ help="Enable faulthandler",
421
+ )
422
+ parser.add_argument(
423
+ "--faulthandler-interval",
424
+ type=float,
425
+ default=None,
426
+ help="Enable faulthandler after specified number of seconds with repeat",
427
+ required=False,
428
+ )
429
+ parser.add_argument(
430
+ "--process-monitoring-interval",
431
+ type=float,
432
+ default=None,
433
+ help="Enable process monitoring after specified number of seconds with repeat",
434
+ required=False,
435
+ )
436
+
437
+ def __init__(self, args):
438
+ """Initializes the TestMonitoringContext object.
439
+
440
+ Args:
441
+ args (argparse.Namespace): The argparse args object to extend with additional arguments.
442
+ """
443
+ self._args = args
444
+
445
+ def __enter__(self):
446
+ """Enters the context manager for the test monitoring."""
447
+ import faulthandler
448
+ import logging.handlers
449
+
450
+ args = self._args
451
+ self._loglevel = log_level = logging.DEBUG if args.verbose else logging.INFO
452
+ logging.basicConfig(level=logging.DEBUG, format=DEFAULT_LOG_FORMAT)
453
+ logger = logging.getLogger()
454
+
455
+ if args.log_path is not None:
456
+ # Create a rotating file handler for the file output logger
457
+ # The file name is based on the log path argument, the maximum size is 10 MB, and the maximum number of files is 500
458
+ file_handler = logging.handlers.RotatingFileHandler(
459
+ args.log_path, maxBytes=args.maximum_log_file, backupCount=500
460
+ )
461
+ file_handler.setFormatter(logging.Formatter(DEFAULT_LOG_FORMAT))
462
+ file_handler.setLevel(logging.DEBUG)
463
+ if args.compress_logs:
464
+ file_handler.namer = lambda name: name + ".gz"
465
+
466
+ def gzip_rotation(source, dest):
467
+ import gzip
468
+ import os
469
+
470
+ with open(source, "rb") as f_in:
471
+ with gzip.open(dest, "wb") as f_out:
472
+ f_out.writelines(f_in)
473
+ os.remove(source)
474
+
475
+ file_handler.rotator = gzip_rotation
476
+
477
+ # Add the file handler to the default logger
478
+ logger.addHandler(file_handler)
479
+ # Get the stream handler that was created by basicConfig
480
+
481
+ # Get the stream handler that was created by basicConfig
482
+ stream_handler = logger.handlers[0]
483
+ # Set the stream handler's level to match the log level argument
484
+ stream_handler.setLevel(log_level)
485
+
486
+ if args.enable_fault_handler:
487
+ faulthandler.enable()
488
+
489
+ if args.faulthandler_interval is not None:
490
+ faulthandler.dump_traceback_later(args.faulthandler_interval, repeat=True, exit=False)
491
+
492
+ custom_monitors = []
493
+
494
+ import os
495
+
496
+ import psutil
497
+
498
+ def monitor_ram_usage(pid=None):
499
+ if pid is None:
500
+ pid = os.getpid()
501
+
502
+ process = psutil.Process(pid)
503
+ logger.debug(f"MONITOR RAM USAGE ({pid}): {process.memory_info()}")
504
+
505
+ custom_monitors.append(monitor_ram_usage)
506
+
507
+ def monitor_file_descriptors(pid=None):
508
+ if pid is None:
509
+ pid = os.getpid()
510
+
511
+ process = psutil.Process(pid)
512
+ logger.debug(f"MONITOR FILE DESCRIPTORS ({pid}): {process.num_fds()}")
513
+
514
+ custom_monitors.append(monitor_file_descriptors)
515
+
516
+ def monitor_cpu_usage(pid=None):
517
+ if pid is None:
518
+ pid = os.getpid()
519
+
520
+ process = psutil.Process(pid)
521
+ logger.debug(f"MONITOR CPU USAGE ({pid}): {process.cpu_percent()}")
522
+
523
+ custom_monitors.append(monitor_cpu_usage)
524
+
525
+ def monitor_threads(pid=None):
526
+ if pid is None:
527
+ pid = os.getpid()
528
+
529
+ process = psutil.Process(pid)
530
+ logger.debug(f"MONITOR THREADS ({pid}): {process.num_threads()}")
531
+
532
+ custom_monitors.append(monitor_threads)
533
+
534
+ def monitor_process_dict(pid=None):
535
+ if pid is None:
536
+ pid = os.getpid()
537
+
538
+ process = psutil.Process(pid)
539
+ logger.debug(f"MONITOR PROCESS DICT ({pid}): {process.as_dict()}")
540
+
541
+ custom_monitors.append(monitor_process_dict)
542
+ if args.process_monitoring_interval is not None:
543
+ monitoring = ProcessMonitoring(os.getpid(), logger, loglevel=logging.DEBUG, remove_color=True)
544
+ for monitor in custom_monitors:
545
+ monitoring.register_custom_monitor(monitor)
546
+
547
+ self._monitor = ProcessMonitoringThread(monitoring, interval=args.process_monitoring_interval)
548
+ self._monitor.start()
549
+ return self
550
+
551
+ def __exit__(self, *args):
552
+ """Stops the monitor thread."""
553
+ if hasattr(self, "_monitor"):
554
+ self._monitor.stop()
555
+ self._monitor = None
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
15
+
16
+ from .client import (
17
+ AsyncioDecoupledModelClient, # noqa: F401
18
+ AsyncioModelClient, # noqa: F401
19
+ DecoupledModelClient, # noqa: F401
20
+ FuturesModelClient, # noqa: F401
21
+ ModelClient, # noqa: F401
22
+ )
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/asyncio_utils.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utility module supporting model clients."""
15
+
16
+ import asyncio
17
+ import logging
18
+ import time
19
+ from typing import Optional, Union
20
+
21
+ import aiohttp
22
+ import grpc
23
+ import tritonclient.grpc
24
+ import tritonclient.http
25
+
26
+ from pytriton.client.exceptions import PyTritonClientModelUnavailableError, PyTritonClientTimeoutError
27
+ from pytriton.client.utils import LATEST_MODEL_VERSION, ModelState, parse_grpc_response, parse_http_response
28
+ from pytriton.model_config.parser import ModelConfigParser
29
+
30
+ aio_clients = Union[tritonclient.grpc.aio.InferenceServerClient, tritonclient.http.aio.InferenceServerClient]
31
+
32
+ _LOGGER = logging.getLogger(__name__)
33
+
34
+ _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S = 60.0 # 60 seconds
35
+ _DEFAULT_ASYNC_SLEEP_FACTOR_S = 0.1 # 10% of timeout
36
+
37
+
38
+ async def asyncio_get_model_state(
39
+ client: aio_clients,
40
+ model_name: str,
41
+ model_version: Optional[str] = None,
42
+ ) -> ModelState:
43
+ """Obtains state of the model deployed in Triton Inference Server.
44
+
45
+ Typical use:
46
+
47
+ >>> import tritonclient.http.aio
48
+ ... client = tritonclient.http.aio.InferenceServerClient("localhost:8000")
49
+ ... model_state = await get_model_state(client, "MyModel", "1")
50
+
51
+ Args:
52
+ client: Triton Inference Server client to use for communication
53
+ model_name: name of the model which state we're requesting.
54
+ model_version:
55
+ version of the model which state we're requesting.
56
+ If model_version is None state of latest model is returned.
57
+ The latest versions of the model are the numerically greatest version numbers.
58
+
59
+ Returns:
60
+ Model state. ModelState.UNAVAILABLE is returned in case if model with given name and version is not found.
61
+
62
+ """
63
+ _LOGGER.debug(f"Obtaining model {model_name} state")
64
+ repository_index = await client.get_model_repository_index()
65
+ _LOGGER.debug("Model repository index obtained")
66
+ if isinstance(repository_index, list):
67
+ models_states = parse_http_response(models=repository_index)
68
+ else:
69
+ models_states = parse_grpc_response(models=repository_index.models)
70
+
71
+ if model_version is None:
72
+ requested_model_states = {
73
+ version: state for (name, version), state in models_states.items() if name == model_name
74
+ }
75
+ if not requested_model_states:
76
+ return ModelState.UNAVAILABLE
77
+ else:
78
+ requested_model_states = sorted(requested_model_states.items(), key=lambda item: int(item[0]))
79
+ latest_version, latest_version_state = requested_model_states[-1]
80
+ _LOGGER.debug(f"Model {model_name} latest version: {latest_version} state: {latest_version_state}")
81
+ return latest_version_state
82
+ else:
83
+ key = (model_name, model_version)
84
+ if key not in models_states:
85
+ return ModelState.UNAVAILABLE
86
+ else:
87
+ model_state = models_states[key]
88
+ _LOGGER.debug(f"Model {model_name} version {model_version} state: {model_state}")
89
+ return model_state
90
+
91
+
92
+ async def asyncio_get_model_config(
93
+ client: aio_clients,
94
+ model_name: str,
95
+ model_version: Optional[str] = None,
96
+ timeout_s: float = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S,
97
+ ):
98
+ """Obtain configuration of model deployed on the Triton Inference Server.
99
+
100
+ Function waits for server readiness.
101
+
102
+ Args:
103
+ client: Triton Inference Server client to use for communication
104
+ model_name: name of the model which configuration we're requesting.
105
+ model_version:
106
+ version of the model which configuration we're requesting.
107
+ If model_version is None configuration of the latest model is returned.
108
+ The latest versions of the model are the numerically greatest version numbers.
109
+ timeout_s: timeout to finish model configuration obtain.
110
+
111
+ Returns:
112
+ Configuration of requested model.
113
+
114
+ Raises:
115
+ PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
116
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
117
+ """
118
+ should_finish_before = time.time() + timeout_s
119
+ _LOGGER.debug(f"Obtaining model {model_name} config (timeout={timeout_s:0.2f})")
120
+ try:
121
+ _LOGGER.debug(f"Waiting for model {model_name} to be ready")
122
+ await asyncio.wait_for(
123
+ asyncio_wait_for_model_ready(
124
+ client, model_name=model_name, model_version=model_version, timeout_s=timeout_s
125
+ ),
126
+ timeout_s,
127
+ )
128
+
129
+ model_version = model_version or ""
130
+
131
+ timeout_s = max(0, should_finish_before - time.time())
132
+ if isinstance(client, tritonclient.grpc.aio.InferenceServerClient):
133
+ _LOGGER.debug(f"Obtaining model {model_name} config as_json=True")
134
+ response = await asyncio.wait_for(
135
+ client.get_model_config(model_name, model_version, as_json=True), timeout_s
136
+ )
137
+ model_config = response["config"]
138
+ else:
139
+ _LOGGER.debug(f"Obtaining model {model_name} config")
140
+ model_config = await asyncio.wait_for(client.get_model_config(model_name, model_version), timeout_s)
141
+ _LOGGER.debug("Model config obtained")
142
+ model_config = ModelConfigParser.from_dict(model_config)
143
+ _LOGGER.debug(f"Model config: {model_config}")
144
+ return model_config
145
+ except asyncio.TimeoutError as e:
146
+ message = f"Timeout while waiting for model {model_name} config (timeout={timeout_s:0.2f})"
147
+ _LOGGER.error(message)
148
+ raise PyTritonClientTimeoutError(message) from e
149
+
150
+
151
+ async def asyncio_wait_for_server_ready(
152
+ asyncio_client: aio_clients,
153
+ sleep_time_s: float,
154
+ ):
155
+ """Wait for Triton Inference Server readiness.
156
+
157
+ There are two functions, which check server status:
158
+ * asyncio_client.is_server_ready()
159
+ * asyncio_client.is_server_live()
160
+ Both must return true to consider server accessible to read model status.
161
+
162
+ Function contains while loop with sleep to check server status periodically.
163
+
164
+ Args:
165
+ asyncio_client: Triton Inference Server client to use for communication
166
+ sleep_time_s: time to sleep between server status checks
167
+
168
+ Raises:
169
+ PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
170
+ """
171
+ _LOGGER.debug("Waiting for server to be ready")
172
+ try:
173
+ while True:
174
+ try:
175
+ _LOGGER.debug("Waiting for server to be ready")
176
+ server_ready = await asyncio_client.is_server_ready()
177
+ _LOGGER.debug("Waiting for server to be live")
178
+ server_live = await asyncio_client.is_server_live()
179
+ except tritonclient.utils.InferenceServerException:
180
+ # Raised by tritonclient/grpc/__init__.py:75
181
+ server_live = False
182
+ server_ready = False
183
+ except aiohttp.client_exceptions.ClientConnectorError:
184
+ # This exception is raised by aiohttp/connector.py:901 in _create_direct_connection
185
+ # and it is not translated to any other error by tritonclient/http/aio/__init__.py:132 in _get method.
186
+ # res = await self._stub.get(url=req_url,
187
+ # and tritonclient/http/aio/__init__.py:242 in is_server_ready method.
188
+ # response = await self._get(request_uri=request_uri,
189
+ server_live = False
190
+ server_ready = False
191
+ except RuntimeError:
192
+ # This exception is raised by aiohttp/client.py:400 in _request
193
+ # and it is not translated to any other error by tritonclient/grpc/aio/__init__.py:151: in is_server_ready method.
194
+ # response = await self._client_stub.ServerReady(request=request,
195
+ server_live = False
196
+ server_ready = False
197
+ except grpc._cython.cygrpc.UsageError:
198
+ # This exception is raised by grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi:124
199
+ # and it is not translated to any other error by tritonclient/grpc/aio/__init__.py", line 151, in is_server_ready
200
+ # response = await self._client_stub.ServerReady(request=request,
201
+ server_live = False
202
+ server_ready = False
203
+ if server_ready and server_live:
204
+ break
205
+ _LOGGER.debug(f"Sleeping for {sleep_time_s:0.2f} seconds")
206
+ await asyncio.sleep(sleep_time_s)
207
+ except asyncio.TimeoutError as e:
208
+ # This error is caused by our timeout, not by Triton Inference Server client.
209
+ message = "Timeout while waiting for model"
210
+ _LOGGER.error(message)
211
+ raise PyTritonClientTimeoutError(message) from e
212
+ _LOGGER.debug("Server is ready")
213
+
214
+
215
+ async def asyncio_wait_for_model_status_loaded(
216
+ asyncio_client: aio_clients,
217
+ model_name: str,
218
+ sleep_time_s: float,
219
+ model_version: Optional[str] = None,
220
+ ):
221
+ """Wait for model status loaded.
222
+
223
+ Function runs the following async function to check model status:
224
+ ```python
225
+ asyncio_get_model_state(asyncio_client, model_name, model_version)
226
+ ```
227
+ If it return _ModelState.READY, then another async function can check if model is really ready:
228
+ ```python
229
+ asyncio_client.is_model_ready(model_name)
230
+ ```
231
+ This function uses the above functions to check if model is ready together
232
+ with asyncio.wait_for(...) to limit the time of waiting.
233
+
234
+ Function contains while loop with sleep to check model status periodically.
235
+
236
+ Args:
237
+ asyncio_client: Triton Inference Server client to use for communication
238
+ model_name: name of the model which configuration we're requesting.
239
+ model_version:
240
+ version of the model which configuration we're requesting.
241
+ If model_version is None configuration of the latest model is returned.
242
+ The latest versions of the model are the numerically greatest version numbers.
243
+ sleep_time_s: time interval, in seconds, between successive checks to determine if the model configuration has been completed.
244
+
245
+ Raises:
246
+ PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
247
+ """
248
+ model_version = model_version or ""
249
+ model_version_msg = model_version or LATEST_MODEL_VERSION
250
+ _LOGGER.debug(f"Waiting for model {model_name}, {model_version_msg} to be ready")
251
+ try:
252
+ while True:
253
+ _LOGGER.debug(f"Checking if model {model_name} is ready")
254
+ is_model_ready = await asyncio_client.is_model_ready(model_name, model_version)
255
+ if is_model_ready:
256
+ break
257
+ _LOGGER.debug(f"Sleeping for {sleep_time_s} seconds")
258
+ await asyncio.sleep(sleep_time_s)
259
+ except asyncio.TimeoutError as e:
260
+ message = f"Timeout while waiting for model {model_name} state (timeout={sleep_time_s:0.2f})"
261
+ _LOGGER.error(message)
262
+ raise PyTritonClientTimeoutError(message) from e
263
+ _LOGGER.debug(f"Model {model_name}, {model_version_msg} is ready")
264
+
265
+
266
+ async def asyncio_wait_for_model_ready(
267
+ asyncio_client: aio_clients,
268
+ model_name: str,
269
+ model_version: Optional[str] = None,
270
+ timeout_s: float = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S,
271
+ ):
272
+ """Wait for Triton Inference Server and deployed on it model readiness.
273
+
274
+ Args:
275
+ asyncio_client: Triton Inference Server client to use for communication
276
+ model_name: name of the model which configuration we're requesting.
277
+ model_version:
278
+ version of the model which configuration we're requesting.
279
+ If model_version is None configuration of the latest model is returned.
280
+ The latest versions of the model are the numerically greatest version numbers.
281
+ timeout_s: timeout to finish model configuration obtain.
282
+
283
+ Raises:
284
+ PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
285
+
286
+ """
287
+ _LOGGER.debug(f"Waiting for model {model_name} to be ready (timeout={timeout_s:0.2f})")
288
+ sleep_time_s = timeout_s * _DEFAULT_ASYNC_SLEEP_FACTOR_S
289
+ try:
290
+ should_finish_before = time.time() + timeout_s
291
+ await asyncio.wait_for(asyncio_wait_for_server_ready(asyncio_client, sleep_time_s), timeout_s)
292
+ _LOGGER.debug(f"Waiting for model {model_name} to be ready")
293
+ timeout_s = max(0, should_finish_before - time.time())
294
+ await asyncio.wait_for(
295
+ asyncio_wait_for_model_status_loaded(
296
+ asyncio_client, model_name=model_name, model_version=model_version, sleep_time_s=sleep_time_s
297
+ ),
298
+ timeout_s,
299
+ )
300
+ except PyTritonClientModelUnavailableError as e:
301
+ _LOGGER.error(f"Failed to obtain model {model_name} config error {e}")
302
+ raise e
303
+ except asyncio.TimeoutError as e:
304
+ _LOGGER.error(f"Failed to obtain model {model_name} config error {e}")
305
+ raise PyTritonClientTimeoutError(
306
+ f"Timeout while waiting for model {model_name} to be ready (timeout={timeout_s:0.2f})"
307
+ ) from e
308
+ _LOGGER.debug(f"Model {model_name} is ready")
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/client.py ADDED
@@ -0,0 +1,2033 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Clients for easy interaction with models deployed on the Triton Inference Server.
16
+
17
+ Typical usage example:
18
+
19
+ ```python
20
+ client = ModelClient("localhost", "MyModel")
21
+ result_dict = client.infer_sample(input_a=a, input_b=b)
22
+ client.close()
23
+ ```
24
+
25
+ Inference inputs can be provided either as positional or keyword arguments:
26
+
27
+ ```python
28
+ result_dict = client.infer_sample(input1, input2)
29
+ result_dict = client.infer_sample(a=input1, b=input2)
30
+ ```
31
+
32
+ Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError.
33
+ """
34
+
35
+ import asyncio
36
+ import contextlib
37
+ import itertools
38
+ import logging
39
+ import socket
40
+ import time
41
+ import warnings
42
+ from concurrent.futures import Future
43
+ from queue import Empty, Full, Queue
44
+ from threading import Lock, Thread
45
+ from typing import Any, Dict, Optional, Tuple, Union
46
+
47
+ import gevent
48
+ import numpy as np
49
+ import tritonclient.grpc
50
+ import tritonclient.grpc.aio
51
+ import tritonclient.http
52
+ import tritonclient.http.aio
53
+ import tritonclient.utils
54
+
55
+ from pytriton.client.asyncio_utils import asyncio_get_model_config, asyncio_wait_for_model_ready
56
+ from pytriton.client.exceptions import (
57
+ PyTritonClientClosedError,
58
+ PyTritonClientInferenceServerError,
59
+ PyTritonClientModelDoesntSupportBatchingError,
60
+ PyTritonClientQueueFullError,
61
+ PyTritonClientTimeoutError,
62
+ PyTritonClientValueError,
63
+ )
64
+ from pytriton.client.utils import (
65
+ _DEFAULT_NETWORK_TIMEOUT_S,
66
+ _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S,
67
+ TritonUrl,
68
+ get_model_config,
69
+ wait_for_model_ready,
70
+ wait_for_server_ready,
71
+ )
72
+ from pytriton.client.warnings import NotSupportedTimeoutWarning
73
+ from pytriton.model_config.triton_model_config import TritonModelConfig
74
+
75
+ _LOGGER = logging.getLogger(__name__)
76
+
77
+ _DEFAULT_SYNC_INIT_TIMEOUT_S = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S
78
+ _DEFAULT_FUTURES_INIT_TIMEOUT_S = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S
79
+ DEFAULT_INFERENCE_TIMEOUT_S = 60.0
80
+
81
+
82
+ _IOType = Union[Tuple[np.ndarray, ...], Dict[str, np.ndarray]]
83
+
84
+
85
+ def _verify_inputs_args(inputs, named_inputs):
86
+ if not inputs and not named_inputs:
87
+ raise PyTritonClientValueError("Provide input data")
88
+ if not bool(inputs) ^ bool(named_inputs):
89
+ raise PyTritonClientValueError("Use either positional either keyword method arguments convention")
90
+
91
+
92
+ def _verify_parameters(parameters_or_headers: Optional[Dict[str, Union[str, int, bool]]] = None):
93
+ if parameters_or_headers is None:
94
+ return
95
+ if not isinstance(parameters_or_headers, dict):
96
+ raise PyTritonClientValueError("Parameters and headers must be a dictionary")
97
+ for key, value in parameters_or_headers.items():
98
+ if not isinstance(key, str):
99
+ raise PyTritonClientValueError("Parameter/header key must be a string")
100
+ if not isinstance(value, (str, int, bool)):
101
+ raise PyTritonClientValueError("Parameter/header value must be a string, integer or boolean")
102
+
103
+
104
+ class BaseModelClient:
105
+ """Base client for model deployed on the Triton Inference Server."""
106
+
107
+ def __init__(
108
+ self,
109
+ url: str,
110
+ model_name: str,
111
+ model_version: Optional[str] = None,
112
+ *,
113
+ lazy_init: bool = True,
114
+ init_timeout_s: Optional[float] = None,
115
+ inference_timeout_s: Optional[float] = None,
116
+ model_config: Optional[TritonModelConfig] = None,
117
+ ensure_model_is_ready: bool = True,
118
+ ):
119
+ """Inits BaseModelClient for given model deployed on the Triton Inference Server.
120
+
121
+ Common usage:
122
+
123
+ ```python
124
+ client = ModelClient("localhost", "BERT")
125
+ result_dict = client.infer_sample(input1_sample, input2_sample)
126
+ client.close()
127
+ ```
128
+
129
+ Args:
130
+ url: The Triton Inference Server url, e.g. `grpc://localhost:8001`.
131
+ In case no scheme is provided http scheme will be used as default.
132
+ In case no port is provided default port for given scheme will be used -
133
+ 8001 for grpc scheme, 8000 for http scheme.
134
+ model_name: name of the model to interact with.
135
+ model_version: version of the model to interact with.
136
+ If model_version is None inference on latest model will be performed.
137
+ The latest versions of the model are numerically the greatest version numbers.
138
+ lazy_init: if initialization should be performed just before sending first request to inference server.
139
+ init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used.
140
+ inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used.
141
+ model_config: model configuration. If not passed, it will be read from inference server during initialization.
142
+ ensure_model_is_ready: if model should be checked if it is ready before first inference request.
143
+
144
+ Raises:
145
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
146
+ PyTritonClientTimeoutError:
147
+ if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
148
+ PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
149
+ """
150
+ self._init_timeout_s = _DEFAULT_SYNC_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s
151
+ self._inference_timeout_s = DEFAULT_INFERENCE_TIMEOUT_S if inference_timeout_s is None else inference_timeout_s
152
+ self._network_timeout_s = min(_DEFAULT_NETWORK_TIMEOUT_S, self._init_timeout_s)
153
+
154
+ self._general_client = self.create_client_from_url(url, network_timeout_s=self._network_timeout_s)
155
+ self._infer_client = self.create_client_from_url(url, network_timeout_s=self._inference_timeout_s)
156
+
157
+ self._model_name = model_name
158
+ self._model_version = model_version
159
+
160
+ self._request_id_generator = itertools.count(0)
161
+
162
+ # Monkey patch __del__ method from client to catch error in client when instance is garbage collected.
163
+ # This is needed because we are closing client in __exit__ method or in close method.
164
+ # (InferenceClient uses gevent library which does not support closing twice from different threads)
165
+ self._monkey_patch_client()
166
+
167
+ if model_config is not None:
168
+ self._model_config = model_config
169
+ self._model_ready = None if ensure_model_is_ready else True
170
+
171
+ else:
172
+ self._model_config = None
173
+ self._model_ready = None
174
+ self._lazy_init: bool = lazy_init
175
+
176
+ self._handle_lazy_init()
177
+
178
+ @classmethod
179
+ def from_existing_client(cls, existing_client: "BaseModelClient"):
180
+ """Create a new instance from an existing client using the same class.
181
+
182
+ Common usage:
183
+ ```python
184
+ client = BaseModelClient.from_existing_client(existing_client)
185
+ ```
186
+
187
+ Args:
188
+ existing_client: An instance of an already initialized subclass.
189
+
190
+ Returns:
191
+ A new instance of the same subclass with shared configuration and readiness state.
192
+ """
193
+ kwargs = {}
194
+ # Copy model configuration and readiness state if present
195
+ if hasattr(existing_client, "_model_config"):
196
+ kwargs["model_config"] = existing_client._model_config
197
+ kwargs["ensure_model_is_ready"] = False
198
+
199
+ new_client = cls(
200
+ url=existing_client._url,
201
+ model_name=existing_client._model_name,
202
+ model_version=existing_client._model_version,
203
+ init_timeout_s=existing_client._init_timeout_s,
204
+ inference_timeout_s=existing_client._inference_timeout_s,
205
+ **kwargs,
206
+ )
207
+
208
+ return new_client
209
+
210
+ def create_client_from_url(self, url: str, network_timeout_s: Optional[float] = None):
211
+ """Create Triton Inference Server client.
212
+
213
+ Args:
214
+ url: url of the server to connect to.
215
+ If url doesn't contain scheme (e.g. "localhost:8001") http scheme is added.
216
+ If url doesn't contain port (e.g. "localhost") default port for given scheme is added.
217
+ network_timeout_s: timeout for client commands. Default value is 60.0 s.
218
+
219
+ Returns:
220
+ Triton Inference Server client.
221
+
222
+ Raises:
223
+ PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
224
+ """
225
+ self._triton_url = TritonUrl.from_url(url)
226
+ self._url = self._triton_url.without_scheme
227
+ self._triton_client_lib = self.get_lib()
228
+ self._monkey_patch_client()
229
+
230
+ if self._triton_url.scheme == "grpc":
231
+ # by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout
232
+ network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s
233
+ warnings.warn(
234
+ f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.",
235
+ NotSupportedTimeoutWarning,
236
+ stacklevel=1,
237
+ )
238
+
239
+ triton_client_init_kwargs = self._get_init_extra_args()
240
+
241
+ _LOGGER.debug(
242
+ f"Creating InferenceServerClient for {self._triton_url.with_scheme} with {triton_client_init_kwargs}"
243
+ )
244
+ return self._triton_client_lib.InferenceServerClient(self._url, **triton_client_init_kwargs)
245
+
246
+ def get_lib(self):
247
+ """Returns tritonclient library for given scheme."""
248
+ raise NotImplementedError
249
+
250
+ @property
251
+ def _next_request_id(self) -> str:
252
+ # pytype complained about creating generator in __init__ method
253
+ # so we create it lazily
254
+ if getattr(self, "_request_id_generator", None) is None:
255
+ self._request_id_generator = itertools.count(0)
256
+ return str(next(self._request_id_generator))
257
+
258
+ def _get_init_extra_args(self):
259
+ timeout = self._inference_timeout_s # pytype: disable=attribute-error
260
+ # The inference timeout is used for both the HTTP and the GRPC protocols. However,
261
+ # the way the timeout is passed to the client differs depending on the protocol.
262
+ # For the HTTP protocol, the timeout is set in the ``__init__`` method as ``network_timeout``
263
+ # and ``connection_timeout``. For the GRPC protocol, the timeout
264
+ # is passed to the infer method as ``client_timeout``.
265
+ # Both protocols support timeouts correctly and will raise an exception
266
+ # if the network request or the inference process takes longer than the timeout.
267
+ # This is a design choice of the underlying tritonclient library.
268
+
269
+ if self._triton_url.scheme != "http":
270
+ return {}
271
+
272
+ kwargs = {
273
+ # This value sets the maximum time allowed for each network request in both model loading and inference process
274
+ "network_timeout": timeout,
275
+ # This value sets the maximum time allowed for establishing a connection to the server.
276
+ # We use the inference timeout here instead of the init timeout because the init timeout
277
+ # is meant for waiting for the model to be ready. The connection timeout should be shorter
278
+ # than the init timeout because it only checks if connection is established (e.g. correct port)
279
+ "connection_timeout": timeout,
280
+ }
281
+ return kwargs
282
+
283
+ def _monkey_patch_client(self):
284
+ pass
285
+
286
+ def _get_model_config_extra_args(self):
287
+ # For the GRPC protocol, the timeout must be passed to the each request as client_timeout
288
+ # model_config doesn't yet support timeout but it is planned for the future
289
+ # grpc_network_timeout_s will be used for model_config
290
+ return {}
291
+
292
+ def _handle_lazy_init(self):
293
+ raise NotImplementedError
294
+
295
+
296
+ def _run_once_per_lib(f):
297
+ def wrapper(_self):
298
+ if _self._triton_client_lib not in wrapper.patched:
299
+ wrapper.patched.add(_self._triton_client_lib)
300
+ return f(_self)
301
+
302
+ wrapper.patched = set()
303
+ return wrapper
304
+
305
+
306
+ class ModelClient(BaseModelClient):
307
+ """Synchronous client for model deployed on the Triton Inference Server."""
308
+
309
+ def __init__(
310
+ self,
311
+ url: str,
312
+ model_name: str,
313
+ model_version: Optional[str] = None,
314
+ *,
315
+ lazy_init: bool = True,
316
+ init_timeout_s: Optional[float] = None,
317
+ inference_timeout_s: Optional[float] = None,
318
+ model_config: Optional[TritonModelConfig] = None,
319
+ ensure_model_is_ready: bool = True,
320
+ ):
321
+ """Inits ModelClient for given model deployed on the Triton Inference Server.
322
+
323
+ If `lazy_init` argument is False, model configuration will be read
324
+ from inference server during initialization.
325
+
326
+ Common usage:
327
+
328
+ ```python
329
+ client = ModelClient("localhost", "BERT")
330
+ result_dict = client.infer_sample(input1_sample, input2_sample)
331
+ client.close()
332
+ ```
333
+
334
+ Client supports also context manager protocol:
335
+
336
+ ```python
337
+ with ModelClient("localhost", "BERT") as client:
338
+ result_dict = client.infer_sample(input1_sample, input2_sample)
339
+ ```
340
+
341
+ The creation of client requires connection to the server and downloading model configuration. You can create client from existing client using the same class:
342
+
343
+ ```python
344
+ client = ModelClient.from_existing_client(existing_client)
345
+ ```
346
+
347
+ Args:
348
+ url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'.
349
+ In case no scheme is provided http scheme will be used as default.
350
+ In case no port is provided default port for given scheme will be used -
351
+ 8001 for grpc scheme, 8000 for http scheme.
352
+ model_name: name of the model to interact with.
353
+ model_version: version of the model to interact with.
354
+ If model_version is None inference on latest model will be performed.
355
+ The latest versions of the model are numerically the greatest version numbers.
356
+ lazy_init: if initialization should be performed just before sending first request to inference server.
357
+ init_timeout_s: timeout for maximum waiting time in loop, which sends retry requests ask if model is ready. It is applied at initialization time only when `lazy_init` argument is False. Default is to do retry loop at first inference.
358
+ inference_timeout_s: timeout in seconds for the model inference process.
359
+ If non passed default 60 seconds timeout will be used.
360
+ For HTTP client it is not only inference timeout but any client request timeout
361
+ - get model config, is model loaded. For GRPC client it is only inference timeout.
362
+ model_config: model configuration. If not passed, it will be read from inference server during initialization.
363
+ ensure_model_is_ready: if model should be checked if it is ready before first inference request.
364
+
365
+ Raises:
366
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
367
+ PyTritonClientTimeoutError:
368
+ if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
369
+ PyTritonClientUrlParseError: In case of problems with parsing url.
370
+ """
371
+ super().__init__(
372
+ url=url,
373
+ model_name=model_name,
374
+ model_version=model_version,
375
+ lazy_init=lazy_init,
376
+ init_timeout_s=init_timeout_s,
377
+ inference_timeout_s=inference_timeout_s,
378
+ model_config=model_config,
379
+ ensure_model_is_ready=ensure_model_is_ready,
380
+ )
381
+
382
+ def get_lib(self):
383
+ """Returns tritonclient library for given scheme."""
384
+ return {"grpc": tritonclient.grpc, "http": tritonclient.http}[self._triton_url.scheme.lower()]
385
+
386
+ def __enter__(self):
387
+ """Create context for using ModelClient as a context manager."""
388
+ return self
389
+
390
+ def __exit__(self, *_):
391
+ """Close resources used by ModelClient instance when exiting from the context."""
392
+ self.close()
393
+
394
+ def load_model(self, config: Optional[str] = None, files: Optional[dict] = None):
395
+ """Load model on the Triton Inference Server.
396
+
397
+ Args:
398
+ config: str - Optional JSON representation of a model config provided for
399
+ the load request, if provided, this config will be used for
400
+ loading the model.
401
+ files: dict - Optional dictionary specifying file path (with "file:" prefix) in
402
+ the override model directory to the file content as bytes.
403
+ The files will form the model directory that the model will be
404
+ loaded from. If specified, 'config' must be provided to be
405
+ the model configuration of the override model directory.
406
+ """
407
+ self._general_client.load_model(self._model_name, config=config, files=files)
408
+
409
+ def unload_model(self):
410
+ """Unload model from the Triton Inference Server."""
411
+ self._general_client.unload_model(self._model_name)
412
+
413
+ def close(self):
414
+ """Close resources used by ModelClient.
415
+
416
+ This method closes the resources used by the ModelClient instance,
417
+ including the Triton Inference Server connections.
418
+ Once this method is called, the ModelClient instance should not be used again.
419
+ """
420
+ _LOGGER.debug("Closing ModelClient")
421
+ try:
422
+ if self._general_client is not None:
423
+ self._general_client.close()
424
+ if self._infer_client is not None:
425
+ self._infer_client.close()
426
+ self._general_client = None
427
+ self._infer_client = None
428
+ except Exception as e:
429
+ _LOGGER.error(f"Error while closing ModelClient resources: {e}")
430
+ raise e
431
+
432
+ def wait_for_model(self, timeout_s: float):
433
+ """Wait for the Triton Inference Server and the deployed model to be ready.
434
+
435
+ Args:
436
+ timeout_s: timeout in seconds to wait for the server and model to be ready.
437
+
438
+ Raises:
439
+ PyTritonClientTimeoutError: If the server and model are not ready before the given timeout.
440
+ PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
441
+ KeyboardInterrupt: If the hosting process receives SIGINT.
442
+ PyTritonClientClosedError: If the ModelClient is closed.
443
+ """
444
+ if self._general_client is None:
445
+ raise PyTritonClientClosedError("ModelClient is closed")
446
+ wait_for_model_ready(self._general_client, self._model_name, self._model_version, timeout_s=timeout_s)
447
+
448
+ @property
449
+ def is_batching_supported(self):
450
+ """Checks if model supports batching.
451
+
452
+ Also waits for server to get into readiness state.
453
+ """
454
+ return self.model_config.max_batch_size > 0
455
+
456
+ def wait_for_server(self, timeout_s: float):
457
+ """Wait for Triton Inference Server readiness.
458
+
459
+ Args:
460
+ timeout_s: timeout to server get into readiness state.
461
+
462
+ Raises:
463
+ PyTritonClientTimeoutError: If server is not in readiness state before given timeout.
464
+ KeyboardInterrupt: If hosting process receives SIGINT
465
+ """
466
+ wait_for_server_ready(self._general_client, timeout_s=timeout_s)
467
+
468
+ @property
469
+ def model_config(self) -> TritonModelConfig:
470
+ """Obtain the configuration of the model deployed on the Triton Inference Server.
471
+
472
+ This method waits for the server to get into readiness state before obtaining the model configuration.
473
+
474
+ Returns:
475
+ TritonModelConfig: configuration of the model deployed on the Triton Inference Server.
476
+
477
+ Raises:
478
+ PyTritonClientTimeoutError: If the server and model are not in readiness state before the given timeout.
479
+ PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
480
+ KeyboardInterrupt: If the hosting process receives SIGINT.
481
+ PyTritonClientClosedError: If the ModelClient is closed.
482
+ """
483
+ if not self._model_config:
484
+ if self._general_client is None:
485
+ raise PyTritonClientClosedError("ModelClient is closed")
486
+
487
+ self._model_config = get_model_config(
488
+ self._general_client, self._model_name, self._model_version, timeout_s=self._init_timeout_s
489
+ )
490
+ return self._model_config
491
+
492
+ def infer_sample(
493
+ self,
494
+ *inputs,
495
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
496
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
497
+ **named_inputs,
498
+ ) -> Dict[str, np.ndarray]:
499
+ """Run synchronous inference on a single data sample.
500
+
501
+ Typical usage:
502
+
503
+ ```python
504
+ client = ModelClient("localhost", "MyModel")
505
+ result_dict = client.infer_sample(input1, input2)
506
+ client.close()
507
+ ```
508
+
509
+ Inference inputs can be provided either as positional or keyword arguments:
510
+
511
+ ```python
512
+ result_dict = client.infer_sample(input1, input2)
513
+ result_dict = client.infer_sample(a=input1, b=input2)
514
+ ```
515
+
516
+ Args:
517
+ *inputs: Inference inputs provided as positional arguments.
518
+ parameters: Custom inference parameters.
519
+ headers: Custom inference headers.
520
+ **named_inputs: Inference inputs provided as named arguments.
521
+
522
+ Returns:
523
+ Dictionary with inference results, where dictionary keys are output names.
524
+
525
+ Raises:
526
+ PyTritonClientValueError: If mixing of positional and named arguments passing detected.
527
+ PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or
528
+ inference request time exceeds `inference_timeout_s`.
529
+ PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
530
+ PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side.
531
+ """
532
+ _verify_inputs_args(inputs, named_inputs)
533
+ _verify_parameters(parameters)
534
+ _verify_parameters(headers)
535
+
536
+ if self.is_batching_supported:
537
+ if inputs:
538
+ inputs = tuple(data[np.newaxis, ...] for data in inputs)
539
+ elif named_inputs:
540
+ named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()}
541
+
542
+ result = self._infer(inputs or named_inputs, parameters, headers)
543
+
544
+ return self._debatch_result(result)
545
+
546
+ def infer_batch(
547
+ self,
548
+ *inputs,
549
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
550
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
551
+ **named_inputs,
552
+ ) -> Dict[str, np.ndarray]:
553
+ """Run synchronous inference on batched data.
554
+
555
+ Typical usage:
556
+
557
+ ```python
558
+ client = ModelClient("localhost", "MyModel")
559
+ result_dict = client.infer_batch(input1, input2)
560
+ client.close()
561
+ ```
562
+
563
+ Inference inputs can be provided either as positional or keyword arguments:
564
+
565
+ ```python
566
+ result_dict = client.infer_batch(input1, input2)
567
+ result_dict = client.infer_batch(a=input1, b=input2)
568
+ ```
569
+
570
+ Args:
571
+ *inputs: Inference inputs provided as positional arguments.
572
+ parameters: Custom inference parameters.
573
+ headers: Custom inference headers.
574
+ **named_inputs: Inference inputs provided as named arguments.
575
+
576
+ Returns:
577
+ Dictionary with inference results, where dictionary keys are output names.
578
+
579
+ Raises:
580
+ PyTritonClientValueError: If mixing of positional and named arguments passing detected.
581
+ PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or
582
+ inference request time exceeds `inference_timeout_s`.
583
+ PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable.
584
+ PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side.
585
+ PyTritonClientModelDoesntSupportBatchingError: If the model doesn't support batching.
586
+ PyTritonClientValueError: if mixing of positional and named arguments passing detected.
587
+ PyTritonClientTimeoutError:
588
+ in case of first method call, `lazy_init` argument is False
589
+ and wait time for server and model being ready exceeds `init_timeout_s` or
590
+ inference time exceeds `inference_timeout_s` passed to `__init__`.
591
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
592
+ PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side,
593
+ """
594
+ _verify_inputs_args(inputs, named_inputs)
595
+ _verify_parameters(parameters)
596
+ _verify_parameters(headers)
597
+
598
+ if not self.is_batching_supported:
599
+ raise PyTritonClientModelDoesntSupportBatchingError(
600
+ f"Model {self.model_config.model_name} doesn't support batching - use infer_sample method instead"
601
+ )
602
+
603
+ return self._infer(inputs or named_inputs, parameters, headers)
604
+
605
+ def _wait_and_init_model_config(self, init_timeout_s: float):
606
+ if self._general_client is None:
607
+ raise PyTritonClientClosedError("ModelClient is closed")
608
+
609
+ should_finish_before_s = time.time() + init_timeout_s
610
+ self.wait_for_model(init_timeout_s)
611
+ self._model_ready = True
612
+ timeout_s = max(0.0, should_finish_before_s - time.time())
613
+ self._model_config = get_model_config(
614
+ self._general_client, self._model_name, self._model_version, timeout_s=timeout_s
615
+ )
616
+
617
+ def _create_request(self, inputs: _IOType):
618
+ if self._infer_client is None:
619
+ raise PyTritonClientClosedError("ModelClient is closed")
620
+
621
+ if not self._model_ready:
622
+ self._wait_and_init_model_config(self._init_timeout_s)
623
+
624
+ if isinstance(inputs, Tuple):
625
+ inputs = {input_spec.name: input_data for input_spec, input_data in zip(self.model_config.inputs, inputs)}
626
+
627
+ inputs_wrapped = []
628
+
629
+ # to help pytype to obtain variable type
630
+ inputs: Dict[str, np.ndarray]
631
+
632
+ for input_name, input_data in inputs.items():
633
+ if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes):
634
+ raise RuntimeError(
635
+ f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \
636
+ \\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}"
637
+ )
638
+ if input_data.dtype.type == np.str_:
639
+ raise RuntimeError(
640
+ "Unicode inputs are not supported. "
641
+ f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))."
642
+ )
643
+ triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype)
644
+ infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype)
645
+ infer_input.set_data_from_numpy(input_data)
646
+ inputs_wrapped.append(infer_input)
647
+
648
+ outputs_wrapped = [
649
+ self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in self.model_config.outputs
650
+ ]
651
+ return inputs_wrapped, outputs_wrapped
652
+
653
+ def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]:
654
+ if self.model_config.decoupled:
655
+ raise PyTritonClientInferenceServerError("Model config is decoupled. Use DecoupledModelClient instead.")
656
+
657
+ inputs_wrapped, outputs_wrapped = self._create_request(inputs)
658
+
659
+ try:
660
+ _LOGGER.debug("Sending inference request to Triton Inference Server")
661
+ response = self._infer_client.infer(
662
+ model_name=self._model_name,
663
+ model_version=self._model_version or "",
664
+ inputs=inputs_wrapped,
665
+ headers=headers,
666
+ outputs=outputs_wrapped,
667
+ request_id=self._next_request_id,
668
+ parameters=parameters,
669
+ **self._get_infer_extra_args(),
670
+ )
671
+ except tritonclient.utils.InferenceServerException as e:
672
+ # tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout
673
+ if "Deadline Exceeded" in e.message():
674
+ raise PyTritonClientTimeoutError(
675
+ f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}"
676
+ ) from e
677
+
678
+ raise PyTritonClientInferenceServerError(
679
+ f"Error occurred during inference request. Message: {e.message()}"
680
+ ) from e
681
+ except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout
682
+ message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
683
+ _LOGGER.error(message)
684
+ raise PyTritonClientTimeoutError(message) from e
685
+ except OSError as e: # tritonclient.http raises socket.error for connection error
686
+ message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
687
+ _LOGGER.error(message)
688
+ raise PyTritonClientTimeoutError(message) from e
689
+
690
+ if isinstance(response, tritonclient.http.InferResult):
691
+ outputs = {
692
+ output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"]
693
+ }
694
+ else:
695
+ outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs}
696
+
697
+ return outputs
698
+
699
+ def _get_numpy_result(self, result):
700
+ if isinstance(result, tritonclient.grpc.InferResult):
701
+ result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs}
702
+ else:
703
+ result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]}
704
+ return result
705
+
706
+ def _debatch_result(self, result):
707
+ if self.is_batching_supported:
708
+ result = {name: data[0] for name, data in result.items()}
709
+ return result
710
+
711
+ def _handle_lazy_init(self):
712
+ if not self._lazy_init:
713
+ self._wait_and_init_model_config(self._init_timeout_s)
714
+
715
+ def _get_infer_extra_args(self):
716
+ if self._triton_url.scheme == "http":
717
+ return {}
718
+ # For the GRPC protocol, the timeout is passed to the infer method as client_timeout
719
+ # This timeout applies to the whole inference process and each network request
720
+
721
+ # The ``infer`` supports also timeout argument for both GRPC and HTTP.
722
+ # It is applied at server side and supported only for dynamic batching.
723
+ # However, it is not used here yet and planned for future release
724
+ kwargs = {"client_timeout": self._inference_timeout_s}
725
+ return kwargs
726
+
727
+ @_run_once_per_lib
728
+ def _monkey_patch_client(self):
729
+ """Monkey patch InferenceServerClient to catch error in __del__."""
730
+ _LOGGER.info(f"Patch ModelClient {self._triton_url.scheme}")
731
+ if not hasattr(self._triton_client_lib.InferenceServerClient, "__del__"):
732
+ return
733
+
734
+ old_del = self._triton_client_lib.InferenceServerClient.__del__
735
+
736
+ def _monkey_patched_del(self):
737
+ """Monkey patched del."""
738
+ try:
739
+ old_del(self)
740
+ except gevent.exceptions.InvalidThreadUseError:
741
+ _LOGGER.info("gevent.exceptions.InvalidThreadUseError in __del__ of InferenceServerClient")
742
+ except Exception as e:
743
+ _LOGGER.error("Exception in __del__ of InferenceServerClient: %s", e)
744
+
745
+ self._triton_client_lib.InferenceServerClient.__del__ = _monkey_patched_del
746
+
747
+
748
+ class DecoupledModelClient(ModelClient):
749
+ """Synchronous client for decoupled model deployed on the Triton Inference Server."""
750
+
751
+ def __init__(
752
+ self,
753
+ url: str,
754
+ model_name: str,
755
+ model_version: Optional[str] = None,
756
+ *,
757
+ lazy_init: bool = True,
758
+ init_timeout_s: Optional[float] = None,
759
+ inference_timeout_s: Optional[float] = None,
760
+ model_config: Optional[TritonModelConfig] = None,
761
+ ensure_model_is_ready: bool = True,
762
+ ):
763
+ """Inits DecoupledModelClient for given decoupled model deployed on the Triton Inference Server.
764
+
765
+ Common usage:
766
+
767
+ ```python
768
+ client = DecoupledModelClient("localhost", "BERT")
769
+ for response in client.infer_sample(input1_sample, input2_sample):
770
+ print(response)
771
+ client.close()
772
+ ```
773
+
774
+ Args:
775
+ url: The Triton Inference Server url, e.g. `grpc://localhost:8001`.
776
+ In case no scheme is provided http scheme will be used as default.
777
+ In case no port is provided default port for given scheme will be used -
778
+ 8001 for grpc scheme, 8000 for http scheme.
779
+ model_name: name of the model to interact with.
780
+ model_version: version of the model to interact with.
781
+ If model_version is None inference on latest model will be performed.
782
+ The latest versions of the model are numerically the greatest version numbers.
783
+ lazy_init: if initialization should be performed just before sending first request to inference server.
784
+ init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used.
785
+ inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used.
786
+ model_config: model configuration. If not passed, it will be read from inference server during initialization.
787
+ ensure_model_is_ready: if model should be checked if it is ready before first inference request.
788
+
789
+ Raises:
790
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
791
+ PyTritonClientTimeoutError:
792
+ if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
793
+ PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
794
+ """
795
+ super().__init__(
796
+ url,
797
+ model_name,
798
+ model_version,
799
+ lazy_init=lazy_init,
800
+ init_timeout_s=init_timeout_s,
801
+ inference_timeout_s=inference_timeout_s,
802
+ model_config=model_config,
803
+ ensure_model_is_ready=ensure_model_is_ready,
804
+ )
805
+ if self._triton_url.scheme == "http":
806
+ raise PyTritonClientValueError("DecoupledModelClient is only supported for grpc protocol")
807
+ self._queue = Queue()
808
+ self._lock = Lock()
809
+
810
+ def close(self):
811
+ """Close resources used by DecoupledModelClient."""
812
+ _LOGGER.debug("Closing DecoupledModelClient")
813
+ if self._lock.acquire(blocking=False):
814
+ try:
815
+ super().close()
816
+ finally:
817
+ self._lock.release()
818
+ else:
819
+ _LOGGER.warning("DecoupledModelClient is stil streaming answers")
820
+ self._infer_client.stop_stream(False)
821
+ super().close()
822
+
823
+ def _infer(self, inputs: _IOType, parameters, headers):
824
+ if not self._lock.acquire(blocking=False):
825
+ raise PyTritonClientInferenceServerError("Inference is already in progress")
826
+ if not self.model_config.decoupled:
827
+ raise PyTritonClientInferenceServerError("Model config is coupled. Use ModelClient instead.")
828
+
829
+ inputs_wrapped, outputs_wrapped = self._create_request(inputs)
830
+ if parameters is not None:
831
+ raise PyTritonClientValueError("DecoupledModelClient does not support parameters")
832
+ if headers is not None:
833
+ raise PyTritonClientValueError("DecoupledModelClient does not support headers")
834
+ try:
835
+ _LOGGER.debug("Sending inference request to Triton Inference Server")
836
+ if self._infer_client._stream is None:
837
+ self._infer_client.start_stream(callback=lambda result, error: self._response_callback(result, error))
838
+
839
+ self._infer_client.async_stream_infer(
840
+ model_name=self._model_name,
841
+ model_version=self._model_version or "",
842
+ inputs=inputs_wrapped,
843
+ outputs=outputs_wrapped,
844
+ request_id=self._next_request_id,
845
+ enable_empty_final_response=True,
846
+ **self._get_infer_extra_args(),
847
+ )
848
+ except tritonclient.utils.InferenceServerException as e:
849
+ # tritonclient.grpc raises execption with message containing "Deadline Exceeded" for timeout
850
+ if "Deadline Exceeded" in e.message():
851
+ raise PyTritonClientTimeoutError(
852
+ f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}"
853
+ ) from e
854
+
855
+ raise PyTritonClientInferenceServerError(
856
+ f"Error occurred during inference request. Message: {e.message()}"
857
+ ) from e
858
+ except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout
859
+ message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
860
+ _LOGGER.error(message)
861
+ raise PyTritonClientTimeoutError(message) from e
862
+ except OSError as e: # tritonclient.http raises socket.error for connection error
863
+ message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}"
864
+ _LOGGER.error(message)
865
+ raise PyTritonClientTimeoutError(message) from e
866
+ _LOGGER.debug("Returning response iterator")
867
+ return self._create_response_iterator()
868
+
869
+ def _response_callback(self, response, error):
870
+ _LOGGER.debug(f"Received response from Triton Inference Server: {response}")
871
+ if error:
872
+ _LOGGER.error(f"Error occurred during inference request. Message: {error}")
873
+ self._queue.put(error)
874
+ else:
875
+ actual_response = response.get_response()
876
+ # Check if the object is not None
877
+ triton_final_response = actual_response.parameters.get("triton_final_response")
878
+ if triton_final_response and triton_final_response.bool_param:
879
+ self._queue.put(None)
880
+ else:
881
+ result = self._get_numpy_result(response)
882
+ self._queue.put(result)
883
+
884
+ def _create_response_iterator(self):
885
+ try:
886
+ while True:
887
+ try:
888
+ item = self._queue.get(self._inference_timeout_s)
889
+ except Empty as e:
890
+ message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s"
891
+ _LOGGER.error(message)
892
+ raise PyTritonClientTimeoutError(message) from e
893
+ if isinstance(item, Exception):
894
+ message = f"Error occurred during inference request. Message: {item.message()}"
895
+ _LOGGER.error(message)
896
+ raise PyTritonClientInferenceServerError(message) from item
897
+
898
+ if item is None:
899
+ break
900
+ yield item
901
+ finally:
902
+ self._lock.release()
903
+
904
+ def _debatch_result(self, result):
905
+ if self.is_batching_supported:
906
+ result = ({name: data[0] for name, data in result_.items()} for result_ in result)
907
+ return result
908
+
909
+ def _get_infer_extra_args(self):
910
+ # kwargs = super()._get_infer_extra_args()
911
+ kwargs = {}
912
+ # kwargs["enable_empty_final_response"] = True
913
+ return kwargs
914
+
915
+
916
+ class AsyncioModelClient(BaseModelClient):
917
+ """Asyncio client for model deployed on the Triton Inference Server.
918
+
919
+ This client is based on Triton Inference Server Python clients and GRPC library:
920
+ - ``tritonclient.http.aio.InferenceServerClient``
921
+ - ``tritonclient.grpc.aio.InferenceServerClient``
922
+
923
+ It can wait for server to be ready with model loaded and then perform inference on it.
924
+ ``AsyncioModelClient`` supports asyncio context manager protocol.
925
+
926
+ Typical usage:
927
+
928
+ ```python
929
+ from pytriton.client import AsyncioModelClient
930
+ import numpy as np
931
+
932
+ input1_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
933
+ input2_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
934
+
935
+ client = AsyncioModelClient("localhost", "MyModel")
936
+ result_dict = await client.infer_sample(input1_sample, input2_sample)
937
+ print(result_dict["output_name"])
938
+ await client.close()
939
+ ```
940
+ """
941
+
942
+ def __init__(
943
+ self,
944
+ url: str,
945
+ model_name: str,
946
+ model_version: Optional[str] = None,
947
+ *,
948
+ lazy_init: bool = True,
949
+ init_timeout_s: Optional[float] = None,
950
+ inference_timeout_s: Optional[float] = None,
951
+ model_config: Optional[TritonModelConfig] = None,
952
+ ensure_model_is_ready: bool = True,
953
+ ):
954
+ """Inits ModelClient for given model deployed on the Triton Inference Server.
955
+
956
+ If `lazy_init` argument is False, model configuration will be read
957
+ from inference server during initialization.
958
+
959
+ Args:
960
+ url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'.
961
+ In case no scheme is provided http scheme will be used as default.
962
+ In case no port is provided default port for given scheme will be used -
963
+ 8001 for grpc scheme, 8000 for http scheme.
964
+ model_name: name of the model to interact with.
965
+ model_version: version of the model to interact with.
966
+ If model_version is None inference on latest model will be performed.
967
+ The latest versions of the model are numerically the greatest version numbers.
968
+ lazy_init: if initialization should be performed just before sending first request to inference server.
969
+ init_timeout_s: timeout for server and model being ready.
970
+ inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used.
971
+ model_config: model configuration. If not passed, it will be read from inference server during initialization.
972
+ ensure_model_is_ready: if model should be checked if it is ready before first inference request.
973
+
974
+ Raises:
975
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
976
+ PyTritonClientTimeoutError: if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`.
977
+ PyTritonClientUrlParseError: In case of problems with parsing url.
978
+ """
979
+ super().__init__(
980
+ url=url,
981
+ model_name=model_name,
982
+ model_version=model_version,
983
+ lazy_init=lazy_init,
984
+ init_timeout_s=init_timeout_s,
985
+ inference_timeout_s=inference_timeout_s,
986
+ model_config=model_config,
987
+ ensure_model_is_ready=ensure_model_is_ready,
988
+ )
989
+
990
+ def get_lib(self):
991
+ """Get Triton Inference Server Python client library."""
992
+ return {"grpc": tritonclient.grpc.aio, "http": tritonclient.http.aio}[self._triton_url.scheme.lower()]
993
+
994
+ async def __aenter__(self):
995
+ """Create context for use AsyncioModelClient as a context manager."""
996
+ _LOGGER.debug("Entering AsyncioModelClient context")
997
+ try:
998
+ if not self._lazy_init:
999
+ _LOGGER.debug("Waiting in AsyncioModelClient context for model to be ready")
1000
+ await self._wait_and_init_model_config(self._init_timeout_s)
1001
+ _LOGGER.debug("Model is ready in AsyncioModelClient context")
1002
+ return self
1003
+ except Exception as e:
1004
+ _LOGGER.error("Error occurred during AsyncioModelClient context initialization")
1005
+ await self.close()
1006
+ raise e
1007
+
1008
+ async def __aexit__(self, *_):
1009
+ """Close resources used by AsyncioModelClient when exiting from context."""
1010
+ await self.close()
1011
+ _LOGGER.debug("Exiting AsyncioModelClient context")
1012
+
1013
+ async def close(self):
1014
+ """Close resources used by _ModelClientBase."""
1015
+ _LOGGER.debug("Closing InferenceServerClient")
1016
+ await self._general_client.close()
1017
+ await self._infer_client.close()
1018
+ _LOGGER.debug("InferenceServerClient closed")
1019
+
1020
+ async def wait_for_model(self, timeout_s: float):
1021
+ """Asynchronous wait for Triton Inference Server and deployed on it model readiness.
1022
+
1023
+ Args:
1024
+ timeout_s: timeout to server and model get into readiness state.
1025
+
1026
+ Raises:
1027
+ PyTritonClientTimeoutError: If server and model are not in readiness state before given timeout.
1028
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
1029
+ KeyboardInterrupt: If hosting process receives SIGINT
1030
+ """
1031
+ _LOGGER.debug(f"Waiting for model {self._model_name} to be ready")
1032
+ try:
1033
+ await asyncio.wait_for(
1034
+ asyncio_wait_for_model_ready(
1035
+ self._general_client, self._model_name, self._model_version, timeout_s=timeout_s
1036
+ ),
1037
+ self._init_timeout_s,
1038
+ )
1039
+ except asyncio.TimeoutError as e:
1040
+ message = f"Timeout while waiting for model {self._model_name} to be ready for {self._init_timeout_s}s"
1041
+ _LOGGER.error(message)
1042
+ raise PyTritonClientTimeoutError(message) from e
1043
+
1044
+ @property
1045
+ async def model_config(self):
1046
+ """Obtain configuration of model deployed on the Triton Inference Server.
1047
+
1048
+ Also waits for server to get into readiness state.
1049
+ """
1050
+ try:
1051
+ if not self._model_config:
1052
+ kwargs = self._get_model_config_extra_args()
1053
+ _LOGGER.debug(f"Obtaining model config for {self._model_name}")
1054
+
1055
+ self._model_config = await asyncio.wait_for(
1056
+ asyncio_get_model_config(
1057
+ self._general_client,
1058
+ self._model_name,
1059
+ self._model_version,
1060
+ timeout_s=self._init_timeout_s,
1061
+ **kwargs,
1062
+ ),
1063
+ self._init_timeout_s,
1064
+ )
1065
+ _LOGGER.debug(f"Obtained model config for {self._model_name}")
1066
+ return self._model_config
1067
+ except asyncio.TimeoutError as e:
1068
+ message = f"Timeout while waiting for model {self._model_name} to be ready for {self._init_timeout_s}s"
1069
+ _LOGGER.error(message)
1070
+ raise PyTritonClientTimeoutError(message) from e
1071
+
1072
+ async def infer_sample(
1073
+ self,
1074
+ *inputs,
1075
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
1076
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
1077
+ **named_inputs,
1078
+ ):
1079
+ """Run asynchronous inference on single data sample.
1080
+
1081
+ Typical usage:
1082
+
1083
+ ```python
1084
+ client = AsyncioModelClient("localhost", "MyModel")
1085
+ result_dict = await client.infer_sample(input1, input2)
1086
+ await client.close()
1087
+ ```
1088
+
1089
+ Inference inputs can be provided either as positional or keyword arguments:
1090
+
1091
+ ```python
1092
+ result_dict = await client.infer_sample(input1, input2)
1093
+ result_dict = await client.infer_sample(a=input1, b=input2)
1094
+ ```
1095
+
1096
+ Mixing of argument passing conventions is not supported and will raise PyTritonClientRuntimeError.
1097
+
1098
+ Args:
1099
+ *inputs: inference inputs provided as positional arguments.
1100
+ parameters: custom inference parameters.
1101
+ headers: custom inference headers.
1102
+ **named_inputs: inference inputs provided as named arguments.
1103
+
1104
+ Returns:
1105
+ dictionary with inference results, where dictionary keys are output names.
1106
+
1107
+ Raises:
1108
+ PyTritonClientValueError: if mixing of positional and named arguments passing detected.
1109
+ PyTritonClientTimeoutError:
1110
+ in case of first method call, `lazy_init` argument is False
1111
+ and wait time for server and model being ready exceeds `init_timeout_s`
1112
+ or inference time exceeds `timeout_s`.
1113
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
1114
+ PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
1115
+ """
1116
+ _verify_inputs_args(inputs, named_inputs)
1117
+ _verify_parameters(parameters)
1118
+ _verify_parameters(headers)
1119
+
1120
+ _LOGGER.debug(f"Running inference for {self._model_name}")
1121
+ model_config = await self.model_config
1122
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1123
+
1124
+ model_supports_batching = model_config.max_batch_size > 0
1125
+ if model_supports_batching:
1126
+ if inputs:
1127
+ inputs = tuple(data[np.newaxis, ...] for data in inputs)
1128
+ elif named_inputs:
1129
+ named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()}
1130
+
1131
+ _LOGGER.debug(f"Running _infer for {self._model_name}")
1132
+ result = await self._infer(inputs or named_inputs, parameters, headers)
1133
+ _LOGGER.debug(f"_infer for {self._model_name} finished")
1134
+ if model_supports_batching:
1135
+ result = {name: data[0] for name, data in result.items()}
1136
+
1137
+ return result
1138
+
1139
+ async def infer_batch(
1140
+ self,
1141
+ *inputs,
1142
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
1143
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
1144
+ **named_inputs,
1145
+ ):
1146
+ """Run asynchronous inference on batched data.
1147
+
1148
+ Typical usage:
1149
+
1150
+ ```python
1151
+ client = AsyncioModelClient("localhost", "MyModel")
1152
+ result_dict = await client.infer_batch(input1, input2)
1153
+ await client.close()
1154
+ ```
1155
+
1156
+ Inference inputs can be provided either as positional or keyword arguments:
1157
+
1158
+ ```python
1159
+ result_dict = await client.infer_batch(input1, input2)
1160
+ result_dict = await client.infer_batch(a=input1, b=input2)
1161
+ ```
1162
+
1163
+ Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError.
1164
+
1165
+ Args:
1166
+ *inputs: inference inputs provided as positional arguments.
1167
+ parameters: custom inference parameters.
1168
+ headers: custom inference headers.
1169
+ **named_inputs: inference inputs provided as named arguments.
1170
+
1171
+ Returns:
1172
+ dictionary with inference results, where dictionary keys are output names.
1173
+
1174
+ Raises:
1175
+ PyTritonClientValueError: if mixing of positional and named arguments passing detected.
1176
+ PyTritonClientTimeoutError:
1177
+ in case of first method call, `lazy_init` argument is False
1178
+ and wait time for server and model being ready exceeds `init_timeout_s`
1179
+ or inference time exceeds `timeout_s`.
1180
+ PyTritonClientModelDoesntSupportBatchingError: if model doesn't support batching.
1181
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
1182
+ PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
1183
+ """
1184
+ _verify_inputs_args(inputs, named_inputs)
1185
+ _verify_parameters(parameters)
1186
+ _verify_parameters(headers)
1187
+
1188
+ _LOGGER.debug(f"Running inference for {self._model_name}")
1189
+ model_config = await self.model_config
1190
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1191
+
1192
+ model_supports_batching = model_config.max_batch_size > 0
1193
+ if not model_supports_batching:
1194
+ _LOGGER.error(f"Model {model_config.model_name} doesn't support batching")
1195
+ raise PyTritonClientModelDoesntSupportBatchingError(
1196
+ f"Model {model_config.model_name} doesn't support batching - use infer_sample method instead"
1197
+ )
1198
+
1199
+ _LOGGER.debug(f"Running _infer for {self._model_name}")
1200
+ result = await self._infer(inputs or named_inputs, parameters, headers)
1201
+ _LOGGER.debug(f"_infer for {self._model_name} finished")
1202
+ return result
1203
+
1204
+ async def _wait_and_init_model_config(self, init_timeout_s: float):
1205
+ """Asynchronous wait for model and obtain model configuration.
1206
+
1207
+ Args:
1208
+ init_timeout_s: timeout for server and model being ready.
1209
+
1210
+ Raises:
1211
+ PyTritonClientTimeoutError: if wait time for server and model being ready exceeds `init_timeout_s`
1212
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
1213
+ """
1214
+ try:
1215
+ should_finish_before_s = time.time() + init_timeout_s
1216
+ _LOGGER.debug(f"Waiting for model {self._model_name} to be ready")
1217
+
1218
+ await asyncio.wait_for(self.wait_for_model(init_timeout_s), init_timeout_s)
1219
+ _LOGGER.debug(f"Model {self._model_name} is ready")
1220
+ self._model_ready = True
1221
+
1222
+ timeout_s = max(0.0, should_finish_before_s - time.time())
1223
+ _LOGGER.debug(f"Obtaining model config for {self._model_name}")
1224
+ self._model_config = await asyncio.wait_for(
1225
+ asyncio_get_model_config(
1226
+ self._general_client, self._model_name, self._model_version, timeout_s=timeout_s
1227
+ ),
1228
+ timeout_s,
1229
+ )
1230
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1231
+ except asyncio.TimeoutError as e:
1232
+ _LOGGER.error(f"Timeout exceeded while waiting for model {self._model_name} to be ready")
1233
+ raise PyTritonClientTimeoutError(
1234
+ f"Timeout exceeded while waiting for model {self._model_name} to be ready"
1235
+ ) from e
1236
+
1237
+ def _validate_input(self, input_name, input_data):
1238
+ if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes):
1239
+ raise RuntimeError(
1240
+ f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \
1241
+ \\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}"
1242
+ )
1243
+ if input_data.dtype.type == np.str_:
1244
+ raise RuntimeError(
1245
+ "Unicode inputs are not supported. "
1246
+ f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))."
1247
+ )
1248
+
1249
+ async def _execute_infer(self, model_config, inputs_wrapped, outputs_wrapped, parameters, headers) -> Any:
1250
+ try:
1251
+ _LOGGER.debug(f"Sending InferRequest for {self._model_name}")
1252
+ kwargs = self._get_infer_extra_args()
1253
+ response = await self._infer_client.infer(
1254
+ model_name=self._model_name,
1255
+ model_version=self._model_version or "",
1256
+ inputs=inputs_wrapped,
1257
+ headers=headers,
1258
+ outputs=outputs_wrapped,
1259
+ request_id=self._next_request_id,
1260
+ parameters=parameters,
1261
+ **kwargs,
1262
+ )
1263
+ except asyncio.exceptions.TimeoutError as e:
1264
+ # HTTP aio client raises asyncio.exceptions.TimeoutError for timeout errors
1265
+ message = f"Timeout exceeded while running inference for {self._model_name}"
1266
+ _LOGGER.error(message)
1267
+ raise PyTritonClientTimeoutError(message) from e
1268
+ except tritonclient.utils.InferenceServerException as e:
1269
+ message = f"Error occurred on Triton Inference Server side:\n {e.message()}"
1270
+ _LOGGER.error(message)
1271
+ if "Deadline Exceeded" in e.message():
1272
+ # GRPC aio client raises InferenceServerException with message "Deadline Exceeded"
1273
+ # for timeout errors
1274
+ raise PyTritonClientTimeoutError(message) from e
1275
+ else:
1276
+ raise PyTritonClientInferenceServerError(message) from e
1277
+ _LOGGER.debug(f"Received InferResponse for {self._model_name}")
1278
+ outputs = {output_spec.name: response.as_numpy(output_spec.name) for output_spec in model_config.outputs}
1279
+ return outputs
1280
+
1281
+ async def _infer(self, inputs: _IOType, parameters, headers):
1282
+ if self._model_ready:
1283
+ _LOGGER.debug(f"Waiting for model {self._model_name} config")
1284
+ await self._wait_and_init_model_config(self._init_timeout_s)
1285
+ _LOGGER.debug(f"Model wait finished for {self._model_name}")
1286
+
1287
+ _LOGGER.debug(f"Obtaining config for {self._model_name}")
1288
+ model_config = await self.model_config
1289
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1290
+ if model_config.decoupled:
1291
+ raise PyTritonClientInferenceServerError(
1292
+ "Model config is decoupled. Use DecouploedAsyncioModelClient instead."
1293
+ )
1294
+
1295
+ if isinstance(inputs, Tuple):
1296
+ inputs = {input_spec.name: input_data for input_spec, input_data in zip(model_config.inputs, inputs)}
1297
+
1298
+ inputs_wrapped = []
1299
+ for input_name, input_data in inputs.items():
1300
+ if isinstance(input_data, np.ndarray):
1301
+ self._validate_input(input_name, input_data)
1302
+ triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype)
1303
+ infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype)
1304
+ infer_input.set_data_from_numpy(input_data)
1305
+ input_wrapped = infer_input
1306
+ inputs_wrapped.append(input_wrapped)
1307
+ else:
1308
+ raise PyTritonClientValueError(
1309
+ f"Input {input_name} is not a numpy array. Got {type(input_data)} instead."
1310
+ )
1311
+
1312
+ outputs_wrapped = [
1313
+ self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in model_config.outputs
1314
+ ]
1315
+ return await self._execute_infer(model_config, inputs_wrapped, outputs_wrapped, parameters, headers)
1316
+
1317
+ def _handle_lazy_init(self):
1318
+ # Asynchronous lazy initialization is done in __aenter__ method
1319
+ pass
1320
+
1321
+ def _get_init_extra_args(self):
1322
+ # The inference timeout is used for both the HTTP and the GRPC protocols. However,
1323
+ # the way the timeout is passed to the client differs depending on the protocol.
1324
+ # For the HTTP protocol, the timeout is set in the ``__init__`` method as ``conn_timeout`` for both connection and request timeouts.
1325
+ # For the GRPC protocol, the timeout
1326
+ # is passed to the infer method as ``client_timeout``.
1327
+ # Both protocols support timeouts correctly and will raise an exception
1328
+ # if the network request or the inference process takes longer than the timeout.
1329
+ # This is a design choice of the underlying tritonclient library.
1330
+
1331
+ if self._triton_url.scheme != "http":
1332
+ return {}
1333
+
1334
+ kwargs = {
1335
+ # This value sets the maximum time allowed for both connection and network requests in both model loading and inference process
1336
+ "conn_timeout": self._inference_timeout_s,
1337
+ }
1338
+ return kwargs
1339
+
1340
+ def _get_infer_extra_args(self):
1341
+ if self._triton_url.scheme == "http":
1342
+ return {}
1343
+ # For the GRPC protocol, the timeout is passed to the infer method as client_timeout
1344
+ # This timeout applies to the whole inference process and each network request
1345
+
1346
+ # The ``infer`` supports also timeout argument for both GRPC and HTTP.
1347
+ # It is applied at server side and supported only for dynamic batching.
1348
+ # However, it is not used here yet and planned for future release
1349
+ kwargs = {"client_timeout": self._inference_timeout_s}
1350
+ return kwargs
1351
+
1352
+
1353
+ class AsyncioDecoupledModelClient(AsyncioModelClient):
1354
+ """Asyncio client for model deployed on the Triton Inference Server.
1355
+
1356
+ This client is based on Triton Inference Server Python clients and GRPC library:
1357
+ * ``tritonclient.grpc.aio.InferenceServerClient``
1358
+
1359
+ It can wait for server to be ready with model loaded and then perform inference on it.
1360
+ ``AsyncioDecoupledModelClient`` supports asyncio context manager protocol.
1361
+
1362
+ The client is intended to be used with decoupled models and will raise an error if model is coupled.
1363
+
1364
+ Typical usage:
1365
+ ```python
1366
+ from pytriton.client import AsyncioDecoupledModelClient
1367
+ import numpy as np
1368
+
1369
+ input1_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
1370
+ input2_sample = np.random.rand(1, 3, 224, 224).astype(np.float32)
1371
+
1372
+ async with AsyncioDecoupledModelClient("grpc://localhost", "MyModel") as client:
1373
+ async for result_dict in client.infer_sample(input1_sample, input2_sample):
1374
+ print(result_dict["output_name"])
1375
+ ```
1376
+ """
1377
+
1378
+ async def infer_sample(
1379
+ self,
1380
+ *inputs,
1381
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
1382
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
1383
+ **named_inputs,
1384
+ ):
1385
+ """Run asynchronous inference on single data sample.
1386
+
1387
+ Typical usage:
1388
+
1389
+ ```python
1390
+ async with AsyncioDecoupledModelClient("grpc://localhost", "MyModel") as client:
1391
+ async for result_dict in client.infer_sample(input1_sample, input2_sample):
1392
+ print(result_dict["output_name"])
1393
+ ```
1394
+
1395
+ Inference inputs can be provided either as positional or keyword arguments:
1396
+
1397
+ ```python
1398
+ results_iterator = client.infer_sample(input1, input2)
1399
+ results_iterator = client.infer_sample(a=input1, b=input2)
1400
+ ```
1401
+
1402
+ Mixing of argument passing conventions is not supported and will raise PyTritonClientRuntimeError.
1403
+
1404
+ Args:
1405
+ *inputs: inference inputs provided as positional arguments.
1406
+ parameters: custom inference parameters.
1407
+ headers: custom inference headers.
1408
+ **named_inputs: inference inputs provided as named arguments.
1409
+
1410
+ Returns:
1411
+ Asynchronous generator, which generates dictionaries with partial inference results, where dictionary keys are output names.
1412
+
1413
+ Raises:
1414
+ PyTritonClientValueError: if mixing of positional and named arguments passing detected.
1415
+ PyTritonClientTimeoutError:
1416
+ in case of first method call, `lazy_init` argument is False
1417
+ and wait time for server and model being ready exceeds `init_timeout_s`
1418
+ or inference time exceeds `timeout_s`.
1419
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
1420
+ PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
1421
+ """
1422
+ _verify_inputs_args(inputs, named_inputs)
1423
+ _verify_parameters(parameters)
1424
+ _verify_parameters(headers)
1425
+
1426
+ _LOGGER.debug(f"Running inference for {self._model_name}")
1427
+ model_config = await self.model_config
1428
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1429
+
1430
+ model_supports_batching = model_config.max_batch_size > 0
1431
+ if model_supports_batching:
1432
+ if inputs:
1433
+ inputs = tuple(data[np.newaxis, ...] for data in inputs)
1434
+ elif named_inputs:
1435
+ named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()}
1436
+
1437
+ _LOGGER.debug(f"Running _infer for {self._model_name}")
1438
+ result = self._infer(inputs or named_inputs, parameters, headers)
1439
+ _LOGGER.debug(f"_infer for {self._model_name} finished")
1440
+
1441
+ async for item in result:
1442
+ if model_supports_batching:
1443
+ debatched_item = {name: data[0] for name, data in item.items()}
1444
+ yield debatched_item
1445
+ else:
1446
+ yield item
1447
+
1448
+ async def infer_batch(
1449
+ self,
1450
+ *inputs,
1451
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
1452
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
1453
+ **named_inputs,
1454
+ ):
1455
+ """Run asynchronous inference on batched data.
1456
+
1457
+ Typical usage:
1458
+
1459
+ ```python
1460
+ async with AsyncioDecoupledModelClient("grpc://localhost", "MyModel") as client:
1461
+ async for result_dict in client.infer_batch(input1_sample, input2_sample):
1462
+ print(result_dict["output_name"])
1463
+ ```
1464
+
1465
+ Inference inputs can be provided either as positional or keyword arguments:
1466
+
1467
+ ```python
1468
+ results_iterator = client.infer_batch(input1, input2)
1469
+ results_iterator = client.infer_batch(a=input1, b=input2)
1470
+ ```
1471
+
1472
+ Mixing of argument passing conventions is not supported and will raise PyTritonClientRuntimeError.
1473
+
1474
+ Args:
1475
+ *inputs: inference inputs provided as positional arguments.
1476
+ parameters: custom inference parameters.
1477
+ headers: custom inference headers.
1478
+ **named_inputs: inference inputs provided as named arguments.
1479
+
1480
+ Returns:
1481
+ Asynchronous generator, which generates dictionaries with partial inference results, where dictionary keys are output names.
1482
+
1483
+ Raises:
1484
+ PyTritonClientValueError: if mixing of positional and named arguments passing detected.
1485
+ PyTritonClientTimeoutError:
1486
+ in case of first method call, `lazy_init` argument is False
1487
+ and wait time for server and model being ready exceeds `init_timeout_s`
1488
+ or inference time exceeds `timeout_s`.
1489
+ PyTritonClientModelDoesntSupportBatchingError: if model doesn't support batching.
1490
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
1491
+ PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side.
1492
+ """
1493
+ _verify_inputs_args(inputs, named_inputs)
1494
+ _verify_parameters(parameters)
1495
+ _verify_parameters(headers)
1496
+
1497
+ _LOGGER.debug(f"Running inference for {self._model_name}")
1498
+ model_config = await self.model_config
1499
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1500
+
1501
+ model_supports_batching = model_config.max_batch_size > 0
1502
+ if not model_supports_batching:
1503
+ _LOGGER.error(f"Model {model_config.model_name} doesn't support batching")
1504
+ raise PyTritonClientModelDoesntSupportBatchingError(
1505
+ f"Model {model_config.model_name} doesn't support batching - use infer_sample method instead"
1506
+ )
1507
+
1508
+ _LOGGER.debug(f"Running _infer for {self._model_name}")
1509
+ result = self._infer(inputs or named_inputs, parameters, headers)
1510
+ _LOGGER.debug(f"_infer for {self._model_name} finished")
1511
+ async for item in result:
1512
+ yield item
1513
+
1514
+ async def _execute_infer(self, model_config, inputs_wrapped, outputs_wrapped, parameters, headers) -> Any:
1515
+ # stream_infer siletly consumes all errors raised inside async_request_iterator and raises CancelledError
1516
+ error_raised_inside_async_request_iterator = set()
1517
+ try:
1518
+ _LOGGER.debug(f"Sending InferRequest for {self._model_name}")
1519
+ kwargs = self._get_infer_extra_args()
1520
+
1521
+ async def async_request_iterator(errors):
1522
+ _LOGGER.debug(f"Begin creating InferRequestHeader for {self._model_name}")
1523
+ try:
1524
+ yield {
1525
+ "model_name": self._model_name,
1526
+ "inputs": inputs_wrapped,
1527
+ "outputs": outputs_wrapped,
1528
+ "request_id": self._next_request_id,
1529
+ "sequence_id": 0,
1530
+ "sequence_start": True,
1531
+ "sequence_end": True,
1532
+ }
1533
+ except Exception as e:
1534
+ _LOGGER.error(f"Error occurred while creating InferRequestHeader for {self._model_name}")
1535
+ errors.add(e)
1536
+ raise e
1537
+ _LOGGER.debug(f"End creating InferRequestHeader for {self._model_name}")
1538
+
1539
+ response_iterator = self._infer_client.stream_infer(
1540
+ inputs_iterator=async_request_iterator(error_raised_inside_async_request_iterator),
1541
+ headers=headers,
1542
+ **kwargs,
1543
+ )
1544
+ _LOGGER.debug(f"End preparing InferRequest for {self._model_name}")
1545
+ while True:
1546
+ try:
1547
+ try:
1548
+ response = await asyncio.wait_for(
1549
+ response_iterator.__anext__(),
1550
+ self._inference_timeout_s,
1551
+ )
1552
+ except asyncio.TimeoutError as e:
1553
+ message = f"Timeout while waiting for model {self._model_name} to return next response {self._inference_timeout_s}s"
1554
+ _LOGGER.error(message)
1555
+ raise PyTritonClientTimeoutError(message) from e
1556
+ result, error = response
1557
+ _LOGGER.debug(f"Received InferResponse for {self._model_name}")
1558
+ if error is not None:
1559
+ raise error
1560
+ else:
1561
+ partial_output = {
1562
+ output_spec.name: result.as_numpy(output_spec.name) for output_spec in model_config.outputs
1563
+ }
1564
+ yield partial_output
1565
+ except StopAsyncIteration:
1566
+ break
1567
+ _LOGGER.debug(f"End receiving InferResponse for {self._model_name}")
1568
+
1569
+ except asyncio.exceptions.TimeoutError as e:
1570
+ # HTTP aio client raises asyncio.exceptions.TimeoutError for timeout errors
1571
+ message = f"Timeout exceeded while running inference for {self._model_name}"
1572
+ _LOGGER.error(message)
1573
+ raise PyTritonClientTimeoutError(message) from e
1574
+ except tritonclient.utils.InferenceServerException as e:
1575
+ message = f"Error occurred on Triton Inference Server side:\n {e.message()}"
1576
+ _LOGGER.error(message)
1577
+ if "Deadline Exceeded" in e.message():
1578
+ # GRPC aio client raises InferenceServerException with message "Deadline Exceeded"
1579
+ # for timeout errors
1580
+ raise PyTritonClientTimeoutError(message) from e
1581
+ else:
1582
+ raise PyTritonClientInferenceServerError(message) from e
1583
+ except asyncio.exceptions.CancelledError as e:
1584
+ _LOGGER.error(f"CancelledError occurred while streaming inference for {self._model_name}")
1585
+ # stream_infer siletly consumes all errors raised inside async_request_iterator and raises CancelledError
1586
+ if len(error_raised_inside_async_request_iterator) > 0:
1587
+ _LOGGER.error(f"Re-raising error raised inside async_request_iterator for {self._model_name} ")
1588
+ raise error_raised_inside_async_request_iterator.pop() from None
1589
+ else:
1590
+ raise e
1591
+
1592
+ async def _infer(self, inputs: _IOType, parameters, headers):
1593
+ if self._model_ready:
1594
+ _LOGGER.debug(f"Waiting for model {self._model_name} config")
1595
+ await self._wait_and_init_model_config(self._init_timeout_s)
1596
+ _LOGGER.debug(f"Model wait finished for {self._model_name}")
1597
+
1598
+ _LOGGER.debug(f"Obtaining config for {self._model_name}")
1599
+ model_config = await self.model_config
1600
+ _LOGGER.debug(f"Model config for {self._model_name} obtained")
1601
+ if not model_config.decoupled:
1602
+ raise PyTritonClientInferenceServerError("Model config is coupled. Use AsyncioModelClient instead.")
1603
+
1604
+ if isinstance(inputs, Tuple):
1605
+ inputs = {input_spec.name: input_data for input_spec, input_data in zip(model_config.inputs, inputs)}
1606
+
1607
+ inputs_wrapped = []
1608
+ for input_name, input_data in inputs.items():
1609
+ if isinstance(input_data, np.ndarray):
1610
+ self._validate_input(input_name, input_data)
1611
+ triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype)
1612
+ infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype)
1613
+ infer_input.set_data_from_numpy(input_data)
1614
+ input_wrapped = infer_input
1615
+ inputs_wrapped.append(input_wrapped)
1616
+ else:
1617
+ raise PyTritonClientValueError(
1618
+ f"Input {input_name} is not a numpy array. Got {type(input_data)} instead."
1619
+ )
1620
+
1621
+ outputs_wrapped = [
1622
+ self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in model_config.outputs
1623
+ ]
1624
+ result = self._execute_infer(model_config, inputs_wrapped, outputs_wrapped, parameters, headers)
1625
+ async for item in result:
1626
+ yield item
1627
+
1628
+ def _get_infer_extra_args(self):
1629
+ if self._triton_url.scheme == "http":
1630
+ raise PyTritonClientValueError("AsyncioDecoupledModelClient is only supported for grpc protocol")
1631
+ warnings.warn(
1632
+ f"tritonclient.aio.grpc doesn't support client_timeout parameter {self._inference_timeout_s} for infer_stream",
1633
+ NotSupportedTimeoutWarning,
1634
+ stacklevel=1,
1635
+ )
1636
+ return {}
1637
+
1638
+
1639
+ @contextlib.contextmanager
1640
+ def _hub_context():
1641
+ hub = gevent.get_hub()
1642
+ try:
1643
+ yield hub
1644
+ finally:
1645
+ hub.destroy()
1646
+
1647
+
1648
+ _INIT = "init"
1649
+ _WAIT_FOR_MODEL = "wait_for_model"
1650
+ _MODEL_CONFIG = "model_config"
1651
+ _INFER_BATCH = "infer_batch"
1652
+ _INFER_SAMPLE = "infer_sample"
1653
+ _CLOSE = "close"
1654
+
1655
+
1656
+ class FuturesModelClient:
1657
+ """A client for interacting with a model deployed on the Triton Inference Server using concurrent.futures.
1658
+
1659
+ This client allows asynchronous inference requests using a thread pool executor. It can be used to perform inference
1660
+ on a model by providing input data and receiving the corresponding output data. The client can be used in a `with`
1661
+ statement to ensure proper resource management.
1662
+
1663
+ Example usage with context manager:
1664
+
1665
+ ```python
1666
+ with FuturesModelClient("localhost", "MyModel") as client:
1667
+ result_future = client.infer_sample(input1=input1_data, input2=input2_data)
1668
+ # do something else
1669
+ print(result_future.result())
1670
+ ```
1671
+
1672
+ Usage without context manager:
1673
+
1674
+ ```python
1675
+ client = FuturesModelClient("localhost", "MyModel")
1676
+ result_future = client.infer_sample(input1=input1_data, input2=input2_data)
1677
+ # do something else
1678
+ print(result_future.result())
1679
+ client.close()
1680
+ ```
1681
+ """
1682
+
1683
+ def __init__(
1684
+ self,
1685
+ url: str,
1686
+ model_name: str,
1687
+ model_version: Optional[str] = None,
1688
+ *,
1689
+ max_workers: int = 128,
1690
+ max_queue_size: int = 128,
1691
+ non_blocking: bool = False,
1692
+ init_timeout_s: Optional[float] = None,
1693
+ inference_timeout_s: Optional[float] = None,
1694
+ ):
1695
+ """Initializes the FuturesModelClient for a given model.
1696
+
1697
+ Args:
1698
+ url: The Triton Inference Server url, e.g. `grpc://localhost:8001`.
1699
+ model_name: The name of the model to interact with.
1700
+ model_version: The version of the model to interact with. If None, the latest version will be used.
1701
+ max_workers: The maximum number of threads that can be used to execute the given calls. If None, there is not limit on the number of threads.
1702
+ max_queue_size: The maximum number of requests that can be queued. If None, there is not limit on the number of requests.
1703
+ non_blocking: If True, the client will raise a PyTritonClientQueueFullError if the queue is full. If False, the client will block until the queue is not full.
1704
+ init_timeout_s: Timeout in seconds for server and model being ready. If non passed default 60 seconds timeout will be used.
1705
+ inference_timeout_s: Timeout in seconds for the single model inference request. If non passed default 60 seconds timeout will be used.
1706
+ """
1707
+ self._url = url
1708
+ self._model_name = model_name
1709
+ self._model_version = model_version
1710
+ self._threads = []
1711
+ self._max_workers = max_workers
1712
+ self._max_queue_size = max_queue_size
1713
+ self._non_blocking = non_blocking
1714
+
1715
+ if self._max_workers is not None and self._max_workers <= 0:
1716
+ raise ValueError("max_workers must be greater than 0")
1717
+ if self._max_queue_size is not None and self._max_queue_size <= 0:
1718
+ raise ValueError("max_queue_size must be greater than 0")
1719
+
1720
+ kwargs = {}
1721
+ if self._max_queue_size is not None:
1722
+ kwargs["maxsize"] = self._max_queue_size
1723
+ self._queue = Queue(**kwargs)
1724
+ self._queue.put((_INIT, None, None))
1725
+ self._init_timeout_s = _DEFAULT_FUTURES_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s
1726
+ self._inference_timeout_s = inference_timeout_s
1727
+ self._closed = False
1728
+ self._lock = Lock()
1729
+ self._existing_client = None
1730
+
1731
+ def __enter__(self):
1732
+ """Create context for using FuturesModelClient as a context manager."""
1733
+ return self
1734
+
1735
+ def __exit__(self, exc_type, exc_value, traceback):
1736
+ """Close resources used by FuturesModelClient instance when exiting from the context."""
1737
+ self.close()
1738
+
1739
+ def close(self, wait=True):
1740
+ """Close resources used by FuturesModelClient.
1741
+
1742
+ This method closes the resources used by the FuturesModelClient instance, including the Triton Inference Server connections.
1743
+ Once this method is called, the FuturesModelClient instance should not be used again.
1744
+
1745
+ Args:
1746
+ wait: If True, then shutdown will not return until all running futures have finished executing.
1747
+ """
1748
+ if self._closed:
1749
+ return
1750
+ _LOGGER.debug("Closing FuturesModelClient.")
1751
+
1752
+ self._closed = True
1753
+ for _ in range(len(self._threads)):
1754
+ self._queue.put((_CLOSE, None, None))
1755
+
1756
+ if wait:
1757
+ _LOGGER.debug("Waiting for futures to finish.")
1758
+ for thread in self._threads:
1759
+ thread.join()
1760
+
1761
+ def wait_for_model(self, timeout_s: float) -> Future:
1762
+ """Returns a Future object which result will be None when the model is ready.
1763
+
1764
+ Typical usage:
1765
+
1766
+ ```python
1767
+ with FuturesModelClient("localhost", "BERT") as client
1768
+ future = client.wait_for_model(300.)
1769
+ # do something else
1770
+ future.result() # wait rest of timeout_s time
1771
+ # till return None if model is ready
1772
+ # or raise PyTritonClientTimeutError
1773
+ ```
1774
+
1775
+ Args:
1776
+ timeout_s: The maximum amount of time to wait for the model to be ready, in seconds.
1777
+
1778
+ Returns:
1779
+ A Future object which result is None when the model is ready.
1780
+ """
1781
+ return self._execute(
1782
+ name=_WAIT_FOR_MODEL,
1783
+ request=timeout_s,
1784
+ )
1785
+
1786
+ def model_config(self) -> Future:
1787
+ """Obtain the configuration of the model deployed on the Triton Inference Server.
1788
+
1789
+ This method returns a Future object that will contain the TritonModelConfig object when it is ready.
1790
+ Client will wait init_timeout_s for the server to get into readiness state before obtaining the model configuration.
1791
+
1792
+ Returns:
1793
+ A Future object that will contain the TritonModelConfig object when it is ready.
1794
+
1795
+ Raises:
1796
+ PyTritonClientClosedError: If the FuturesModelClient is closed.
1797
+ """
1798
+ return self._execute(name=_MODEL_CONFIG)
1799
+
1800
+ def infer_sample(
1801
+ self,
1802
+ *inputs,
1803
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
1804
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
1805
+ **named_inputs,
1806
+ ) -> Future:
1807
+ """Run asynchronous inference on a single data sample and return a Future object.
1808
+
1809
+ This method allows the user to perform inference on a single data sample by providing input data and receiving the
1810
+ corresponding output data. The method returns a Future object that wraps a dictionary of inference results, where dictionary keys are output names.
1811
+
1812
+ Example usage:
1813
+
1814
+ ```python
1815
+ with FuturesModelClient("localhost", "BERT") as client:
1816
+ result_future = client.infer_sample(input1=input1_data, input2=input2_data)
1817
+ # do something else
1818
+ print(result_future.result())
1819
+ ```
1820
+
1821
+ Inference inputs can be provided either as positional or keyword arguments:
1822
+
1823
+ ```python
1824
+ future = client.infer_sample(input1, input2)
1825
+ future = client.infer_sample(a=input1, b=input2)
1826
+ ```
1827
+
1828
+ Args:
1829
+ *inputs: Inference inputs provided as positional arguments.
1830
+ parameters: Optional dictionary of inference parameters.
1831
+ headers: Optional dictionary of HTTP headers for the inference request.
1832
+ **named_inputs: Inference inputs provided as named arguments.
1833
+
1834
+ Returns:
1835
+ A Future object wrapping a dictionary of inference results, where dictionary keys are output names.
1836
+
1837
+ Raises:
1838
+ PyTritonClientClosedError: If the FuturesModelClient is closed.
1839
+ """
1840
+ return self._execute(
1841
+ name=_INFER_SAMPLE,
1842
+ request=(inputs, parameters, headers, named_inputs),
1843
+ )
1844
+
1845
+ def infer_batch(
1846
+ self,
1847
+ *inputs,
1848
+ parameters: Optional[Dict[str, Union[str, int, bool]]] = None,
1849
+ headers: Optional[Dict[str, Union[str, int, bool]]] = None,
1850
+ **named_inputs,
1851
+ ) -> Future:
1852
+ """Run asynchronous inference on batched data and return a Future object.
1853
+
1854
+ This method allows the user to perform inference on batched data by providing input data and receiving the corresponding output data.
1855
+ The method returns a Future object that wraps a dictionary of inference results, where dictionary keys are output names.
1856
+
1857
+ Example usage:
1858
+
1859
+ ```python
1860
+ with FuturesModelClient("localhost", "BERT") as client:
1861
+ future = client.infer_batch(input1_sample, input2_sample)
1862
+ # do something else
1863
+ print(future.result())
1864
+ ```
1865
+
1866
+ Inference inputs can be provided either as positional or keyword arguments:
1867
+
1868
+ ```python
1869
+ future = client.infer_batch(input1, input2)
1870
+ future = client.infer_batch(a=input1, b=input2)
1871
+ ```
1872
+
1873
+ Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError.
1874
+
1875
+ Args:
1876
+ *inputs: Inference inputs provided as positional arguments.
1877
+ parameters: Optional dictionary of inference parameters.
1878
+ headers: Optional dictionary of HTTP headers for the inference request.
1879
+ **named_inputs: Inference inputs provided as named arguments.
1880
+
1881
+ Returns:
1882
+ A Future object wrapping a dictionary of inference results, where dictionary keys are output names.
1883
+
1884
+ Raises:
1885
+ PyTritonClientClosedError: If the FuturesModelClient is closed.
1886
+ """
1887
+ return self._execute(name=_INFER_BATCH, request=(inputs, parameters, headers, named_inputs))
1888
+
1889
+ def _execute(self, name, request=None):
1890
+ if self._closed:
1891
+ raise PyTritonClientClosedError("FutureModelClient is already closed")
1892
+ self._extend_thread_pool()
1893
+ future = Future()
1894
+ if self._non_blocking:
1895
+ try:
1896
+ self._queue.put_nowait((future, request, name))
1897
+ except Full as e:
1898
+ raise PyTritonClientQueueFullError("Queue is full") from e
1899
+ else:
1900
+ kwargs = {}
1901
+ if self._inference_timeout_s is not None:
1902
+ kwargs["timeout"] = self._inference_timeout_s
1903
+ try:
1904
+ self._queue.put((future, request, name), **kwargs)
1905
+ except Full as e:
1906
+ raise PyTritonClientQueueFullError("Queue is full") from e
1907
+ return future
1908
+
1909
+ def _extend_thread_pool(self):
1910
+ if self._closed:
1911
+ return
1912
+
1913
+ with self._lock:
1914
+ if not self._queue.empty() and (self._max_workers is None or len(self._threads) < self._max_workers):
1915
+ _LOGGER.debug("Create new thread")
1916
+ thread = Thread(target=self._worker)
1917
+ self._threads.append(thread)
1918
+ thread.start()
1919
+ else:
1920
+ _LOGGER.debug("No need to create new thread")
1921
+
1922
+ def _client_request_executor(self, client, request, name):
1923
+ _LOGGER.debug(f"Running {name} for {self._model_name}")
1924
+ if name == _INFER_SAMPLE:
1925
+ inputs, parameters, headers, named_inputs = request
1926
+ result = client.infer_sample(
1927
+ *inputs,
1928
+ parameters=parameters,
1929
+ headers=headers,
1930
+ **named_inputs,
1931
+ )
1932
+ elif name == _INFER_BATCH:
1933
+ inputs, parameters, headers, named_inputs = request
1934
+ result = client.infer_batch(
1935
+ *inputs,
1936
+ parameters=parameters,
1937
+ headers=headers,
1938
+ **named_inputs,
1939
+ )
1940
+ elif name == _MODEL_CONFIG:
1941
+ result = client.model_config
1942
+ elif name == _WAIT_FOR_MODEL:
1943
+ timeout_s = request
1944
+ result = client.wait_for_model(timeout_s)
1945
+ else:
1946
+ raise PyTritonClientValueError(f"Unknown request name {name}")
1947
+ self._set_existing_client(client)
1948
+ return result
1949
+
1950
+ def _create_client(self, lazy_init):
1951
+ _LOGGER.debug(f"Creating ModelClient lazy_init={lazy_init}")
1952
+ return ModelClient(
1953
+ self._url,
1954
+ self._model_name,
1955
+ self._model_version,
1956
+ lazy_init=lazy_init,
1957
+ init_timeout_s=self._init_timeout_s,
1958
+ inference_timeout_s=self._inference_timeout_s,
1959
+ )
1960
+
1961
+ def _set_existing_client(self, client):
1962
+ if client._model_config is not None:
1963
+ with self._lock:
1964
+ if self._existing_client is None:
1965
+ _LOGGER.debug("Setting existing client")
1966
+ self._existing_client = client
1967
+
1968
+ def _remove_existing_client(self, client):
1969
+ if client is not None:
1970
+ with self._lock:
1971
+ if self._existing_client is not None:
1972
+ if self._existing_client is client:
1973
+ _LOGGER.debug("Resetting existing client")
1974
+ self._existing_client = None
1975
+
1976
+ def _worker(self):
1977
+ _LOGGER.debug("Starting worker thread")
1978
+ client = None
1979
+ # Work around for AttributeError: '_Threadlocal' object has no attribute 'hub'
1980
+ # gevent/_hub_local.py", line 77, in gevent._gevent_c_hub_local.get_hub_noargs
1981
+ with _hub_context():
1982
+ while True:
1983
+ future, request, name = self._queue.get()
1984
+ if future == _CLOSE:
1985
+ _LOGGER.debug("Closing thread")
1986
+ self._queue.task_done()
1987
+ break
1988
+ if future == _INIT:
1989
+ with self._lock:
1990
+ if self._existing_client is None:
1991
+ try:
1992
+ _LOGGER.debug("Initial client creation")
1993
+ client = self._create_client(False)
1994
+ _LOGGER.debug("Setting existing client")
1995
+ self._existing_client = client
1996
+ except Exception as e:
1997
+ _LOGGER.warning(f"Error {e} occurred during init for {self._model_name}")
1998
+ continue
1999
+ try:
2000
+ if client is None:
2001
+ with self._lock:
2002
+ if self._existing_client is not None:
2003
+ _LOGGER.debug("Creating new client from existing client")
2004
+ client = ModelClient.from_existing_client(self._existing_client)
2005
+ if client is None:
2006
+ _LOGGER.debug("Creating new client")
2007
+ client = self._create_client(name == _WAIT_FOR_MODEL)
2008
+ with client:
2009
+ self._set_existing_client(client)
2010
+ while True:
2011
+ try:
2012
+ result = self._client_request_executor(client, request, name)
2013
+ _LOGGER.debug(f"Finished {name} for {self._model_name}")
2014
+ future.set_result(result)
2015
+ self._queue.task_done()
2016
+ except Exception as e:
2017
+ _LOGGER.error(f"Error {e} occurred during {name} for {self._model_name}")
2018
+ future.set_exception(e)
2019
+ self._queue.task_done()
2020
+ break
2021
+ future, request, name = self._queue.get()
2022
+ if future == _CLOSE:
2023
+ _LOGGER.debug("Closing thread")
2024
+ self._queue.task_done()
2025
+ return
2026
+ except Exception as e:
2027
+ _LOGGER.error(f"Error {e} occurred during {name} for {self._model_name}")
2028
+ future.set_exception(e)
2029
+ self._queue.task_done()
2030
+ finally:
2031
+ self._remove_existing_client(client)
2032
+ client = None
2033
+ _LOGGER.debug("Finishing worker thread")
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/exceptions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Exceptions thrown in pytriton.client module."""
15
+
16
+
17
+ class PyTritonClientError(Exception):
18
+ """Generic pytriton client exception."""
19
+
20
+ def __init__(self, message: str):
21
+ """Initialize exception with message.
22
+
23
+ Args:
24
+ message: Error message
25
+ """
26
+ self._message = message
27
+
28
+ def __str__(self) -> str:
29
+ """String representation of error.
30
+
31
+ Returns:
32
+ Message content
33
+ """
34
+ return self._message
35
+
36
+ @property
37
+ def message(self):
38
+ """Get the exception message.
39
+
40
+ Returns:
41
+ The message associated with this exception, or None if no message.
42
+
43
+ """
44
+ return self._message
45
+
46
+
47
+ class PyTritonClientValueError(PyTritonClientError):
48
+ """Generic error raised in case of incorrect values are provided into API."""
49
+
50
+ pass
51
+
52
+
53
+ class PyTritonClientInvalidUrlError(PyTritonClientValueError):
54
+ """Error raised when provided Triton Inference Server url is invalid."""
55
+
56
+ pass
57
+
58
+
59
+ class PyTritonClientTimeoutError(PyTritonClientError):
60
+ """Timeout occurred during communication with the Triton Inference Server."""
61
+
62
+ pass
63
+
64
+
65
+ class PyTritonClientModelUnavailableError(PyTritonClientError):
66
+ """Model with given name and version is unavailable on the given Triton Inference Server."""
67
+
68
+ pass
69
+
70
+
71
+ class PyTritonClientClosedError(PyTritonClientError):
72
+ """Error raised in case of trying to use closed client."""
73
+
74
+ pass
75
+
76
+
77
+ class PyTritonClientModelDoesntSupportBatchingError(PyTritonClientError):
78
+ """Error raised in case of trying to infer batch on model not supporting batching."""
79
+
80
+ pass
81
+
82
+
83
+ class PyTritonClientInferenceServerError(PyTritonClientError):
84
+ """Error raised in case of error on inference callable or Triton Inference Server side."""
85
+
86
+ pass
87
+
88
+
89
+ class PyTritonClientQueueFullError(PyTritonClientError):
90
+ """Error raised in case of trying to push request to full queue."""
91
+
92
+ pass
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/utils.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utility module supporting model clients."""
15
+
16
+ import dataclasses
17
+ import enum
18
+ import logging
19
+ import socket
20
+ import sys
21
+ import time
22
+ import urllib
23
+ import warnings
24
+ from typing import Optional, Union
25
+
26
+ import tritonclient.grpc
27
+ import tritonclient.http
28
+ import tritonclient.http.aio
29
+ from grpc import RpcError
30
+ from tritonclient.utils import InferenceServerException
31
+
32
+ from pytriton.client.exceptions import PyTritonClientInvalidUrlError, PyTritonClientTimeoutError
33
+ from pytriton.client.warnings import NotSupportedTimeoutWarning
34
+ from pytriton.constants import DEFAULT_GRPC_PORT, DEFAULT_HTTP_PORT
35
+ from pytriton.model_config.parser import ModelConfigParser
36
+
37
+ _LOGGER = logging.getLogger(__name__)
38
+
39
+ _TritonSyncClientType = Union[tritonclient.grpc.InferenceServerClient, tritonclient.http.InferenceServerClient]
40
+
41
+ _DEFAULT_NETWORK_TIMEOUT_S = 60.0 # 1min
42
+ _DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S = 60.0 # 1min
43
+ _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S = 300.0 # 5min
44
+
45
+ LATEST_MODEL_VERSION = "<latest>"
46
+
47
+
48
+ # Special value for model_version argument. If model_version is None, the latest version of the model is returned.
49
+
50
+
51
+ class ModelState(enum.Enum):
52
+ """Describe model state in Triton.
53
+
54
+ Attributes:
55
+ LOADING: Loading of model
56
+ UNLOADING: Unloading of model
57
+ UNAVAILABLE: Model is missing or could not be loaded
58
+ READY: Model is ready for inference
59
+ """
60
+
61
+ LOADING = "LOADING"
62
+ UNLOADING = "UNLOADING"
63
+ UNAVAILABLE = "UNAVAILABLE"
64
+ READY = "READY"
65
+
66
+
67
+ def parse_http_response(models):
68
+ """Parse model repository index response from Triton Inference Server for HTTP."""
69
+ models_states = {}
70
+ _LOGGER.debug("Parsing model repository index entries:")
71
+ for model in models:
72
+ _LOGGER.debug(f" name={model.get('name')} version={model.get('version')} state={model.get('state')}")
73
+ if not model.get("version"):
74
+ continue
75
+
76
+ model_state = ModelState(model["state"]) if model.get("state") else ModelState.LOADING
77
+ models_states[(model["name"], model["version"])] = model_state
78
+
79
+ return models_states
80
+
81
+
82
+ def parse_grpc_response(models):
83
+ """Parse model repository index response from Triton Inference Server for GRCP."""
84
+ models_states = {}
85
+ _LOGGER.debug("Parsing model repository index entries:")
86
+ for model in models:
87
+ _LOGGER.debug(f" name={model.name} version={model.version} state={model.state}")
88
+ if not model.version:
89
+ continue
90
+
91
+ model_state = ModelState(model.state) if model.state else ModelState.LOADING
92
+ models_states[(model.name, model.version)] = model_state
93
+
94
+ return models_states
95
+
96
+
97
+ def get_model_state(
98
+ client: _TritonSyncClientType,
99
+ model_name: str,
100
+ model_version: Optional[str] = None,
101
+ ) -> ModelState:
102
+ """Obtains state of the model deployed in Triton Inference Server.
103
+
104
+ Args:
105
+ client: Triton Inference Server client to use for communication
106
+ model_name: name of the model which state we're requesting.
107
+ model_version:
108
+ version of the model which state we're requesting.
109
+ If model_version is None state of latest model is returned.
110
+ The latest versions of the model are the numerically greatest version numbers.
111
+
112
+ Returns:
113
+ Model state. _ModelState.UNAVAILABLE is returned in case if model with given name and version is not found.
114
+
115
+ """
116
+ repository_index = client.get_model_repository_index()
117
+ if isinstance(repository_index, list):
118
+ models_states = parse_http_response(models=repository_index)
119
+ else:
120
+ models_states = parse_grpc_response(models=repository_index.models)
121
+
122
+ if model_version is None:
123
+ requested_model_states = {
124
+ version: state for (name, version), state in models_states.items() if name == model_name
125
+ }
126
+ if not requested_model_states:
127
+ return ModelState.UNAVAILABLE
128
+ else:
129
+ requested_model_states = sorted(requested_model_states.items(), key=lambda item: int(item[0]))
130
+ _latest_version, latest_version_state = requested_model_states[-1]
131
+ return latest_version_state
132
+ else:
133
+ state = models_states.get((model_name, model_version), ModelState.UNAVAILABLE)
134
+ return state
135
+
136
+
137
+ def get_model_config(
138
+ client: _TritonSyncClientType,
139
+ model_name: str,
140
+ model_version: Optional[str] = None,
141
+ timeout_s: Optional[float] = None,
142
+ ):
143
+ """Obtain configuration of model deployed on the Triton Inference Server.
144
+
145
+ Function waits for server readiness.
146
+
147
+ Typical use:
148
+
149
+ client = tritonclient.grpc.Client("localhost:8001")
150
+ model_config = get_model_config(client, "MyModel", "1", 60.0)
151
+ model_config = get_model_config(client, "MyModel")
152
+
153
+ Args:
154
+ client: Triton Inference Server client to use for communication
155
+ model_name: name of the model which configuration we're requesting.
156
+ model_version:
157
+ version of the model which configuration we're requesting.
158
+ If model_version is None configuration of the latest model is returned.
159
+ The latest versions of the model are the numerically greatest version numbers.
160
+ timeout_s: timeout to finish model configuration obtain. Default value is 300.0 s.
161
+
162
+ Returns:
163
+ Configuration of requested model.
164
+
165
+ Raises:
166
+ PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
167
+ PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable.
168
+ """
169
+ wait_for_model_ready(client, model_name=model_name, model_version=model_version, timeout_s=timeout_s)
170
+
171
+ model_version = model_version or ""
172
+
173
+ _LOGGER.debug(f"Obtaining model {model_name} config")
174
+ if isinstance(client, tritonclient.grpc.InferenceServerClient):
175
+ response = client.get_model_config(model_name, model_version, as_json=True)
176
+ model_config = response["config"]
177
+ else:
178
+ model_config = client.get_model_config(model_name, model_version)
179
+ model_config = ModelConfigParser.from_dict(model_config)
180
+ _LOGGER.debug(f"Model config: {model_config}")
181
+ return model_config
182
+
183
+
184
+ def _warn_on_too_big_network_timeout(client: _TritonSyncClientType, timeout_s: float):
185
+ if isinstance(client, tritonclient.http.InferenceServerClient):
186
+ connection_pool = client._client_stub._connection_pool
187
+ network_reldiff_s = (connection_pool.network_timeout - timeout_s) / timeout_s
188
+ connection_reldiff_s = (connection_pool.connection_timeout - timeout_s) / timeout_s
189
+ rtol = 0.001
190
+ if network_reldiff_s > rtol or connection_reldiff_s > rtol:
191
+ warnings.warn(
192
+ "Client network and/or connection timeout is smaller than requested timeout_s. This may cause unexpected behavior. "
193
+ f"network_timeout={connection_pool.network_timeout} "
194
+ f"connection_timeout={connection_pool.connection_timeout} "
195
+ f"timeout_s={timeout_s}",
196
+ NotSupportedTimeoutWarning,
197
+ stacklevel=1,
198
+ )
199
+
200
+
201
+ def wait_for_server_ready(
202
+ client: _TritonSyncClientType,
203
+ timeout_s: Optional[float] = None,
204
+ ):
205
+ """Waits for Triton Inference Server to be ready.
206
+
207
+ Typical use:
208
+
209
+ client = tritonclient.http.Client("localhost:8001")
210
+ wait_for_server_ready(client, timeout_s=600.0)
211
+
212
+ Args:
213
+ client: Triton Inference Server client to use for communication
214
+ timeout_s: timeout to server get into readiness state. Default value is 60.0 s.
215
+
216
+ Raises:
217
+ PyTritonClientTimeoutError: If obtain of model configuration didn't finish before given timeout.
218
+ """
219
+ timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_SERVER_READY_TIMEOUT_S
220
+ should_finish_before_s = time.time() + timeout_s
221
+ _warn_on_too_big_network_timeout(client, timeout_s)
222
+
223
+ def _is_server_ready():
224
+ try:
225
+ return client.is_server_ready() and client.is_server_live()
226
+ except InferenceServerException:
227
+ return False
228
+ except (RpcError, ConnectionError, socket.gaierror): # GRPC and HTTP clients raises these errors
229
+ return False
230
+ except Exception as e:
231
+ _LOGGER.exception(f"Exception while checking server readiness: {e}")
232
+ raise e
233
+
234
+ timeout_s = max(0.0, should_finish_before_s - time.time())
235
+ _LOGGER.debug(f"Waiting for server to be ready (timeout={timeout_s})")
236
+ is_server_ready = _is_server_ready()
237
+ while not is_server_ready:
238
+ time.sleep(min(1.0, timeout_s))
239
+ is_server_ready = _is_server_ready()
240
+ if not is_server_ready and time.time() >= should_finish_before_s:
241
+ raise PyTritonClientTimeoutError("Waiting for server to be ready timed out.")
242
+
243
+
244
+ def wait_for_model_ready(
245
+ client: _TritonSyncClientType,
246
+ model_name: str,
247
+ model_version: Optional[str] = None,
248
+ timeout_s: Optional[float] = None,
249
+ ):
250
+ """Wait for Triton Inference Server to be ready.
251
+
252
+ Args:
253
+ client: Triton Inference Server client to use for communication.
254
+ model_name: name of the model to wait for readiness.
255
+ model_version:
256
+ version of the model to wait for readiness.
257
+ If model_version is None waiting for latest version of the model.
258
+ The latest versions of the model are the numerically greatest version numbers.
259
+ timeout_s: timeout to server and model get into readiness state. Default value is 300.0 s.
260
+
261
+ Raises:
262
+ PyTritonClientTimeoutError: If server readiness didn't finish before given timeout.
263
+ """
264
+ model_version = model_version or ""
265
+ model_version_msg = model_version or LATEST_MODEL_VERSION
266
+ timeout_s = timeout_s if timeout_s is not None else _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S
267
+ should_finish_before_s = time.time() + timeout_s
268
+
269
+ wait_for_server_ready(client, timeout_s=timeout_s)
270
+ timeout_s = max(0.0, should_finish_before_s - time.time())
271
+ _LOGGER.debug(f"Waiting for model {model_name}/{model_version_msg} to be ready (timeout={timeout_s})")
272
+ is_model_ready = client.is_model_ready(model_name, model_version)
273
+ while not is_model_ready:
274
+ time.sleep(min(1.0, timeout_s))
275
+ is_model_ready = client.is_model_ready(model_name, model_version)
276
+
277
+ if not is_model_ready and time.time() >= should_finish_before_s:
278
+ raise PyTritonClientTimeoutError(
279
+ f"Waiting for model {model_name}/{model_version_msg} to be ready timed out."
280
+ )
281
+
282
+
283
+ def create_client_from_url(url: str, network_timeout_s: Optional[float] = None) -> _TritonSyncClientType: # type: ignore
284
+ """Create Triton Inference Server client.
285
+
286
+ Args:
287
+ url: url of the server to connect to.
288
+ If url doesn't contain scheme (e.g. "localhost:8001") http scheme is added.
289
+ If url doesn't contain port (e.g. "localhost") default port for given scheme is added.
290
+ network_timeout_s: timeout for client commands. Default value is 60.0 s.
291
+
292
+ Returns:
293
+ Triton Inference Server client.
294
+
295
+ Raises:
296
+ PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid.
297
+ """
298
+ url = TritonUrl.from_url(url)
299
+ triton_client_lib = {"grpc": tritonclient.grpc, "http": tritonclient.http}[url.scheme]
300
+
301
+ if url.scheme == "grpc":
302
+ # by default grpc client has very large number of timeout, thus we want to make it equal to http client timeout
303
+ network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s
304
+ warnings.warn(
305
+ f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.",
306
+ NotSupportedTimeoutWarning,
307
+ stacklevel=1,
308
+ )
309
+
310
+ triton_client_init_kwargs = {}
311
+ if network_timeout_s is not None:
312
+ triton_client_init_kwargs.update(
313
+ **{
314
+ "grpc": {},
315
+ "http": {"connection_timeout": network_timeout_s, "network_timeout": network_timeout_s},
316
+ }[url.scheme]
317
+ )
318
+
319
+ _LOGGER.debug(f"Creating InferenceServerClient for {url.with_scheme} with {triton_client_init_kwargs}")
320
+ return triton_client_lib.InferenceServerClient(url.without_scheme, **triton_client_init_kwargs)
321
+
322
+
323
+ @dataclasses.dataclass
324
+ class TritonUrl:
325
+ """TritonUrl class for parsing Triton Inference Server url.
326
+
327
+ Attributes:
328
+ scheme: scheme of the url (http or grpc)
329
+ hostname: hostname of the url
330
+ port: port of the url
331
+
332
+ Examples:
333
+ triton_url = TritonUrl.from_url("localhost:8000")
334
+ triton_url.with_scheme
335
+ >>> "http://localhost:8000"
336
+ triton_url.without_scheme
337
+ >>> "localhost:8000"
338
+ triton_url.scheme, triton_url.hostname, triton_url.port
339
+ >>> ("http", "localhost", 8000)
340
+ """
341
+
342
+ scheme: str
343
+ hostname: str
344
+ port: int
345
+
346
+ @classmethod
347
+ def from_url(cls, url):
348
+ """Parse triton url and create TritonUrl instance.
349
+
350
+ Returns:
351
+ TritonUrl object with scheme, hostname and port.
352
+ """
353
+ if not isinstance(url, str):
354
+ raise PyTritonClientInvalidUrlError(f"Invalid url {url}. Url must be a string.")
355
+ try:
356
+ parsed_url = urllib.parse.urlparse(url)
357
+ # change in py3.9+
358
+ # https://github.com/python/cpython/commit/5a88d50ff013a64fbdb25b877c87644a9034c969
359
+ if sys.version_info < (3, 9) and not parsed_url.scheme and "://" in parsed_url.path:
360
+ raise ValueError(f"Invalid url {url}. Only grpc and http are supported.")
361
+ if (not parsed_url.scheme and "://" not in parsed_url.path) or (
362
+ sys.version_info >= (3, 9) and parsed_url.scheme and not parsed_url.netloc
363
+ ):
364
+ _LOGGER.debug(f"Adding http scheme to {url}")
365
+ parsed_url = urllib.parse.urlparse(f"http://{url}")
366
+
367
+ scheme = parsed_url.scheme.lower()
368
+ if scheme not in ["grpc", "http"]:
369
+ raise ValueError(f"Invalid scheme {scheme}. Only grpc and http are supported.")
370
+
371
+ port = parsed_url.port or {"grpc": DEFAULT_GRPC_PORT, "http": DEFAULT_HTTP_PORT}[scheme]
372
+ except ValueError as e:
373
+ raise PyTritonClientInvalidUrlError(f"Invalid url {url}") from e
374
+ return cls(scheme, parsed_url.hostname, port)
375
+
376
+ @property
377
+ def with_scheme(self):
378
+ """Get Triton Inference Server url with scheme."""
379
+ return f"{self.scheme}://{self.hostname}:{self.port}"
380
+
381
+ @property
382
+ def without_scheme(self):
383
+ """Get Triton Inference Server url without scheme."""
384
+ return f"{self.hostname}:{self.port}"
stf/stf-api-alternative/pytriton/build/lib/pytriton/client/warnings.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Warnings for pytriton module."""
15
+
16
+
17
+ class PyTritonWarning(UserWarning):
18
+ """Base warning for pytriton module."""
19
+
20
+ pass
21
+
22
+
23
+ class NotSupportedTimeoutWarning(PyTritonWarning):
24
+ """A warning for client, which doesn't support timeout configuration."""
25
+
26
+ pass
stf/stf-api-alternative/pytriton/build/lib/pytriton/constants.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
15
+ """Constants for pytriton."""
16
+
17
+ import os
18
+ import pathlib
19
+
20
+ DEFAULT_HTTP_PORT = 8000
21
+ DEFAULT_GRPC_PORT = 8001
22
+ DEFAULT_METRICS_PORT = 8002
23
+ TRITON_LOCAL_IP = "127.0.0.1"
24
+ TRITON_CONTEXT_FIELD_NAME = "triton_context"
25
+ TRITON_PYTHON_BACKEND_INTERPRETER_DIRNAME = "python_backend_interpreter"
26
+ DEFAULT_TRITON_STARTUP_TIMEOUT_S = 120
27
+ CREATE_TRITON_CLIENT_TIMEOUT_S = 10
28
+
29
+ __DEFAULT_PYTRITON_HOME = os.path.join(os.getenv("XDG_CACHE_HOME", "$HOME/.cache"), "pytriton")
30
+ __PYTRITON_HOME = os.path.expanduser(os.path.expandvars(os.getenv("PYTRITON_HOME", __DEFAULT_PYTRITON_HOME)))
31
+ PYTRITON_HOME = pathlib.Path(__PYTRITON_HOME).resolve().absolute()
stf/stf-api-alternative/pytriton/build/lib/pytriton/decorators.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Inference callable decorators."""
15
+
16
+ import collections
17
+ import dataclasses
18
+ import inspect
19
+ import itertools
20
+ import operator
21
+ import typing
22
+ from bisect import bisect_left
23
+ from collections.abc import MutableMapping
24
+ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+ import wrapt
28
+
29
+ from pytriton.constants import TRITON_CONTEXT_FIELD_NAME
30
+ from pytriton.exceptions import PyTritonBadParameterError, PyTritonRuntimeError, PyTritonValidationError
31
+ from pytriton.model_config.triton_model_config import TritonModelConfig
32
+ from pytriton.proxy.data import _serialize_byte_tensor
33
+ from pytriton.proxy.telemetry import start_span_from_span
34
+
35
+
36
+ class _WrappedWithWrapper(NamedTuple):
37
+ wrapped: Optional[Callable]
38
+ wrapper: Optional[Callable]
39
+
40
+
41
+ InputNames = typing.List[str]
42
+ InferenceRequest = typing.Dict[str, np.ndarray]
43
+ InferenceRequests = typing.Union[typing.List[InferenceRequest], typing.Tuple[InferenceRequest, ...]]
44
+ InferenceResult = typing.Dict[str, np.ndarray]
45
+ InferenceResults = typing.Union[typing.List[InferenceResult], typing.Tuple[InferenceResult, ...]]
46
+
47
+
48
+ def get_inference_request_batch_size(inference_request: InferenceRequest) -> int:
49
+ """Get batch size from triton request.
50
+
51
+ Args:
52
+ inference_request (InferenceRequest): Triton request.
53
+
54
+ Returns:
55
+ int: Batch size.
56
+ """
57
+ first_input_value = next(iter(inference_request.values()))
58
+ batch_size, *_dims = first_input_value.shape
59
+ return batch_size
60
+
61
+
62
+ def _get_wrapt_stack(wrapped) -> List[_WrappedWithWrapper]:
63
+ """Returns stack of wrapped functions with wrappers applied to inference callable."""
64
+ stack = []
65
+ infer_callable = wrapped
66
+ while infer_callable is not None:
67
+ stack.append(_WrappedWithWrapper(infer_callable, getattr(infer_callable, "_self_wrapper", None)))
68
+ infer_callable = getattr(infer_callable, "__wrapped__", None)
69
+
70
+ return stack
71
+
72
+
73
+ class ModelConfigDict(MutableMapping):
74
+ """Dictionary for storing model configs for inference callable."""
75
+
76
+ def __init__(self):
77
+ """Create ModelConfigDict object."""
78
+ self._data: Dict[str, TritonModelConfig] = {}
79
+ self._keys: List[Callable] = []
80
+
81
+ def __getitem__(self, infer_callable: Callable) -> TritonModelConfig:
82
+ """Get model config for inference callable."""
83
+ key = self._get_model_config_key(infer_callable)
84
+ return self._data[key]
85
+
86
+ def __setitem__(self, infer_callable: Callable, item: TritonModelConfig):
87
+ """Set model config for inference callable."""
88
+ self._keys.append(infer_callable)
89
+ key = self._get_model_config_key(infer_callable)
90
+ self._data[key] = item
91
+
92
+ def __delitem__(self, infer_callable: Callable):
93
+ """Delete model config for inference callable."""
94
+ key = self._get_model_config_key(infer_callable)
95
+ del self._data[key]
96
+
97
+ def __len__(self):
98
+ """Get number of inference callable keys."""
99
+ return len(self._data)
100
+
101
+ def __iter__(self):
102
+ """Iterate over inference callable keys."""
103
+ return iter(self._keys)
104
+
105
+ @staticmethod
106
+ def _get_model_config_key(infer_callable: Callable) -> str:
107
+ """Prepares TritonModelConfig dictionary key for function/callable."""
108
+ dict_key = infer_callable
109
+ if inspect.ismethod(dict_key) and dict_key.__name__ == "__call__":
110
+ dict_key = dict_key.__self__
111
+ return str(dict_key)
112
+
113
+
114
+ @dataclasses.dataclass
115
+ class TritonContext:
116
+ """Triton context definition class."""
117
+
118
+ model_configs: ModelConfigDict = dataclasses.field(default_factory=ModelConfigDict)
119
+
120
+
121
+ def get_triton_context(wrapped, instance) -> TritonContext:
122
+ """Retrieves triton context from callable.
123
+
124
+ It is used in @triton_context to get triton context registered by triton binding in inference callable.
125
+ If you use @triton_context decorator you do not need this function.
126
+ """
127
+ caller = instance or wrapped
128
+ if not hasattr(caller, "__triton_context__"):
129
+ raise PyTritonValidationError("Wrapped function or object must bound with triton to get __triton_context__")
130
+ return caller.__triton_context__
131
+
132
+
133
+ def get_model_config(wrapped, instance) -> TritonModelConfig:
134
+ """Retrieves instance of TritonModelConfig from callable.
135
+
136
+ It is internally used in convert_output function to get output list from model.
137
+ You can use this in custom decorators if you need access to model_config information.
138
+ If you use @triton_context decorator you do not need this function (you can get model_config directly
139
+ from triton_context passing function/callable to dictionary getter).
140
+ """
141
+ return get_triton_context(wrapped, instance).model_configs[wrapped]
142
+
143
+
144
+ def convert_output(
145
+ outputs: Union[Dict, List, Tuple], wrapped=None, instance=None, model_config: Optional[TritonModelConfig] = None
146
+ ):
147
+ """Converts output from tuple ot list to dictionary.
148
+
149
+ It is utility function useful for mapping output list into dictionary of outputs.
150
+ Currently, it is used in @sample and @batch decorators (we assume that user can return list or tuple of outputs
151
+ instead of dictionary if this list matches output list in model config (size and order).
152
+ """
153
+ if isinstance(outputs, dict):
154
+ return outputs
155
+ elif isinstance(outputs, (list, tuple)):
156
+ if model_config is None:
157
+ model_config = get_model_config(wrapped, instance)
158
+ if len(outputs) != len(model_config.outputs):
159
+ raise PyTritonValidationError("Outputs length different than config outputs length")
160
+ outputs = {config_output.name: output for config_output, output in zip(model_config.outputs, outputs)}
161
+ return outputs
162
+ else:
163
+ raise PyTritonValidationError(f"Unsupported output type {type(outputs)}.")
164
+
165
+
166
+ @wrapt.decorator
167
+ def sample(wrapped, instance, args, kwargs):
168
+ """Decorator is used for non-batched inputs to convert from one element list of requests to request kwargs.
169
+
170
+ Decorator takes first request and convert it into named inputs.
171
+ Useful with non-batching models - instead of one element list of request, we will get named inputs - `kwargs`.
172
+ """
173
+ kwargs.update(args[0][0])
174
+ outputs = wrapped(*args[1:], **kwargs)
175
+ outputs = convert_output(outputs, wrapped, instance)
176
+ return [outputs]
177
+
178
+
179
+ @wrapt.decorator
180
+ def batch(wrapped, instance, args, kwargs):
181
+ """Decorator for converting list of request dicts to dict of input batches.
182
+
183
+ Converts list of request dicts to dict of input batches.
184
+ It passes **kwargs to inference callable where each named input contains numpy array with batch of requests
185
+ received by Triton server.
186
+ We assume that each request has the same set of keys (you can use group_by_keys decorator before
187
+ using @batch decorator if your requests may have different set of keys).
188
+
189
+ Raises:
190
+ PyTritonValidationError: If the requests have different set of keys.
191
+ ValueError: If the output tensors have different than expected batch sizes. Expected batch size is
192
+ calculated as a sum of batch sizes of all requests.
193
+ """
194
+ telemetry_name = "pytriton-batch-decorator-span"
195
+
196
+ req_list = args[0]
197
+ input_names = req_list[0].keys()
198
+
199
+ for req_dict2 in req_list[1:]:
200
+ if input_names != req_dict2.keys():
201
+ raise PyTritonValidationError("Cannot batch requests with different set of inputs keys")
202
+
203
+ inputs = {}
204
+ for model_input in input_names:
205
+ concatenated_input_data = np.concatenate([req[model_input] for req in req_list])
206
+ inputs[model_input] = concatenated_input_data
207
+
208
+ args = args[1:]
209
+ new_kwargs = dict(kwargs)
210
+ new_kwargs.update(inputs)
211
+ spans = [start_span_from_span(request.span, telemetry_name) for request in req_list if request.span is not None]
212
+ try:
213
+ outputs = wrapped(*args, **new_kwargs)
214
+ finally:
215
+ for span in spans:
216
+ span.end()
217
+
218
+ def _split_result(_result):
219
+ outputs = convert_output(_result, wrapped, instance)
220
+ output_names = outputs.keys()
221
+
222
+ requests_total_batch_size = sum(get_inference_request_batch_size(req) for req in req_list)
223
+ not_matching_tensors_shapes = {
224
+ output_name: output_tensor.shape
225
+ for output_name, output_tensor in outputs.items()
226
+ if output_tensor.shape[0] != requests_total_batch_size
227
+ }
228
+ if not_matching_tensors_shapes:
229
+ raise ValueError(
230
+ f"Received output tensors with different batch sizes: {', '.join(': '.join(map(str, item)) for item in not_matching_tensors_shapes.items())}. "
231
+ f"Expected batch size: {requests_total_batch_size}. "
232
+ )
233
+
234
+ out_list = []
235
+ start_idx = 0
236
+ for request in req_list:
237
+ # get batch_size of first input for each request - assume that all inputs have same batch_size
238
+ request_batch_size = get_inference_request_batch_size(request)
239
+ req_output_dict = {}
240
+ for _output_ind, output_name in enumerate(output_names):
241
+ req_output = outputs[output_name][start_idx : start_idx + request_batch_size, ...]
242
+ req_output_dict[output_name] = req_output
243
+ out_list.append(req_output_dict)
244
+ start_idx += request_batch_size
245
+ return out_list
246
+
247
+ if inspect.isgenerator(outputs):
248
+ return (_split_result(_result) for _result in outputs)
249
+ else:
250
+ return _split_result(outputs)
251
+
252
+
253
+ def group_by_values(*keys, pad_fn: typing.Optional[typing.Callable[[InferenceRequests], InferenceRequests]] = None):
254
+ """Decorator for grouping requests by values of selected keys.
255
+
256
+ This function splits a batch into multiple sub-batches based on the specified keys values and
257
+ calls the decorated function with each sub-batch. This is particularly useful when working with models
258
+ that require dynamic parameters sent by the user.
259
+
260
+ For example, given an input of the form:
261
+
262
+ ```python
263
+ {"sentences": [b"Sentence1", b"Sentence2", b"Sentence3"], "param1": [1, 1, 2], "param2": [1, 1, 1]}
264
+ ```
265
+
266
+ Using @group_by_values("param1", "param2") will split the batch into two sub-batches:
267
+
268
+ ```python
269
+ [
270
+ {"sentences": [b"Sentence1", b"Sentence2"], "param1": [1, 1], "param2": [1, 1]},
271
+ {"sentences": [b"Sentence3"], "param1": [2], "param2": [1]}
272
+ ]
273
+ ```
274
+
275
+ This decorator should be used after the @batch decorator.
276
+
277
+ Example usage:
278
+ ```python
279
+ @batch
280
+ @group_by_values("param1", "param2")
281
+ def infer_fun(**inputs):
282
+ ...
283
+ return outputs
284
+ ```
285
+
286
+ Args:
287
+ *keys: List of keys to group by.
288
+ pad_fn: Optional function to pad the batch to the same size before merging again to a single batch.
289
+
290
+ Returns:
291
+ The decorator function.
292
+ """
293
+
294
+ def value_to_key(value):
295
+ if isinstance(value, np.ndarray):
296
+ if value.dtype == np.object_ or value.dtype.type == np.bytes_:
297
+ return _serialize_byte_tensor(value)
298
+ else:
299
+ return value.tobytes()
300
+ return value
301
+
302
+ def _get_sort_key_for_sample(_request, _sample_idx: int):
303
+ return tuple(value_to_key(_request[_key][_sample_idx]) for _key in keys)
304
+
305
+ def _group_request(_request: InferenceRequest, _batch_size: int):
306
+ idx_inputs = [(sample_idx, _get_sort_key_for_sample(_request, sample_idx)) for sample_idx in range(_batch_size)]
307
+ idx_inputs.sort(key=operator.itemgetter(1))
308
+ for _, group in itertools.groupby(idx_inputs, key=operator.itemgetter(1)):
309
+ _samples_idxes, _ = zip(*group)
310
+ grouped_request = {input_name: value[_samples_idxes, ...] for input_name, value in _request.items()}
311
+ yield _samples_idxes, grouped_request
312
+
313
+ @wrapt.decorator
314
+ def _wrapper(wrapped, instance, args, kwargs):
315
+ wrappers_stack = [
316
+ callable_with_wrapper.wrapper
317
+ for callable_with_wrapper in _get_wrapt_stack(wrapped)
318
+ if callable_with_wrapper.wrapper is not None
319
+ ]
320
+ if batch in wrappers_stack:
321
+ raise PyTritonRuntimeError("The @group_by_values decorator must be used after the @batch decorator.")
322
+
323
+ request = {k: v for k, v in kwargs.items() if k not in _SPECIAL_KEYS}
324
+ other_kwargs = {k: v for k, v in kwargs.items() if k in _SPECIAL_KEYS}
325
+
326
+ batch_size = get_inference_request_batch_size(request)
327
+ sample_indices_with_interim_result = []
328
+ for sample_indices, _grouped_sub_request in _group_request(request, batch_size):
329
+ interim_result = wrapped(*args, **_grouped_sub_request, **other_kwargs)
330
+ sample_indices_with_interim_result.append((sample_indices, interim_result))
331
+
332
+ if pad_fn is not None:
333
+ indices, results = tuple(map(tuple, zip(*sample_indices_with_interim_result)))
334
+ results = pad_fn(results)
335
+ sample_indices_with_interim_result = tuple(zip(indices, results))
336
+
337
+ _, first_result_data = sample_indices_with_interim_result[0]
338
+ result = {
339
+ output_name: np.zeros((batch_size,) + data.shape[1:], dtype=data.dtype)
340
+ for output_name, data in first_result_data.items()
341
+ }
342
+ for indices, results in sample_indices_with_interim_result:
343
+ for output_name, data in results.items():
344
+ result[output_name][indices, ...] = data
345
+
346
+ return result
347
+
348
+ return _wrapper
349
+
350
+
351
+ class ConstantPadder:
352
+ """Padder that pads the given batches with a constant value."""
353
+
354
+ def __init__(self, pad_value=0):
355
+ """Initialize the padder.
356
+
357
+ Args:
358
+ pad_value (int, optional): Padding value. Defaults to 0.
359
+ """
360
+ self.pad_value = pad_value
361
+
362
+ def __call__(self, batches_list: InferenceResults) -> InferenceResults:
363
+ """Pad the given batches with the specified value to pad size enabling further batching to single arrays.
364
+
365
+ Args:
366
+ batches_list (List[Dict[str, np.ndarray]]): List of batches to pad.
367
+
368
+ Returns:
369
+ List[Dict[str, np.ndarray]]: List of padded batches.
370
+
371
+ Raises:
372
+ PyTritonRuntimeError: If the input arrays for a given input name have different dtypes.
373
+ """
374
+
375
+ def _get_padded_shape(_batches: List[np.ndarray]) -> Tuple[int, ...]:
376
+ """Get the shape of the padded array without batch axis."""
377
+ return tuple(np.max([batch.shape[1:] for batch in _batches if batch is not None], axis=0))
378
+
379
+ def _get_padded_dtype(_batches: List[np.ndarray]) -> np.dtype:
380
+ dtypes = [batch.dtype for batch in _batches if batch is not None]
381
+ result_dtype = dtypes[0]
382
+
383
+ if not all(dtype.kind == result_dtype.kind for dtype in dtypes):
384
+ raise PyTritonRuntimeError("All input arrays for given input name must have the same dtype.")
385
+
386
+ # for bytes (encoded string) or unicode string need to obtain the max length
387
+ if result_dtype.kind in "SU":
388
+ order_and_kind = result_dtype.str[:2]
389
+ max_len = max([int(dtype.str[2:]) for dtype in dtypes])
390
+ result_dtype = f"{order_and_kind}{max_len}"
391
+ else:
392
+ if not all(dtype == result_dtype for dtype in dtypes):
393
+ raise PyTritonRuntimeError("All input arrays for given input name must have the same dtype.")
394
+
395
+ return np.dtype(result_dtype)
396
+
397
+ input_names = list(
398
+ collections.OrderedDict.fromkeys(input_name for batch in batches_list for input_name in batch.keys())
399
+ )
400
+ batches_by_name = {input_name: [batch.get(input_name) for batch in batches_list] for input_name in input_names}
401
+ for input_batches in batches_by_name.values():
402
+ result_shape, result_dtype = _get_padded_shape(input_batches), _get_padded_dtype(input_batches)
403
+ for batch_idx, batch in enumerate(input_batches):
404
+ if batch is not None:
405
+ input_batches[batch_idx] = np.pad(
406
+ batch,
407
+ [(0, 0)] + [(0, b - a) for a, b in zip(batch.shape[1:], result_shape)],
408
+ mode="constant",
409
+ constant_values=self.pad_value if result_dtype.kind not in ["S", "U", "O"] else b"",
410
+ ).astype(result_dtype)
411
+
412
+ return [
413
+ {name: batches[batch_idx] for name, batches in batches_by_name.items() if batches[batch_idx] is not None}
414
+ for batch_idx in range(len(batches_list))
415
+ ]
416
+
417
+
418
+ @wrapt.decorator
419
+ def group_by_keys(wrapped, instance, args, kwargs):
420
+ """Group by keys.
421
+
422
+ Decorator prepares groups of requests with the same set of keys and calls wrapped function
423
+ for each group separately (it is convenient to use this decorator before batching, because the batching decorator
424
+ requires consistent set of inputs as it stacks them into batches).
425
+ """
426
+ inputs = args[0]
427
+ idx_inputs = [(idx, tuple(sorted(input.keys())), input) for idx, input in enumerate(inputs)]
428
+ idx_inputs.sort(key=operator.itemgetter(1))
429
+ idx_groups_res = []
430
+ for _, group in itertools.groupby(idx_inputs, key=operator.itemgetter(1)):
431
+ idx, _key, sample_list = zip(*group)
432
+ args = (list(sample_list),) + args[1:]
433
+ out = wrapped(*args, **kwargs)
434
+ idx_groups_res.extend(zip(idx, out))
435
+
436
+ idx_groups_res.sort(key=operator.itemgetter(0))
437
+ res_flat = [r[1] for r in idx_groups_res]
438
+ return res_flat
439
+
440
+
441
+ def fill_optionals(**defaults):
442
+ """This decorator ensures that any missing inputs in requests are filled with default values specified by the user.
443
+
444
+ Default values should be NumPy arrays without batch axis.
445
+
446
+ If you plan to group requests ex. with
447
+ [@group_by_keys][pytriton.decorators.group_by_keys] or
448
+ [@group_by_vales][pytriton.decorators.group_by_values] decorators
449
+ provide default values for optional parameters at the beginning of decorators stack.
450
+ The other decorators can then group requests into bigger batches resulting in a better model performance.
451
+
452
+ Typical use:
453
+ ```python
454
+ @fill_optionals()
455
+ @group_by_keys()
456
+ @batch
457
+ def infer_fun(**inputs):
458
+ ...
459
+ return outputs
460
+ ```
461
+
462
+ Args:
463
+ defaults: keyword arguments containing default values for missing inputs
464
+
465
+
466
+ If you have default values for some optional parameter it is good idea to provide them at the very beginning,
467
+ so the other decorators (e.g. @group_by_keys) can make bigger consistent groups.
468
+ """
469
+
470
+ def _verify_defaults(model_config: TritonModelConfig):
471
+ inputs = {spec.name: spec for spec in model_config.inputs}
472
+ not_matching_default_names = sorted(set(defaults) - set(inputs))
473
+ if not_matching_default_names:
474
+ raise PyTritonBadParameterError(f"Could not found {', '.join(not_matching_default_names)} inputs")
475
+
476
+ non_numpy_items = {k: v for k, v in defaults.items() if not isinstance(v, np.ndarray)}
477
+ if non_numpy_items:
478
+ raise PyTritonBadParameterError(
479
+ f"Could not use {', '.join([f'{k}={v}' for k, v in non_numpy_items.items()])} defaults "
480
+ "as they are not NumPy arrays"
481
+ )
482
+
483
+ not_matching_dtypes = {k: (v.dtype, inputs[k].dtype) for k, v in defaults.items() if v.dtype != inputs[k].dtype}
484
+ if not_matching_dtypes:
485
+ non_matching_dtypes_str_list = [
486
+ f"{name}: dtype={have_dtype} expected_dtype={expected_dtype}"
487
+ for name, (have_dtype, expected_dtype) in not_matching_dtypes.items()
488
+ ]
489
+ raise PyTritonBadParameterError(
490
+ f"Could not use {', '.join(non_matching_dtypes_str_list)} "
491
+ f"defaults as they have different than input signature dtypes"
492
+ )
493
+
494
+ def _shape_match(_have_shape, _expected_shape):
495
+ return len(_have_shape) == len(_expected_shape) and all(
496
+ e == -1 or h == e for h, e in zip(_have_shape, _expected_shape)
497
+ )
498
+
499
+ not_matching_shapes = {
500
+ k: (v.shape, inputs[k].shape) for k, v in defaults.items() if not _shape_match(v.shape, inputs[k].shape)
501
+ }
502
+ if not_matching_shapes:
503
+ non_matching_shapes_str_list = [
504
+ f"{name}: shape={have_shape} expected_shape={expected_shape}"
505
+ for name, (have_shape, expected_shape) in not_matching_shapes.items()
506
+ ]
507
+ raise PyTritonBadParameterError(
508
+ f"Could not use {', '.join(non_matching_shapes_str_list)} "
509
+ f"defaults as they have different than input signature shapes"
510
+ )
511
+
512
+ @wrapt.decorator
513
+ def _wrapper(wrapped, instance, args, kwargs):
514
+ model_config = get_model_config(wrapped, instance)
515
+ _verify_defaults(model_config)
516
+ # verification if not after group wrappers is in group wrappers
517
+
518
+ (requests,) = args
519
+
520
+ model_supports_batching = model_config.batching
521
+ for request in requests:
522
+ batch_size = get_inference_request_batch_size(request) if model_supports_batching else None
523
+ for default_key, default_value in defaults.items():
524
+ if default_key in request:
525
+ continue
526
+
527
+ if model_supports_batching:
528
+ ones_reps = (1,) * default_value.ndim # repeat once default_value on each axis
529
+ axis_reps = (batch_size,) + ones_reps # ... except on batch axis. we repeat it batch_size times
530
+ default_value = np.tile(default_value, axis_reps)
531
+
532
+ request[default_key] = default_value
533
+ return wrapped(*args, **kwargs)
534
+
535
+ return _wrapper
536
+
537
+
538
+ @wrapt.decorator
539
+ def triton_context(wrapped, instance, args, kwargs):
540
+ """Adds triton context.
541
+
542
+ It gives you additional argument passed to the function in **kwargs called 'triton_context'.
543
+ You can read model config from it and in the future possibly have some interaction with triton.
544
+ """
545
+ kwargs[TRITON_CONTEXT_FIELD_NAME] = get_triton_context(wrapped, instance)
546
+ return wrapped(*args, **kwargs)
547
+
548
+
549
+ @wrapt.decorator
550
+ def pad_batch(wrapped, instance, args, kwargs):
551
+ """Add padding to the inputs batches.
552
+
553
+ Decorator appends last rows to the inputs multiple times to get desired batch size (preferred batch size or
554
+ max batch size from model config whatever is closer to current input size).
555
+ """
556
+ inputs = {k: v for k, v in kwargs.items() if k != "__triton_context__"}
557
+ first_input = next(iter(inputs.values()))
558
+ config = get_model_config(wrapped, instance)
559
+ batch_sizes = (
560
+ []
561
+ if (config.batcher is None or config.batcher.preferred_batch_size is None)
562
+ else sorted(config.batcher.preferred_batch_size)
563
+ )
564
+ batch_sizes.append(config.max_batch_size)
565
+ batch_size = batch_sizes[bisect_left(batch_sizes, first_input.shape[0])]
566
+
567
+ new_inputs = {
568
+ input_name: np.repeat(
569
+ input_array,
570
+ np.concatenate([
571
+ np.ones(input_array.shape[0] - 1),
572
+ np.array([batch_size - input_array.shape[0] + 1]),
573
+ ]).astype(np.int64),
574
+ axis=0,
575
+ )
576
+ for input_name, input_array in inputs.items()
577
+ }
578
+
579
+ kwargs.update(new_inputs)
580
+ return wrapped(*args, **kwargs)
581
+
582
+
583
+ _SPECIAL_KEYS = ["__triton_context__"]
584
+
585
+
586
+ def first_value(*keys: str, squeeze_single_values=True, strict: bool = True):
587
+ """This decorator overwrites selected inputs with first element of the given input.
588
+
589
+ It can be used in two ways:
590
+
591
+ 1. Wrapping a single request inference callable by chaining with @batch decorator:
592
+ ```python
593
+ @batch
594
+ @first_value("temperature")
595
+ def infer_fn(**inputs):
596
+ ...
597
+ return result
598
+ ```
599
+
600
+ 2. Wrapping a multiple requests inference callable:
601
+ ```python
602
+ @first_value("temperature")
603
+ def infer_fn(requests):
604
+ ...
605
+ return results
606
+ ```
607
+
608
+ By default, the decorator squeezes single value arrays to scalars.
609
+ This behavior can be disabled by setting the `squeeze_single_values` flag to False.
610
+
611
+ By default, the decorator checks the equality of the values on selected values.
612
+ This behavior can be disabled by setting the `strict` flag to False.
613
+
614
+ Wrapper can only be used with models that support batching.
615
+
616
+ Args:
617
+ keys: The input keys selected for conversion.
618
+ squeeze_single_values: squeeze single value ND array to scalar values. Defaults to True.
619
+ strict: enable checking if all values on single selected input of request are equal. Defaults to True.
620
+
621
+ Raises:
622
+ PyTritonRuntimeError: if not all values on a single selected input of the request are equal
623
+ and the strict flag is set to True. Additionally, if the decorator is used with a model that doesn't support batching,
624
+ PyTritonBadParameterError: if any of the keys passed to the decorator are not allowed.
625
+ """
626
+ if any(k in _SPECIAL_KEYS for k in keys):
627
+ not_allowed_keys = [key for key in keys if key in _SPECIAL_KEYS]
628
+ raise PyTritonBadParameterError(
629
+ f"The keys {', '.join(not_allowed_keys)} are not allowed as keys for @first_value wrapper. "
630
+ f"The set of not allowed keys are {', '.join(_SPECIAL_KEYS)}"
631
+ )
632
+
633
+ @wrapt.decorator
634
+ def wrapper(wrapped, instance, args, kwargs):
635
+ model_config = get_model_config(wrapped, instance)
636
+ if not model_config.batching:
637
+ raise PyTritonRuntimeError("The @first_value decorator can only be used with models that support batching.")
638
+
639
+ def _replace_inputs_with_first_value(_request):
640
+ for input_name in keys:
641
+ if input_name not in _request:
642
+ continue
643
+
644
+ values = _request[input_name]
645
+ if strict:
646
+ # do not set axis for arrays with strings (object) or models not supporting batching
647
+ axis_of_uniqueness = None if values.dtype == object else 0
648
+ unique_values = np.unique(values, axis=axis_of_uniqueness)
649
+ if len(unique_values) > 1:
650
+ raise PyTritonRuntimeError(
651
+ f"The values on the {input_name!r} input are not equal. "
652
+ "To proceed, either disable strict mode in @first_value wrapper "
653
+ "or ensure that the values always are consistent. "
654
+ f"The current values of {input_name!r} are {_request[input_name]!r}."
655
+ )
656
+
657
+ _first_value = values[0]
658
+ if (
659
+ squeeze_single_values
660
+ and not np.isscalar(_first_value)
661
+ and all(dim == 1 for dim in _first_value.shape)
662
+ ):
663
+ _dim_0_array = np.squeeze(_first_value)
664
+ _first_value = _dim_0_array[()] # obtain scalar from 0-dim array with numpy type
665
+
666
+ _request[input_name] = _first_value
667
+ return _request
668
+
669
+ inputs_names = set(kwargs) - set(_SPECIAL_KEYS)
670
+ if inputs_names:
671
+ kwargs = _replace_inputs_with_first_value(kwargs)
672
+ return wrapped(*args, **kwargs)
673
+ else:
674
+ requests, *other_args = args
675
+ requests = [_replace_inputs_with_first_value(request) for request in requests]
676
+ return wrapped(requests, *other_args, **kwargs)
677
+
678
+ return wrapper
stf/stf-api-alternative/pytriton/build/lib/pytriton/exceptions.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTriton exceptions definition."""
15
+
16
+
17
+ class PyTritonError(Exception):
18
+ """Generic PyTriton exception."""
19
+
20
+ def __init__(self, message: str):
21
+ """Initialize exception with message.
22
+
23
+ Args:
24
+ message: Error message
25
+ """
26
+ self._message = message
27
+
28
+ def __str__(self) -> str:
29
+ """Return exception as a string.
30
+
31
+ Returns:
32
+ Message content
33
+ """
34
+ return self._message
35
+
36
+ @property
37
+ def message(self):
38
+ """Get the exception message.
39
+
40
+ Returns:
41
+ The message associated with this exception, or None if no message.
42
+
43
+ """
44
+ return self._message
45
+
46
+
47
+ class PyTritonValidationError(PyTritonError):
48
+ """PyTriton configuration validation exception."""
49
+
50
+ pass
51
+
52
+
53
+ class PyTritonInvalidOperationError(PyTritonError):
54
+ """PyTriton invalid operation exception."""
55
+
56
+ pass
57
+
58
+
59
+ class PyTritonBadParameterError(PyTritonError):
60
+ """PyTriton invalid parameter exception."""
61
+
62
+ pass
63
+
64
+
65
+ class PyTritonModelConfigError(PyTritonError):
66
+ """PyTriton invalid model config exception."""
67
+
68
+ pass
69
+
70
+
71
+ class PyTritonUnrecoverableError(PyTritonError):
72
+ """Unrecoverable error occurred in inference callable, thus no further inferences possible."""
73
+
74
+ pass
75
+
76
+
77
+ class PyTritonRuntimeError(PyTritonError):
78
+ """Raised when an error is detected that doesn’t fall in any of the other categories."""
79
+
80
+ pass
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
15
+ from .common import DeviceKind, DynamicBatcher, QueuePolicy, TimeoutAction # noqa: F401
16
+ from .model_config import ModelConfig # noqa: F401
17
+ from .tensor import Tensor # noqa: F401
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/common.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Common structures for internal and external ModelConfig."""
15
+
16
+ import dataclasses
17
+ import enum
18
+ from typing import Dict, Optional
19
+
20
+
21
+ class DeviceKind(enum.Enum):
22
+ """Device kind for model deployment.
23
+
24
+ Args:
25
+ KIND_AUTO: Automatically select the device for model deployment.
26
+ KIND_CPU: Model is deployed on CPU.
27
+ KIND_GPU: Model is deployed on GPU.
28
+ """
29
+
30
+ KIND_AUTO = "KIND_AUTO"
31
+ KIND_CPU = "KIND_CPU"
32
+ KIND_GPU = "KIND_GPU"
33
+
34
+
35
+ class TimeoutAction(enum.Enum):
36
+ """Timeout action definition for timeout_action QueuePolicy field.
37
+
38
+ Args:
39
+ REJECT: Reject the request and return error message accordingly.
40
+ DELAY: Delay the request until all other requests at the same (or higher) priority levels
41
+ that have not reached their timeouts are processed.
42
+ """
43
+
44
+ REJECT = "REJECT"
45
+ DELAY = "DELAY"
46
+
47
+
48
+ @dataclasses.dataclass
49
+ class QueuePolicy:
50
+ """Model queue policy configuration.
51
+
52
+ More in Triton Inference Server [documentation]
53
+ [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1037
54
+
55
+ Args:
56
+ timeout_action: The action applied to timed-out request.
57
+ default_timeout_microseconds: The default timeout for every request, in microseconds.
58
+ allow_timeout_override: Whether individual request can override the default timeout value.
59
+ max_queue_size: The maximum queue size for holding requests.
60
+ """
61
+
62
+ timeout_action: TimeoutAction = TimeoutAction.REJECT
63
+ default_timeout_microseconds: int = 0
64
+ allow_timeout_override: bool = False
65
+ max_queue_size: int = 0
66
+
67
+
68
+ @dataclasses.dataclass
69
+ class DynamicBatcher:
70
+ """Dynamic batcher configuration.
71
+
72
+ More in Triton Inference Server [documentation]
73
+ [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1104
74
+
75
+ Args:
76
+ max_queue_delay_microseconds: The maximum time, in microseconds, a request will be delayed in
77
+ the scheduling queue to wait for additional requests for batching.
78
+ preferred_batch_size: Preferred batch sizes for dynamic batching.
79
+ preserve_ordering : Should the dynamic batcher preserve the ordering of responses to
80
+ match the order of requests received by the scheduler.
81
+ priority_levels: The number of priority levels to be enabled for the model.
82
+ default_priority_level: The priority level used for requests that don't specify their priority.
83
+ default_queue_policy: The default queue policy used for requests.
84
+ priority_queue_policy: Specify the queue policy for the priority level.
85
+ """
86
+
87
+ max_queue_delay_microseconds: int = 0
88
+ preferred_batch_size: Optional[list] = None
89
+ preserve_ordering: bool = False
90
+ priority_levels: int = 0
91
+ default_priority_level: int = 0
92
+ default_queue_policy: Optional[QueuePolicy] = None
93
+ priority_queue_policy: Optional[Dict[int, QueuePolicy]] = None
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/generator.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Generator class for creating Triton model config.
15
+
16
+ The class consume the TritonModelConfig object as a constructor argument and produce the Triton model config in form of
17
+ dict or file.
18
+
19
+ Typical usage example:
20
+
21
+ model_config = TritonModelConfig(model_name="simple")
22
+ generator = ModelConfigGenerator(model_config)
23
+ generator.to_file("/path/to/config.pbtxt")
24
+ """
25
+
26
+ import json
27
+ import logging
28
+ import pathlib
29
+ from typing import Dict, Union
30
+
31
+ import numpy as np
32
+ from google.protobuf import json_format, text_format # pytype: disable=pyi-error
33
+
34
+ from pytriton.exceptions import PyTritonBadParameterError
35
+
36
+ from .triton_model_config import DynamicBatcher, TensorSpec, TritonModelConfig
37
+
38
+ try:
39
+ import tritonclient.grpc as grpc_client
40
+ from tritonclient import utils as client_utils # noqa: F401
41
+ except ImportError:
42
+ try:
43
+ import tritonclientutils as client_utils # noqa: F401
44
+ import tritongrpcclient as grpc_client
45
+ except ImportError:
46
+ client_utils = None
47
+ grpc_client = None
48
+
49
+ LOGGER = logging.getLogger(__name__)
50
+
51
+
52
+ class ModelConfigGenerator:
53
+ """Generate the protobuf config from ModelConfig object."""
54
+
55
+ def __init__(self, config: TritonModelConfig):
56
+ """Initialize generator.
57
+
58
+ Args:
59
+ config: model config object
60
+ """
61
+ self._config = config
62
+
63
+ def to_file(self, config_path: Union[str, pathlib.Path]) -> str:
64
+ """Serialize ModelConfig to prototxt and save to config_path directory.
65
+
66
+ Args:
67
+ config_path: path to configuration file
68
+
69
+ Returns:
70
+ A string with generated model configuration
71
+ """
72
+ from tritonclient.grpc import model_config_pb2 # pytype: disable=import-error
73
+
74
+ # https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
75
+ model_config = self.get_config()
76
+ LOGGER.debug(f"Generated Triton config:\n{json.dumps(model_config, indent=4)}")
77
+
78
+ config_payload = json_format.ParseDict(model_config, model_config_pb2.ModelConfig())
79
+ LOGGER.debug(f"Generated Triton config payload:\n{config_payload}")
80
+
81
+ config_path = pathlib.Path(config_path)
82
+ config_path.parent.mkdir(parents=True, exist_ok=True)
83
+
84
+ model_config_bytes = text_format.MessageToBytes(config_payload)
85
+
86
+ # WAR: triton requires max_batch_size = 0 to be explicit written
87
+ # while this is not stored in payload during MessageToBytes
88
+ if model_config["max_batch_size"] == 0:
89
+ model_config_bytes += b"max_batch_size: 0\n"
90
+
91
+ with config_path.open("wb") as cfg:
92
+ cfg.write(model_config_bytes)
93
+
94
+ LOGGER.debug(f"Generated config stored in {config_path}")
95
+
96
+ return config_payload
97
+
98
+ def get_config(self) -> Dict:
99
+ """Create a Triton model config from ModelConfig object.
100
+
101
+ Returns:
102
+ Dict with model configuration data
103
+ """
104
+ model_config = {"name": self._config.model_name, "backend": self._config.backend}
105
+ self._set_batching(model_config)
106
+ self._set_model_signature(model_config)
107
+ self._set_instance_group(model_config)
108
+ self._set_model_transaction_policy(model_config)
109
+ self._set_backend_parameters(model_config)
110
+ self._set_response_cache(model_config)
111
+ return model_config
112
+
113
+ def _set_batching(self, model_config: Dict) -> None:
114
+ """Configure batching for model deployment on Triton Inference Server.
115
+
116
+ Args:
117
+ model_config: Dict with model config for Triton Inference Server
118
+ """
119
+ if not self._config.batching:
120
+ model_config["max_batch_size"] = 0
121
+ LOGGER.debug("Batching for model is disabled. The `max_batch_size` field value set to 0.")
122
+ return
123
+ elif self._config.max_batch_size < 1:
124
+ raise PyTritonBadParameterError("The `max_batch_size` must be greater or equal to 1.")
125
+
126
+ model_config["max_batch_size"] = self._config.max_batch_size
127
+ if isinstance(self._config.batcher, DynamicBatcher):
128
+ dynamic_batching_config = {}
129
+ if self._config.batcher.max_queue_delay_microseconds > 0:
130
+ dynamic_batching_config["maxQueueDelayMicroseconds"] = int(
131
+ self._config.batcher.max_queue_delay_microseconds
132
+ )
133
+
134
+ if self._config.batcher.preferred_batch_size:
135
+ dynamic_batching_config["preferredBatchSize"] = [
136
+ int(bs) for bs in self._config.batcher.preferred_batch_size
137
+ ]
138
+
139
+ if self._config.batcher.preserve_ordering:
140
+ dynamic_batching_config["preserveOrdering"] = self._config.batcher.preserve_ordering
141
+
142
+ if self._config.batcher.priority_levels:
143
+ dynamic_batching_config["priorityLevels"] = self._config.batcher.priority_levels
144
+
145
+ if self._config.batcher.default_priority_level:
146
+ if self._config.batcher.default_priority_level > self._config.batcher.priority_levels:
147
+ raise PyTritonBadParameterError(
148
+ "The `default_priority_level` must be between 1 and " f"{self._config.batcher.priority_levels}."
149
+ )
150
+ dynamic_batching_config["defaultPriorityLevel"] = self._config.batcher.default_priority_level
151
+
152
+ if self._config.batcher.default_queue_policy:
153
+ priority_queue_policy_config = {
154
+ "timeoutAction": self._config.batcher.default_queue_policy.timeout_action.value,
155
+ "defaultTimeoutMicroseconds": int(
156
+ self._config.batcher.default_queue_policy.default_timeout_microseconds
157
+ ),
158
+ "allowTimeoutOverride": self._config.batcher.default_queue_policy.allow_timeout_override,
159
+ "maxQueueSize": int(self._config.batcher.default_queue_policy.max_queue_size),
160
+ }
161
+ dynamic_batching_config["defaultQueuePolicy"] = priority_queue_policy_config
162
+
163
+ if self._config.batcher.priority_queue_policy:
164
+ if not self._config.batcher.priority_levels:
165
+ raise PyTritonBadParameterError(
166
+ "Provide the `priority_levels` if you want to define `priority_queue_policy` "
167
+ "for Dynamic Batching."
168
+ )
169
+
170
+ priority_queue_policy_config = {}
171
+ for priority, queue_policy in self._config.batcher.priority_queue_policy.items():
172
+ if priority < 0 or priority > self._config.batcher.priority_levels:
173
+ raise PyTritonBadParameterError(
174
+ f"Invalid `priority`={priority} provided. The value must be between "
175
+ f"1 and {self._config.batcher.priority_levels}."
176
+ )
177
+
178
+ priority_queue_policy_config[priority] = {
179
+ "timeoutAction": queue_policy.timeout_action.value,
180
+ "defaultTimeoutMicroseconds": int(queue_policy.default_timeout_microseconds),
181
+ "allowTimeoutOverride": queue_policy.allow_timeout_override,
182
+ "maxQueueSize": int(queue_policy.max_queue_size),
183
+ }
184
+
185
+ dynamic_batching_config["priorityQueuePolicy"] = priority_queue_policy_config
186
+
187
+ model_config["dynamic_batching"] = dynamic_batching_config
188
+ else:
189
+ LOGGER.debug("Default batching used")
190
+
191
+ def _set_instance_group(self, model_config: Dict) -> None:
192
+ """Configure instance group for model deployment on Triton Inference Server.
193
+
194
+ Args:
195
+ model_config: Dict with model config for Triton Inference Server
196
+ """
197
+ instance_groups = []
198
+ for device_kind, count in self._config.instance_group.items():
199
+ instance_groups.append({
200
+ "count": count,
201
+ "kind": device_kind.value,
202
+ })
203
+
204
+ if instance_groups:
205
+ model_config["instance_group"] = instance_groups
206
+
207
+ def _set_model_transaction_policy(self, model_config: Dict) -> None:
208
+ """Configure model transaction policy for model deployment on Triton Inference Server.
209
+
210
+ Args:
211
+ model_config: Dict with model config for Triton Inference Server
212
+ """
213
+ if self._config.decoupled:
214
+ model_config["model_transaction_policy"] = {"decoupled": True}
215
+
216
+ def _set_backend_parameters(self, model_config: Dict) -> None:
217
+ """Configure backend parameters for model deployment on Triton Inference Server.
218
+
219
+ Args:
220
+ model_config: Dict with model config for Triton Inference Server
221
+ """
222
+ parameters = {}
223
+ for key, value in self._config.backend_parameters.items():
224
+ parameters[key] = {
225
+ "string_value": str(value),
226
+ }
227
+
228
+ if parameters:
229
+ model_config["parameters"] = parameters
230
+
231
+ def _set_model_signature(self, model_config: Dict) -> None:
232
+ """Configure model signature for model deployment on Triton Inference Server.
233
+
234
+ Args:
235
+ model_config: Dict with model config for Triton Inference Server
236
+
237
+ """
238
+
239
+ def _rewrite_io_spec(spec_: TensorSpec) -> Dict:
240
+ if spec_.dtype in [np.object_, object, bytes, np.bytes_]:
241
+ dtype = "TYPE_STRING"
242
+ else:
243
+ # pytype: disable=attribute-error
244
+ dtype = spec_.dtype().dtype
245
+ # pytype: enable=attribute-error
246
+ dtype = f"TYPE_{client_utils.np_to_triton_dtype(dtype)}"
247
+
248
+ dims = spec_.shape
249
+
250
+ item = {
251
+ "name": spec_.name,
252
+ "dims": list(dims),
253
+ "data_type": dtype,
254
+ }
255
+
256
+ if spec_.optional:
257
+ item["optional"] = True
258
+
259
+ return item
260
+
261
+ if self._config.inputs:
262
+ model_config["input"] = [_rewrite_io_spec(spec) for spec in self._config.inputs]
263
+
264
+ if self._config.outputs:
265
+ outputs = [_rewrite_io_spec(spec) for spec in self._config.outputs]
266
+ if outputs:
267
+ optional_outputs = [o for o in outputs if o.get("optional")]
268
+ if optional_outputs:
269
+ raise PyTritonBadParameterError(
270
+ "Optional flag for outputs is not supported. "
271
+ f"Outputs marked as optional: {', '.join([o['name'] for o in optional_outputs])}."
272
+ )
273
+ model_config["output"] = outputs
274
+
275
+ def _set_response_cache(self, model_config: Dict):
276
+ """Configure response cache for model.
277
+
278
+ Args:
279
+ model_config: Dictionary where configuration is attached.
280
+ """
281
+ if self._config.response_cache:
282
+ model_config["response_cache"] = {
283
+ "enable": self._config.response_cache.enable,
284
+ }
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/model_config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Model configurations.
15
+
16
+ Dataclasses with specialized deployment paths for models on Triton. The purpose of this module is to provide clear options
17
+ to configure models of given types.
18
+
19
+ The dataclasses are exposed in the user API.
20
+ """
21
+
22
+ import dataclasses
23
+
24
+ from pytriton.model_config import DynamicBatcher
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class ModelConfig:
29
+ """Additional model configuration for running model through Triton Inference Server.
30
+
31
+ Args:
32
+ batching: Flag to enable/disable batching for model.
33
+ max_batch_size: The maximal batch size that would be handled by model.
34
+ batcher: Configuration of Dynamic Batching for the model.
35
+ response_cache: Flag to enable/disable response cache for the model
36
+ decoupled: Flag to enable/disable decoupled from requests execution
37
+ """
38
+
39
+ batching: bool = True
40
+ max_batch_size: int = 4
41
+ batcher: DynamicBatcher = dataclasses.field(default_factory=DynamicBatcher)
42
+ response_cache: bool = False
43
+ decoupled: bool = False
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/parser.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ModelConfigParser class definition.
15
+
16
+ Provide functionality to parse the Triton model configuration stored in file or form of dictionary into the object of
17
+ class ModelConfig.
18
+
19
+ Examples of use:
20
+
21
+ # Parse from dict
22
+ model_config = ModelConfigParser.from_dict(model_config_dict)
23
+
24
+ # Parse from file
25
+ model_config = ModelConfigParser.from_file("/path/to/config.pbtxt")
26
+
27
+ """
28
+
29
+ import json
30
+ import logging
31
+ import pathlib
32
+ from typing import Dict
33
+
34
+ import numpy as np
35
+ from google.protobuf import json_format, text_format # pytype: disable=pyi-error
36
+
37
+ from pytriton.exceptions import PyTritonModelConfigError
38
+
39
+ from .common import QueuePolicy, TimeoutAction
40
+ from .triton_model_config import DeviceKind, DynamicBatcher, ResponseCache, TensorSpec, TritonModelConfig
41
+
42
+ try:
43
+ import tritonclient.grpc as grpc_client
44
+ from tritonclient import utils as client_utils # noqa: F401
45
+ except ImportError:
46
+ try:
47
+ import tritonclientutils as client_utils # noqa: F401
48
+ import tritongrpcclient as grpc_client
49
+ except ImportError:
50
+ client_utils = None
51
+ grpc_client = None
52
+
53
+ LOGGER = logging.getLogger(__name__)
54
+
55
+
56
+ class ModelConfigParser:
57
+ """Provide functionality to parse dictionary or file to ModelConfig object."""
58
+
59
+ @classmethod
60
+ def from_dict(cls, model_config_dict: Dict) -> TritonModelConfig:
61
+ """Create ModelConfig from configuration stored in dictionary.
62
+
63
+ Args:
64
+ model_config_dict: Dictionary with model config
65
+
66
+ Returns:
67
+ A ModelConfig object with data parsed from the dictionary
68
+ """
69
+ LOGGER.debug(f"Parsing Triton config model from dict: \n{json.dumps(model_config_dict, indent=4)}")
70
+
71
+ if model_config_dict.get("max_batch_size", 0) > 0:
72
+ batching = True
73
+ else:
74
+ batching = False
75
+
76
+ dynamic_batcher_config = model_config_dict.get("dynamic_batching")
77
+ if dynamic_batcher_config is not None:
78
+ batcher = cls._parse_dynamic_batching(dynamic_batcher_config)
79
+ else:
80
+ batcher = None
81
+
82
+ instance_group = {
83
+ DeviceKind(entry["kind"]): entry.get("count") for entry in model_config_dict.get("instance_group", [])
84
+ }
85
+
86
+ decoupled = model_config_dict.get("model_transaction_policy", {}).get("decoupled", False)
87
+
88
+ backend_parameters_config = model_config_dict.get("parameters", [])
89
+ if isinstance(backend_parameters_config, list):
90
+ # If the backend_parameters_config is a list of strings, use them as keys with empty values
91
+ LOGGER.debug(f"backend_parameters_config is a list of strings: {backend_parameters_config}")
92
+ backend_parameters = {name: "" for name in backend_parameters_config}
93
+ elif isinstance(backend_parameters_config, dict):
94
+ # If the backend_parameters_config is a dictionary, use the key and "string_value" fields as key-value pairs
95
+ LOGGER.debug(f"backend_parameters_config is a dictionary: {backend_parameters_config}")
96
+ backend_parameters = {
97
+ name: backend_parameters_config[name]["string_value"] for name in backend_parameters_config
98
+ }
99
+ else:
100
+ # Otherwise, raise an error
101
+ LOGGER.error(
102
+ f"Invalid type {type(backend_parameters_config)} for backend_parameters_config: {backend_parameters_config}"
103
+ )
104
+ raise TypeError(f"Invalid type for backend_parameters_config: {type(backend_parameters_config)}")
105
+
106
+ inputs = [
107
+ cls.rewrite_io_spec(item, "input", idx) for idx, item in enumerate(model_config_dict.get("input", []))
108
+ ] or None
109
+ outputs = [
110
+ cls.rewrite_io_spec(item, "output", idx) for idx, item in enumerate(model_config_dict.get("output", []))
111
+ ] or None
112
+
113
+ response_cache_config = model_config_dict.get("response_cache")
114
+ if response_cache_config:
115
+ response_cache = cls._parse_response_cache(response_cache_config)
116
+ else:
117
+ response_cache = None
118
+
119
+ return TritonModelConfig(
120
+ model_name=model_config_dict["name"],
121
+ batching=batching,
122
+ max_batch_size=model_config_dict.get("max_batch_size", 0),
123
+ batcher=batcher,
124
+ inputs=inputs,
125
+ outputs=outputs,
126
+ instance_group=instance_group,
127
+ decoupled=decoupled,
128
+ backend_parameters=backend_parameters,
129
+ response_cache=response_cache,
130
+ )
131
+
132
+ @classmethod
133
+ def from_file(cls, *, config_path: pathlib.Path) -> TritonModelConfig:
134
+ """Create ModelConfig from configuration stored in file.
135
+
136
+ Args:
137
+ config_path: location of file with model config
138
+
139
+ Returns:
140
+ A ModelConfig object with data parsed from the file
141
+ """
142
+ from tritonclient.grpc import model_config_pb2 # pytype: disable=import-error
143
+
144
+ LOGGER.debug(f"Parsing Triton config model config_path={config_path}")
145
+
146
+ with config_path.open("r") as config_file:
147
+ payload = config_file.read()
148
+ model_config_proto = text_format.Parse(payload, model_config_pb2.ModelConfig())
149
+
150
+ model_config_dict = json_format.MessageToDict(model_config_proto, preserving_proto_field_name=True)
151
+ return ModelConfigParser.from_dict(model_config_dict=model_config_dict)
152
+
153
+ @classmethod
154
+ def rewrite_io_spec(cls, item: Dict, io_type: str, idx: int) -> TensorSpec:
155
+ """Rewrite the IO Spec provided in form of dictionary to TensorSpec.
156
+
157
+ Args:
158
+ item: IO data for input
159
+ io_type: Type of the IO (input or output)
160
+ idx: Index of IO
161
+
162
+ Returns:
163
+ TensorSpec with input or output data
164
+ """
165
+ name = item.get("name")
166
+ if not name:
167
+ raise PyTritonModelConfigError(f"Name for {io_type} at index {idx} not provided.")
168
+
169
+ data_type = item.get("data_type")
170
+ if not data_type:
171
+ raise PyTritonModelConfigError(f"Data type for {io_type} with name `{name}` not defined.")
172
+
173
+ data_type_val = data_type.split("_")
174
+ if len(data_type_val) != 2:
175
+ raise PyTritonModelConfigError(
176
+ f"Invalid data type `{data_type}` for {io_type} with name `{name}` not defined. "
177
+ "The expected name is TYPE_{type}."
178
+ )
179
+
180
+ data_type = data_type_val[1]
181
+ if data_type == "STRING":
182
+ dtype = np.bytes_
183
+ else:
184
+ dtype = client_utils.triton_to_np_dtype(data_type)
185
+ if dtype is None:
186
+ raise PyTritonModelConfigError(f"Unsupported data type `{data_type}` for {io_type} with name `{name}`")
187
+
188
+ dtype = np.dtype("bool") if dtype is bool else dtype
189
+
190
+ dims = item.get("dims", [])
191
+ if not dims:
192
+ raise PyTritonModelConfigError(f"Dimension for {io_type} with name `{name}` not defined.")
193
+
194
+ shape = tuple(int(s) for s in dims)
195
+
196
+ optional = item.get("optional", False)
197
+ return TensorSpec(name=item["name"], shape=shape, dtype=dtype, optional=optional)
198
+
199
+ @classmethod
200
+ def _parse_dynamic_batching(cls, dynamic_batching_config: Dict) -> DynamicBatcher:
201
+ """Parse config to create DynamicBatcher object.
202
+
203
+ Args:
204
+ dynamic_batching_config: Configuration of dynamic batcher from config
205
+
206
+ Returns:
207
+ DynamicBatcher object with configuration
208
+ """
209
+ default_queue_policy = None
210
+ default_queue_policy_config = dynamic_batching_config.get("default_queue_policy")
211
+ if default_queue_policy_config:
212
+ default_queue_policy = QueuePolicy(
213
+ timeout_action=TimeoutAction(
214
+ default_queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value)
215
+ ),
216
+ default_timeout_microseconds=int(default_queue_policy_config.get("default_timeout_microseconds", 0)),
217
+ allow_timeout_override=bool(default_queue_policy_config.get("allow_timeout_override", False)),
218
+ max_queue_size=int(default_queue_policy_config.get("max_queue_size", 0)),
219
+ )
220
+
221
+ priority_queue_policy = None
222
+ priority_queue_policy_config = dynamic_batching_config.get("priority_queue_policy")
223
+ if priority_queue_policy_config:
224
+ priority_queue_policy = {}
225
+ for priority, queue_policy_config in priority_queue_policy_config.items():
226
+ queue_policy = QueuePolicy(
227
+ timeout_action=TimeoutAction(queue_policy_config.get("timeout_action", TimeoutAction.REJECT.value)),
228
+ default_timeout_microseconds=int(queue_policy_config.get("default_timeout_microseconds", 0)),
229
+ allow_timeout_override=bool(queue_policy_config.get("allow_timeout_override", False)),
230
+ max_queue_size=int(queue_policy_config.get("max_queue_size", 0)),
231
+ )
232
+ priority_queue_policy[int(priority)] = queue_policy
233
+
234
+ batcher = DynamicBatcher(
235
+ preferred_batch_size=dynamic_batching_config.get("preferred_batch_size"),
236
+ max_queue_delay_microseconds=int(dynamic_batching_config.get("max_queue_delay_microseconds", 0)),
237
+ preserve_ordering=bool(dynamic_batching_config.get("preserve_ordering", False)),
238
+ priority_levels=int(dynamic_batching_config.get("priority_levels", 0)),
239
+ default_priority_level=int(dynamic_batching_config.get("default_priority_level", 0)),
240
+ default_queue_policy=default_queue_policy,
241
+ priority_queue_policy=priority_queue_policy,
242
+ )
243
+ return batcher
244
+
245
+ @classmethod
246
+ def _parse_response_cache(cls, response_cache_config: Dict) -> ResponseCache:
247
+ """Parse config for response cache.
248
+
249
+ Args:
250
+ response_cache_config: response cache configuration
251
+
252
+ Returns:
253
+ ResponseCache object with configuration
254
+ """
255
+ response_cache = ResponseCache(
256
+ enable=bool(response_cache_config["enable"]),
257
+ )
258
+ return response_cache
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/tensor.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tensor object definition.
15
+
16
+ Describe the model input or output.
17
+
18
+ Examples of use:
19
+
20
+ # Minimal constructors
21
+ tensor = Tensor(dtype=np.bytes_, shape=(-1,))
22
+ tensor = Tensor(dtype=np.float32, shape=(-1,))
23
+
24
+ # Type definition from existing object
25
+ a = np.array([1, 2, 3, 4])
26
+ tensor = Tensor(dtype=a.dtype, shape=(-1,))
27
+
28
+ # Custom name
29
+ tensor = Tensor(name="data", dtype=np.float32, shape=(16,))
30
+ """
31
+
32
+ import dataclasses
33
+ from typing import Optional, Type, Union
34
+
35
+ import numpy as np
36
+
37
+
38
+ @dataclasses.dataclass(frozen=True)
39
+ class Tensor:
40
+ """Model input and output definition for Triton deployment.
41
+
42
+ Args:
43
+ shape: Shape of the input/output tensor.
44
+ dtype: Data type of the input/output tensor.
45
+ name: Name of the input/output of model.
46
+ optional: Flag to mark if input is optional.
47
+ """
48
+
49
+ shape: tuple
50
+ dtype: Union[np.dtype, Type[np.dtype], Type[object]]
51
+ name: Optional[str] = None
52
+ optional: Optional[bool] = False
53
+
54
+ def __post_init__(self):
55
+ """Override object values on post init or field override."""
56
+ if isinstance(self.dtype, np.dtype):
57
+ object.__setattr__(self, "dtype", self.dtype.type) # pytype: disable=attribute-error
stf/stf-api-alternative/pytriton/build/lib/pytriton/model_config/triton_model_config.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ModelConfig related objects."""
15
+
16
+ import dataclasses
17
+ from typing import Dict, Optional, Sequence, Type, Union
18
+
19
+ import numpy as np
20
+
21
+ from .common import DeviceKind, DynamicBatcher
22
+
23
+
24
+ @dataclasses.dataclass
25
+ class ResponseCache:
26
+ """Model response cache configuration.
27
+
28
+ More in Triton Inference Server [documentation]
29
+ [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto#L1765
30
+ """
31
+
32
+ enable: bool
33
+
34
+
35
+ @dataclasses.dataclass
36
+ class TensorSpec:
37
+ """Stores specification of single tensor. This includes name, shape and dtype."""
38
+
39
+ name: str
40
+ shape: tuple
41
+ dtype: Union[Type[np.dtype], Type[object]]
42
+ optional: Optional[bool] = False
43
+
44
+
45
+ @dataclasses.dataclass
46
+ class TritonModelConfig:
47
+ """Triton Model Config dataclass for simplification and specialization of protobuf config generation.
48
+
49
+ More in Triton Inference Server [documentation]
50
+ [documentation]: https://github.com/triton-inference-server/common/blob/main/protobuf/model_config.proto
51
+ """
52
+
53
+ model_name: str
54
+ model_version: int = 1
55
+ max_batch_size: int = 4
56
+ batching: bool = True
57
+ batcher: Optional[DynamicBatcher] = None
58
+ instance_group: Dict[DeviceKind, Optional[int]] = dataclasses.field(default_factory=lambda: {})
59
+ decoupled: bool = False
60
+ backend_parameters: Dict[str, str] = dataclasses.field(default_factory=lambda: {})
61
+ inputs: Optional[Sequence[TensorSpec]] = None
62
+ outputs: Optional[Sequence[TensorSpec]] = None
63
+ response_cache: Optional[ResponseCache] = None
64
+
65
+ @property
66
+ def backend(self) -> str:
67
+ """Return backend parameter."""
68
+ return "python"
stf/stf-api-alternative/pytriton/build/lib/pytriton/models/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
stf/stf-api-alternative/pytriton/build/lib/pytriton/models/manager.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ModelManager class.
15
+
16
+ The ModelManager is responsible for maintaining the models that has to be server on Triton Inference Server.
17
+
18
+ Examples of use:
19
+ manager = ModelManager(model_repository)
20
+ manager.add_model(model)
21
+
22
+ manager.create_models()
23
+ """
24
+
25
+ import contextlib
26
+ import json
27
+ import logging
28
+ import pathlib
29
+ import socket
30
+ from typing import Dict, Iterable, Optional, Tuple
31
+
32
+ from tritonclient.grpc import InferenceServerException
33
+
34
+ from pytriton.client import ModelClient
35
+ from pytriton.client.utils import create_client_from_url, wait_for_server_ready
36
+ from pytriton.constants import CREATE_TRITON_CLIENT_TIMEOUT_S, DEFAULT_TRITON_STARTUP_TIMEOUT_S
37
+ from pytriton.exceptions import PyTritonInvalidOperationError
38
+ from pytriton.models.model import Model
39
+
40
+ LOGGER = logging.getLogger(__name__)
41
+
42
+
43
+ class ModelManager:
44
+ """ModelManager class for maintaining Triton models."""
45
+
46
+ def __init__(
47
+ self,
48
+ triton_url: str,
49
+ model_store_path: Optional[pathlib.Path] = None,
50
+ ):
51
+ """Create ModelManager object.
52
+
53
+ Args:
54
+ triton_url: Triton server URL
55
+ model_store_path: Path to local model store
56
+ """
57
+ self._triton_url = triton_url
58
+ self._models: Dict[Tuple[str, int], Model] = {}
59
+ self._model_store_path = model_store_path
60
+
61
+ @property
62
+ def models(self) -> Iterable[Model]:
63
+ """List models added to manage.
64
+
65
+ Returns:
66
+ List with models added to ModelManager.
67
+ """
68
+ return self._models.values()
69
+
70
+ def add_model(self, model: Model, load_model: bool = False) -> None:
71
+ """Add model to manage.
72
+
73
+ Args:
74
+ model: Model instance
75
+ load_model: If True, model will be loaded to Triton server.
76
+ """
77
+ key = self._format_key(model)
78
+ if key in self._models:
79
+ raise PyTritonInvalidOperationError("Cannot add model with the same name twice.")
80
+
81
+ LOGGER.debug(f"Adding {model.model_name} ({model.model_version}) to registry under {key}.")
82
+ self._models[key] = model
83
+
84
+ _is_model_store_local = self._model_store_path is not None
85
+ if _is_model_store_local:
86
+ model.generate_model(self._model_store_path)
87
+
88
+ if load_model:
89
+ self._load_model(model, _is_model_store_local)
90
+ model.setup()
91
+
92
+ def load_models(self) -> None:
93
+ """Load bound models to Triton server and setup loaded models."""
94
+ for model in self._models.values():
95
+ if not model.is_alive():
96
+ self._load_model(model)
97
+ model.setup()
98
+
99
+ def setup_models(self) -> None:
100
+ """Setup loaded models."""
101
+ for model in self._models.values():
102
+ if not model.is_alive():
103
+ model.setup()
104
+
105
+ def clean(self) -> None:
106
+ """Clean the model and internal registry."""
107
+ with contextlib.closing(
108
+ create_client_from_url(self._triton_url, network_timeout_s=CREATE_TRITON_CLIENT_TIMEOUT_S)
109
+ ) as client:
110
+ server_live = False
111
+ try:
112
+ server_live = client.is_server_live()
113
+ # TimeoutError and ConnectionRefusedError are derived from OSError so they are redundant here
114
+ # OSError is raised from gevent/_socketcommon.py:590 sometimes, when server is not ready
115
+ except (socket.timeout, OSError, InferenceServerException):
116
+ pass
117
+ except Exception as ex:
118
+ LOGGER.error(f"Unexpected exception during server live check: {ex}")
119
+ raise ex
120
+
121
+ for name, model in self._models.items():
122
+ LOGGER.debug(f"Clean model {name}.")
123
+ model.clean()
124
+ if server_live:
125
+ client.unload_model(model.model_name)
126
+
127
+ if server_live:
128
+ # after unload there is a short period of time when server is not ready
129
+ wait_for_server_ready(client, timeout_s=DEFAULT_TRITON_STARTUP_TIMEOUT_S)
130
+
131
+ self._models.clear()
132
+
133
+ def _format_key(self, model: Model) -> Tuple[str, int]:
134
+ key = (model.model_name.lower(), model.model_version)
135
+ return key
136
+
137
+ def _load_model(self, model: Model, local_model_store=False):
138
+ """Prepare model config and required files dict and load model to Triton server."""
139
+ LOGGER.debug(f"Creating model {model.model_name} with version {model.model_version}.")
140
+ config = None if local_model_store else json.dumps(model.get_model_config())
141
+ files = None if local_model_store else model.get_proxy_model_files()
142
+ with ModelClient(
143
+ url=self._triton_url, model_name=model.model_name, model_version=str(model.model_version)
144
+ ) as client:
145
+ client.wait_for_server(timeout_s=DEFAULT_TRITON_STARTUP_TIMEOUT_S)
146
+ client.load_model(config=config, files=files)
147
+ LOGGER.debug("Done.")
stf/stf-api-alternative/pytriton/build/lib/pytriton/models/model.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Model base class."""
15
+
16
+ import base64
17
+ import copy
18
+ import enum
19
+ import json
20
+ import logging
21
+ import os
22
+ import pathlib
23
+ import shutil
24
+ import threading
25
+ import typing
26
+ from typing import Callable, List, Optional, Sequence, Union
27
+
28
+ from pytriton.decorators import TritonContext
29
+ from pytriton.exceptions import PyTritonValidationError
30
+ from pytriton.model_config.generator import ModelConfigGenerator
31
+ from pytriton.model_config.model_config import ModelConfig
32
+ from pytriton.model_config.tensor import Tensor
33
+ from pytriton.model_config.triton_model_config import DeviceKind, ResponseCache, TensorSpec, TritonModelConfig
34
+ from pytriton.proxy.communication import get_config_from_handshake_server
35
+ from pytriton.proxy.data import Base64SerializerDeserializer, TensorStoreSerializerDeserializer
36
+ from pytriton.proxy.inference import InferenceHandler, InferenceHandlerEvent, RequestsResponsesConnector
37
+ from pytriton.proxy.validators import TritonResultsValidator
38
+ from pytriton.utils.workspace import Workspace
39
+
40
+ LOGGER = logging.getLogger(__name__)
41
+
42
+
43
+ class ModelEvent(enum.Enum):
44
+ """Represents model event."""
45
+
46
+ RUNTIME_TERMINATING = "runtime-terminating"
47
+ RUNTIME_TERMINATED = "runtime-terminated"
48
+
49
+
50
+ ModelEventsHandler = typing.Callable[["Model", ModelEvent, typing.Optional[typing.Any]], None]
51
+
52
+
53
+ def _inject_triton_context(triton_context: TritonContext, model_callable: Callable) -> Callable:
54
+ """Inject triton context into callable.
55
+
56
+ Args:
57
+ triton_context: Triton context
58
+ model_callable: Callable to inject triton context
59
+
60
+ Returns:
61
+ Callable with injected triton context
62
+ """
63
+ if hasattr(model_callable, "__self__"):
64
+ model_callable.__self__.__triton_context__ = triton_context
65
+ else:
66
+ model_callable.__triton_context__ = triton_context
67
+ return model_callable
68
+
69
+
70
+ class Model:
71
+ """Model definition."""
72
+
73
+ SCRIPT_FILES_TO_COPY = ["communication.py", "data.py", "model.py", "types.py", "telemetry.py"]
74
+
75
+ def __init__(
76
+ self,
77
+ model_name: str,
78
+ model_version: int,
79
+ inference_fn: Union[Callable, Sequence[Callable]],
80
+ inputs: Sequence[Tensor],
81
+ outputs: Sequence[Tensor],
82
+ config: ModelConfig,
83
+ workspace: Workspace,
84
+ triton_context: TritonContext,
85
+ strict: bool,
86
+ trace_config: Optional[List[str]] = None,
87
+ ):
88
+ """Create Python model with required data.
89
+
90
+ Args:
91
+ model_name: Model name
92
+ model_version: Model version
93
+ inference_fn: Inference handler (function or lambda)
94
+ inputs: Model inputs definition
95
+ outputs: Model outputs definition
96
+ config: model configuration parameters
97
+ workspace: workspace for storing artifacts
98
+ triton_context: Triton context
99
+ strict: Enable strict validation of model outputs
100
+ trace_config: List of trace config parameters
101
+
102
+ Raises:
103
+ PyTritonValidationError if one or more of provided values are incorrect.
104
+ """
105
+ self.triton_context = triton_context
106
+ self.model_name = model_name
107
+ self.model_version = model_version
108
+ self._inference_handlers_lock = threading.Lock()
109
+ self._inference_handlers = []
110
+ self._requests_respones_connectors = []
111
+ self._observers_lock = threading.Lock()
112
+ self._strict = strict
113
+ self._trace_config = trace_config
114
+
115
+ self.infer_functions = [inference_fn] if isinstance(inference_fn, Callable) else inference_fn
116
+ if not isinstance(self.infer_functions, (Sequence, Callable)):
117
+ raise PyTritonValidationError("inference_fn has to be either callable or sequence of callables")
118
+
119
+ self.inputs = inputs
120
+ self.outputs = outputs
121
+
122
+ if any(output.optional for output in self.outputs):
123
+ raise PyTritonValidationError("Output tensors cannot be optional.")
124
+
125
+ self.config = config
126
+ self._workspace = workspace
127
+ if os.environ.get("PYTRITON_NO_TENSORSTORE"):
128
+ self._serializer_deserializer = Base64SerializerDeserializer()
129
+ else:
130
+ self._serializer_deserializer = TensorStoreSerializerDeserializer()
131
+ self._triton_model_config: Optional[TritonModelConfig] = None
132
+ self._model_events_observers: typing.List[ModelEventsHandler] = []
133
+
134
+ def get_model_config(self) -> dict:
135
+ """Get model config.
136
+
137
+ Returns:
138
+ Dictionary with model config
139
+ """
140
+ triton_model_config = self._get_triton_model_config()
141
+ generator = ModelConfigGenerator(config=triton_model_config)
142
+ return generator.get_config()
143
+
144
+ def get_proxy_model_files(self) -> typing.Dict[str, bytes]:
145
+ """Get proxy model files.
146
+
147
+ Returns:
148
+ Dictionary with model files to be copied to Triton model store on server side:
149
+ key: file path in following format - 'file:{model_version}/{file_name}'
150
+ value: file content as bytes
151
+ """
152
+ proxy_model_files_dict = {}
153
+ proxy_path = pathlib.Path(__file__).parent.parent / "proxy"
154
+ for file_to_copy in self.SCRIPT_FILES_TO_COPY:
155
+ src_file_path = proxy_path / file_to_copy
156
+ with open(src_file_path, "rb") as f:
157
+ src_file = f.read()
158
+ proxy_model_files_dict[f"file:{self.model_version}/{file_to_copy}"] = src_file
159
+
160
+ return proxy_model_files_dict
161
+
162
+ def generate_model(self, model_repository: pathlib.Path) -> None:
163
+ """Generate model and its config in the model repository.
164
+
165
+ Args:
166
+ model_repository: Path to Triton model repository
167
+
168
+ Raises:
169
+ OSError: when model repository not exists
170
+ """
171
+ LOGGER.debug(
172
+ f"Generating model and config for {self.model_name} and {self.model_version} to {model_repository}"
173
+ )
174
+
175
+ model_catalog = model_repository / self.model_name
176
+
177
+ config_file_path = model_catalog / "config.pbtxt"
178
+ if config_file_path.exists():
179
+ LOGGER.warning(f"The config file {config_file_path} is going to be overridden.")
180
+
181
+ triton_model_config = self._get_triton_model_config()
182
+ generator = ModelConfigGenerator(config=triton_model_config)
183
+ generator.to_file(config_file_path)
184
+
185
+ model_version_catalog = model_catalog / str(self.model_version)
186
+ model_version_catalog.mkdir(exist_ok=True, parents=True)
187
+
188
+ proxy_path = pathlib.Path(__file__).parent.parent / "proxy"
189
+
190
+ for script_file in self.SCRIPT_FILES_TO_COPY:
191
+ src_file_path = proxy_path / script_file
192
+ dst_file_path = model_version_catalog / script_file
193
+ shutil.copy(src_file_path, dst_file_path)
194
+
195
+ def setup(self) -> None:
196
+ """Create deployments and bindings to Triton Inference Server."""
197
+ with self._inference_handlers_lock:
198
+ if not self._inference_handlers:
199
+ triton_model_config = self._get_triton_model_config()
200
+ workspace_path = pathlib.Path(triton_model_config.backend_parameters["workspace-path"])
201
+ validator = TritonResultsValidator(triton_model_config, self._strict)
202
+
203
+ inference_handler_config_path = workspace_path / f"{self.model_name}-config.sock"
204
+ inference_handler_config = get_config_from_handshake_server(inference_handler_config_path)
205
+
206
+ data_socket = pathlib.Path(inference_handler_config["data_socket"])
207
+ authkey = base64.decodebytes(inference_handler_config["authkey"].encode("ascii"))
208
+ self._serializer_deserializer.connect(data_socket.as_posix(), authkey)
209
+
210
+ for i, infer_function in enumerate(self.infer_functions):
211
+ self.triton_context.model_configs[infer_function] = copy.deepcopy(triton_model_config)
212
+ _inject_triton_context(self.triton_context, infer_function)
213
+
214
+ request_server_socket = workspace_path / f"{self.model_name}_0_{i}-server.sock"
215
+ request_server_socket = f"ipc://{request_server_socket.as_posix()}"
216
+
217
+ requests_respones_connector = RequestsResponsesConnector(
218
+ url=request_server_socket,
219
+ serializer_deserializer=self._serializer_deserializer,
220
+ )
221
+ requests_respones_connector.start()
222
+ self._requests_respones_connectors.append(requests_respones_connector)
223
+ inference_handler = InferenceHandler(
224
+ model_callable=infer_function,
225
+ requests_responses_connector=requests_respones_connector,
226
+ validator=validator,
227
+ name=f"inference_handler-{i}",
228
+ )
229
+ inference_handler.on_inference_handler_event(self._on_inference_handler_event)
230
+ inference_handler.start()
231
+ self._inference_handlers.append(inference_handler)
232
+
233
+ def clean(self) -> None:
234
+ """Post unload actions to perform on model."""
235
+ with self._observers_lock:
236
+ LOGGER.debug("Clearing model events observers")
237
+ self._model_events_observers.clear()
238
+ LOGGER.debug("Socket closed. Waiting for inference handler and communication threads to shut down")
239
+ with self._inference_handlers_lock:
240
+ for inference_handler in self._inference_handlers:
241
+ inference_handler.stop()
242
+ for inference_handler in self._inference_handlers:
243
+ inference_handler.join()
244
+ self._inference_handlers.clear()
245
+ for requests_responses_connector in self._requests_respones_connectors:
246
+ requests_responses_connector.close()
247
+ for requests_responses_connector in self._requests_respones_connectors:
248
+ requests_responses_connector.join()
249
+ self._requests_respones_connectors.clear()
250
+ self._serializer_deserializer.close()
251
+
252
+ def is_alive(self) -> bool:
253
+ """Validate if model is working on Triton.
254
+
255
+ If model is fully loaded by Triton, return True. Otherwise, perform a custom verification.
256
+
257
+ Returns:
258
+ True if model is working, False otherwise
259
+ """
260
+ with self._inference_handlers_lock:
261
+ return (
262
+ bool(self._inference_handlers)
263
+ and bool(self._requests_respones_connectors)
264
+ and all(inference_handler.is_alive() for inference_handler in self._inference_handlers)
265
+ and all(
266
+ requests_responses_connector.is_alive()
267
+ for requests_responses_connector in self._requests_respones_connectors
268
+ )
269
+ )
270
+
271
+ def _get_triton_model_config(self) -> TritonModelConfig:
272
+ """Generate ModelConfig from descriptor and custom arguments for Python model.
273
+
274
+ Returns:
275
+ ModelConfig object with configuration for Python model deployment
276
+ """
277
+ if not self._triton_model_config:
278
+ backend_parameters = {"workspace-path": self._workspace.path.as_posix()}
279
+ if self._trace_config:
280
+ backend_parameters["trace-config"] = base64.b64encode(json.dumps(self._trace_config).encode()).decode()
281
+ triton_model_config = TritonModelConfig(
282
+ model_name=self.model_name,
283
+ model_version=self.model_version,
284
+ batching=self.config.batching,
285
+ batcher=self.config.batcher,
286
+ max_batch_size=self.config.max_batch_size,
287
+ decoupled=self.config.decoupled,
288
+ backend_parameters=backend_parameters,
289
+ instance_group={DeviceKind.KIND_CPU: len(self.infer_functions)},
290
+ )
291
+ inputs = []
292
+ for idx, input_spec in enumerate(self.inputs, start=1):
293
+ input_name = input_spec.name if input_spec.name else f"INPUT_{idx}"
294
+ tensor = TensorSpec(
295
+ name=input_name, dtype=input_spec.dtype, shape=input_spec.shape, optional=input_spec.optional
296
+ )
297
+ inputs.append(tensor)
298
+
299
+ outputs = []
300
+ for idx, output_spec in enumerate(self.outputs, start=1):
301
+ output_name = output_spec.name if output_spec.name else f"OUTPUT_{idx}"
302
+ tensor = TensorSpec(name=output_name, dtype=output_spec.dtype, shape=output_spec.shape)
303
+ outputs.append(tensor)
304
+
305
+ triton_model_config.inputs = inputs
306
+ triton_model_config.outputs = outputs
307
+
308
+ if self.config.response_cache:
309
+ triton_model_config.response_cache = ResponseCache(enable=True)
310
+
311
+ self._triton_model_config = triton_model_config
312
+
313
+ return self._triton_model_config
314
+
315
+ def on_model_event(self, model_event_handle_fn: ModelEventsHandler):
316
+ """Register ModelEventsHandler callable.
317
+
318
+ Args:
319
+ model_event_handle_fn: function to be called when model events arises
320
+ """
321
+ with self._observers_lock:
322
+ self._model_events_observers.append(model_event_handle_fn)
323
+
324
+ def _notify_model_events_observers(self, event: ModelEvent, context: typing.Any):
325
+ with self._observers_lock:
326
+ for model_event_handle_fn in self._model_events_observers:
327
+ model_event_handle_fn(self, event, context)
328
+
329
+ def _on_inference_handler_event(
330
+ self, proxy_backend: InferenceHandler, event: InferenceHandlerEvent, context: typing.Optional[typing.Any] = None
331
+ ):
332
+ if event in [InferenceHandlerEvent.CLOSING, InferenceHandlerEvent.UNRECOVERABLE_ERROR]:
333
+ self._notify_model_events_observers(ModelEvent.RUNTIME_TERMINATING, context)
334
+ elif event == InferenceHandlerEvent.CLOSED:
335
+ self._notify_model_events_observers(ModelEvent.RUNTIME_TERMINATED, context)
stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # noqa: D104
stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/communication.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Module handling communication between RequestsServer and RequestsServerClients."""
15
+
16
+ import asyncio
17
+ import enum
18
+ import functools
19
+ import json
20
+ import logging
21
+ import pathlib
22
+ import socket
23
+ import threading
24
+ import time
25
+ import traceback
26
+ import typing
27
+ import uuid
28
+ from concurrent.futures import Future as ConcurrentFuture
29
+
30
+ import zmq # pytype: disable=import-error
31
+ import zmq.asyncio # pytype: disable=import-error
32
+
33
+ LOGGER = logging.getLogger(__name__)
34
+ SERVER_LOGGER = LOGGER.getChild("server")
35
+ CLIENT_LOGGER = LOGGER.getChild("client")
36
+
37
+ _STARTUP_TIMEOUT_S = 1.0
38
+
39
+
40
+ class PyTritonResponseFlags(enum.IntFlag):
41
+ """Response flags for PyTritonInferenceHandler."""
42
+
43
+ EOS = enum.auto() # End Of Stream
44
+ ERROR = enum.auto()
45
+
46
+
47
+ class _RequestsServerState(enum.Enum):
48
+ STOPPED = enum.auto()
49
+ STARTING = enum.auto()
50
+ STARTED = enum.auto()
51
+ STOPPING = enum.auto()
52
+
53
+
54
+ def _set_current_task_name(name: str):
55
+ current_task = asyncio.current_task()
56
+ if current_task is not None:
57
+ current_task.set_name(name)
58
+
59
+
60
+ _RequestScope = typing.Dict[str, typing.Any]
61
+ _HandleRequestsCoro = typing.Callable[[_RequestScope, bytes, zmq.asyncio.Socket], typing.Awaitable[typing.Any]]
62
+ HandleResponsesCoro = typing.Callable[[_RequestScope, asyncio.Queue, ConcurrentFuture], typing.Awaitable[typing.Any]]
63
+
64
+
65
+ class RequestsServer:
66
+ """Class for serving available inference requests and passing inference responses."""
67
+
68
+ def __init__(self, url: str, handle_responses_fn: HandleResponsesCoro):
69
+ """Initialize RequestsServer.
70
+
71
+ Args:
72
+ url: url to bind socket
73
+ handle_responses_fn: couroutine handling responses from InferenceHandler
74
+ """
75
+ self._url = url
76
+ self._handle_responses_fn = handle_responses_fn
77
+ self._state = _RequestsServerState.STOPPED
78
+ self._state_condition = threading.Condition()
79
+ self._shutdown_event = asyncio.Event() # TODO: is it still required having condition?
80
+ self._server_loop = None
81
+
82
+ # requests_id -> results asyncio.Queue map
83
+ self._responses_queues: typing.Dict[bytes, asyncio.Queue] = {}
84
+ self._handle_responses_tasks: typing.Dict[bytes, asyncio.Task] = {}
85
+
86
+ def run(self):
87
+ """Run RequestsServer.
88
+
89
+ It stops when handle_messages coroutine finishes.
90
+
91
+ Raises:
92
+ RuntimeError: if RequestsServer is already running
93
+ """
94
+ with self._state_condition:
95
+ if self._state != _RequestsServerState.STOPPED:
96
+ raise RuntimeError(f"Cannot run {type(self).__name__} as it is already running")
97
+
98
+ self._state = _RequestsServerState.STARTING
99
+ self._state_condition.notify_all()
100
+
101
+ assert len(self._responses_queues) == 0
102
+ assert len(self._handle_responses_tasks) == 0
103
+
104
+ asyncio.run(self.handle_messages())
105
+
106
+ @property
107
+ def server_loop(self) -> typing.Optional[asyncio.AbstractEventLoop]:
108
+ """Get asyncio loop for RequestsServer.
109
+
110
+ Returns:
111
+ asyncio.AbstractEventLoop: asyncio loop for RequestsServer or None if server is not started yet
112
+ """
113
+ return self._server_loop
114
+
115
+ def wait_till_running(self):
116
+ """Wait till RequestsServer is running.
117
+
118
+ Raises:
119
+ RuntimeError: if RequestsServer is shutting down or not launched yet
120
+ """
121
+ with self._state_condition:
122
+ if self._state == _RequestsServerState.STARTING:
123
+ self._state_condition.wait_for(
124
+ lambda: self._state == _RequestsServerState.STARTED, timeout=_STARTUP_TIMEOUT_S
125
+ )
126
+ elif self._state == _RequestsServerState.STOPPED:
127
+ raise RuntimeError("Cannot push requests before RequestsServer is started")
128
+ elif self._state == _RequestsServerState.STOPPING:
129
+ raise RuntimeError(f"Cannot push requests while {type(self).__name__} is shutting down")
130
+
131
+ async def handle_messages(self):
132
+ """Coroutine for handling messages from InferenceHandler."""
133
+ self._server_loop = asyncio.get_running_loop()
134
+ try:
135
+ SERVER_LOGGER.debug(f"Binding socket to url='{self._url}'")
136
+ self._zmq_context = zmq.asyncio.Context()
137
+ self._socket = self._zmq_context.socket(zmq.DEALER)
138
+ self._socket.bind(self._url)
139
+ except (TypeError, zmq.error.ZMQError) as e:
140
+ raise ValueError(
141
+ f"Error occurred during binding socket to url='{self._url}' (e: {e})." "RequestsServer will be closed."
142
+ ) from e
143
+
144
+ _set_current_task_name("handle_messages")
145
+
146
+ with self._state_condition:
147
+ if self._state != _RequestsServerState.STARTING:
148
+ self._state = _RequestsServerState.STOPPED
149
+ self._state_condition.notify_all()
150
+ raise RuntimeError(f"Cannot start {type(self).__name__} as it is not in STARTING state")
151
+
152
+ self._state = _RequestsServerState.STARTED
153
+ self._state_condition.notify_all()
154
+
155
+ def _all_responses_processed():
156
+ return not any([self._handle_responses_tasks, self._responses_queues])
157
+
158
+ try:
159
+ flag_check_interval_s = 1.0
160
+ # have to receive mssages untill all requestss to be processed, despite shutdown event is set
161
+ while not self._shutdown_event.is_set() or not _all_responses_processed():
162
+ requests_id = b"<unknown>"
163
+ try:
164
+ requests_id, flags, responses_payload = await asyncio.wait_for(
165
+ self._socket.recv_multipart(), flag_check_interval_s
166
+ )
167
+ flags = int.from_bytes(flags, byteorder="big")
168
+ responses_queue = self._responses_queues[requests_id]
169
+ responses_queue.put_nowait((flags, responses_payload)) # queue have no max_size
170
+ except asyncio.TimeoutError:
171
+ continue
172
+ except KeyError:
173
+ SERVER_LOGGER.warning(f"Received response for unknown requests {requests_id.hex()}. Ignoring it.")
174
+ except asyncio.CancelledError:
175
+ SERVER_LOGGER.info("Received CancelledError")
176
+ self._shutdown_event.set()
177
+ finally:
178
+ # Received all responses, close socket
179
+ SERVER_LOGGER.debug("Closing socket")
180
+ try:
181
+ if self._socket is not None:
182
+ self._socket.close(linger=0)
183
+ self._socket = None
184
+ except zmq.error.ZMQError as e:
185
+ SERVER_LOGGER.error(f"Error occurred during closing socket (e: {e}).")
186
+
187
+ try:
188
+ if self._zmq_context is not None:
189
+ self._zmq_context.term()
190
+ self._zmq_context = None
191
+ except zmq.error.ZMQError as e:
192
+ SERVER_LOGGER.error(f"Error occurred during closing zmq context (e: {e}).")
193
+
194
+ self._server_loop = None
195
+
196
+ with self._state_condition:
197
+ self._state = _RequestsServerState.STOPPED
198
+ self._state_condition.notify_all()
199
+
200
+ SERVER_LOGGER.debug("Socket for handle_messages task closed")
201
+ self._shutdown_event.clear()
202
+ SERVER_LOGGER.debug(f"Leaving handle_messages task from {type(self).__name__}")
203
+
204
+ def shutdown(self):
205
+ """Close RequestsServer.
206
+
207
+ Don't wait for handle_messages coroutine to finish.
208
+ """
209
+ SERVER_LOGGER.debug("Closing RequestsServer")
210
+ with self._state_condition:
211
+ self._state = _RequestsServerState.STOPPING
212
+ self._state_condition.notify_all()
213
+ self._shutdown_event.set()
214
+
215
+ async def send_requests(
216
+ self, requests_id: bytes, requests_payload: bytes, responses_future: ConcurrentFuture
217
+ ) -> asyncio.Task:
218
+ """Send requests to InferenceHandler.
219
+
220
+ Args:
221
+ requests_id: id of requests
222
+ requests_payload: payload of requests
223
+ responses_future: future for waiting in another thread
224
+
225
+ Returns:
226
+ asyncio.Task: task handling responses from InferenceHandler
227
+
228
+ Raises:
229
+ RuntimeError: if RequestsServer is shutting down or requests_id is already pending
230
+ """
231
+ if self._shutdown_event.is_set():
232
+ SERVER_LOGGER.debug(f"Cannot send requests while {type(self).__name__} is {self._state.name}")
233
+ raise RuntimeError(f"Cannot send requests while {type(self).__name__} is {self._state.name}")
234
+
235
+ if requests_id in self._responses_queues or requests_id in self._handle_responses_tasks:
236
+ SERVER_LOGGER.debug(f"Cannot send requests with id {requests_id.hex()} as such id is already pending")
237
+ raise RuntimeError(f"Cannot send requests with id {requests_id.hex()} as such id is already pending")
238
+
239
+ _set_current_task_name(f"send_requests-{requests_id.hex()}")
240
+
241
+ self._responses_queues[requests_id] = asyncio.Queue()
242
+ scope = {"requests_id": requests_id}
243
+ handle_responses_task = self._server_loop.create_task(
244
+ self._handle_responses(scope, self._responses_queues[requests_id], responses_future),
245
+ name=f"handle_responses-{requests_id.hex()}",
246
+ )
247
+ self._handle_responses_tasks[requests_id] = handle_responses_task
248
+
249
+ # FIXME: check if can not copy buffers; in case copy=False send_multipart returns MessageTracker
250
+ # https://pyzmq.readthedocs.io/en/latest/api/zmq.html#zmq.Socket.send_multipart
251
+ # consider send_pyobject|send_serialized (but it is not multipart)
252
+
253
+ # sending in same loop, thus thread as handle_messages
254
+ # send_multipart doesn't return anything, as it copies requests_payload
255
+ await self._socket.send_multipart([requests_id, requests_payload])
256
+
257
+ return handle_responses_task
258
+
259
+ async def _handle_responses(self, scope, responses_queue: asyncio.Queue, responses_future: ConcurrentFuture):
260
+ """Handle responses from InferenceHandler.
261
+
262
+ Args:
263
+ scope: scope for handling responses
264
+ responses_queue: queue with responses payload from InferenceHandler
265
+ responses_future: future for waiting in another thread
266
+ """
267
+ requests_id = scope["requests_id"]
268
+ try:
269
+ return await self._handle_responses_fn(scope, responses_queue, responses_future)
270
+ finally:
271
+ self._responses_queues.pop(requests_id)
272
+ self._handle_responses_tasks.pop(requests_id)
273
+
274
+
275
+ class RequestsServerClient:
276
+ """RequestsServer client for handling requests from RequestsServer and sending back responses."""
277
+
278
+ def __init__(self, url: str, handle_requests_fn: _HandleRequestsCoro, name: typing.Optional[str] = None):
279
+ """Initialize RequestsServerClient.
280
+
281
+ Args:
282
+ url: url to connect socket
283
+ handle_requests_fn: couroutine handling requests from InferenceHandler
284
+ name: name of RequestsServerClient
285
+ """
286
+ self._shutdown_event = asyncio.Event()
287
+ self._url = url
288
+ self._handle_requests_fn = handle_requests_fn
289
+ self._handle_requests_tasks: typing.Dict[bytes, asyncio.Task] = {}
290
+ self._handle_requests_tasks_condition = asyncio.Condition()
291
+ self._name = name or f"requests_server_client-{uuid.uuid4().hex[-4:]}"
292
+ self._loop = None
293
+
294
+ def run(self):
295
+ """Run RequestsServerClient.
296
+
297
+ It stops when handle_requests coroutine finishes.
298
+ """
299
+ asyncio.run(self.handle_requests())
300
+
301
+ def shutdown(self) -> None:
302
+ """Close RequestsServerClient.
303
+
304
+ Don't wait for handle_requests coroutine to finish.
305
+ """
306
+ CLIENT_LOGGER.debug(f"Closing {type(self).__name__} {self._name}")
307
+ self._shutdown_event.set()
308
+
309
+ async def handle_requests(self):
310
+ """Coroutine for handling requests from RequestsServer."""
311
+ name = self._name
312
+ _set_current_task_name(name)
313
+
314
+ zmq_context = None
315
+ socket = None
316
+ self._loop = asyncio.get_running_loop()
317
+ try:
318
+ CLIENT_LOGGER.debug(f"Connecting {name} to server listening on {self._url}")
319
+ zmq_context = zmq.asyncio.Context()
320
+ socket = zmq_context.socket(zmq.DEALER)
321
+ socket.connect(self._url)
322
+
323
+ send = functools.partial(self._send, socket)
324
+
325
+ flag_check_interval_s = 1.0
326
+ while True:
327
+ try:
328
+ requests_id, requests_payloads = await asyncio.wait_for(
329
+ socket.recv_multipart(), flag_check_interval_s
330
+ )
331
+ scope = {"requests_id": requests_id}
332
+ CLIENT_LOGGER.debug(f"{requests_id.hex()} received requests")
333
+ handle_requests_task = self._loop.create_task(self._handle_requests(scope, requests_payloads, send))
334
+ self._handle_requests_tasks[requests_id] = handle_requests_task
335
+ handle_requests_task.set_name(f"handle_requests-{requests_id.hex()}")
336
+ except asyncio.TimeoutError:
337
+ if self._shutdown_event.is_set():
338
+ break
339
+ continue
340
+
341
+ CLIENT_LOGGER.debug("Waiting for handle_requests tasks to finish")
342
+ async with self._handle_requests_tasks_condition:
343
+ await self._handle_requests_tasks_condition.wait_for(lambda: len(self._handle_requests_tasks) == 0)
344
+ CLIENT_LOGGER.debug("All handle_requests tasks finished")
345
+
346
+ except zmq.error.ZMQError:
347
+ CLIENT_LOGGER.exception(
348
+ "Connection error occurred during reading requests. " f"{type(self).__name__} will be closed."
349
+ )
350
+ self._shutdown_event.set()
351
+ except Exception:
352
+ CLIENT_LOGGER.exception(f"Internal {type(self).__name__}. " f"{type(self).__name__} will be closed.")
353
+ self._shutdown_event.set()
354
+ finally:
355
+ try:
356
+ socket_close_timeout_ms = 0 # immediate close (drop not sent messages)
357
+ if socket is not None:
358
+ socket.close(linger=socket_close_timeout_ms)
359
+ except zmq.error.ZMQError as e:
360
+ CLIENT_LOGGER.error(f"Error occurred during closing socket (e: {e}).")
361
+
362
+ try:
363
+ if zmq_context is not None:
364
+ zmq_context.term()
365
+ except zmq.error.ZMQError as e:
366
+ CLIENT_LOGGER.error(f"Error occurred during closing zmq context (e: {e}).")
367
+
368
+ CLIENT_LOGGER.debug(f"Socket for {name} closed")
369
+ self._shutdown_event.clear()
370
+ self._loop = None
371
+ CLIENT_LOGGER.debug(f"Leaving {name}")
372
+
373
+ @property
374
+ def name(self) -> str:
375
+ """Get name of RequestsServerClient.
376
+
377
+ Returns:
378
+ name of RequestsServerClient
379
+ """
380
+ return self._name
381
+
382
+ @property
383
+ def loop(self) -> asyncio.AbstractEventLoop:
384
+ """Get asyncio loop for RequestsServerClient.
385
+
386
+ Returns:
387
+ asyncio.AbstractEventLoop: asyncio loop for RequestsServerClient
388
+ """
389
+ return self._loop
390
+
391
+ async def _handle_requests(self, scope, requests_payload, send):
392
+ try:
393
+ await self._handle_requests_fn(scope, requests_payload, send)
394
+ # except PyTritonUnrecoverableError:
395
+ # error = traceback.format_exc()
396
+ # responses = InferenceHandlerResponses(error=error)
397
+ # CLIENT_LOGGER.error(
398
+ # "Unrecoverable error thrown during calling model callable. "
399
+ # "Shutting down Triton Inference Server. "
400
+ # f"{error}"
401
+ # )
402
+ # self.stopped = True
403
+ # self._notify_proxy_backend_observers(InferenceHandlerEvent.UNRECOVERABLE_ERROR, error)
404
+ # CLIENT_LOGGER.debug(f"Send response to proxy model for {model_name}.")
405
+ # send(responses.as_bytes())
406
+ except Exception:
407
+ error = traceback.format_exc()
408
+ flags = PyTritonResponseFlags.ERROR | PyTritonResponseFlags.EOS
409
+ await send(scope, flags, error.encode())
410
+ CLIENT_LOGGER.error(f"Error occurred during handling requests {scope['requests_id'].hex()}\n{error}")
411
+ finally:
412
+ async with self._handle_requests_tasks_condition:
413
+ self._handle_requests_tasks.pop(scope["requests_id"], None)
414
+ self._handle_requests_tasks_condition.notify()
415
+ CLIENT_LOGGER.debug(f"Finished handling requests {scope['requests_id'].hex()}")
416
+
417
+ async def _send(self, socket, scope, flags, requests_payload):
418
+ """Send requests to RequestsServer.
419
+
420
+ Args:
421
+ socket: socket for sending requests
422
+ scope: scope for sending requests
423
+ flags: flags for sending requests
424
+ requests_payload: payload of requests
425
+ """
426
+ flags = flags.to_bytes(1, "big")
427
+ await socket.send_multipart([scope["requests_id"], flags, requests_payload])
428
+
429
+
430
+ class HandshakeServer(threading.Thread):
431
+ """Handshake server for passing config."""
432
+
433
+ def __init__(self, socket_path: pathlib.Path, inference_handler_config) -> None:
434
+ """Initialize HandshakeServer.
435
+
436
+ Args:
437
+ socket_path: path to socket
438
+ inference_handler_config: config for InferenceHandler
439
+ """
440
+ super().__init__(daemon=True, name="handshake-server")
441
+ self._socket_path = socket_path
442
+ try:
443
+ self._config_payload = json.dumps(inference_handler_config).encode()
444
+ except TypeError:
445
+ raise ValueError(f"InferenceHandler config is not serializable: {inference_handler_config}") from None
446
+
447
+ self._server = None
448
+ self._error_from_thread = None
449
+
450
+ def start(self):
451
+ """Start HandshakeServer.
452
+
453
+ Raises:
454
+ RuntimeError: if HandshakeServer is already running or error occurred during starting
455
+ """
456
+ if self._server:
457
+ raise RuntimeError("HandshakeServer is already running")
458
+
459
+ super().start()
460
+ while self._server is None and not self._error_from_thread:
461
+ time.sleep(0.001)
462
+ if self._error_from_thread is not None:
463
+ raise self._error_from_thread
464
+
465
+ def run(self):
466
+ """Run HandshakeServer."""
467
+ asyncio.run(self._run())
468
+
469
+ async def _run(self):
470
+ try:
471
+ self._server = await asyncio.start_unix_server(self._handle_request, self._socket_path)
472
+ async with self._server:
473
+ try:
474
+ await self._server.serve_forever()
475
+ except asyncio.CancelledError:
476
+ pass
477
+ except Exception as e:
478
+ SERVER_LOGGER.error(f"Error occurred during running handshake server (e: {e})")
479
+ self._error_from_thread = e
480
+
481
+ def close(self):
482
+ """Close HandshakeServer."""
483
+ loop = self._server.get_loop()
484
+ loop_tasks = asyncio.all_tasks(loop=loop)
485
+ for task in loop_tasks:
486
+ loop.call_soon_threadsafe(task.cancel)
487
+
488
+ self.join()
489
+ SERVER_LOGGER.debug("Closed handshake server")
490
+
491
+ async def _handle_request(self, reader, writer):
492
+ peername = writer.get_extra_info("peername")
493
+ try:
494
+ request_name = await asyncio.wait_for(reader.readuntil(b"\n"), timeout=1.0)
495
+
496
+ if request_name == b"get_config\n":
497
+ writer.write(len(self._config_payload).to_bytes(4, "big"))
498
+ writer.write(self._config_payload)
499
+ await writer.drain()
500
+ else:
501
+ SERVER_LOGGER.warning(f"Unknown request {request_name} from {peername}")
502
+
503
+ except asyncio.TimeoutError:
504
+ SERVER_LOGGER.debug(f"Timeout occurred during handling request from {peername}")
505
+ except Exception as e:
506
+ SERVER_LOGGER.error(f"Error occurred during handling request from {peername} (e: {e})")
507
+ finally:
508
+ writer.close()
509
+ await writer.wait_closed()
510
+
511
+
512
+ def get_config_from_handshake_server(socket_path: pathlib.Path, timeout_s: float = 1.0) -> dict:
513
+ """Get config from handshake server.
514
+
515
+ Args:
516
+ socket_path: path to socket
517
+ timeout_s: timeout for waiting for the response
518
+
519
+ Returns:
520
+ config from handshake server
521
+
522
+ Raises:
523
+ TimeoutError: if timeout occurred while waiting for the response
524
+ ValueError: if invalid JSON response from the server
525
+ """
526
+ should_stop_before_s = time.time() + timeout_s
527
+ sock = None
528
+ try:
529
+ LOGGER.debug(f"Waiting for config file {socket_path}")
530
+ while not socket_path.exists() and time.time() < should_stop_before_s:
531
+ time.sleep(0.001)
532
+
533
+ if not socket_path.exists():
534
+ raise TimeoutError(f"Timeout occurred while waiting for config file {socket_path}")
535
+
536
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
537
+ sock.settimeout(max(0.0, should_stop_before_s - time.time()))
538
+ sock.connect(socket_path.as_posix())
539
+ sock.sendall(b"get_config\n")
540
+
541
+ sock.settimeout(max(0.0, should_stop_before_s - time.time()))
542
+ payload_size = sock.recv(4)
543
+ payload_size = int.from_bytes(payload_size, "big")
544
+
545
+ sock.settimeout(max(0.0, should_stop_before_s - time.time()))
546
+ config_payload = sock.recv(payload_size)
547
+ config = json.loads(config_payload)
548
+ return config
549
+ except socket.timeout as e:
550
+ raise TimeoutError(f"Timeout occurred while waiting for config file {socket_path}") from e
551
+ except json.JSONDecodeError as e:
552
+ raise ValueError("Invalid JSON response from the server.") from e
553
+ finally:
554
+ if sock is not None:
555
+ sock.close()
stf/stf-api-alternative/pytriton/build/lib/pytriton/proxy/data.py ADDED
@@ -0,0 +1,1133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Communication utility module.
15
+
16
+ It is used for interaction between model and proxy_backend.
17
+ """
18
+
19
+ import abc
20
+ import atexit
21
+ import base64
22
+ import ctypes
23
+ import ctypes.util
24
+ import dataclasses
25
+ import fcntl
26
+ import gc
27
+ import json
28
+ import logging
29
+ import math
30
+ import multiprocessing.managers
31
+ import multiprocessing.popen_spawn_posix
32
+ import multiprocessing.shared_memory
33
+ import os
34
+ import pathlib
35
+ import signal
36
+ import struct
37
+ import threading
38
+ import time
39
+ import uuid
40
+ import weakref
41
+ from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
42
+
43
+ import numpy as np
44
+
45
+ from .telemetry import get_span_dict, start_span_from_remote
46
+ from .types import Request, Requests, Response, Responses
47
+
48
+ LOGGER = logging.getLogger(__name__)
49
+
50
+ PROTOCOL_VERSION = "3"
51
+
52
+
53
+ # copy from
54
+ # https://github.com/triton-inference-server/python_backend/blob/main/src/resources/triton_python_backend_utils.py
55
+
56
+
57
+ def _serialize_byte_tensor(tensor) -> bytes:
58
+ """Serializes a bytes tensor into a flat numpy array of length prepended bytes.
59
+
60
+ The numpy array should use dtype of np.object_. For np.bytes_,
61
+ numpy will remove trailing zeros at the end of byte sequence and because
62
+ of this it should be avoided.
63
+
64
+ Args:
65
+ tensor: The bytes tensor to serialize.
66
+
67
+ Returns:
68
+ serialized array as bytes buffer.
69
+
70
+ Raises:
71
+ UnicodeEncodeErrors: raised when try to cast to string of non-bytes items fails
72
+ """
73
+ if tensor.size == 0:
74
+ return b""
75
+
76
+ # If the input is a tensor of string/bytes objects, then must flatten those
77
+ # into a 1-dimensional array containing the 4-byte byte size followed by the
78
+ # actual element bytes. All elements are concatenated together in "C" order.
79
+ assert (tensor.dtype == np.object_) or (tensor.dtype.type == np.bytes_)
80
+ flattened_ls = []
81
+ total_len = 0
82
+ for obj in np.nditer(tensor, flags=["refs_ok"], order="C"):
83
+ # If directly passing bytes to BYTES type,
84
+ # don't convert it to str as Python will encode the
85
+ # bytes which may distort the meaning
86
+ if tensor.dtype == np.object_ and not isinstance(obj.item(), bytes):
87
+ s = str(obj.item()).encode("utf-8")
88
+ else:
89
+ s = obj.item()
90
+ item_len = len(s)
91
+ flattened_ls.append(struct.pack("<I", item_len))
92
+ flattened_ls.append(s)
93
+ total_len += struct.calcsize("<I") + item_len
94
+ flattened_ls.insert(0, struct.pack("<I", total_len))
95
+ flattened = b"".join(flattened_ls)
96
+ return flattened
97
+
98
+
99
+ # copy from
100
+ # https://github.com/triton-inference-server/python_backend/blob/main/src/resources/triton_python_backend_utils.py
101
+ def _deserialize_bytes_tensor(encoded_tensor, dtype, order: Literal["C", "F"] = "C") -> np.ndarray:
102
+ """Deserializes an encoded bytes tensor into an numpy array of dtype of python objects.
103
+
104
+ Args:
105
+ encoded_tensor : The encoded bytes tensor where each element has its length in
106
+ first 4 bytes followed by the content
107
+ dtype: The dtype of the numpy array to deserialize to.
108
+ order: The order of the numpy array to deserialize to.
109
+
110
+ Returns:
111
+ The 1-D numpy array of type object containing the deserialized bytes in 'C' order.
112
+ """
113
+ strs = []
114
+ offset = 0
115
+ val_buf = encoded_tensor
116
+ val_len = struct.unpack_from("<I", val_buf, offset)[0] + 4
117
+ offset += 4
118
+ while offset < val_len:
119
+ item_length = struct.unpack_from("<I", val_buf, offset)[0]
120
+ offset += 4
121
+ item = struct.unpack_from(f"<{item_length}s", val_buf, offset)[0]
122
+ offset += item_length
123
+ strs.append(item)
124
+ return np.array(strs, dtype=dtype, order=order)
125
+
126
+
127
+ _MAX_DTYPE_DESCR = 16 # up to 16 chars in dtype descr; |S2147483647 (2^31-1) with margin
128
+ _PARTIAL_HEADER_FORMAT = f"<{_MAX_DTYPE_DESCR}scH"
129
+
130
+
131
+ def _pack_header(shape: Tuple[int, ...], dtype: np.dtype, order: Literal["C", "F"] = "C") -> bytes:
132
+ header_format = _PARTIAL_HEADER_FORMAT + "Q" * len(shape)
133
+ dtype_descr = np.lib.format.dtype_to_descr(dtype)
134
+ assert (
135
+ len(dtype_descr) <= _MAX_DTYPE_DESCR
136
+ ), f"dtype descr is too long; dtype_descr={dtype_descr} max={_MAX_DTYPE_DESCR}"
137
+ return struct.pack(header_format, dtype_descr.encode("utf-8"), order.encode("ascii"), len(shape), *shape)
138
+
139
+
140
+ def _unpack_header(header: bytes) -> Tuple[Tuple[int, ...], np.dtype, Literal["C", "F"]]:
141
+ shape_offset = struct.calcsize(_PARTIAL_HEADER_FORMAT)
142
+ dtype_descr, order, ndim = struct.unpack_from(_PARTIAL_HEADER_FORMAT, header, offset=0)
143
+ shape = struct.unpack_from("Q" * ndim, header, offset=shape_offset)
144
+ dtype = np.lib.format.descr_to_dtype(dtype_descr.decode("utf-8").rstrip("\x00"))
145
+ order = order.decode("ascii")
146
+ return shape, dtype, order
147
+
148
+
149
+ def serialize_numpy_with_struct_header(tensor: np.ndarray) -> List[Union[bytes, memoryview]]:
150
+ """Serialize numpy array to list of bytes and memoryviews.
151
+
152
+ Args:
153
+ tensor: numpy array to serialize
154
+
155
+ Returns:
156
+ List of data frames in form of bytes and memoryviews
157
+ """
158
+ if tensor.dtype.hasobject:
159
+ data = _serialize_byte_tensor(tensor.ravel())
160
+ order = "C" # as _serialize_byte_tensor returns C-ordered array
161
+ else:
162
+ if not tensor.data.contiguous:
163
+ tensor = np.ascontiguousarray(tensor)
164
+ data = tensor.data
165
+ order = "C" if tensor.flags.c_contiguous else "F"
166
+
167
+ header = _pack_header(tensor.shape, tensor.dtype, order)
168
+ frames = [header, data]
169
+ return frames
170
+
171
+
172
+ def deserialize_numpy_with_struct_header(frames: List[Union[bytes, memoryview]]) -> np.ndarray:
173
+ """Deserialize numpy array from list of bytes and memoryviews.
174
+
175
+ Args:
176
+ frames: List of data frames in form of bytes and memoryviews
177
+
178
+ Returns:
179
+ numpy array
180
+ """
181
+ header, data = frames
182
+ shape, dtype, order = _unpack_header(header)
183
+ if dtype.hasobject:
184
+ tensor = _deserialize_bytes_tensor(data, dtype).reshape(shape)
185
+ else:
186
+ tensor = np.ndarray(shape, dtype=dtype, buffer=data, order=order)
187
+ return tensor
188
+
189
+
190
+ def calc_serialized_size_of_numpy_with_struct_header(tensor: np.ndarray) -> List[int]:
191
+ """Calculate size of serialized numpy array.
192
+
193
+ Args:
194
+ tensor: numpy array to serialize
195
+
196
+ Returns:
197
+ List of sizes of data frames
198
+ """
199
+ header_size = struct.calcsize(_PARTIAL_HEADER_FORMAT) + struct.calcsize("Q") * len(tensor.shape)
200
+ if tensor.dtype.hasobject:
201
+ items_sizes = []
202
+ order = "C" if tensor.flags.c_contiguous else "F"
203
+ for obj in np.nditer(tensor, flags=["refs_ok"], order=order):
204
+ if tensor.dtype == np.object_ and not isinstance(obj.item(), bytes):
205
+ s = str(obj.item()).encode("utf-8")
206
+ else:
207
+ s = obj.item()
208
+ items_sizes.append(len(s))
209
+
210
+ # total_size + for size of each item + each item
211
+ data_size = struct.calcsize("<I") + struct.calcsize("<I") * len(items_sizes) + sum(items_sizes)
212
+ else:
213
+ data_size = tensor.nbytes
214
+
215
+ return [header_size, data_size]
216
+
217
+
218
+ @dataclasses.dataclass
219
+ class BlockDescriptor:
220
+ """Descriptor of block in shared memory."""
221
+
222
+ shm_name: str
223
+ offset: int
224
+ size: Optional[int] = None
225
+
226
+ def __post_init__(self):
227
+ """Initialize other attributes."""
228
+ self.id = f"{self.shm_name}:{self.offset}"
229
+
230
+ @classmethod
231
+ def from_id(cls, tensor_id: str):
232
+ """Create BlockDescriptor from dict."""
233
+ shm_name, offset = tensor_id.split(":")
234
+ return cls(shm_name, int(offset))
235
+
236
+
237
+ class _SharedMemorySegment:
238
+ def __init__(self, size):
239
+ self.shared_memory = multiprocessing.shared_memory.SharedMemory(create=True, size=size)
240
+ multiprocessing.util.debug(f"Created {self.shared_memory.name} of size {self.shared_memory.size}")
241
+ self.used_blocks: List[BlockDescriptor] = []
242
+ self.used_blocks_lock = threading.RLock()
243
+ self.free_blocks = [BlockDescriptor(self.shared_memory.name, offset=0, size=size)]
244
+ self.max_free_block_size = size
245
+
246
+ def _update_free_blocks(self):
247
+ total_size = self.shared_memory.size
248
+ free_blocks = []
249
+ offset = 0
250
+
251
+ with self.used_blocks_lock:
252
+ # find holes between used blocks
253
+ for used_block in self.used_blocks:
254
+ if used_block.offset > offset:
255
+ free_blocks.append(
256
+ BlockDescriptor(self.shared_memory.name, offset=offset, size=used_block.offset - offset)
257
+ )
258
+ offset = used_block.offset + used_block.size
259
+ # if tail is free
260
+ if offset < total_size:
261
+ free_blocks.append(BlockDescriptor(self.shared_memory.name, offset=offset, size=total_size - offset))
262
+
263
+ self.free_blocks = free_blocks
264
+ self.max_free_block_size = max(block.size for block in self.free_blocks) if self.free_blocks else 0
265
+
266
+ def __contains__(self, block_id: str) -> bool:
267
+ with self.used_blocks_lock:
268
+ return any(block_id == block.id for block in self.used_blocks) # pytype: disable=attribute-error
269
+
270
+ def __getitem__(self, block_id: str) -> BlockDescriptor:
271
+ with self.used_blocks_lock:
272
+ for block in self.used_blocks:
273
+ if block.id == block_id: # pytype: disable=attribute-error
274
+ return block
275
+ raise KeyError(f"Block with id {block_id} not found in segment {self.shared_memory.name}")
276
+
277
+ def allocate(self, offset, byte_size):
278
+ block = BlockDescriptor(self.shared_memory.name, offset=offset, size=byte_size)
279
+ with self.used_blocks_lock:
280
+ self.used_blocks.append(block)
281
+ self.used_blocks.sort(key=lambda block: block.offset)
282
+ self._update_free_blocks()
283
+ return block
284
+
285
+ def release(self, block: BlockDescriptor):
286
+ with self.used_blocks_lock:
287
+ self.used_blocks.remove(block)
288
+ self._update_free_blocks()
289
+
290
+
291
+ class _DataBlocksServer:
292
+ _instance = None
293
+ _cnt = 0
294
+ _minimal_segment_size = 4096 # 4KB
295
+
296
+ def __new__(cls):
297
+ if cls._instance is None:
298
+ cls._instance = super().__new__(cls)
299
+ return cls._instance
300
+
301
+ def __init__(self):
302
+ # WAR: for some reason, the __init__ is called on each create of proxy object
303
+ if self._cnt == 1:
304
+ return
305
+ self._cnt += 1
306
+ self._id = uuid.uuid4() # to verify that it is singleton across processes
307
+ self._segments = []
308
+ self._segments_lock = threading.RLock()
309
+ atexit.register(self.close)
310
+
311
+ def get_free_blocks(self, bytes_sizes: Sequence[int]) -> Sequence[str]:
312
+ tensors_ids = []
313
+ with self._segments_lock:
314
+ for byte_size in bytes_sizes:
315
+ for segment in self._segments:
316
+ if segment.max_free_block_size >= byte_size:
317
+ for free_block in segment.free_blocks:
318
+ if free_block.size >= byte_size:
319
+ block = self._allocate_block(segment, free_block.offset, byte_size)
320
+ tensors_ids.append(block.id) # pytype: disable=attribute-error
321
+ break
322
+ else:
323
+ continue # If no suitable block was found, try the next segment
324
+ break # If a suitable block was found, don't try any more segments
325
+ else: # If no suitable block was found in any segment
326
+ new_segment_size = int(
327
+ max(self._minimal_segment_size, math.pow(2, math.ceil(math.log2(byte_size))))
328
+ )
329
+ block = self._allocate_block(
330
+ self._create_new_segment(new_segment_size), offset=0, byte_size=byte_size
331
+ )
332
+ tensors_ids.append(block.id) # pytype: disable=attribute-error
333
+ return tensors_ids
334
+
335
+ def release_block(self, block_id: str):
336
+ with self._segments_lock:
337
+ for segment in self._segments:
338
+ try:
339
+ block = segment[block_id]
340
+ segment.release(block)
341
+ return
342
+ except KeyError:
343
+ pass
344
+ raise KeyError(f"Block with id {block_id} not found in server")
345
+
346
+ def _allocate_block(self, segment: _SharedMemorySegment, offset: int, byte_size: int) -> BlockDescriptor:
347
+ return segment.allocate(offset, byte_size)
348
+
349
+ def _create_new_segment(self, segment_size):
350
+ segment = _SharedMemorySegment(segment_size)
351
+ self._segments.append(segment)
352
+ return segment
353
+
354
+ def get_debug_status(self):
355
+ return {
356
+ "server_id": str(self._id),
357
+ "host_pid": multiprocessing.current_process().pid,
358
+ "segments": [
359
+ {
360
+ "shared_memory": segment.shared_memory.name,
361
+ "used_blocks": [str(block) for block in segment.used_blocks],
362
+ }
363
+ for segment in self._segments
364
+ ],
365
+ }
366
+
367
+ def close(self):
368
+ multiprocessing.util.debug(f"Closing server {self._id}")
369
+ with self._segments_lock:
370
+ while self._segments:
371
+ segment = self._segments.pop()
372
+ multiprocessing.util.debug(f"Closing and delete segment {segment.shared_memory.name}")
373
+ segment.shared_memory.close()
374
+ segment.shared_memory.unlink()
375
+
376
+
377
+ class BlocksStoreManager(multiprocessing.managers.BaseManager):
378
+ """Remote block store for storing and retrieving numpy arrays in/from shared memory."""
379
+
380
+ @classmethod
381
+ def _run_server(cls, registry, address, authkey, serializer, writer, initializer=None, initargs=()):
382
+ PR_SET_PDEATHSIG = 1 # noqa
383
+ libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
384
+ libc.prctl(PR_SET_PDEATHSIG, signal.SIGTERM) # terminate process when parent **thread** dies
385
+
386
+ if bool(os.environ.get("PYTRITON_VIZTRACER")):
387
+ from viztracer import VizTracer # type: ignore # pytype: disable=import-error
388
+
389
+ cls._tracer = VizTracer(log_async=True, log_gc=True, tracer_entries=10000000, pid_suffix=True)
390
+ cls._tracer.register_exit()
391
+ cls._tracer.start()
392
+
393
+ super()._run_server(
394
+ registry, address, authkey, serializer, writer, initializer, initargs
395
+ ) # pytype: disable=attribute-error
396
+
397
+
398
+ class _DataBlocksServerProxy(multiprocessing.managers.BaseProxy):
399
+ def release_block(self, /, *args, **kwargs):
400
+ return self._callmethod("release_block", args, kwargs)
401
+
402
+ def get_free_blocks(self, /, *args, **kwargs):
403
+ return self._callmethod("get_free_blocks", args, kwargs)
404
+
405
+ def _get_debug_status(self, /, *args, **kwargs):
406
+ return self._callmethod("get_debug_status", args, kwargs)
407
+
408
+ def close(self, /, *args, **kwargs):
409
+ return self._callmethod("close", args, kwargs)
410
+
411
+
412
+ BlocksStoreManager.register("blocks", _DataBlocksServer, proxytype=_DataBlocksServerProxy)
413
+
414
+
415
+ class _FileLock:
416
+ _locks = {}
417
+
418
+ def __new__(cls, file_path):
419
+ if file_path not in cls._locks:
420
+ cls._locks[file_path] = super().__new__(cls)
421
+ return cls._locks[file_path]
422
+
423
+ def __init__(self, file_path):
424
+ if hasattr(self, "_file_path"):
425
+ return
426
+ self._file_path = pathlib.Path(file_path)
427
+ self._file_lock = None
428
+ self._lock = threading.RLock()
429
+ atexit.register(self._clean)
430
+
431
+ def __enter__(self):
432
+ self._file_lock = self._file_path.open("a")
433
+ fcntl.flock(self._file_lock.fileno(), fcntl.LOCK_EX)
434
+ self._lock.acquire()
435
+
436
+ def __exit__(self, exc_type, exc_value, traceback):
437
+ fcntl.flock(self._file_lock.fileno(), fcntl.LOCK_UN)
438
+ self._lock.release()
439
+
440
+ def _clean(self):
441
+ if self._file_lock is not None:
442
+ self._file_lock.close()
443
+ try:
444
+ self._file_path.unlink(missing_ok=True)
445
+ except OSError as e:
446
+ LOGGER.warning(f"Could not remove lock file {self._file_path}; {e}")
447
+
448
+
449
+ class _Popen(multiprocessing.popen_spawn_posix.Popen):
450
+ def _launch(self, process_obj):
451
+ # Modified version of multiprocessing.popen_spawn_posix.Popen._launch
452
+ import io
453
+ import os
454
+ from multiprocessing import context, resource_tracker, spawn, util
455
+
456
+ tracker_fd = resource_tracker.getfd()
457
+ self._fds.append(tracker_fd) # pytype: disable=attribute-error
458
+
459
+ # get prep_data + remove init_main_from* as they are not required for TensorStore process
460
+ prep_data = spawn.get_preparation_data(process_obj._name)
461
+ prep_data.pop("init_main_from_module", None)
462
+ prep_data.pop("init_main_from_path", None)
463
+
464
+ fp = io.BytesIO()
465
+ context.set_spawning_popen(self)
466
+ try:
467
+ context.reduction.dump(prep_data, fp) # pytype: disable=module-attr
468
+ context.reduction.dump(process_obj, fp) # pytype: disable=module-attr
469
+ finally:
470
+ context.set_spawning_popen(None)
471
+
472
+ parent_r = child_w = child_r = parent_w = None
473
+ try:
474
+ parent_r, child_w = os.pipe()
475
+ child_r, parent_w = os.pipe()
476
+ cmd = spawn.get_command_line(tracker_fd=tracker_fd, pipe_handle=child_r)
477
+ self._fds.extend([child_r, child_w]) # pytype: disable=attribute-error
478
+ self.pid = util.spawnv_passfds(
479
+ spawn.get_executable(),
480
+ cmd,
481
+ self._fds, # pytype: disable=attribute-error,wrong-arg-types
482
+ )
483
+ self.sentinel = parent_r
484
+ with open(parent_w, "wb", closefd=False) as f:
485
+ f.write(fp.getbuffer())
486
+ finally:
487
+ fds_to_close = []
488
+ for fd in (parent_r, parent_w):
489
+ if fd is not None:
490
+ fds_to_close.append(fd)
491
+ self.finalizer = util.Finalize(self, util.close_fds, fds_to_close) # pytype: disable=module-attr
492
+
493
+ for fd in (child_r, child_w):
494
+ if fd is not None:
495
+ os.close(fd)
496
+
497
+
498
+ class _SpawnProcess(multiprocessing.process.BaseProcess):
499
+ _start_method = "spawn"
500
+
501
+ @staticmethod
502
+ def _Popen(process_obj): # noqa N802
503
+ return _Popen(process_obj)
504
+
505
+
506
+ class _SpawnContext(multiprocessing.context.BaseContext):
507
+ _name = "spawn"
508
+ Process = _SpawnProcess
509
+
510
+
511
+ class TensorStore:
512
+ """Tensor store for storing and retrieving numpy arrays in/from shared memory."""
513
+
514
+ _SOCKET_EXISTANCE_CHECK_INTERVAL_S = 0.1
515
+ _instances = {}
516
+
517
+ def __new__(cls, *args, **kwargs):
518
+ """Create TensorStore object. If object with given address already exists, return it."""
519
+ if args:
520
+ address = args[0]
521
+ elif "address" in kwargs:
522
+ address = kwargs["address"]
523
+ else:
524
+ raise TypeError("TensorStore() missing 1 required positional argument: 'address'")
525
+
526
+ address = address.as_posix() if isinstance(address, pathlib.Path) else address
527
+
528
+ if address not in cls._instances:
529
+ cls._instances[address] = super().__new__(cls)
530
+
531
+ return cls._instances[address]
532
+
533
+ def __init__(self, address: Union[str, pathlib.Path], auth_key: Optional[bytes] = None):
534
+ """Initialize TensorStore object.
535
+
536
+ Args:
537
+ address: address of data store
538
+ auth_key: authentication key required to setup connection. If not provided, current process authkey will be used
539
+ """
540
+ if not hasattr(self, "_remote_blocks_store_manager"):
541
+ address = address.as_posix() if isinstance(address, pathlib.Path) else address
542
+ self._remote_blocks_store_manager = BlocksStoreManager(address, authkey=auth_key, ctx=_SpawnContext())
543
+ self._remote_blocks_store = None
544
+ self._manager_start_stop_filelock = _FileLock(f"{address}.lock")
545
+
546
+ # container for keeping map between tensor_id and numpy array weak ref
547
+ self._handled_blocks: Dict[str, weakref.ReferenceType] = {}
548
+ self._handled_blocks_lock = threading.RLock()
549
+
550
+ self._shm_segments: Dict[str, multiprocessing.shared_memory.SharedMemory] = {}
551
+ self._shm_segments_lock = threading.RLock()
552
+
553
+ self.serialize = serialize_numpy_with_struct_header
554
+ self.deserialize = deserialize_numpy_with_struct_header
555
+ self._calc_serialized_tensor_size = calc_serialized_size_of_numpy_with_struct_header
556
+
557
+ @property
558
+ def address(self) -> str:
559
+ """Return address of remote block store."""
560
+ return self._remote_blocks_store_manager.address
561
+
562
+ def start(self):
563
+ """Start remote block store."""
564
+ with self._manager_start_stop_filelock:
565
+ if self._remote_blocks_store is not None:
566
+ raise RuntimeError("Remote block store is already started/connected")
567
+
568
+ self._remote_blocks_store_manager.start()
569
+ self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error
570
+
571
+ address = pathlib.Path(self._remote_blocks_store_manager.address)
572
+ self._wait_for_address(address)
573
+ LOGGER.debug(
574
+ f"Started remote block store at {address} (pid={self._remote_blocks_store_manager._process.pid})" # pytype: disable=attribute-error
575
+ )
576
+
577
+ def connect(self, timeout_s: Optional[float] = None):
578
+ """Connect to remote block store."""
579
+ if self._remote_blocks_store is None:
580
+ address = pathlib.Path(self._remote_blocks_store_manager.address)
581
+
582
+ self._wait_for_address(address, timeout_s)
583
+ self._remote_blocks_store_manager.connect()
584
+ self._remote_blocks_store = self._remote_blocks_store_manager.blocks() # pytype: disable=attribute-error
585
+ LOGGER.debug(f"Connected to remote block store at {address})")
586
+ else:
587
+ LOGGER.debug(f"Already connectd to remote block store at {self.address}")
588
+
589
+ def _wait_for_address(self, address, timeout_s: Optional[float] = None):
590
+ should_stop_at = time.time() + timeout_s if timeout_s is not None else None
591
+ if timeout_s is not None and self._SOCKET_EXISTANCE_CHECK_INTERVAL_S > timeout_s:
592
+ socket_existance_check_interval = timeout_s
593
+ else:
594
+ socket_existance_check_interval = self._SOCKET_EXISTANCE_CHECK_INTERVAL_S
595
+
596
+ while not address.exists():
597
+ if should_stop_at is not None and time.time() >= should_stop_at:
598
+ raise TimeoutError(f"Timeout while waiting for {address} to be created")
599
+ time.sleep(socket_existance_check_interval)
600
+
601
+ def _calc_serialized_size(self, tensor: np.ndarray) -> int:
602
+ # frames payload sum + total size + frames sizes
603
+ # assume 2 frames: header with tensor description + data
604
+ return sum(self._calc_serialized_tensor_size(tensor)) + struct.calcsize("<I") + 2 * struct.calcsize("<I")
605
+
606
+ def put(self, tensors: Sequence[np.ndarray]) -> Sequence[str]:
607
+ """Append tensor to shared memory buffer.
608
+
609
+ Args:
610
+ tensors: numpy arrays to store
611
+
612
+ Returns:
613
+ List of ids of stored tensors
614
+ """
615
+ byte_size_of_frames_containers = [self._calc_serialized_size(tensor) for tensor in tensors]
616
+ tensors_ids = self._remote_blocks_store.get_free_blocks(byte_size_of_frames_containers)
617
+ blocks = [BlockDescriptor.from_id(tensor_id) for tensor_id in tensors_ids]
618
+
619
+ for tensor, block in zip(tensors, blocks):
620
+ with self._shm_segments_lock:
621
+ shm = self._shm_segments.get(block.shm_name)
622
+ if shm is None:
623
+ shm = multiprocessing.shared_memory.SharedMemory(block.shm_name, create=False)
624
+ self._shm_segments[block.shm_name] = shm
625
+
626
+ frames = self.serialize(tensor)
627
+ self._copy_frames(frames, shm, block.offset)
628
+
629
+ return tensors_ids
630
+
631
+ def get(self, tensor_id: str) -> np.ndarray:
632
+ """Get numpy array from tensor store.
633
+
634
+ Args:
635
+ tensor_id: id of of tenosr to get
636
+
637
+ Returns:
638
+ numpy array
639
+ """
640
+ tensor = None
641
+ # try to handle already handled tensor from weakref
642
+ with self._handled_blocks_lock:
643
+ tensor_ref = self._handled_blocks.get(tensor_id)
644
+ if tensor_ref is not None:
645
+ tensor = tensor_ref()
646
+
647
+ if tensor is None: # if tensor was not handled yet or weakref is already empty
648
+ block = BlockDescriptor.from_id(tensor_id)
649
+
650
+ # check if shm segment is already opened
651
+ with self._shm_segments_lock:
652
+ shm = self._shm_segments.get(block.shm_name)
653
+
654
+ # if not open it and put into cache
655
+ if shm is None:
656
+ shm = multiprocessing.shared_memory.SharedMemory(block.shm_name, create=False)
657
+ with self._shm_segments_lock:
658
+ shm = self._shm_segments.setdefault(block.shm_name, shm) # in meantime other thread could create it
659
+
660
+ frames = self._handle_frames(shm, block.offset)
661
+ tensor = self.deserialize(frames)
662
+
663
+ # store tensor in weakref to be able to release shared memory when tensor will be garbage collected
664
+ with self._handled_blocks_lock:
665
+ tensor_ref = self._handled_blocks.setdefault(tensor_id, weakref.ref(tensor))
666
+ tensor = tensor_ref()
667
+
668
+ return tensor # pytype: disable=bad-return-type
669
+
670
+ def release_block(self, tensor_id: str):
671
+ """Release shared memory block.
672
+
673
+ Args:
674
+ tensor_id: id of tensor to release
675
+ """
676
+ tensor_ref = None
677
+ with self._handled_blocks_lock:
678
+ tensor_ref = self._handled_blocks.pop(tensor_id, None)
679
+
680
+ try:
681
+ if tensor_ref is not None:
682
+ self._remote_blocks_store.release_block(tensor_id)
683
+ except OSError: # thrown when remote process is already closed
684
+ LOGGER.warning(
685
+ f"Failed to release block {tensor_id} on remote process at {self.address}. Probably remote process is already closed"
686
+ )
687
+
688
+ def _copy_frames(
689
+ self,
690
+ frames: List[Union[bytes, memoryview]],
691
+ shm: multiprocessing.shared_memory.SharedMemory,
692
+ offset: int,
693
+ ) -> int:
694
+ total_size = struct.calcsize("<I") # start after total_size; max 4GB for all frames
695
+ for frame in frames:
696
+ if isinstance(frame, bytes):
697
+ frame = memoryview(frame)
698
+
699
+ assert frame.contiguous, "Only contiguous arrays are supported"
700
+ struct.pack_into("<I", shm.buf, offset + total_size, frame.nbytes) # pytype: disable=wrong-arg-types
701
+ total_size += struct.calcsize("<I")
702
+ shm.buf[offset + total_size : offset + total_size + frame.nbytes] = frame.cast("B")
703
+
704
+ total_size += frame.nbytes
705
+
706
+ struct.pack_into("<I", shm.buf, offset, total_size) # pytype: disable=wrong-arg-types
707
+ return total_size
708
+
709
+ def _handle_frames(self, shm: multiprocessing.shared_memory.SharedMemory, block_offset: int) -> List[memoryview]:
710
+ frames = []
711
+ (total_size,) = struct.unpack_from("<I", shm.buf, block_offset) # pytype: disable=wrong-arg-types
712
+ offset = struct.calcsize("<I")
713
+ while offset < total_size:
714
+ (frame_size,) = struct.unpack_from("<I", shm.buf, block_offset + offset) # pytype: disable=wrong-arg-types
715
+ offset += struct.calcsize("<I")
716
+ frame = shm.buf[block_offset + offset : block_offset + offset + frame_size]
717
+ offset += frame_size
718
+ frames.append(frame)
719
+ return frames
720
+
721
+ def close(self):
722
+ """Free resources used by TensorStore object."""
723
+ from multiprocessing.resource_tracker import register, unregister
724
+
725
+ LOGGER.debug(f"TensorStore is being closed (is_started={self.is_started()})")
726
+
727
+ gc.collect()
728
+ with self._handled_blocks_lock:
729
+ tensors_ids = list(self._handled_blocks)
730
+ for tensor_id in tensors_ids:
731
+ self.release_block(tensor_id)
732
+
733
+ with self._shm_segments_lock:
734
+ while self._shm_segments:
735
+ _, shm = self._shm_segments.popitem()
736
+ LOGGER.debug(f"Closing shared memory {shm.name}")
737
+ try:
738
+ shm.close()
739
+ except Exception as e:
740
+ LOGGER.warning(f"Failed to close shared memory {shm.name}: {e}")
741
+ finally:
742
+ if not self.is_started():
743
+ register(shm._name, "shared_memory") # pytype: disable=attribute-error
744
+ unregister(shm._name, "shared_memory") # pytype: disable=attribute-error
745
+
746
+ if self.is_started():
747
+ if self._remote_blocks_store is not None:
748
+ LOGGER.debug(f"Releasing all resources on remote process at {self.address}")
749
+ try:
750
+ self._remote_blocks_store.close()
751
+ except FileNotFoundError: # thrown when remote process is already closed
752
+ pass
753
+ self._remote_blocks_store = None
754
+ LOGGER.debug(f"Shutting down side process of data store at {self.address}")
755
+ self._remote_blocks_store_manager.shutdown()
756
+ LOGGER.debug(f"TensorStore at {self.address} closed")
757
+
758
+ def is_started(self) -> bool:
759
+ """Check if remote block store was started by this instance.
760
+
761
+ Returns:
762
+ True if remote block store was started by this instance, False otherwise
763
+ """
764
+ return hasattr(self._remote_blocks_store_manager, "shutdown")
765
+
766
+
767
+ def get_debug_status(tensor_store: TensorStore) -> dict:
768
+ """Get debug status of remote block store.
769
+
770
+ Args:
771
+ tensor_store: TensorStore object
772
+
773
+ Returns:
774
+ Debug status of remote block store
775
+ """
776
+ if tensor_store._remote_blocks_store is None:
777
+ raise RuntimeError("Remote block store is not initialized")
778
+
779
+ return tensor_store._remote_blocks_store._get_debug_status()
780
+
781
+
782
+ class BaseRequestsResponsesSerializerDeserializer(abc.ABC):
783
+ """Base class for requests/responses serializer/deserializer."""
784
+
785
+ @abc.abstractmethod
786
+ def serialize_requests(self, requests: Requests) -> bytes:
787
+ """Serialize requests.
788
+
789
+ Args:
790
+ requests: list of requests to serialize
791
+
792
+ Returns:
793
+ Serialized requests
794
+ """
795
+ pass
796
+
797
+ @abc.abstractmethod
798
+ def deserialize_requests(self, requests_payload: bytes) -> Requests:
799
+ """Deserialize requests.
800
+
801
+ Args:
802
+ requests_payload: serialized requests
803
+
804
+ Returns:
805
+ List of deserialized requests
806
+ """
807
+ pass
808
+
809
+ @abc.abstractmethod
810
+ def free_requests_resources(self, requests_payload: bytes):
811
+ """Free resources used by requests."""
812
+ pass
813
+
814
+ @abc.abstractmethod
815
+ def serialize_responses(self, responses: Responses) -> bytes:
816
+ """Serialize responses.
817
+
818
+ Args:
819
+ responses: list of responses to serialize
820
+
821
+ Returns:
822
+ Serialized responses
823
+ """
824
+ pass
825
+
826
+ @abc.abstractmethod
827
+ def deserialize_responses(self, responses_payload: bytes) -> Responses:
828
+ """Deserialize responses.
829
+
830
+ Args:
831
+ responses_payload: serialized responses
832
+
833
+ Returns:
834
+ List of deserialized responses
835
+ """
836
+ pass
837
+
838
+ @abc.abstractmethod
839
+ def free_responses_resources(self, responses_payload: bytes):
840
+ """Free resources used by responses."""
841
+ pass
842
+
843
+
844
+ class Base64SerializerDeserializer(BaseRequestsResponsesSerializerDeserializer):
845
+ """Serializer/deserializer for requests/responses using base64 implementation."""
846
+
847
+ def serialize_requests(self, requests: Requests) -> bytes:
848
+ """Serialize requests.
849
+
850
+ Args:
851
+ requests: list of requests to serialize
852
+
853
+ Returns:
854
+ Serialized requests
855
+ """
856
+ serialized_requests = self._serialize_named_tensors_lists(requests)
857
+ requests_list = []
858
+ for request, serialized_request in zip(requests, serialized_requests):
859
+ serialized_request = {"data": serialized_request, "parameters": request.parameters}
860
+ if request.span is not None:
861
+ serialized_request["span"] = get_span_dict(request.span)
862
+ requests_list.append(serialized_request)
863
+
864
+ requests = {"requests": requests_list}
865
+ requests = json.dumps(requests).encode("utf-8")
866
+ return requests
867
+
868
+ def deserialize_requests(self, requests_payload: bytes) -> Requests:
869
+ """Deserialize requests.
870
+
871
+ Args:
872
+ requests_payload: serialized requests
873
+
874
+ Returns:
875
+ List of deserialized requests
876
+ """
877
+ requests = json.loads(requests_payload)
878
+ requests_data = [request["data"] for request in requests["requests"]]
879
+ requests_data = self._deserialized_named_tensors_lists(requests_data)
880
+
881
+ deserialized_requests = []
882
+ for request, request_data in zip(requests["requests"], requests_data):
883
+ kwargs = {"data": request_data, "parameters": request.get("parameters")}
884
+ # FIXME: move span creation above just after json.loads
885
+ if "span" in request:
886
+ span_dict = request["span"]
887
+ span = start_span_from_remote(span_dict, "proxy_inference_callable")
888
+ kwargs["span"] = span
889
+ request_wrapped = Request(**kwargs)
890
+ deserialized_requests.append(request_wrapped)
891
+
892
+ return deserialized_requests
893
+
894
+ def free_requests_resources(self, requests_payload: bytes):
895
+ """Free resources used by requests."""
896
+ pass
897
+
898
+ def serialize_responses(self, responses: Responses) -> bytes:
899
+ """Serialize responses.
900
+
901
+ Args:
902
+ responses: list of responses to serialize
903
+
904
+ Returns:
905
+ Serialized responses
906
+ """
907
+ responses = self._serialize_named_tensors_lists(responses)
908
+ responses = {"responses": [{"data": response} for response in responses]}
909
+ return json.dumps(responses).encode("utf-8")
910
+
911
+ def deserialize_responses(self, responses_payload: bytes) -> Responses:
912
+ """Deserialize responses.
913
+
914
+ Args:
915
+ responses_payload: serialized responses
916
+
917
+ Returns:
918
+ List of deserialized responses
919
+ """
920
+ if responses_payload:
921
+ responses = json.loads(responses_payload)
922
+ responses = [response["data"] for response in responses["responses"]]
923
+ responses = self._deserialized_named_tensors_lists(responses)
924
+ return [Response(data=response) for response in responses]
925
+ else:
926
+ return []
927
+
928
+ def free_responses_resources(self, responses_payload: bytes):
929
+ """Free resources used by responses."""
930
+ pass
931
+
932
+ def _serialize_named_tensors_lists(self, named_tensors_lists):
933
+ def _encode(_tensor):
934
+ frames = serialize_numpy_with_struct_header(_tensor)
935
+ return [base64.b64encode(frame).decode("utf-8") for frame in frames]
936
+
937
+ return [
938
+ {tensor_name: _encode(tensor) for tensor_name, tensor in tensors.items()} for tensors in named_tensors_lists
939
+ ]
940
+
941
+ def _deserialized_named_tensors_lists(self, named_tensors_lists):
942
+ def _decode(decoded_tensor):
943
+ frames = [base64.b64decode(frame.encode("utf-8")) for frame in decoded_tensor]
944
+ return deserialize_numpy_with_struct_header(frames)
945
+
946
+ return [
947
+ {tensor_name: _decode(encoded_tensor) for tensor_name, encoded_tensor in tensors.items()}
948
+ for tensors in named_tensors_lists
949
+ ]
950
+
951
+ def start(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
952
+ """Start Dummy implementation.
953
+
954
+ Args:
955
+ url: address of data store
956
+ authkey: authentication key required to setup connection. If not provided, current process authkey will be used
957
+ """
958
+ pass
959
+
960
+ def connect(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
961
+ """Connect to Dummy implementation.
962
+
963
+ Args:
964
+ url: address of data store
965
+ authkey: authentication key required to setup connection. If not provided, current process authkey will be used
966
+ """
967
+ pass
968
+
969
+ def close(self):
970
+ """Close Dummy implementation."""
971
+ pass
972
+
973
+
974
+ class TensorStoreSerializerDeserializer(BaseRequestsResponsesSerializerDeserializer):
975
+ """Serializer/deserializer for requests/responses using TensorStore."""
976
+
977
+ def __init__(self):
978
+ """Initialize TensorStoreSerializerDeserializer object."""
979
+ self._tensor_store = None
980
+
981
+ def serialize_requests(self, requests: Requests) -> bytes:
982
+ """Serialize requests.
983
+
984
+ Args:
985
+ requests: list of requests to serialize
986
+
987
+ Returns:
988
+ Serialized requests
989
+ """
990
+ serialized_requests = self._serialize_named_tensors_lists(requests)
991
+ requests_list = []
992
+ for request, serialized_request in zip(requests, serialized_requests):
993
+ serialized_request = {"data": serialized_request, "parameters": request.parameters}
994
+ if request.span is not None:
995
+ serialized_request["span"] = get_span_dict(request.span)
996
+ requests_list.append(serialized_request)
997
+
998
+ requests = {"requests": requests_list}
999
+ return json.dumps(requests).encode("utf-8")
1000
+
1001
+ def deserialize_requests(self, requests_payload: bytes) -> Requests:
1002
+ """Deserialize requests.
1003
+
1004
+ Args:
1005
+ requests_payload: serialized requests
1006
+
1007
+ Returns:
1008
+ List of deserialized requests
1009
+ """
1010
+ requests = json.loads(requests_payload)
1011
+ deserialized_requests = []
1012
+ for request in requests["requests"]:
1013
+ kwargs = {}
1014
+ if "span" in request:
1015
+ span_dict = request["span"]
1016
+ span = start_span_from_remote(span_dict, "proxy_inference_callable")
1017
+ kwargs["span"] = span
1018
+ request_data = {
1019
+ input_name: self._tensor_store.get(tensor_id)
1020
+ for input_name, tensor_id in request.get("data", {}).items()
1021
+ }
1022
+ kwargs["data"] = request_data
1023
+ kwargs["parameters"] = request.get("parameters")
1024
+ request_wrapped = Request(**kwargs)
1025
+ deserialized_requests.append(request_wrapped)
1026
+
1027
+ return deserialized_requests
1028
+
1029
+ def free_requests_resources(self, requests_payload: bytes):
1030
+ """Free resources used by requests."""
1031
+ if requests_payload:
1032
+ requests = json.loads(requests_payload)
1033
+ for response in requests["requests"]:
1034
+ for _, tensor_id in response.get("data", {}).items():
1035
+ self._tensor_store.release_block(tensor_id)
1036
+
1037
+ def serialize_responses(self, responses: Responses) -> bytes:
1038
+ """Serialize responses.
1039
+
1040
+ Args:
1041
+ responses: list of responses to serialize
1042
+
1043
+ Returns:
1044
+ Serialized responses
1045
+ """
1046
+ responses = self._serialize_named_tensors_lists(responses)
1047
+ responses = {"responses": [{"data": response} for response in responses]}
1048
+ return json.dumps(responses).encode("utf-8")
1049
+
1050
+ def deserialize_responses(self, responses_payload: bytes) -> Responses:
1051
+ """Deserialize responses.
1052
+
1053
+ Args:
1054
+ responses_payload: serialized responses
1055
+
1056
+ Returns:
1057
+ List of deserialized responses
1058
+ """
1059
+ if responses_payload:
1060
+ responses = json.loads(responses_payload)
1061
+ return [
1062
+ Response(
1063
+ data={
1064
+ input_name: self._tensor_store.get(tensor_id)
1065
+ for input_name, tensor_id in response.get("data", {}).items()
1066
+ }
1067
+ )
1068
+ for response in responses["responses"]
1069
+ ]
1070
+ else:
1071
+ return []
1072
+
1073
+ def free_responses_resources(self, responses_payload: bytes):
1074
+ """Free resources used by responses."""
1075
+ if responses_payload:
1076
+ responses = json.loads(responses_payload)
1077
+ for response in responses["responses"]:
1078
+ for _, tensor_id in response.get("data", {}).items():
1079
+ self._tensor_store.release_block(tensor_id)
1080
+
1081
+ def _serialize_named_tensors_lists(self, named_tensors_lists):
1082
+ values_with_coords = [
1083
+ (idx, tensor_name, tensor)
1084
+ for idx, tensors in enumerate(named_tensors_lists)
1085
+ for tensor_name, tensor in tensors.items()
1086
+ ]
1087
+ tensor_ids = self._tensor_store.put([tensor for _, _, tensor in values_with_coords])
1088
+ named_tensors_lists = [{} for _ in range(len(named_tensors_lists))]
1089
+ for (idx, tensor_name, _), tensor_id in zip(values_with_coords, tensor_ids):
1090
+ named_tensors_lists[idx][tensor_name] = tensor_id
1091
+
1092
+ return named_tensors_lists
1093
+
1094
+ def start(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
1095
+ """Start TensorStore.
1096
+
1097
+ Args:
1098
+ url: address of data store
1099
+ authkey: authentication key required to setup connection. If not provided, current process authkey will be used
1100
+ """
1101
+ self._tensor_store = self._create(url, authkey)
1102
+ self._tensor_store.start()
1103
+
1104
+ def connect(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
1105
+ """Connect to TensorStore.
1106
+
1107
+ Args:
1108
+ url: address of data store
1109
+ authkey: authentication key required to setup connection. If not provided, current process authkey will be used
1110
+ """
1111
+ self._tensor_store = self._create(url, authkey)
1112
+ self._tensor_store.connect()
1113
+
1114
+ def _create(self, url: Union[str, pathlib.Path], authkey: Optional[bytes] = None):
1115
+ authkey = authkey or multiprocessing.current_process().authkey
1116
+ return TensorStore(url, authkey)
1117
+
1118
+ def close(self):
1119
+ """Close TensorStore."""
1120
+ if self._tensor_store:
1121
+ # check if run by this serializer/deserializer
1122
+ if self._tensor_store.is_started():
1123
+ debug_status = get_debug_status(self._tensor_store)
1124
+ used_blocks = [block for segment in debug_status["segments"] for block in segment["used_blocks"]]
1125
+ if used_blocks:
1126
+ LOGGER.debug(f"TensorStore used blocks while closing: {used_blocks}")
1127
+ # raise RuntimeError(
1128
+ # f"TensorStore at {self._tensor_store.address} is still running. Used blocks: {used_blocks}"
1129
+ # )
1130
+ LOGGER.debug(f"Closing TensorStore process at {self._tensor_store.address}")
1131
+
1132
+ self._tensor_store.close()
1133
+ self._tensor_store = None