Skip to content

Commit

Permalink
Return error status instead of silently erroring out during TensorRT weight conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Nov 6, 2024
1 parent 8cf2586 commit 7254cbe
Showing 1 changed file with 16 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -520,41 +520,41 @@ static void packNonSplatInt4Tensor(ElementsAttr values, int64_t count,
}
}

static void serializeSplatElements(DenseIntOrFPElementsAttr values,
std::vector<int8_t> &data) {
static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values,
std::vector<int8_t> &data) {
assert(values.isSplat() && "expected SplatElementsAttr");

auto rtt = cast<RankedTensorType>(values.getType());
if (rtt.getElementType().isInteger(32)) {
std::fill_n(reinterpret_cast<int32_t *>(data.data()),
values.getNumElements(), values.getSplatValue<int32_t>());
return;
return llvm::success();
}
if (rtt.getElementType().isInteger(8)) {
std::fill_n(reinterpret_cast<int8_t *>(data.data()),
values.getNumElements(), values.getSplatValue<int8_t>());
return;
return llvm::success();
}
if (rtt.getElementType().isF32()) {
std::fill_n(reinterpret_cast<float *>(data.data()), values.getNumElements(),
values.getSplatValue<float>());
return;
return llvm::success();
}
if (rtt.getElementType().isF16() || rtt.getElementType().isBF16()) {
APInt tmp = values.getSplatValue<APFloat>().bitcastToAPInt();
assert(tmp.getBitWidth() == 16 && "unexpected bitwidth");
uint16_t fillValue = *reinterpret_cast<const uint16_t *>(tmp.getRawData());
std::fill_n(reinterpret_cast<uint16_t *>(data.data()),
values.getNumElements(), fillValue);
return;
return llvm::success();
}
if (rtt.getElementType().isFloat8E4M3FN()) {
APInt tmp = values.getSplatValue<APFloat>().bitcastToAPInt();
assert(tmp.getBitWidth() == 8 && "unexpected bitwidth");
uint8_t fillValue = *reinterpret_cast<const uint8_t *>(tmp.getRawData());
std::fill_n(reinterpret_cast<uint8_t *>(data.data()),
values.getNumElements(), fillValue);
return;
return llvm::success();
}
if (rtt.getElementType().isInteger(4)) {
APInt tmp = values.getSplatValue<APInt>();
Expand All @@ -566,11 +566,12 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values,
packed |= ((value & 0x0F) << 4);
// Fill `data` vector with `packed`
std::fill_n(reinterpret_cast<uint8_t *>(data.data()), data.size(), packed);
return;
return llvm::success();
}

llvm_unreachable("unsupported data type to convert MLIR splat attribute to "
"TensorRT weights!");
llvm::errs() << "Error: "
<< "unsupported data type to convert MLIR splat attribute to "
"TensorRT weights!";
return llvm::failure();
}

FailureOr<nvinfer1::Weights>
Expand Down Expand Up @@ -615,8 +616,10 @@ NvInferNetworkEncoder::getNvInferWeights(ElementsAttr values) {
weights.values = data.data();

if (values.isSplat() && isa<DenseIntOrFPElementsAttr>(values)) {
serializeSplatElements(cast<DenseIntOrFPElementsAttr>(values),
weightsMap[values]);
LogicalResult status = serializeSplatElements(
cast<DenseIntOrFPElementsAttr>(values), weightsMap[values]);
if (failed(status))
return failure();
return weights;
}

Expand Down

0 comments on commit 7254cbe

Please sign in to comment.