Skip to content

Commit

Permalink
feat: expose environment, session, & value pointers
Browse files Browse the repository at this point in the history
for libs which might want to use unsafe ort-sys functions that don't currently have a safe implementation.
  • Loading branch information
decahedron1 committed Feb 22, 2024
1 parent a5d4b80 commit c2074eb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
25 changes: 21 additions & 4 deletions src/environment.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
use std::{cell::UnsafeCell, ffi::CString, sync::atomic::AtomicPtr, sync::Arc};
use std::{
cell::UnsafeCell,
ffi::CString,
sync::{
atomic::{AtomicPtr, Ordering},
Arc
}
};

use tracing::debug;

Expand All @@ -19,11 +26,18 @@ unsafe impl Sync for EnvironmentSingleton {}
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };

#[derive(Debug)]
pub(crate) struct Environment {
pub struct Environment {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
pub(crate) env_ptr: AtomicPtr<ort_sys::OrtEnv>
}

impl Environment {
/// Loads the underlying [`ort_sys::OrtEnv`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
self.env_ptr.load(Ordering::Relaxed)
}
}

impl Drop for Environment {
#[tracing::instrument]
fn drop(&mut self) {
Expand All @@ -36,7 +50,8 @@ impl Drop for Environment {
}
}

pub(crate) fn get_environment() -> Result<&'static Arc<Environment>> {
/// Gets a reference to the global environment, creating one if an environment has been committed yet.
pub fn get_environment() -> Result<&'static Arc<Environment>> {
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
Ok(c)
} else {
Expand Down Expand Up @@ -110,7 +125,9 @@ impl EnvironmentBuilder {
/// Commit the environment configuration and set the global environment.
pub fn commit(self) -> Result<()> {
// drop global reference to previous environment
drop(unsafe { (*G_ENV.cell.get()).take() });
if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } {
drop(env_arc);
}

let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use tracing::Level;

#[cfg(feature = "load-dynamic")]
pub use self::environment::init_from;
pub use self::environment::{init, EnvironmentBuilder, EnvironmentGlobalThreadPoolOptions};
pub use self::environment::{get_environment, init, Environment, EnvironmentBuilder, EnvironmentGlobalThreadPoolOptions};
#[cfg(feature = "fetch-models")]
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
pub use self::error::FetchModelError;
Expand Down
14 changes: 13 additions & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,13 @@ pub struct SharedSessionInner {
_environment: Arc<Environment>
}

impl SharedSessionInner {
/// Returns the underlying [`ort_sys::OrtSession`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtSession {
self.session_ptr.as_ptr()
}
}

unsafe impl Send for SharedSessionInner {}
unsafe impl Sync for SharedSessionInner {}

Expand Down Expand Up @@ -626,7 +633,12 @@ impl Session {
IoBinding::new(self)
}

/// Get an shared ([`Arc`]'d) reference to the underlying [`SharedSessionInner`], which holds the
/// Returns the underlying [`ort_sys::OrtSession`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtSession {
self.inner.ptr()
}

/// Get a shared ([`Arc`]'d) reference to the underlying [`SharedSessionInner`], which holds the
/// [`ort_sys::OrtSession`] pointer and the session allocator.
#[must_use]
pub fn inner(&self) -> Arc<SharedSessionInner> {
Expand Down
3 changes: 2 additions & 1 deletion src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ impl Value {
}
}

pub(crate) fn ptr(&self) -> *mut ort_sys::OrtValue {
/// Returns the underlying [`ort_sys::OrtValue`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtValue {
match &self.inner {
ValueInner::CppOwnedRef { ptr } | ValueInner::CppOwned { ptr, .. } | ValueInner::RustOwned { ptr, .. } => ptr.as_ptr()
}
Expand Down

0 comments on commit c2074eb

Please sign in to comment.