Support BF16 data type in TensorRT backend (#90)
* Adding support for TensorRT 10 APIs in the backend. Keep TRT 8 support as well (#88)

* Replace binding index-based methods with name-based alternatives

* Remove unused variables

* Remove unused variables

* Remove allInput*Specified()

* Delete TRTV1Interface

* Replace getProfileShapeValues() with getProfileTensorValues()

* Remove buffer_bindings_

* Enhancements

* Replace isExecutionBinding()

* Add INT64 support

* Remove hasImplicitBatchDimension()

* Update Copyright

* Remove unused variables

* Undo copyright

* Undo Copyright

* Undo copyright

* Fix the handling in INT64 shape tensors output

* Fix data dependent output shapes

* Fix pre commit errors

* Update copyright

* Resolve review comments

* Include source for building on TRT 8 (#86) (#87)

* Include source for building on TRT 8

* Apply suggestions from code review



---------

Co-authored-by: Misha Chornyi <[email protected]>

* Fix envvar access in CMake

---------

Co-authored-by: Sai Kiran Polisetty <[email protected]>
Co-authored-by: Misha Chornyi <[email protected]>

* Add support for kBF16

---------

Co-authored-by: Tanmay Verma <[email protected]>
Co-authored-by: Misha Chornyi <[email protected]>
3 people authored Jun 5, 2024
1 parent 75d3361 commit fd867b7
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/tensorrt_utils.cc
@@ -44,6 +44,8 @@ ConvertTrtTypeToDataType(nvinfer1::DataType trt_type)
       return TRITONSERVER_TYPE_INT8;
     case nvinfer1::DataType::kUINT8:
       return TRITONSERVER_TYPE_UINT8;
+    case nvinfer1::DataType::kBF16:
+      return TRITONSERVER_TYPE_BF16;
     case nvinfer1::DataType::kINT32:
       return TRITONSERVER_TYPE_INT32;
     case nvinfer1::DataType::kINT64:
@@ -67,6 +69,8 @@ ConvertTrtTypeToConfigDataType(nvinfer1::DataType trt_type)
       return "TYPE_INT8";
     case nvinfer1::DataType::kUINT8:
       return "TYPE_UINT8";
+    case nvinfer1::DataType::kBF16:
+      return "TYPE_BF16";
     case nvinfer1::DataType::kINT32:
       return "TYPE_INT32";
     case nvinfer1::DataType::kINT64:
@@ -116,6 +120,9 @@ ConvertDataTypeToTrtType(const TRITONSERVER_DataType& dtype)
     case TRITONSERVER_TYPE_UINT8:
       trt_type = nvinfer1::DataType::kUINT8;
       break;
+    case TRITONSERVER_TYPE_BF16:
+      trt_type = nvinfer1::DataType::kBF16;
+      break;
     case TRITONSERVER_TYPE_INT32:
       trt_type = nvinfer1::DataType::kINT32;
       break;
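
The three hunks above add the same BF16 case to each conversion direction: TensorRT type to Triton type, TensorRT type to config string, and Triton type back to TensorRT type. The following is a minimal, self-contained sketch of that three-way mapping and the round-trip property it is meant to preserve. It does not use the real headers: the `sketch` namespace, the `TrtType` and `TritonType` enums, and the functions `ToTriton`, `ToConfigString`, `ToTrt`, and `RoundTrips` are hypothetical stand-ins for `nvinfer1::DataType`, `TRITONSERVER_DataType`, and the `Convert*` functions in src/tensorrt_utils.cc. The fallback values and the bool-plus-out-parameter return style are assumptions of this sketch and may not match the backend.

```cpp
#include <string>

namespace sketch {

// Hypothetical stand-ins for nvinfer1::DataType and TRITONSERVER_DataType.
// Only the types visible in the diff context are modeled here.
enum class TrtType { kINT8, kUINT8, kBF16, kINT32, kINT64 };
enum class TritonType { INVALID, INT8, UINT8, BF16, INT32, INT64 };

// Mirrors the shape of ConvertTrtTypeToDataType: TensorRT type -> Triton type.
inline TritonType ToTriton(TrtType t) {
  switch (t) {
    case TrtType::kINT8:  return TritonType::INT8;
    case TrtType::kUINT8: return TritonType::UINT8;
    case TrtType::kBF16:  return TritonType::BF16;  // case added by this commit
    case TrtType::kINT32: return TritonType::INT32;
    case TrtType::kINT64: return TritonType::INT64;
  }
  return TritonType::INVALID;  // assumed fallback for unmapped values
}

// Mirrors the shape of ConvertTrtTypeToConfigDataType: TensorRT type -> config string.
inline std::string ToConfigString(TrtType t) {
  switch (t) {
    case TrtType::kINT8:  return "TYPE_INT8";
    case TrtType::kUINT8: return "TYPE_UINT8";
    case TrtType::kBF16:  return "TYPE_BF16";  // case added by this commit
    case TrtType::kINT32: return "TYPE_INT32";
    case TrtType::kINT64: return "TYPE_INT64";
  }
  return "TYPE_INVALID";  // assumed fallback for unmapped values
}

// Mirrors the shape of ConvertDataTypeToTrtType: Triton type -> TensorRT type.
// Returns false and leaves *out untouched when there is no mapping.
inline bool ToTrt(TritonType t, TrtType* out) {
  switch (t) {
    case TritonType::INT8:  *out = TrtType::kINT8;  return true;
    case TritonType::UINT8: *out = TrtType::kUINT8; return true;
    case TritonType::BF16:  *out = TrtType::kBF16;  return true;  // case added by this commit
    case TritonType::INT32: *out = TrtType::kINT32; return true;
    case TritonType::INT64: *out = TrtType::kINT64; return true;
    case TritonType::INVALID: break;
  }
  return false;
}

// True when converting t to a Triton type and back yields t again.
inline bool RoundTrips(TrtType t) {
  TrtType back = TrtType::kINT8;
  return ToTrt(ToTriton(t), &back) && back == t;
}

}  // namespace sketch
```

In this sketch, leaving the BF16 case out of any one of the three functions would break the round trip or the config string for that type, which is why the commit touches all three switch statements together.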
