test: Add BF16 test for python backend (#7483)
rmccorm4 authored Jul 30, 2024
1 parent fb056b1 commit 69d768d
Showing 4 changed files with 172 additions and 0 deletions.
30 changes: 30 additions & 0 deletions qa/L0_backend_python/python_test.py
@@ -365,6 +365,36 @@ def test_bool(self):
        self.assertIsNotNone(output0)
        self.assertTrue(np.all(output0 == input_data))

    def test_bf16(self):
        model_name = "identity_bf16"
        shape = [2, 2]
        with self._shm_leak_detector.Probe() as shm_probe:
            with httpclient.InferenceServerClient(
                f"{_tritonserver_ipaddr}:8000"
            ) as client:
                # NOTE: The client truncates FP32 to BF16 internally,
                # since numpy has no built-in BF16 representation.
                np_input = np.ones(shape, dtype=np.float32)
                inputs = [
                    httpclient.InferInput(
                        "INPUT0", np_input.shape, "BF16"
                    ).set_data_from_numpy(np_input)
                ]
                result = client.infer(model_name, inputs)

                # Assert that Triton correctly returned a BF16 tensor.
                response = result.get_response()
                triton_output = response["outputs"][0]
                triton_dtype = triton_output["datatype"]
                self.assertEqual(triton_dtype, "BF16")

                np_output = result.as_numpy("OUTPUT0")
                self.assertIsNotNone(np_output)
                # Because numpy has no native BF16 type, as_numpy() returns
                # BF16 tensors widened to FP32, so verify that here.
                self.assertEqual(np_output.dtype, np.float32)
                self.assertTrue(np.allclose(np_output, np_input))

    def test_infer_pytorch(self):
        # FIXME: This model requires torch. Because Windows tests are not run in a docker
        # environment with torch installed, we need to think about how we want to install
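The truncation mentioned in the NOTE above is simple at the bit level: BF16 keeps an FP32 value's sign, exponent, and top 7 mantissa bits, i.e. the high 16 bits of the FP32 bit pattern. Here is a minimal standalone sketch of that round trip in numpy; the helper names are ours for illustration, and the client's exact rounding may differ from plain truncation:

import numpy as np

def fp32_to_bf16_bits(a):
    # Keep the high 16 bits of each FP32 bit pattern (sign, exponent,
    # and the top 7 mantissa bits) -- a truncating FP32 -> BF16 cast.
    return (a.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_fp32(b):
    # Widening BF16 -> FP32 is exact: zero-fill the dropped mantissa bits.
    return (b.astype(np.uint32) << 16).view(np.float32)

a = np.ones((2, 2), dtype=np.float32)
# 1.0 is exactly representable in BF16, so this round trip is lossless,
# which is why the test above can compare against the all-ones input.
assert np.array_equal(bf16_bits_to_fp32(fp32_to_bf16_bits(a)), a)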
3 changes: 3 additions & 0 deletions qa/L0_backend_python/test.sh
@@ -95,6 +95,9 @@ fi
mkdir -p models/identity_fp32/1/
cp ../python_models/identity_fp32/model.py ./models/identity_fp32/1/model.py
cp ../python_models/identity_fp32/config.pbtxt ./models/identity_fp32/config.pbtxt
mkdir -p models/identity_bf16/1/
cp ../python_models/identity_bf16/model.py ./models/identity_bf16/1/model.py
cp ../python_models/identity_bf16/config.pbtxt ./models/identity_bf16/config.pbtxt
RET=0

cp -r ./models/identity_fp32 ./models/identity_uint8
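The mkdir/cp lines above give the new model the standard Triton repository layout, with the config at the model root and model.py inside version directory 1/:

models/identity_bf16/
├── config.pbtxt
└── 1/
    └── model.py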
51 changes: 51 additions & 0 deletions qa/python_models/identity_bf16/config.pbtxt
@@ -0,0 +1,51 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

backend: "python"
max_batch_size: 64
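# NOTE (added for clarity): with max_batch_size > 0, Triton prepends a batch
# dimension to the dims listed below, so the full I/O shape is [ batch, -1 ].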

input [
  {
    name: "INPUT0"
    data_type: TYPE_BF16
    dims: [ -1 ]
  }
]

output [
  {
    name: "OUTPUT0"
    data_type: TYPE_BF16
    dims: [ -1 ]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
88 changes: 88 additions & 0 deletions qa/python_models/identity_bf16/model.py
@@ -0,0 +1,88 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json

import torch
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        # args["model_config"] is a JSON string; parse it into a dict here.
        self.model_config = json.loads(args["model_config"])

        # Get tensor configurations for testing/validation.
        self.input0_config = pb_utils.get_input_config_by_name(
            self.model_config, "INPUT0"
        )
        self.output0_config = pb_utils.get_output_config_by_name(
            self.model_config, "OUTPUT0"
        )

    def validate_bf16_tensor(self, tensor, tensor_config):
        # I/O datatypes can be queried from the model config if needed.
        dtype = tensor_config["data_type"]
        if dtype != "TYPE_BF16":
            raise Exception(f"Expected a BF16 tensor, but got {dtype} instead.")

        # Converting BF16 tensors to numpy is not supported, so DLPack
        # should be used instead via to_dlpack and from_dlpack.
        try:
            _ = tensor.as_numpy()
        except pb_utils.TritonModelException as e:
            expected_error = "tensor dtype is bf16 and cannot be converted to numpy"
            assert expected_error in str(e).lower()
        else:
            raise Exception("Expected BF16 conversion to numpy to fail")

    def execute(self, requests):
        """
        Identity model in Python backend with example BF16 and PyTorch usage.
        """
        responses = []
        for request in requests:
            input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0")

            # Numpy does not support BF16, so use DLPack instead.
            bf16_dlpack = input_tensor.to_dlpack()

            # OPTIONAL: The tensor can be converted to other DLPack-compatible
            # frameworks like PyTorch and TensorFlow with their DLPack utilities.
            torch_tensor = torch.utils.dlpack.from_dlpack(bf16_dlpack)

            # When complete, convert back to a pb_utils.Tensor via DLPack.
            output_tensor = pb_utils.Tensor.from_dlpack(
                "OUTPUT0", torch.utils.dlpack.to_dlpack(torch_tensor)
            )
            responses.append(pb_utils.InferenceResponse([output_tensor]))

            # NOTE: The following helper calls are for testing and example
            # purposes only; remove them in practice.
            self.validate_bf16_tensor(input_tensor, self.input0_config)
            self.validate_bf16_tensor(output_tensor, self.output0_config)

        return responses
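For reference, the DLPack hop used in execute() can be exercised outside Triton with PyTorch alone. A minimal sketch (the tensor shape and values are illustrative):

import torch

# BF16 tensor -> DLPack capsule -> back to a tensor, all zero-copy.
t = torch.ones(2, 2, dtype=torch.bfloat16)
capsule = torch.utils.dlpack.to_dlpack(t)
t2 = torch.utils.dlpack.from_dlpack(capsule)

assert t2.dtype == torch.bfloat16
assert torch.equal(t, t2)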
