diff --git a/tiledb/api/src/vfs.rs b/tiledb/api/src/vfs.rs index f097248a..4bb303ca 100644 --- a/tiledb/api/src/vfs.rs +++ b/tiledb/api/src/vfs.rs @@ -6,6 +6,12 @@ use crate::config::Config; use crate::context::Context; use crate::Result as TileDBResult; +pub enum VFSLsStatus { + Continue, + Stop, + Error, +} + pub(crate) enum RawVFS { Owned(*mut ffi::tiledb_vfs_t), } @@ -415,19 +421,26 @@ impl<'ctx> VFS<'ctx> { } } - /// # Safety - /// This function is unsafe because of the data pointer being passed. - pub unsafe fn ls( - &self, - uri: &str, - callback: ffi::LSCallback, - data: *mut ::std::os::raw::c_void, - ) -> TileDBResult<()> { + pub fn ls(&self, uri: &str, mut callback: F) -> TileDBResult<()> + where + F: FnMut(&str) -> VFSLsStatus, + { let c_ctx = self.context.as_mut_ptr(); let c_vfs = *self.raw; let c_uri = cstring!(uri); + + // See the StackOverflow link on vfs_ls_cb_handler + let mut cb: &mut dyn FnMut(&str) -> VFSLsStatus = &mut callback; + let cb = &mut cb; + let res = unsafe { - ffi::tiledb_vfs_ls(c_ctx, c_vfs, c_uri.as_ptr(), callback, data) + ffi::tiledb_vfs_ls( + c_ctx, + c_vfs, + c_uri.as_ptr(), + Some(vfs_ls_cb_handler), + cb as *mut _ as *mut std::ffi::c_void, + ) }; if res == ffi::TILEDB_OK { @@ -437,24 +450,29 @@ impl<'ctx> VFS<'ctx> { } } - /// # Safety - /// This function is unsafe because of the data pointer being passed. - pub unsafe fn ls_recursive( + pub fn ls_recursive( &self, uri: &str, - callback: ffi::LSRecursiveCallback, - data: *mut ::std::os::raw::c_void, - ) -> TileDBResult<()> { + mut callback: F, + ) -> TileDBResult<()> + where + F: FnMut(&str, u64) -> VFSLsStatus, + { let c_ctx = self.context.as_mut_ptr(); let c_vfs = *self.raw; let c_uri = cstring!(uri); + + // See the StackOverflow link on vfs_ls_recursive_cb_handler + let mut cb: &mut dyn FnMut(&str, u64) -> VFSLsStatus = &mut callback; + let cb = &mut cb; + let res = unsafe { ffi::tiledb_vfs_ls_recursive( c_ctx, c_vfs, c_uri.as_ptr(), - callback, - data, + Some(vfs_ls_recursive_cb_handler), + cb as *mut _ as *mut std::ffi::c_void, ) }; @@ -466,6 +484,73 @@ impl<'ctx> VFS<'ctx> { } } +// This bit of complexity is based on the StackOverflow answer here: +// https://stackoverflow.com/a/32270215 +extern "C" fn vfs_ls_cb_handler( + path: *const ::std::os::raw::c_char, + callback_data: *mut ::std::os::raw::c_void, +) -> std::ffi::c_int { + let closure: &mut &mut dyn FnMut(&str) -> VFSLsStatus = unsafe { + std::mem::transmute( + // This complicated cast is brought to you by clippy. The original + // did not require this, but the original is also two years old. + &mut *(callback_data + as *mut &mut dyn for<'a> std::ops::FnMut( + &'a str, + ) + -> VFSLsStatus), + ) + }; + + let c_str: &std::ffi::CStr = unsafe { std::ffi::CStr::from_ptr(path) }; + let slice = c_str.to_str(); + + if slice.is_err() { + return -1; + } + + match closure(slice.unwrap()) { + VFSLsStatus::Continue => 1, + VFSLsStatus::Stop => 0, + VFSLsStatus::Error => -1, + } +} + +// This bit of complexity is based on the StackOverflow answer here: +// https://stackoverflow.com/a/32270215 +extern "C" fn vfs_ls_recursive_cb_handler( + path: *const ::std::os::raw::c_uchar, + path_len: usize, + object_size: u64, + callback_data: *mut ::std::os::raw::c_void, +) -> std::ffi::c_int { + let closure: &mut &mut dyn FnMut(&str, u64) -> VFSLsStatus = unsafe { + std::mem::transmute( + // This complicated cast is brought to you by clippy. The original + // did not require this, but the original is also two years old. + &mut *(callback_data + as *mut &mut dyn for<'a> std::ops::FnMut( + &'a str, + &'a u64, + ) + -> VFSLsStatus), + ) + }; + + let path_slice: &[u8] = + unsafe { std::slice::from_raw_parts(path, path_len) }; + let c_str = std::str::from_utf8(path_slice); + if c_str.is_err() { + return -1; + } + + match closure(c_str.unwrap(), object_size) { + VFSLsStatus::Continue => 1, + VFSLsStatus::Stop => 0, + VFSLsStatus::Error => -1, + } +} + impl<'ctx> VFSHandle<'ctx> { pub fn is_closed(&self) -> TileDBResult { let c_ctx = self.context.as_mut_ptr(); @@ -767,14 +852,6 @@ mod tests { Ok(()) } - unsafe extern "C" fn ls_callback( - _: *const std::os::raw::c_char, - count: *mut std::os::raw::c_void, - ) -> i32 { - *(count as *mut u64) += 1; - 1 - } - #[test] fn vfs_ls() -> TileDBResult<()> { let ctx = Context::new()?; @@ -792,14 +869,12 @@ mod tests { let tmp_uri = tmp_dir.path().to_str().expect("Error getting temp dir"); let mut count: u64 = 0; - unsafe { - vfs.ls( - tmp_uri, - Some(ls_callback), - &mut count as *mut std::ffi::c_ulonglong - as *mut std::ffi::c_void, - )?; - } + let cb = |_: &str| -> VFSLsStatus { + count += 1; + VFSLsStatus::Continue + }; + + vfs.ls(tmp_uri, cb)?; // ls only sees the three directories. assert_eq!(count, 3); @@ -807,16 +882,6 @@ mod tests { Ok(()) } - unsafe extern "C" fn ls_recursive_callback( - _: *const std::os::raw::c_char, - _: usize, - _: u64, - count: *mut std::os::raw::c_void, - ) -> i32 { - *(count as *mut u64) += 1; - 1 - } - #[test] fn vfs_ls_recursive_old() -> TileDBResult<()> { // Recursive ls over the Posix backend doesn't exist before 2.21 @@ -838,15 +903,11 @@ mod tests { let tmp_uri = tmp_dir.path().to_str().expect("Error getting tmp_uri"); let mut count: u64 = 0; - assert!(unsafe { - vfs.ls_recursive( - tmp_uri, - Some(ls_recursive_callback), - &mut count as *mut std::ffi::c_ulonglong - as *mut std::ffi::c_void, - ) - .is_err() - }); + let cb = |_: &str, _: u64| -> VFSLsStatus { + count += 1; + VFSLsStatus::Continue + }; + assert!(vfs.ls_recursive(tmp_uri, cb).is_err()); Ok(()) } @@ -875,14 +936,11 @@ mod tests { let tmp_uri = tmp_dir.path().to_str().expect("Error getting temp dir"); let mut count: u64 = 0; - unsafe { - vfs.ls_recursive( - tmp_uri, - Some(ls_recursive_callback), - &mut count as *mut std::ffi::c_ulonglong - as *mut std::ffi::c_void, - )?; - } + let cb = |_: &str, _: u64| -> VFSLsStatus { + count += 1; + VFSLsStatus::Continue + }; + vfs.ls_recursive(tmp_uri, cb)?; // ls_recursive sees three directories and one file. assert_eq!(count, 4); diff --git a/tiledb/sys/src/vfs.rs b/tiledb/sys/src/vfs.rs index 69885a64..83aee51e 100644 --- a/tiledb/sys/src/vfs.rs +++ b/tiledb/sys/src/vfs.rs @@ -18,7 +18,7 @@ pub type LSCallback = ::std::option::Option< pub type LSRecursiveCallback = ::std::option::Option< unsafe extern "C" fn( - path: *const std::os::raw::c_char, + path: *const std::os::raw::c_uchar, path_len: usize, object_size: u64, callback_data: *mut ::std::os::raw::c_void,