From d6f6d01499d84d0186114262a5bd190aa09e99fa Mon Sep 17 00:00:00 2001 From: Michael Butler Date: Thu, 21 Apr 2022 16:40:09 -0700 Subject: [PATCH] Handle case where NN AIDL callback is null in IDevice::prepareModel* Prior to this change, if IDevice::prepareModel* was passed a null callback, the code would still attempt to call "notify" on the callback to return the error to the client. This CL ensures the "notify" method will not be invoked if the callback is null. Bug: N/A Test: mma Test: presubmit Change-Id: I4a15d02c4879a0261ec26cc0e7a47d0a4da86b8b --- .../utils/adapter/aidl/src/Device.cpp | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/neuralnetworks/utils/adapter/aidl/src/Device.cpp b/neuralnetworks/utils/adapter/aidl/src/Device.cpp index 1b90a1ab4b..453ec9b7a9 100644 --- a/neuralnetworks/utils/adapter/aidl/src/Device.cpp +++ b/neuralnetworks/utils/adapter/aidl/src/Device.cpp @@ -135,16 +135,26 @@ std::shared_ptr adaptPreparedModel(nn::SharedPreparedModel prepar return ndk::SharedRefBase::make(std::move(preparedModel)); } +void notify(IPreparedModelCallback* callback, ErrorStatus status, + const std::shared_ptr& preparedModel) { + if (callback != nullptr) { + const auto ret = callback->notify(status, preparedModel); + if (!ret.isOk()) { + LOG(ERROR) << "IPreparedModelCallback::notify failed with " << ret.getDescription(); + } + } +} + void notify(IPreparedModelCallback* callback, PrepareModelResult result) { if (!result.has_value()) { const auto& [message, status] = result.error(); LOG(ERROR) << message; const auto aidlCode = utils::convert(status).value_or(ErrorStatus::GENERAL_FAILURE); - callback->notify(aidlCode, nullptr); + notify(callback, aidlCode, nullptr); } else { auto preparedModel = std::move(result).value(); auto aidlPreparedModel = adaptPreparedModel(std::move(preparedModel)); - callback->notify(ErrorStatus::NONE, std::move(aidlPreparedModel)); + notify(callback, ErrorStatus::NONE, std::move(aidlPreparedModel)); } } @@ -284,7 +294,7 @@ ndk::ScopedAStatus Device::prepareModel(const Model& model, ExecutionPreference if (!result.has_value()) { const auto& [message, code] = result.error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); - callback->notify(aidlCode, nullptr); + notify(callback.get(), aidlCode, nullptr); return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( static_cast(aidlCode), message.c_str()); } @@ -300,7 +310,7 @@ ndk::ScopedAStatus Device::prepareModelFromCache( if (!result.has_value()) { const auto& [message, code] = result.error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); - callback->notify(aidlCode, nullptr); + notify(callback.get(), aidlCode, nullptr); return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( static_cast(aidlCode), message.c_str()); } @@ -317,7 +327,7 @@ ndk::ScopedAStatus Device::prepareModelWithConfig( if (!result.has_value()) { const auto& [message, code] = result.error(); const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE); - callback->notify(aidlCode, nullptr); + notify(callback.get(), aidlCode, nullptr); return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage( static_cast(aidlCode), message.c_str()); }