Skip to content

Commit

Permalink
'The Cleanening', part 1
Browse files Browse the repository at this point in the history
- Update existing documentation to be relevant to v2.0 changes
- Add missing documentation and doctests/examples
- Remove unused code
- Replace most pointer usage with `NonNull`
- Use slightly safer pointer `.cast()`s
- Disallow extracting string tensors with `extract_raw_tensor` (TODO: owned variant of this function)
- Add lifetime bound to `ModelMetadata` to fix a potential use-after-free
- Add per-run tag to `RunOptions`
- Replace `SessionInputKey` with `Cow<str>`
- Register execution providers inside `SessionBuilder::with_execution_providers`, not on commit
  • Loading branch information
decahedron1 committed Feb 21, 2024
1 parent 75aa918 commit 735284c
Show file tree
Hide file tree
Showing 27 changed files with 987 additions and 674 deletions.
96 changes: 33 additions & 63 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ pub(crate) struct Environment {
impl Drop for Environment {
#[tracing::instrument]
fn drop(&mut self) {
let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut();

debug!("Releasing environment");

let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut();

assert_ne!(env_ptr, std::ptr::null_mut());
ortsys![unsafe ReleaseEnv(env_ptr)];
}
Expand All @@ -55,16 +55,7 @@ pub struct EnvironmentGlobalThreadPoolOptions {
pub intra_op_thread_affinity: Option<String>
}

/// Struct used to build an environment [`Environment`].
///
/// This is ONNX Runtime's main entry point. An environment _must_ be created as the first step. An [`Environment`] can
/// only be built using `EnvironmentBuilder` to configure it.
///
/// Libraries using `ort` should **not** create an environment, as only one is allowed per process. Instead, allow the
/// user to pass their own environment to the library.
///
/// **NOTE**: If the same configuration method (for example [`EnvironmentBuilder::with_name()`] is called multiple
/// times, the last value will have precedence.
/// Struct used to build an `Environment`.
pub struct EnvironmentBuilder {
name: String,
execution_providers: Vec<ExecutionProviderDispatch>,
Expand All @@ -82,55 +73,26 @@ impl Default for EnvironmentBuilder {
}

impl EnvironmentBuilder {
/// Configure the environment with a given name
///
/// **NOTE**: Since ONNX can only define one environment per process, creating multiple environments using multiple
/// [`EnvironmentBuilder`]s will end up re-using the same environment internally; a new one will _not_ be created.
/// New parameters will be ignored.
pub fn with_name<S>(mut self, name: S) -> EnvironmentBuilder
/// Configure the environment with a given name for logging purposes.
#[must_use]
pub fn with_name<S>(mut self, name: S) -> Self
where
S: Into<String>
{
self.name = name.into();
self
}

/// Configures a list of execution providers sessions created under this environment will use by default. Sessions
/// may override these via
/// [`SessionBuilder::with_execution_providers`](crate::SessionBuilder::with_execution_providers).
/// Sets a list of execution providers which all sessions created in this environment will register.
///
/// Execution providers are loaded in the order they are provided until a suitable execution provider is found. Most
/// execution providers will silently fail if they are unavailable or misconfigured (see notes below), however, some
/// may log to the console, which is sadly unavoidable. The CPU execution provider is always available, so always
/// put it last in the list (though it is not required).
/// If a session is created in this environment with [`crate::SessionBuilder::with_execution_providers`], those EPs
/// will take precedence over the environment's EPs.
///
/// Execution providers will only work if the corresponding `onnxep-*` feature is enabled and ONNX Runtime was built
/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
/// feature enabled are currently ignored.
///
/// Execution provider options can be specified in the second argument. Refer to ONNX Runtime's
/// [execution provider docs](https://onnxruntime.ai/docs/execution-providers/) for configuration options. In most
/// cases, passing `None` to configure with no options is suitable.
///
/// It is recommended to enable the `cuda` EP for x86 platforms and the `acl` EP for ARM platforms for the best
/// performance, though this does mean you'll have to build ONNX Runtime for these targets. Microsoft's prebuilt
/// binaries are built with CUDA and TensorRT support, if you built `ort` with the `onnxep-cuda` or
/// `onnxep-tensorrt` features enabled.
///
/// Supported execution providers:
/// - `cpu`: Default CPU/MLAS execution provider. Available on all platforms.
/// - `acl`: Arm Compute Library
/// - `cuda`: NVIDIA CUDA/cuDNN
/// - `tensorrt`: NVIDIA TensorRT
///
/// ## Notes
///
/// - Using the CUDA/TensorRT execution providers **can terminate the process if the CUDA/TensorRT installation is
/// misconfigured**. Configuring the execution provider will seem to work, but when you attempt to run a session,
/// it will hard crash the process with a "stack buffer overrun" error. This can occur when CUDA/TensorRT is
/// missing a DLL such as `zlibwapi.dll`. To prevent your app from crashing, you can check to see if you can load
/// `zlibwapi.dll` before enabling the CUDA/TensorRT execution providers.
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> EnvironmentBuilder {
/// feature enabled will emit a warning.
#[must_use]
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self {
self.execution_providers = execution_providers.as_ref().to_vec();
self
}
Expand All @@ -139,12 +101,13 @@ impl EnvironmentBuilder {
///
/// Sessions will only use the global thread pool if they are created with
/// [`SessionBuilder::with_disable_per_session_threads`](crate::SessionBuilder::with_disable_per_session_threads).
pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> EnvironmentBuilder {
#[must_use]
pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self {
self.global_thread_pool_options = Some(options);
self
}

/// Commit the configuration to a new [`Environment`].
/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<()> {
// drop global reference to previous environment
drop(unsafe { (*G_ENV.cell.get()).take() });
Expand All @@ -164,7 +127,7 @@ impl EnvironmentBuilder {
ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism) -> Error::CreateEnvironment];
}
if let Some(spin_control) = global_thread_pool.spin_control {
ortsys![unsafe SetGlobalSpinControl(thread_options, if spin_control { 1 } else { 0 }) -> Error::CreateEnvironment];
ortsys![unsafe SetGlobalSpinControl(thread_options, i32::from(spin_control)) -> Error::CreateEnvironment];
}
if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity {
let cstr = CString::new(intra_op_thread_affinity).unwrap();
Expand Down Expand Up @@ -196,7 +159,7 @@ impl EnvironmentBuilder {
) -> Error::CreateEnvironment; nonNull(env_ptr)];
env_ptr
};
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");

unsafe {
*G_ENV.cell.get() = Some(Arc::new(Environment {
Expand All @@ -211,21 +174,28 @@ impl EnvironmentBuilder {

/// Creates an ONNX Runtime environment.
///
/// If this is not called, a default environment will be created.
///
/// In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
/// # Notes
/// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a
/// default environment will be created.
/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it.
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
#[must_use]
pub fn init() -> EnvironmentBuilder {
EnvironmentBuilder::default()
}

/// Creates an ONNX Runtime environment, using the ONNX Runtime dynamic library specified by `path`.
/// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`)
/// specified by `path`.
///
/// If this is not called, a default environment will be created.
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
///
/// In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
/// # Notes
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
#[cfg(feature = "load-dynamic")]
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
#[must_use]
pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string()));
EnvironmentBuilder::default()
Expand Down
27 changes: 19 additions & 8 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ pub enum Error {
/// Error occurred when unterminating run options.
#[error("Failed to unterminate run options: {0}")]
RunOptionsUnsetTerminate(ErrorInternal),
/// Error occurred when setting run tag.
#[error("Failed to set run tag: {0}")]
RunOptionsSetTag(ErrorInternal),
/// Error occurred when converting data to a String
#[error("Data was not UTF-8: {0}")]
StringFromUtf8Error(#[from] string::FromUtf8Error),
Expand Down Expand Up @@ -150,10 +153,10 @@ pub enum Error {
WideFfiStringNull(#[from] widestring::error::ContainsNul<u16>),
#[error("`{0}` should be a null pointer")]
/// ORT pointer should have been null
PointerShouldBeNull(String),
PointerShouldBeNull(&'static str),
/// ORT pointer should not have been null
#[error("`{0}` should not be a null pointer")]
PointerShouldNotBeNull(String),
PointerShouldNotBeNull(&'static str),
/// The runtime type was undefined.
#[error("Undefined tensor element type")]
UndefinedTensorElementType,
Expand Down Expand Up @@ -254,6 +257,16 @@ pub enum ErrorInternal {
IntoStringError(std::ffi::IntoStringError)
}

impl ErrorInternal {
#[must_use]
pub fn as_str(&self) -> Option<&str> {
match self {
ErrorInternal::Msg(msg) => Some(msg.as_str()),
ErrorInternal::IntoStringError(_) => None
}
}
}

/// Error from downloading pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models).
#[non_exhaustive]
#[derive(Error, Debug)]
Expand Down Expand Up @@ -290,14 +303,12 @@ impl From<*mut ort_sys::OrtStatus> for OrtStatusWrapper {
}
}

pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
ptr.is_null().then_some(()).ok_or_else(|| Error::PointerShouldBeNull(name.to_owned()))
pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &'static str) -> Result<()> {
ptr.is_null().then_some(()).ok_or_else(|| Error::PointerShouldBeNull(name))
}

pub(crate) fn assert_non_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
(!ptr.is_null())
.then_some(())
.ok_or_else(|| Error::PointerShouldNotBeNull(name.to_owned()))
pub(crate) fn assert_non_null_pointer<T>(ptr: *const T, name: &'static str) -> Result<()> {
(!ptr.is_null()).then_some(()).ok_or_else(|| Error::PointerShouldNotBeNull(name))
}

impl From<OrtStatusWrapper> for Result<(), ErrorInternal> {
Expand Down
4 changes: 3 additions & 1 deletion src/execution_providers/acl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ pub struct ACLExecutionProvider {
}

impl ACLExecutionProvider {
#[must_use]
pub fn with_arena_allocator(mut self) -> Self {
self.use_arena = true;
self
}

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
}
Expand All @@ -43,7 +45,7 @@ impl ExecutionProvider for ACLExecutionProvider {
{
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ACL(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
return crate::error::status_to_result(unsafe {
OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.session_options_ptr, self.use_arena.into())
OrtSessionOptionsAppendExecutionProvider_ACL(session_builder.session_options_ptr.as_ptr(), self.use_arena.into())
})
.map_err(Error::ExecutionProvider);
}
Expand Down
4 changes: 3 additions & 1 deletion src/execution_providers/armnn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ pub struct ArmNNExecutionProvider {
}

impl ArmNNExecutionProvider {
#[must_use]
pub fn with_arena_allocator(mut self) -> Self {
self.use_arena = true;
self
}

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
}
Expand All @@ -43,7 +45,7 @@ impl ExecutionProvider for ArmNNExecutionProvider {
{
super::get_ep_register!(OrtSessionOptionsAppendExecutionProvider_ArmNN(options: *mut ort_sys::OrtSessionOptions, use_arena: std::os::raw::c_int) -> ort_sys::OrtStatusPtr);
return crate::error::status_to_result(unsafe {
OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.session_options_ptr, self.use_arena.into())
OrtSessionOptionsAppendExecutionProvider_ArmNN(session_builder.session_options_ptr.as_ptr(), self.use_arena.into())
})
.map_err(Error::ExecutionProvider);
}
Expand Down
11 changes: 10 additions & 1 deletion src/execution_providers/cann.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,45 +36,52 @@ pub struct CANNExecutionProvider {
}

impl CANNExecutionProvider {
#[must_use]
pub fn with_device_id(mut self, device_id: i32) -> Self {
self.device_id = Some(device_id);
self
}

/// Configure the size limit of the device memory arena in bytes. This size limit is only for the execution
/// provider’s arena. The total device memory usage may be higher.
#[must_use]
pub fn with_memory_limit(mut self, limit: usize) -> Self {
self.npu_mem_limit = Some(limit);
self
}

/// Configure the strategy for extending the device's memory arena.
#[must_use]
pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
self.arena_extend_strategy = Some(strategy);
self
}

/// Configure whether to use the graph inference engine to speed up performance. The recommended and default setting
/// is true. If false, it will fall back to the single-operator inference engine.
#[must_use]
pub fn with_cann_graph(mut self, enable: bool) -> Self {
self.enable_cann_graph = Some(enable);
self
}

/// Configure whether to dump the subgraph into ONNX format for analysis of subgraph segmentation.
#[must_use]
pub fn with_dump_graphs(mut self) -> Self {
self.dump_graphs = Some(true);
self
}

/// Set the precision mode of the operator. See [`CANNExecutionProviderPrecisionMode`].
#[must_use]
pub fn with_precision_mode(mut self, mode: CANNExecutionProviderPrecisionMode) -> Self {
self.precision_mode = Some(mode);
self
}

/// Configure the implementation mode for operators. Some CANN operators can have both high-precision and
/// high-performance implementations.
#[must_use]
pub fn with_implementation_mode(mut self, mode: CANNExecutionProviderImplementationMode) -> Self {
self.op_select_impl_mode = Some(mode);
self
Expand All @@ -88,11 +95,13 @@ impl CANNExecutionProvider {
/// - `SoftmaxV2`
/// - `LRN`
/// - `ROIAlign`
#[must_use]
pub fn with_implementation_mode_oplist(mut self, list: impl ToString) -> Self {
self.optypelist_for_impl_mode = Some(list.to_string());
self
}

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
}
Expand Down Expand Up @@ -150,7 +159,7 @@ impl ExecutionProvider for CANNExecutionProvider {
return Err(e);
}

let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_builder.session_options_ptr, cann_options)];
let status = crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_builder.session_options_ptr.as_ptr(), cann_options)];
crate::ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
std::mem::drop((keys, values));
return crate::error::status_to_result(status).map_err(Error::ExecutionProvider);
Expand Down
10 changes: 8 additions & 2 deletions src/execution_providers/coreml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ pub struct CoreMLExecutionProvider {
impl CoreMLExecutionProvider {
/// Limit CoreML to running on CPU only. This may decrease the performance but will provide reference output value
/// without precision loss, which is useful for validation.
#[must_use]
pub fn with_cpu_only(mut self) -> Self {
self.use_cpu_only = true;
self
}

/// Enable CoreML EP to run on a subgraph in the body of a control flow operator (i.e. a Loop, Scan or If operator).
#[must_use]
pub fn with_subgraphs(mut self) -> Self {
self.enable_on_subgraph = true;
self
Expand All @@ -30,11 +32,13 @@ impl CoreMLExecutionProvider {
/// By default the CoreML EP will be enabled for all compatible Apple devices. Setting this option will only enable
/// CoreML EP for Apple devices with a compatible Apple Neural Engine (ANE). Note, enabling this option does not
/// guarantee the entire model to be executed using ANE only.
#[must_use]
pub fn with_ane_only(mut self) -> Self {
self.only_enable_device_with_ane = true;
self
}

#[must_use]
pub fn build(self) -> ExecutionProviderDispatch {
self.into()
}
Expand Down Expand Up @@ -70,8 +74,10 @@ impl ExecutionProvider for CoreMLExecutionProvider {
if self.only_enable_device_with_ane {
flags |= 0x004;
}
return crate::error::status_to_result(unsafe { OrtSessionOptionsAppendExecutionProvider_CoreML(session_builder.session_options_ptr, flags) })
.map_err(Error::ExecutionProvider);
return crate::error::status_to_result(unsafe {
OrtSessionOptionsAppendExecutionProvider_CoreML(session_builder.session_options_ptr.as_ptr(), flags)
})
.map_err(Error::ExecutionProvider);
}

Err(Error::ExecutionProviderNotRegistered(self.as_str()))
Expand Down
Loading

0 comments on commit 735284c

Please sign in to comment.