Skip to content

Commit

Permalink
feat: device-allocated Values
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Mar 9, 2024
1 parent 1ee8b21 commit e57e08b
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 80 deletions.
6 changes: 5 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,11 @@ pub enum Error {
#[error("{0}")]
CustomError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("String tensors cannot be borrowed as mutable")]
StringTensorNotMutable
StringTensorNotMutable,
#[error("Could't get `MemoryInfo` from allocator: {0}")]
AllocatorGetInfo(ErrorInternal),
#[error("Could't get `MemoryType` from memory info: {0}")]
GetMemoryType(ErrorInternal)
}

impl From<Infallible> for Error {
Expand Down
47 changes: 1 addition & 46 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub use self::error::FetchModelError;
pub use self::error::{Error, ErrorInternal, Result};
pub use self::execution_providers::*;
pub use self::io_binding::IoBinding;
pub use self::memory::{AllocationDevice, Allocator, MemoryInfo};
pub use self::memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType};
pub use self::metadata::ModelMetadata;
pub use self::operator::{
io::{OperatorInput, OperatorOutput},
Expand Down Expand Up @@ -380,51 +380,6 @@ impl From<GraphOptimizationLevel> for ort_sys::GraphOptimizationLevel {
}
}

/// Execution provider allocator type.
#[derive(Debug, Copy, Clone)]
pub enum AllocatorType {
/// Default device-specific allocator.
Device,
/// Arena allocator.
Arena
}

impl From<AllocatorType> for ort_sys::OrtAllocatorType {
fn from(val: AllocatorType) -> Self {
match val {
AllocatorType::Device => ort_sys::OrtAllocatorType::OrtDeviceAllocator,
AllocatorType::Arena => ort_sys::OrtAllocatorType::OrtArenaAllocator
}
}
}

/// Memory types for allocated memory.
#[derive(Default, Debug, Copy, Clone)]
pub enum MemoryType {
/// Any CPU memory used by non-CPU execution provider.
CPUInput,
/// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED.
CPUOutput,
/// The default allocator for an execution provider.
#[default]
Default
}

impl MemoryType {
/// Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. `CUDA_PINNED`.
pub const CPU: MemoryType = MemoryType::CPUOutput;
}

impl From<MemoryType> for ort_sys::OrtMemType {
fn from(val: MemoryType) -> Self {
match val {
MemoryType::CPUInput => ort_sys::OrtMemType::OrtMemTypeCPUInput,
MemoryType::CPUOutput => ort_sys::OrtMemType::OrtMemTypeCPUOutput,
MemoryType::Default => ort_sys::OrtMemType::OrtMemTypeDefault
}
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
105 changes: 93 additions & 12 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,55 @@ use std::{

use super::{
error::{Error, Result},
ortsys, AllocatorType, MemoryType
ortsys
};
use crate::{char_p_to_string, error::status_to_result};
use crate::{char_p_to_string, error::status_to_result, Session};

/// An ONNX Runtime allocator, used to manage the allocation of [`crate::Value`]s.
#[derive(Debug)]
pub struct Allocator {
pub(crate) ptr: NonNull<ort_sys::OrtAllocator>,
is_default: bool
is_default: bool,
_info: Option<MemoryInfo>
}

impl Allocator {
pub(crate) unsafe fn from_raw_unchecked(ptr: *mut ort_sys::OrtAllocator) -> Allocator {
Allocator {
ptr: NonNull::new_unchecked(ptr),
is_default: false
is_default: false,
_info: None
}
}

pub(crate) unsafe fn free<T>(&self, ptr: *mut T) {
self.ptr.as_ref().Free.unwrap_unchecked()(self.ptr.as_ptr(), ptr.cast());
}

/// Creates a new [`Allocator`] for the given session, to allocate memory on the device described in the
/// [`MemoryInfo`].
///
/// For example, to create an allocator to allocate pinned memory for CUDA:
/// ```no_run
/// # use ort::{Allocator, Session, MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// # let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?;
/// let allocator = Allocator::new(
/// &session,
/// MemoryInfo::new(AllocationDevice::CUDAPinned, 0, AllocatorType::Device, MemoryType::CPUInput)?
/// )?;
/// # Ok(())
/// # }
/// ```
pub fn new(session: &Session, memory_info: MemoryInfo) -> Result<Self> {
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
ortsys![unsafe CreateAllocator(session.ptr(), memory_info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
Ok(Self {
ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
is_default: false,
_info: Some(memory_info)
})
}
}

impl Default for Allocator {
Expand All @@ -35,7 +62,8 @@ impl Default for Allocator {
status_to_result(ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr); nonNull(allocator_ptr)]).unwrap();
Self {
ptr: unsafe { NonNull::new_unchecked(allocator_ptr) },
is_default: true
is_default: true,
_info: None
}
}
}
Expand Down Expand Up @@ -107,10 +135,64 @@ impl TryFrom<String> for AllocationDevice {
}
}

/// Execution provider allocator type.
#[derive(Debug, Copy, Clone)]
pub enum AllocatorType {
/// Default device-specific allocator.
Device,
/// Arena allocator.
Arena
}

impl From<AllocatorType> for ort_sys::OrtAllocatorType {
fn from(val: AllocatorType) -> Self {
match val {
AllocatorType::Device => ort_sys::OrtAllocatorType::OrtDeviceAllocator,
AllocatorType::Arena => ort_sys::OrtAllocatorType::OrtArenaAllocator
}
}
}

/// Memory types for allocated memory.
#[derive(Default, Debug, Copy, Clone)]
pub enum MemoryType {
/// Any CPU memory used by non-CPU execution provider.
CPUInput,
/// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED.
CPUOutput,
/// The default allocator for an execution provider.
#[default]
Default
}

impl MemoryType {
/// Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. `CUDA_PINNED`.
pub const CPU: MemoryType = MemoryType::CPUOutput;
}

impl From<MemoryType> for ort_sys::OrtMemType {
fn from(val: MemoryType) -> Self {
match val {
MemoryType::CPUInput => ort_sys::OrtMemType::OrtMemTypeCPUInput,
MemoryType::CPUOutput => ort_sys::OrtMemType::OrtMemTypeCPUOutput,
MemoryType::Default => ort_sys::OrtMemType::OrtMemTypeDefault
}
}
}

impl From<ort_sys::OrtMemType> for MemoryType {
fn from(value: ort_sys::OrtMemType) -> Self {
match value {
ort_sys::OrtMemType::OrtMemTypeCPUInput => MemoryType::CPUInput,
ort_sys::OrtMemType::OrtMemTypeCPUOutput => MemoryType::CPUOutput,
ort_sys::OrtMemType::OrtMemTypeDefault => MemoryType::Default
}
}
}

#[derive(Debug)]
pub struct MemoryInfo {
pub(crate) ptr: NonNull<ort_sys::OrtMemoryInfo>,
memory_type: MemoryType,
should_release: bool
}

Expand All @@ -124,7 +206,6 @@ impl MemoryInfo {
];
Ok(Self {
ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) },
memory_type,
should_release: true
})
}
Expand All @@ -140,15 +221,15 @@ impl MemoryInfo {
];
Ok(Self {
ptr: unsafe { NonNull::new_unchecked(memory_info_ptr) },
memory_type,
should_release: true
})
}

/// Returns the [`MemoryType`] this struct was created with.
#[must_use]
pub fn memory_type(&self) -> MemoryType {
self.memory_type
/// Returns the [`MemoryType`] described by this struct.
pub fn memory_type(&self) -> Result<MemoryType> {
let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault;
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type) -> Error::GetMemoryType];
Ok(MemoryType::from(raw_type))
}

/// Returns the [`AllocationDevice`] this struct was created with.
Expand Down
Loading

0 comments on commit e57e08b

Please sign in to comment.