
Commit

Make InferenceCalculatorDarwinn support float and int32 as input data type.

Previously, the input data type for Darwinn defaulted to float. Some TFLite models now require int32 inputs, so this support has been added.

PiperOrigin-RevId: 602450883
aaronndx authored and copybara-github committed Jan 29, 2024
1 parent 850da4e commit 82060a1
Showing 5 changed files with 293 additions and 0 deletions.
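For orientation, here is a minimal, hypothetical sketch of how a caller might use the new CopyCpuInputIntoInterpreterTensor helper to feed an int32 input. The wrapper function RunInt32Input, the assumption that input 0 is an already-allocated kTfLiteInt32 tensor, and the error handling are illustrative only, not code from this commit.

#include <cstdint>
#include <cstring>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"

// Hypothetical usage sketch (not part of this commit): copies int32 values
// into input 0 of an already-allocated TFLite interpreter and runs inference.
absl::Status RunInt32Input(tflite::Interpreter& interpreter,
                           const std::vector<int32_t>& values) {
  mediapipe::Tensor input(
      mediapipe::Tensor::ElementType::kInt32,
      mediapipe::Tensor::Shape({static_cast<int>(values.size())}));
  std::memcpy(input.GetCpuWriteView().buffer<int32_t>(), values.data(),
              values.size() * sizeof(int32_t));
  // New utility added in this commit: verifies type and size, then copies the
  // CPU buffer into the interpreter's input tensor.
  MP_RETURN_IF_ERROR(mediapipe::CopyCpuInputIntoInterpreterTensor(
      input, interpreter, /*input_tensor_index=*/0));
  return interpreter.Invoke() == kTfLiteOk
             ? absl::OkStatus()
             : absl::InternalError("Invoke() failed.");
}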
20 changes: 20 additions & 0 deletions mediapipe/calculators/tensor/BUILD
@@ -574,6 +574,12 @@ cc_library(
deps = [
":inference_calculator_cc_proto",
"//mediapipe/framework:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:common",
] + select({
"//conditions:default": [
"//mediapipe/util:cpu_util",
@@ -582,6 +588,20 @@ cc_library(
alwayslink = 1,
)

cc_test(
name = "inference_calculator_utils_test",
srcs = ["inference_calculator_utils_test.cc"],
deps = [
":inference_calculator_utils",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:common",
],
)

cc_library(
name = "inference_calculator_xnnpack",
srcs = [
91 changes: 91 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator_utils.cc
@@ -14,8 +14,17 @@

#include "mediapipe/calculators/tensor/inference_calculator_utils.h"

#include <cstring>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port.h" // NOLINT: provides MEDIAPIPE_ANDROID/IOS
#include "mediapipe/framework/port/status_macros.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"

#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
#include "mediapipe/util/cpu_util.h"
@@ -37,6 +46,58 @@ int GetXnnpackDefaultNumThreads() {
#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__
}

// Checks whether a MediaPipe Tensor's element type matches a TfLite data type.
bool DoTypesMatch(Tensor::ElementType tensor_type, TfLiteType tflite_type) {
switch (tensor_type) {
// kNone is treated as matching kTfLiteNoType.
case Tensor::ElementType::kNone:
return tflite_type == TfLiteType::kTfLiteNoType;
case Tensor::ElementType::kFloat16:
return tflite_type == TfLiteType::kTfLiteFloat16;
case Tensor::ElementType::kFloat32:
return tflite_type == TfLiteType::kTfLiteFloat32;
case Tensor::ElementType::kUInt8:
return tflite_type == TfLiteType::kTfLiteUInt8;
case Tensor::ElementType::kInt8:
return tflite_type == TfLiteType::kTfLiteInt8;
case Tensor::ElementType::kInt32:
return tflite_type == TfLiteType::kTfLiteInt32;
case Tensor::ElementType::kBool:
return tflite_type == TfLiteType::kTfLiteBool;
// TfLite has no char type, so any remaining element types do not match.
default:
return false;
}
}

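// Copies the data of `input_tensor` into the interpreter's input tensor at
// `input_tensor_index`, after checking that both buffers are non-null and
// that their byte sizes match.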
template <typename T>
absl::Status CopyTensorBufferToInterpreter(const Tensor& input_tensor,
tflite::Interpreter& interpreter,
int input_tensor_index) {
auto input_tensor_view = input_tensor.GetCpuReadView();
const T* input_tensor_buffer = input_tensor_view.buffer<T>();
if (input_tensor_buffer == nullptr) {
return absl::InternalError("Input tensor buffer is null.");
}
T* local_tensor_buffer =
interpreter.typed_input_tensor<T>(input_tensor_index);
if (local_tensor_buffer == nullptr) {
return absl::InvalidArgumentError(
"Interpreter's input tensor buffer is null, possibly because it does not "
"support the specified input type.");
}
const TfLiteTensor* local_tensor =
interpreter.input_tensor(input_tensor_index);
if (local_tensor->bytes != input_tensor.bytes()) {
return absl::InvalidArgumentError(
absl::StrCat("Interpreter's input size does not match the input tensor's "
"size for index ",
input_tensor_index, "."));
}
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
return absl::OkStatus();
}

} // namespace

int GetXnnpackNumThreads(
@@ -50,4 +111,34 @@ int GetXnnpackNumThreads(
return GetXnnpackDefaultNumThreads();
}

absl::Status CopyCpuInputIntoInterpreterTensor(const Tensor& input_tensor,
tflite::Interpreter& interpreter,
int input_tensor_index) {
const TfLiteType interpreter_tensor_type =
interpreter.tensor(interpreter.inputs()[input_tensor_index])->type;
const Tensor::ElementType input_tensor_type = input_tensor.element_type();
if (!DoTypesMatch(input_tensor_type, interpreter_tensor_type)) {
return absl::InvalidArgumentError(absl::StrCat(
"Input and interpreter tensor type mismatch: ", input_tensor_type,
" vs. ", interpreter_tensor_type));
}
switch (interpreter_tensor_type) {
case TfLiteType::kTfLiteFloat16:
case TfLiteType::kTfLiteFloat32: {
MP_RETURN_IF_ERROR(CopyTensorBufferToInterpreter<float>(
input_tensor, interpreter, input_tensor_index));
break;
}
case TfLiteType::kTfLiteInt32: {
MP_RETURN_IF_ERROR(CopyTensorBufferToInterpreter<int>(
input_tensor, interpreter, input_tensor_index));
break;
}
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported input data type: ", input_tensor_type));
}
return absl::OkStatus();
}

} // namespace mediapipe
7 changes: 7 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator_utils.h
@@ -15,7 +15,10 @@
#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_

#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {

@@ -26,6 +29,10 @@ int GetXnnpackNumThreads(
const bool opts_has_delegate,
const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate);

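// Copies a CPU-backed MediaPipe Tensor into the interpreter's input tensor at
// `input_tensor_index`. Returns an error if the element types or byte sizes
// do not match, or if either buffer is null.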
absl::Status CopyCpuInputIntoInterpreterTensor(const Tensor& input_tensor,
tflite::Interpreter& interpreter,
int input_tensor_index);

} // namespace mediapipe

#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_
173 changes: 173 additions & 0 deletions mediapipe/calculators/tensor/inference_calculator_utils_test.cc
@@ -0,0 +1,173 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensor/inference_calculator_utils.h"

#include <cstdint>
#include <cstring>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"

namespace mediapipe {
namespace {

using ElementType = ::mediapipe::Tensor::ElementType;
using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
using ::tflite::Interpreter;

// Adds a tensor of the given type and size to the interpreter and updates
// the tensor index.
void AddInterpreterInput(TfLiteType type, int size, int& tensor_index,
bool allocate_tensor, Interpreter& interpreter) {
ABSL_CHECK_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk);
TfLiteQuantizationParams quant;
interpreter.SetTensorParametersReadWrite(tensor_index, type, "", {size},
quant);
interpreter.SetInputs({tensor_index});
ABSL_CHECK_EQ(interpreter.tensor(interpreter.inputs()[tensor_index])->type,
type);
if (allocate_tensor) {
ABSL_CHECK_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}
}

template <typename T>
std::vector<T> TfLiteTensorData(const Interpreter& interpreter,
int tensor_index) {
const TfLiteTensor* tensor =
interpreter.tensor(interpreter.inputs()[tensor_index]);
const T* tensor_ptr = reinterpret_cast<T*>(tensor->data.data);
ABSL_CHECK_NE(tensor_ptr, nullptr);
size_t tensor_size = tensor->bytes / sizeof(T);
return std::vector<T>(tensor_ptr, tensor_ptr + tensor_size);
}

TEST(InferenceCalculatorUtilsTest,
CopyCpuInputIntoInterpreterTensorWorksCorrectlyForInt32) {
tflite::Interpreter interpreter;
int tensor_index, tensor_len = 4;
AddInterpreterInput(kTfLiteInt32, tensor_len, tensor_index,
/*allocate_tensor=*/true, interpreter);
std::vector<int32_t> values{1, 2, 3, 4};
int values_len = values.size();
Tensor tensor(ElementType::kInt32, Tensor::Shape({values_len}));
std::memcpy(tensor.GetCpuWriteView().buffer<int32_t>(), values.data(),
values_len * sizeof(int32_t));
MP_EXPECT_OK(
CopyCpuInputIntoInterpreterTensor(tensor, interpreter, tensor_index));
EXPECT_THAT(TfLiteTensorData<int32_t>(interpreter, tensor_index),
ElementsAreArray(values));
}

TEST(InferenceCalculatorUtilsTest,
CopyCpuInputIntoInterpreterTensorWorksCorrectlyForFloat32) {
tflite::Interpreter interpreter;
int tensor_index, tensor_len = 4;
AddInterpreterInput(kTfLiteFloat32, tensor_len, tensor_index,
/*allocate_tensor=*/true, interpreter);
std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};
int values_len = values.size();
Tensor tensor(ElementType::kFloat32, Tensor::Shape({values_len}));
std::memcpy(tensor.GetCpuWriteView().buffer<float>(), values.data(),
values_len * sizeof(float));
MP_EXPECT_OK(
CopyCpuInputIntoInterpreterTensor(tensor, interpreter, tensor_index));
EXPECT_THAT(TfLiteTensorData<float>(interpreter, tensor_index),
ElementsAreArray(values));
}

TEST(InferenceCalculatorUtilsTest,
CopyCpuInputIntoInterpreterTensorTypeMismatch) {
tflite::Interpreter interpreter;
int tensor_index, tensor_len = 4;
AddInterpreterInput(kTfLiteInt32, tensor_len, tensor_index,
/*allocate_tensor=*/true, interpreter);
std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};
int values_len = values.size();
Tensor tensor(ElementType::kFloat32, Tensor::Shape({values_len}));
std::memcpy(tensor.GetCpuWriteView().buffer<float>(), values.data(),
values_len * sizeof(float));
absl::Status status =
CopyCpuInputIntoInterpreterTensor(tensor, interpreter, tensor_index);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("Input and interpreter tensor type mismatch:"));
}

TEST(InferenceCalculatorUtilsTest,
CopyCpuInputIntoInterpreterTensorSizeMismatch) {
tflite::Interpreter interpreter;
int tensor_index, tensor_len = 5;
AddInterpreterInput(kTfLiteFloat32, tensor_len, tensor_index,
/*allocate_tensor=*/true, interpreter);
std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};
int values_len = values.size();
Tensor tensor(ElementType::kFloat32, Tensor::Shape({values_len}));
std::memcpy(tensor.GetCpuWriteView().buffer<float>(), values.data(),
values_len * sizeof(float));
absl::Status status =
CopyCpuInputIntoInterpreterTensor(tensor, interpreter, tensor_index);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("Interpreter's input size does not match the input "
"tensor's size for index"));
}

TEST(InferenceCalculatorUtilsTest,
CopyCpuInputIntoInterpreterTensorNullBuffer) {
tflite::Interpreter interpreter;
int tensor_index, tensor_len = 4;
// Make TFLite interpreter's buffer null.
AddInterpreterInput(kTfLiteFloat32, tensor_len, tensor_index,
/*allocate_tensor=*/false, interpreter);
std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f};
int values_len = values.size();
Tensor tensor(ElementType::kFloat32, Tensor::Shape({values_len}));
std::memcpy(tensor.GetCpuWriteView().buffer<float>(), values.data(),
values_len * sizeof(float));
absl::Status status =
CopyCpuInputIntoInterpreterTensor(tensor, interpreter, tensor_index);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.message(),
HasSubstr("Interpreter's input tensor buffer is null"));
}

TEST(InferenceCalculatorUtilsTest,
CopyCpuInputIntoInterpreterTensorUnsupportedType) {
tflite::Interpreter interpreter;
int tensor_index, tensor_len = 4;
AddInterpreterInput(kTfLiteUInt8, tensor_len, tensor_index,
/*allocate_tensor=*/true, interpreter);
std::vector<uint8_t> values{1, 2, 3, 4};
int values_len = values.size();
Tensor tensor(ElementType::kUInt8, Tensor::Shape({values_len}));
std::memcpy(tensor.GetCpuWriteView().buffer<uint8_t>(), values.data(),
values_len * sizeof(uint8_t));
absl::Status status =
CopyCpuInputIntoInterpreterTensor(tensor, interpreter, tensor_index);
EXPECT_FALSE(status.ok());
EXPECT_THAT(status.message(), HasSubstr("Unsupported input data type:"));
}

} // namespace
} // namespace mediapipe
2 changes: 2 additions & 0 deletions mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc
@@ -109,6 +109,8 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
// Reallocation is needed for memory sanity.
if (resized_tensor_shapes) interpreter_->AllocateTensors();

// TODO: Replace this with the util function in
// inference_calculator_utils.
for (int i = 0; i < input_tensors.size(); ++i) {
const TfLiteType input_tensor_type =
interpreter_->tensor(interpreter_->inputs()[i])->type;

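As a closing note on the TODO left in inference_interpreter_delegate_runner.cc above, here is a rough sketch of how the manual per-input copy loop might later be collapsed onto the new helper, assuming the runner's interpreter_ member dereferences to a tflite::Interpreter; this is speculative and not part of the commit.

// Hypothetical future form of the input-copy loop in
// InferenceInterpreterDelegateRunner::Run() (sketch only).
for (int i = 0; i < input_tensors.size(); ++i) {
  MP_RETURN_IF_ERROR(CopyCpuInputIntoInterpreterTensor(
      input_tensors[i], *interpreter_, /*input_tensor_index=*/i));
}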