Skip to content

Commit

Permalink
Avoid double allocation when passing strings via IntoParam (#1713)
Browse files Browse the repository at this point in the history
* Use an iterator instead of double allocation for strings

* Add a missing trailing new line

* Use `T::default` instead of `Default::default` for clarity

* Rename `from_iter` to `string_from_iter` for clarity

* Run `cargo fmt`

* Make string_from_iter from memory safe

* Encoder is fused because Chain is fused

* Remove unsafe from string_from_iter as it is memory-safe

* Incorporate ryancerium and rylev's suggestions

* Correct documentation error
  • Loading branch information
AronParker authored Apr 28, 2022
1 parent 9deec0e commit e518d03
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 14 deletions.
34 changes: 23 additions & 11 deletions crates/libs/windows/src/core/heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,32 @@ pub unsafe fn heap_free(ptr: RawPtr) {
}
}

/// Copy a slice of `T` into a freshly allocated buffer with an additional default `T` at the end.
/// Copy len elements of an iterator of type `T` into a freshly allocated buffer.
///
/// Returns a pointer to the beginning of the buffer
/// Returns a pointer to the beginning of the buffer. This pointer must be freed when done using `heap_free`.
///
/// # Panics
///
/// This function panics if the heap allocation fails or if the pointer returned from
/// the heap allocation is not properly aligned to `T`.
pub fn heap_string<T: Copy + Default + Sized>(slice: &[T]) -> *const T {
unsafe {
let buffer = heap_alloc((slice.len() + 1) * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;
assert!(buffer.align_offset(std::mem::align_of::<T>()) == 0, "heap allocated buffer is not properly aligned");
buffer.copy_from_nonoverlapping(slice.as_ptr(), slice.len());
buffer.add(slice.len()).write(T::default());
buffer
/// This function panics if the heap allocation fails, the alignment requirements of 'T' surpass
/// 8 (HeapAlloc's alignment).
pub fn alloc_from_iter<I, T>(iter: I, len: usize) -> *const T
where
I: Iterator<Item = T>,
T: Copy,
{
// alignment of memory returned by HeapAlloc is at least 8
// Source: https://docs.microsoft.com/en-us/windows/win32/api/heapapi/nf-heapapi-heapalloc
// Ensure that T has sufficient alignment requirements
assert!(std::mem::align_of::<T>() <= 8, "T alignment surpasses HeapAlloc alignment");

let ptr = heap_alloc(len * std::mem::size_of::<T>()).expect("could not allocate string") as *mut T;

for (offset, c) in iter.take(len).enumerate() {
// SAFETY: ptr points to an allocation object of size `len`, indices accessed are always lower than `len`
unsafe {
ptr.add(offset).write(c);
}
}

ptr
}
2 changes: 1 addition & 1 deletion crates/libs/windows/src/core/pcstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ unsafe impl Abi for PCSTR {
#[cfg(feature = "alloc")]
impl<'a> IntoParam<'a, PCSTR> for &str {
fn into_param(self) -> Param<'a, PCSTR> {
Param::Boxed(PCSTR(heap_string(self.as_bytes())))
Param::Boxed(PCSTR(alloc_from_iter(self.as_bytes().iter().copied().chain(core::iter::once(0)), self.len() + 1)))
}
}
#[cfg(feature = "alloc")]
Expand Down
4 changes: 2 additions & 2 deletions crates/libs/windows/src/core/pcwstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ unsafe impl Abi for PCWSTR {
#[cfg(feature = "alloc")]
impl<'a> IntoParam<'a, PCWSTR> for &str {
fn into_param(self) -> Param<'a, PCWSTR> {
Param::Boxed(PCWSTR(heap_string(&self.encode_utf16().collect::<alloc::vec::Vec<u16>>())))
Param::Boxed(PCWSTR(alloc_from_iter(self.encode_utf16().chain(core::iter::once(0)), self.len() + 1)))
}
}
#[cfg(feature = "alloc")]
Expand All @@ -58,7 +58,7 @@ impl<'a> IntoParam<'a, PCWSTR> for alloc::string::String {
impl<'a> IntoParam<'a, PCWSTR> for &::std::ffi::OsStr {
fn into_param(self) -> Param<'a, PCWSTR> {
use ::std::os::windows::ffi::OsStrExt;
Param::Boxed(PCWSTR(heap_string(&self.encode_wide().collect::<alloc::vec::Vec<u16>>())))
Param::Boxed(PCWSTR(alloc_from_iter(self.encode_wide().chain(core::iter::once(0)), self.len() + 1)))
}
}
#[cfg(feature = "alloc")]
Expand Down

0 comments on commit e518d03

Please sign in to comment.