feature/merge_parse_model (#2651)
* Fix pep8, strict formatting

* merge parse_model_http and parse_model_grpc
kimdwkimdw authored Apr 6, 2021
1 parent 8fd536e commit b8997d1
Showing 1 changed file with 40 additions and 115 deletions.
155 changes: 40 additions & 115 deletions python/examples/image_client.py
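The merge hinges on how the two client libraries hand back model metadata: tritonclient.http returns plain Python dicts, while tritonclient.grpc returns protobuf messages with attribute access. Wrapping the HTTP dicts in attrdict.AttrDict gives them the same attribute-style access, so a single parse_model can serve both protocols. A minimal sketch of the idea (the metadata dict below is a made-up example, not a real server response):

    from attrdict import AttrDict

    # Made-up HTTP-style metadata: plain dicts and lists, as the HTTP
    # client would return them.
    http_metadata = {
        "name": "resnet50",
        "inputs": [{"name": "input", "datatype": "FP32"}],
    }

    meta = AttrDict(http_metadata)
    # Attribute access now works as it does on the gRPC protobuf message.
    print(meta.name)                # resnet50
    print(meta.inputs[0].datatype)  # FP32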
@@ -26,17 +26,19 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import argparse
-import numpy as np
-from PIL import Image
-import sys
 from functools import partial
 import os
+import sys
+
+from PIL import Image
+import numpy as np
+from attrdict import AttrDict
 
 import tritonclient.grpc as grpcclient
 import tritonclient.grpc.model_config_pb2 as mc
 import tritonclient.http as httpclient
-from tritonclient.utils import triton_to_np_dtype
 from tritonclient.utils import InferenceServerException
+from tritonclient.utils import triton_to_np_dtype
 
 if sys.version_info >= (3, 0):
     import queue
@@ -59,7 +61,7 @@ def completion_callback(user_data, result, error):
 FLAGS = None
 
 
-def parse_model_grpc(model_metadata, model_config):
+def parse_model(model_metadata, model_config):
     """
     Check the configuration of a model to make sure it meets the
     requirements for an image classification network (as expected by
@@ -107,11 +109,15 @@ def parse_model_grpc(model_metadata, model_config):
     if len(input_metadata.shape) != expected_input_dims:
         raise Exception(
             "expecting input to have {} dimensions, model '{}' input has {}".
-                format(expected_input_dims, model_metadata.name,
-                       len(input_metadata.shape)))
+            format(expected_input_dims, model_metadata.name,
+                   len(input_metadata.shape)))
 
+    if type(input_config.format) == str:
+        FORMAT_ENUM_TO_INT = dict(mc.ModelInput.Format.items())
+        input_config.format = FORMAT_ENUM_TO_INT[input_config.format]
+
     if ((input_config.format != mc.ModelInput.FORMAT_NCHW) and
-        (input_config.format != mc.ModelInput.FORMAT_NHWC)):
+            (input_config.format != mc.ModelInput.FORMAT_NHWC)):
         raise Exception("unexpected input format " +
                         mc.ModelInput.Format.Name(input_config.format) +
                         ", expecting " +
@@ -133,79 +139,6 @@ def parse_model_grpc(model_metadata, model_config):
                             input_metadata.datatype)
 
 
-def parse_model_http(model_metadata, model_config):
-    """
-    Check the configuration of a model to make sure it meets the
-    requirements for an image classification network (as expected by
-    this client)
-    """
-    if len(model_metadata['inputs']) != 1:
-        raise Exception("expecting 1 input, got {}".format(
-            len(model_metadata['inputs'])))
-    if len(model_metadata['outputs']) != 1:
-        raise Exception("expecting 1 output, got {}".format(
-            len(model_metadata['outputs'])))
-
-    if len(model_config['input']) != 1:
-        raise Exception(
-            "expecting 1 input in model configuration, got {}".format(
-                len(model_config['input'])))
-
-    input_metadata = model_metadata['inputs'][0]
-    input_config = model_config['input'][0]
-    output_metadata = model_metadata['outputs'][0]
-
-    max_batch_size = 0
-    if 'max_batch_size' in model_config:
-        max_batch_size = model_config['max_batch_size']
-
-    if output_metadata['datatype'] != "FP32":
-        raise Exception("expecting output datatype to be FP32, model '" +
-                        model_metadata['name'] + "' output type is " +
-                        output_metadata['datatype'])
-
-    # Output is expected to be a vector. But allow any number of
-    # dimensions as long as all but 1 is size 1 (e.g. { 10 }, { 1, 10
-    # }, { 10, 1, 1 } are all ok). Ignore the batch dimension if there
-    # is one.
-    output_batch_dim = (max_batch_size > 0)
-    non_one_cnt = 0
-    for dim in output_metadata['shape']:
-        if output_batch_dim:
-            output_batch_dim = False
-        elif dim > 1:
-            non_one_cnt += 1
-            if non_one_cnt > 1:
-                raise Exception("expecting model output to be a vector")
-
-    # Model input must have 3 dims (not counting the batch dimension),
-    # either CHW or HWC
-    input_batch_dim = (max_batch_size > 0)
-    expected_input_dims = 3 + (1 if input_batch_dim else 0)
-    if len(input_metadata['shape']) != expected_input_dims:
-        raise Exception(
-            "expecting input to have {} dimensions, model '{}' input has {}".
-            format(expected_input_dims, model_metadata['name'],
-                   len(input_metadata['shape'])))
-
-    if ((input_config['format'] != "FORMAT_NCHW") and
-        (input_config['format'] != "FORMAT_NHWC")):
-        raise Exception("unexpected input format " + input_config['format'] +
-                        ", expecting FORMAT_NCHW or FORMAT_NHWC")
-
-    if input_config['format'] == "FORMAT_NHWC":
-        h = input_metadata['shape'][1 if input_batch_dim else 0]
-        w = input_metadata['shape'][2 if input_batch_dim else 1]
-        c = input_metadata['shape'][3 if input_batch_dim else 2]
-    else:
-        c = input_metadata['shape'][1 if input_batch_dim else 0]
-        h = input_metadata['shape'][2 if input_batch_dim else 1]
-        w = input_metadata['shape'][3 if input_batch_dim else 2]
-
-    return (max_batch_size, input_metadata['name'], output_metadata['name'], c,
-            h, w, input_config['format'], input_metadata['datatype'])
-
-
 def preprocess(img, format, dtype, c, h, w, scaling, protocol):
     """
     Pre-process an image to meet the size, type and format
@@ -237,16 +170,10 @@ def preprocess(img, format, dtype, c, h, w, scaling, protocol):
         scaled = typed
 
     # Swap to CHW if necessary
-    if protocol == "grpc":
-        if format == mc.ModelInput.FORMAT_NCHW:
-            ordered = np.transpose(scaled, (2, 0, 1))
-        else:
-            ordered = scaled
+    if format == mc.ModelInput.FORMAT_NCHW:
+        ordered = np.transpose(scaled, (2, 0, 1))
     else:
-        if format == "FORMAT_NCHW":
-            ordered = np.transpose(scaled, (2, 0, 1))
-        else:
-            ordered = scaled
+        ordered = scaled
 
     # Channels are in RGB order. Currently model configuration data
     # doesn't provide any information as to other channel orderings
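With format normalized to the protobuf enum inside parse_model, preprocess no longer needs a protocol branch; a single comparison against mc.ModelInput.FORMAT_NCHW decides whether to transpose. A standalone illustration of the HWC-to-CHW swap (the shape is an arbitrary example value):

    import numpy as np

    hwc = np.zeros((224, 224, 3), dtype=np.float32)  # height, width, channels
    chw = np.transpose(hwc, (2, 0, 1))               # channels, height, width
    assert chw.shape == (3, 224, 224)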
@@ -277,32 +204,29 @@ def postprocess(results, output_name, batch_size, batching):
 
 
 def requestGenerator(batched_image_data, input_name, output_name, dtype, FLAGS):
-    # Set the input data
-    inputs = []
-    if FLAGS.protocol.lower() == "grpc":
-        inputs.append(
-            grpcclient.InferInput(input_name, batched_image_data.shape, dtype))
-        inputs[0].set_data_from_numpy(batched_image_data)
+    protocol = FLAGS.protocol.lower()
+
+    if protocol == "grpc":
+        client = grpcclient
     else:
-        inputs.append(
-            httpclient.InferInput(input_name, batched_image_data.shape, dtype))
-        inputs[0].set_data_from_numpy(batched_image_data, binary_data=True)
+        client = httpclient
 
-    outputs = []
-    if FLAGS.protocol.lower() == "grpc":
-        outputs.append(
-            grpcclient.InferRequestedOutput(output_name,
-                                            class_count=FLAGS.classes))
-    else:
-        outputs.append(
-            httpclient.InferRequestedOutput(output_name,
-                                            binary_data=True,
-                                            class_count=FLAGS.classes))
+    # Set the input data
+    inputs = [client.InferInput(input_name, batched_image_data.shape, dtype)]
+    inputs[0].set_data_from_numpy(batched_image_data)
+
+    outputs = [client.InferRequestedOutput(output_name, class_count=FLAGS.classes)]
 
     yield inputs, outputs, FLAGS.model_name, FLAGS.model_version
 
 
+def convert_http_metadata_config(_metadata, _config):
+    _model_metadata = AttrDict(_metadata)
+    _model_config = AttrDict(_config)
+
+    return _model_metadata, _model_config
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-v',
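requestGenerator now leans on the fact that tritonclient.grpc and tritonclient.http expose mirror-image APIs (InferInput, InferRequestedOutput, InferenceServerClient), so the module object itself can be selected once and used like an interface. One detail: the old HTTP branch passed binary_data=True explicitly; that appears to be the default in the HTTP client, so dropping it should leave the wire behavior unchanged. A sketch of the module-as-interface pattern (build_request is a hypothetical helper, not part of the file):

    import numpy as np
    import tritonclient.grpc as grpcclient
    import tritonclient.http as httpclient

    def build_request(protocol, input_name, output_name, data, dtype, classes):
        # Both modules expose the same class names, so the chosen module
        # stands in for a client "interface".
        client = grpcclient if protocol == "grpc" else httpclient
        inputs = [client.InferInput(input_name, data.shape, dtype)]
        inputs[0].set_data_from_numpy(data)
        outputs = [
            client.InferRequestedOutput(output_name, class_count=classes)
        ]
        return inputs, outputs

    # Example with a dummy FP32 tensor.
    ins, outs = build_request("http", "input", "output",
                              np.zeros((1, 3, 224, 224), np.float32),
                              "FP32", 1)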
@@ -323,7 +247,7 @@ def requestGenerator(batched_image_data, input_name, output_name, dtype, FLAGS):
                         required=False,
                         default=False,
                         help='Use streaming inference API. ' +
-                        'The flag is only available with gRPC protocol.')
+                             'The flag is only available with gRPC protocol.')
     parser.add_argument('-m',
                         '--model-name',
                         type=str,
@@ -368,7 +292,7 @@ def requestGenerator(batched_image_data, input_name, output_name, dtype, FLAGS):
                         required=False,
                         default='HTTP',
                         help='Protocol (HTTP/gRPC) used to communicate with ' +
-                        'the inference service. Default is HTTP.')
+                             'the inference service. Default is HTTP.')
     parser.add_argument('image_filename',
                         type=str,
                         nargs='?',
@@ -411,11 +335,12 @@ def requestGenerator(batched_image_data, input_name, output_name, dtype, FLAGS):
         sys.exit(1)
 
     if FLAGS.protocol.lower() == "grpc":
-        max_batch_size, input_name, output_name, c, h, w, format, dtype = parse_model_grpc(
-            model_metadata, model_config.config)
+        model_config = model_config.config
     else:
-        max_batch_size, input_name, output_name, c, h, w, format, dtype = parse_model_http(
-            model_metadata, model_config)
+        model_metadata, model_config = convert_http_metadata_config(model_metadata, model_config)
 
+    max_batch_size, input_name, output_name, c, h, w, format, dtype = parse_model(
+        model_metadata, model_config)
+
     filenames = []
     if os.path.isdir(FLAGS.image_filename):
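At the call site, the two protocol-specific parsers collapse into one dispatch step: the gRPC path unwraps the protobuf configuration (model_config.config), the HTTP path converts its dicts with convert_http_metadata_config, and both feed the same parse_model. A condensed sketch of that flow, assuming parse_model and convert_http_metadata_config from this file (fetch_and_parse is a hypothetical wrapper):

    import tritonclient.grpc as grpcclient
    import tritonclient.http as httpclient

    def fetch_and_parse(protocol, url, model_name, model_version=""):
        if protocol == "grpc":
            client = grpcclient.InferenceServerClient(url=url)
            metadata = client.get_model_metadata(model_name, model_version)
            config = client.get_model_config(model_name, model_version).config
        else:
            client = httpclient.InferenceServerClient(url=url)
            metadata, config = convert_http_metadata_config(
                client.get_model_metadata(model_name, model_version),
                client.get_model_config(model_name, model_version))
        # One parser for both protocols.
        return parse_model(metadata, config)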
