Add bfloat to client #4521
```diff
@@ -31,18 +31,25 @@
 import sys
 import requests as httpreq
 from builtins import range
-import tritongrpcclient as grpcclient
-import tritonhttpclient as httpclient
-from tritonclientutils import np_to_triton_dtype
+import tritonclient.grpc as grpcclient
+import tritonclient.http as httpclient
+from tritonclient.utils import np_to_triton_dtype

 FLAGS = None

-def test_bf16_raw_http(shape):
+def test_bf16_http(shape):
+    if ("tensorflow" not in sys.modules):
+        from bfloat16 import bfloat16
+    else:
+        # Known incompatibility issue here:
+        # https://github.com/GreenWaves-Technologies/bfloat16/issues/2
+        # Can solve when numpy officially supports bfloat16:
+        # https://github.com/numpy/numpy/issues/19808
+        print("error: tensorflow is included in module. This module cannot "
+              "co-exist with pypi version of bfloat16")
+        sys.exit(1)
     model = "identity_bf16"
-    # Using fp16 data as a WAR since it is same byte_size as bf16
-    # and is supported by numpy for ease-of-use. Since this is an
-    # identity model, it's OK that the bytes are interpreted differently
-    input_data = (16384 * np.random.randn(*shape)).astype(np.float16)
+    input_data = (np.random.randn(*shape)).astype(bfloat16)
     input_bytes = input_data.tobytes()
     headers = {'Inference-Header-Content-Length': '0'}
     r = httpreq.post("http://localhost:8000/v2/models/{}/infer".format(model),
```

Review comment (on the `tensorflow`/`bfloat16` import check): Is this an error/check that should be done by the client as well? The comment is helpful, but users may run into issues if this results in incompatibility.

Review comment: I think it makes more sense to
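The workaround removed in the diff above relied on the fact that numpy's `float16` and bfloat16 occupy the same two bytes per element, so an fp16 buffer has exactly the byte size Triton expects for a bf16 tensor of the same shape. A minimal sketch of that rationale (the manual float32-to-bf16 truncation here is illustrative, not code from the PR):

```python
import numpy as np

# fp16 is 2 bytes per element, the same byte_size as bf16, which is why the
# old test could ship fp16 bytes to an identity bf16 model.
shape = (2, 2)
fp16_data = (16384 * np.random.randn(*shape)).astype(np.float16)
print(fp16_data.itemsize)  # 2 bytes per element
print(fp16_data.nbytes)    # 4 elements * 2 bytes = 8

# bfloat16 is the upper 16 bits of a float32; a manual truncation shows it
# also yields 2 bytes per element (plain numpy has no native bfloat16 dtype).
fp32_data = np.random.randn(*shape).astype(np.float32)
bf16_words = (fp32_data.view(np.uint32) >> 16).astype(np.uint16)
bf16_buffer = bf16_words.tobytes()
print(len(bf16_buffer))    # same byte count as fp16_data.nbytes
```

Because the model is an identity model, the server simply echoes the bytes back, so it did not matter that the bits were interpreted as a different 16-bit format.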
```diff
@@ -264,7 +271,5 @@ def test_bf16_raw_http(shape):
         print("error: expected 'param2' == False")
         sys.exit(1)

-# FIXME: Use identity_bf16 model in test above once proper python client
-# support is added, and remove this raw HTTP test. See DLIS-3720.
-test_bf16_raw_http([2, 2])
+test_bf16_http([2, 2])
```
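For context, the raw-HTTP path being discussed posts the tensor bytes directly to the model's infer endpoint, bypassing the Python client entirely. A minimal sketch of that request, assuming a Triton server at `localhost:8000` as in the diff (the request is only prepared here, never sent, so no server is required):

```python
import numpy as np
import requests

# Build the raw inference POST the test issues. The model name, URL, and
# header come from the diff; fp16 data stands in for bf16 as discussed above.
model = "identity_bf16"
shape = (2, 2)
input_data = (16384 * np.random.randn(*shape)).astype(np.float16)

req = requests.Request(
    method="POST",
    url="http://localhost:8000/v2/models/{}/infer".format(model),
    # Header length 0 signals a fully binary payload with no JSON header.
    headers={"Inference-Header-Content-Length": "0"},
    data=input_data.tobytes(),
).prepare()  # prepared but not sent

print(req.url)
print(len(req.body))  # 4 elements * 2 bytes = 8
```

Sending it with `requests.Session().send(req)` would exercise the same code path as `httpreq.post(...)` in the test.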
Review comment (on `test_bf16_http`): This helper was only meant to be a temporary workaround until python client bf16 support was added. Can we instead add to the existing test below, now that there is bf16 support in the client? I believe there's an `identity_bf16` model that will work now, or if not we could make a very simple python identity model. This would exercise the client-side conversion helpers such as `triton_to_np_dtype`, `np_to_triton_dtype`, etc.