Bring back Python backend based PyTorch backend (#117)
* Add Python backend based PyTorch runtime

* Add exec env build

* Add note for adding .pt2 model support

* Do not specify pytorch cuda version

* Do not install Python runtime on non x86

* Remove legacy comment

* User to build PyTorch env

* Add docs

* Update copyright

* Clarify model layout between PyTorch and TorchScript

* Fix header size
kthui authored Jan 11, 2024
1 parent d900538 commit 7468381
Showing 4 changed files with 500 additions and 2 deletions.
9 changes: 8 additions & 1 deletion CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -502,6 +502,13 @@ install(
${INSTALL_CONFIGDIR}
)

install(
  FILES
    src/model.py
  DESTINATION
    ${CMAKE_INSTALL_PREFIX}/backends/pytorch
)

include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/TritonPyTorchBackendConfig.cmake.in
126 changes: 125 additions & 1 deletion README.md
@@ -1,5 +1,5 @@
<!--
-# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -243,3 +243,127 @@ instance in the
[model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#instance-groups)
to ensure that the model instance and the tensors used for inference are
assigned to the same GPU device as on which the model was traced.

# PyTorch 2.0 Backend \[Experimental\]

> [!WARNING]
> *This feature is subject to change and removal.*

Starting from 24.01, PyTorch models can be served directly via the
[Python runtime](src/model.py). By default, Triton will use the
[LibTorch runtime](#pytorch-libtorch-backend) for PyTorch models. To use the
Python runtime, provide the following
[runtime setting](https://github.com/triton-inference-server/backend/blob/main/README.md#backend-shared-library)
in the model configuration:

```
runtime: "model.py"
```
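
For context, a complete minimal model configuration using the Python runtime
might look like the sketch below; the model name, inputs, outputs, and data
types are hypothetical and must match your own model:

```
name: "model_directory"
backend: "pytorch"
runtime: "model.py"
input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]
```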

## Dependencies

### Python backend dependency

This feature depends on
[Python backend](https://github.com/triton-inference-server/python_backend),
see
[Python-based Backends](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md)
for more details.

### PyTorch dependency

This feature takes advantage of the
[`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
optimization. Make sure the
[PyTorch 2.0+ pip package](https://pypi.org/project/torch) is available in the
same Python environment.
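
As a quick sanity check (an illustrative snippet, not part of the backend),
you can confirm from Python that the environment provides a
`torch.compile`-capable PyTorch:

```
import torch

# torch.compile was introduced in PyTorch 2.0, so both of these should
# succeed in a suitable environment.
print(torch.__version__)
print(callable(torch.compile))
```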

Alternatively, a [Python Execution Environment](#using-custom-python-execution-environments)
with the PyTorch dependency may be used. It can be created with the
[provided script](tools/gen_pb_exec_env.sh). The resulting
`pb_exec_env_model.py.tar.gz` file should be placed in the same
[backend shared library](https://github.com/triton-inference-server/backend/blob/main/README.md#backend-shared-library)
directory as the [Python runtime](src/model.py).
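
For example, assuming a default install, the backend directory would then
look like this (exact contents may vary):

```
backends/
`-- pytorch
    |-- libtriton_pytorch.so
    |-- model.py
    `-- pb_exec_env_model.py.tar.gz
```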

## Model Layout

### PyTorch 2.0 models

The model repository should look like:

```
model_repository/
`-- model_directory
    |-- 1
    |   |-- model.py
    |   `-- [model.pt]
    `-- config.pbtxt
```

The `model.py` file contains the class definition of the PyTorch model. The
class should extend
[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
A `model.pt` file containing the saved
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference)
of the model may optionally be provided.
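
As an illustration, a minimal `model.py` might look like the sketch below;
the class name and the inputs/outputs are hypothetical and must line up with
the model configuration:

```
# model.py
import torch


class AddSubNet(torch.nn.Module):
    """Toy model that returns the sum and difference of its two inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, input0, input1):
        return input0 + input1, input0 - input1
```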

### TorchScript models

The model repository should look like:

```
model_repository/
`-- model_directory
    |-- 1
    |   `-- model.pt
    `-- config.pbtxt
```

The `model.pt` is the TorchScript model file.
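
A sketch of producing such a file with TorchScript is shown below; the
`SimpleModel` module and the tensor shape are hypothetical:

```
import torch


class SimpleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


# Trace the model with a representative example input, then save the
# resulting TorchScript program where the model repository expects it.
traced = torch.jit.trace(SimpleModel(), torch.randn(1, 4))
traced.save("model_repository/model_directory/1/model.pt")
```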

## Customization

The following PyTorch settings may be customized by setting parameters in
`config.pbtxt`.

[`torch.set_num_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads)
- Key: NUM_THREADS
- Value: The number of threads used for intraop parallelism on CPU.

[`torch.set_num_interop_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads.html#torch.set_num_interop_threads)
- Key: NUM_INTEROP_THREADS
- Value: The number of threads used for interop parallelism (e.g. in JIT
interpreter) on CPU.

[`torch.compile()` parameters](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
- Key: TORCH_COMPILE_OPTIONAL_PARAMETERS
- Value: Any of the following parameter(s), encoded as a JSON object.
  - fullgraph (*bool*): Whether it is OK to break the model into several
    subgraphs.
  - dynamic (*bool*): Use dynamic shape tracing.
  - backend (*str*): The backend to be used.
  - mode (*str*): Can be either "default", "reduce-overhead" or "max-autotune".
  - options (*dict*): A dictionary of options to pass to the backend.
  - disable (*bool*): Turn `torch.compile()` into a no-op for testing.

For example:
```
parameters: {
  key: "NUM_THREADS"
  value: { string_value: "4" }
}
parameters: {
  key: "TORCH_COMPILE_OPTIONAL_PARAMETERS"
  value: { string_value: "{\"disable\": true}" }
}
```
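
Conceptually, the decoded JSON object is forwarded to `torch.compile()` as
keyword arguments (see [src/model.py](src/model.py) for the actual behavior);
a minimal sketch of the equivalent Python, with a hypothetical `Identity`
module:

```
import json

import torch


class Identity(torch.nn.Module):
    def forward(self, x):
        return x


# The string_value above, decoded; {"disable": true} turns torch.compile()
# into a no-op, which is handy when isolating compilation issues.
params = json.loads('{"disable": true}')
compiled = torch.compile(Identity(), **params)
print(compiled(torch.ones(2)))
```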

## Limitations

The following are a few known limitations of this feature:
- Python functions optimizable by `torch.compile` may not be served directly
  in the `model.py` file; they need to be enclosed in a class that extends
  [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
- Model weights cannot be shared across multiple instances on the same GPU
  device.
- When using `KIND_MODEL` as the model instance kind, the default device of
  the first parameter of the model is used.
