From 906d5071aded7863fe754a8e0508e6cfe17e6c67 Mon Sep 17 00:00:00 2001 From: "Paul J. Davis" Date: Fri, 20 Sep 2024 15:08:18 -0500 Subject: [PATCH 01/42] Replace buffer management with Arrow buffers This implements a second version of the Query interface using Arrow Arrays. The core idea here is that we can share Arrow arrays freely and then convert them to mutable buffers as long as there are no external references to them. This is all done safely and returns an error if any buffer is externally referenced. --- tiledb/api/Cargo.toml | 10 +- .../examples/multi_range_subarray_arrow.rs | 154 ++ .../examples/query_condition_dense_arrow.rs | 242 +++ .../examples/query_condition_sparse_arrow.rs | 233 +++ tiledb/api/examples/quickstart_dense_arrow.rs | 176 ++ .../quickstart_sparse_string_arrow.rs | 139 ++ .../api/examples/reading_incomplete_arrow.rs | 179 ++ tiledb/api/src/array/attribute/mod.rs | 1 - tiledb/api/src/array/dimension/mod.rs | 1 - tiledb/api/src/array/mod.rs | 4 - tiledb/api/src/array/schema/mod.rs | 5 - tiledb/api/src/array/schema/strategy.rs | 9 + tiledb/api/src/array/strategy.rs | 17 +- tiledb/api/src/datatype/mod.rs | 1 - tiledb/api/src/filter/mod.rs | 1 - tiledb/api/src/lib.rs | 2 +- tiledb/api/src/query/buffer/mod.rs | 1 - tiledb/api/src/query/read/output/mod.rs | 1 - tiledb/api/src/query/write/input/mod.rs | 1 - tiledb/api/src/query_arrow/arrow.rs | 520 ++++++ tiledb/api/src/query_arrow/buffers.rs | 1645 +++++++++++++++++ tiledb/api/src/query_arrow/fields.rs | 171 ++ tiledb/api/src/query_arrow/mod.rs | 578 ++++++ tiledb/api/src/query_arrow/subarray.rs | 64 + tiledb/api/src/range.rs | 6 + 25 files changed, 4139 insertions(+), 22 deletions(-) create mode 100644 tiledb/api/examples/multi_range_subarray_arrow.rs create mode 100644 tiledb/api/examples/query_condition_dense_arrow.rs create mode 100644 tiledb/api/examples/query_condition_sparse_arrow.rs create mode 100644 tiledb/api/examples/quickstart_dense_arrow.rs create mode 100644 tiledb/api/examples/quickstart_sparse_string_arrow.rs create mode 100644 tiledb/api/examples/reading_incomplete_arrow.rs create mode 100644 tiledb/api/src/query_arrow/arrow.rs create mode 100644 tiledb/api/src/query_arrow/buffers.rs create mode 100644 tiledb/api/src/query_arrow/fields.rs create mode 100644 tiledb/api/src/query_arrow/mod.rs create mode 100644 tiledb/api/src/query_arrow/subarray.rs diff --git a/tiledb/api/Cargo.toml b/tiledb/api/Cargo.toml index 840629f3..3b7cfd38 100644 --- a/tiledb/api/Cargo.toml +++ b/tiledb/api/Cargo.toml @@ -9,7 +9,7 @@ path = "src/lib.rs" [dependencies] anyhow = { workspace = true } -arrow = { version = "52.0.0", features = ["prettyprint"], optional = true } +arrow = { version = "52.0.0", features = ["prettyprint"] } itertools = "0" num-traits = { version = "0.2", optional = true } paste = "1.0" @@ -35,5 +35,9 @@ tiledb-utils = { workspace = true } [features] default = [] -proptest-strategies = ["dep:num-traits", "dep:proptest", "dep:proptest-derive", "dep:tiledb-test-utils"] -arrow = ["dep:arrow"] +proptest-strategies = [ + "dep:num-traits", + "dep:proptest", + "dep:proptest-derive", + "dep:tiledb-test-utils", +] diff --git a/tiledb/api/examples/multi_range_subarray_arrow.rs b/tiledb/api/examples/multi_range_subarray_arrow.rs new file mode 100644 index 00000000..1019a5d1 --- /dev/null +++ b/tiledb/api/examples/multi_range_subarray_arrow.rs @@ -0,0 +1,154 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::array as aa; +use itertools::izip; + +use tiledb::array::{ + Array, ArrayType, AttributeData, CellOrder, DimensionData, DomainData, + SchemaData, TileOrder, +}; +use tiledb::context::Context; +use tiledb::error::Error as TileDBError; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryStatus, QueryType}; +use tiledb::Result as TileDBResult; +use tiledb::{Datatype, Factory}; + +const ARRAY_URI: &str = "multi_range_slicing"; + +/// This example creates a 4x4 dense array with the contents: +/// +/// Col: 1 2 3 4 +/// Row: 1 1 2 3 4 +/// 2 5 6 7 8 +/// 3 9 10 11 12 +/// 4 13 14 15 16 +/// +/// The query run restricts rows to [1, 2, 4] and returns all columns which +/// should produce these rows: +/// +/// Row Col Value +/// 1 1 1 +/// 1 2 2 +/// 1 3 3 +/// 1 4 4 +/// 2 1 5 +/// 2 2 6 +/// 2 3 7 +/// 2 4 8 +/// 4 1 13 +/// 4 2 14 +/// 4 3 15 +/// 4 4 16 +fn main() -> TileDBResult<()> { + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + let _ = std::env::set_current_dir( + PathBuf::from(manifest_dir).join("examples").join("output"), + ); + } + + let ctx = Context::new()?; + if Array::exists(&ctx, ARRAY_URI)? { + Array::delete(&ctx, ARRAY_URI)?; + } + + create_array(&ctx)?; + write_array(&ctx)?; + + let array = Array::open(&ctx, ARRAY_URI, tiledb::array::Mode::Read)?; + let mut query = QueryBuilder::new(array, QueryType::Read) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field("rows") + .field("cols") + .field("a") + .end_fields() + .start_subarray() + .add_range("rows", &[1, 2]) + .add_range("rows", &[4, 4]) + .add_range("cols", &[1, 4]) + .end_subarray() + .build() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + let status = query + .submit() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + if !matches!(status, QueryStatus::Completed) { + return Err(TileDBError::Other("Make this better.".to_string())); + } + + let buffers = query + .buffers() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + let rows = buffers.get::("rows").unwrap(); + let cols = buffers.get::("cols").unwrap(); + let attr = buffers.get::("a").unwrap(); + + for (r, c, a) in izip!(rows.values(), cols.values(), attr.values()) { + println!("{} {} {}", r, c, a); + } + + Ok(()) +} + +fn create_array(ctx: &Context) -> TileDBResult<()> { + let schema = SchemaData { + array_type: ArrayType::Dense, + domain: DomainData { + dimension: vec![ + DimensionData { + name: "rows".to_owned(), + datatype: Datatype::Int32, + constraints: ([1i32, 4], 4i32).into(), + filters: None, + }, + DimensionData { + name: "cols".to_owned(), + datatype: Datatype::Int32, + constraints: ([1i32, 4], 4i32).into(), + filters: None, + }, + ], + }, + attributes: vec![AttributeData { + name: "a".to_owned(), + datatype: Datatype::Int32, + ..Default::default() + }], + tile_order: Some(TileOrder::RowMajor), + cell_order: Some(CellOrder::RowMajor), + + ..Default::default() + }; + + let schema = schema.create(ctx)?; + Array::create(ctx, ARRAY_URI, schema)?; + Ok(()) +} + +fn write_array(ctx: &Context) -> TileDBResult<()> { + let data = Arc::new(aa::Int32Array::from(vec![ + 1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ])); + + let array = + tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Write)?; + + let mut query = QueryBuilder::new(array, QueryType::Write) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field_with_buffer("a", data) + .end_fields() + .build() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + let (_, _) = query + .submit() + .and_then(|_| query.finalize()) + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + Ok(()) +} diff --git a/tiledb/api/examples/query_condition_dense_arrow.rs b/tiledb/api/examples/query_condition_dense_arrow.rs new file mode 100644 index 00000000..5cef4450 --- /dev/null +++ b/tiledb/api/examples/query_condition_dense_arrow.rs @@ -0,0 +1,242 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::array as aa; +use itertools::izip; + +use tiledb::array::{ + Array, ArrayType, AttributeBuilder, DimensionBuilder, DomainBuilder, + SchemaBuilder, +}; +use tiledb::error::Error as TileDBError; +use tiledb::query::conditions::QueryConditionExpr as QC; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; +use tiledb::{Context, Datatype, Result as TileDBResult}; + +const ARRAY_URI: &str = "query_condition_dense"; +const NUM_ELEMS: i32 = 10; +const C_FILL_VALUE: i32 = -1; +const D_FILL_VALUE: f32 = 0.0; + +/// Demonstrate reading dense arrays with query conditions. +fn main() -> TileDBResult<()> { + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + let _ = std::env::set_current_dir( + PathBuf::from(manifest_dir).join("examples").join("output"), + ); + } + + let ctx = Context::new()?; + if Array::exists(&ctx, ARRAY_URI)? { + Array::delete(&ctx, ARRAY_URI)?; + } + + create_array(&ctx)?; + write_array(&ctx)?; + + println!("Reading the entire array:"); + read_array(&ctx, None)?; + + println!("Reading: a is null"); + let qc = QC::field("a").is_null(); + read_array(&ctx, Some(qc))?; + + println!("Reading: b < \"eve\""); + let qc = QC::field("b").lt("eve"); + read_array(&ctx, Some(qc))?; + + println!("Reading: c >= 1"); + let qc = QC::field("c").ge(1i32); + read_array(&ctx, Some(qc))?; + + println!("Reading: 3.0 <= d <= 4.0"); + let qc = QC::field("d").ge(3.0f32) & QC::field("d").le(4.0f32); + read_array(&ctx, Some(qc))?; + + println!("Reading: (a is not null) && (b < \"eve\") && (3.0 <= d <= 4.0)"); + let qc = QC::field("a").not_null() + & QC::field("b").lt("eve") + & QC::field("d").ge(3.0f32) + & QC::field("d").le(4.0f32); + read_array(&ctx, Some(qc))?; + + Ok(()) +} + +/// Read the array with the optional query condition and print the results +/// to stdout. +fn read_array(ctx: &Context, qc: Option) -> TileDBResult<()> { + let array = tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Read)?; + let mut query = QueryBuilder::new(array, QueryType::Read) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field("index") + .field("a") + .field("b") + .field("c") + .field("d") + .end_fields() + .start_subarray() + .add_range("index", &[0i32, NUM_ELEMS - 1]) + .end_subarray(); + + query = if let Some(qc) = qc { + query.with_query_condition(qc) + } else { + query + }; + + let mut query = query + .build() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + let status = query + .submit() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + if !status.is_complete() { + return Err(TileDBError::Other("Query incomplete.".to_string())); + } + + let buffers = query.buffers().map_err(|e| { + TileDBError::Other(format!("Error getting buffers: {e}")) + })?; + + let index = buffers.get::("index").unwrap(); + let a = buffers.get::("a").unwrap(); + let b = buffers.get::("b").unwrap(); + let c = buffers.get::("c").unwrap(); + let d = buffers.get::("d").unwrap(); + + for (index, a, b, c, d) in izip!(index, a, b, c, d) { + if a.is_some() { + println!( + "{}: '{}' '{}' '{}' '{}'", + index.unwrap(), + a.unwrap(), + b.unwrap(), + c.unwrap(), + d.unwrap() + ); + } else { + println!( + "{}: null '{}' '{}' '{}'", + index.unwrap(), + b.unwrap(), + c.unwrap(), + d.unwrap() + ); + } + } + + Ok(()) +} + +/// Function to create the TileDB array used in this example. +/// The array will be 1D with size 1 with dimension "index". +/// The bounds on the index will be 0 through 9, inclusive. +/// +/// The array has four attributes. The four attributes are +/// - "a" (type i32) +/// - "b" (type String) +/// - "c" (type i32) +/// - "d" (type f32) +fn create_array(ctx: &Context) -> TileDBResult<()> { + let domain = { + let index = DimensionBuilder::new( + ctx, + "index", + Datatype::Int32, + ([0, NUM_ELEMS - 1], 4), + )? + .build(); + + DomainBuilder::new(ctx)?.add_dimension(index)?.build() + }; + + let attr_a = AttributeBuilder::new(ctx, "a", Datatype::Int32)? + .nullability(true)? + .build(); + + let attr_b = AttributeBuilder::new(ctx, "b", Datatype::StringAscii)? + .var_sized()? + .build(); + + let attr_c = AttributeBuilder::new(ctx, "c", Datatype::Int32)? + .fill_value(C_FILL_VALUE)? + .build(); + + let attr_d = AttributeBuilder::new(ctx, "d", Datatype::Float32)? + .fill_value(D_FILL_VALUE)? + .build(); + + let schema = SchemaBuilder::new(ctx, ArrayType::Dense, domain)? + .add_attribute(attr_a)? + .add_attribute(attr_b)? + .add_attribute(attr_c)? + .add_attribute(attr_d)? + .build()?; + + Array::create(ctx, ARRAY_URI, schema) +} + +/// Write the following data to the array: +/// +/// index | a | b | c | d +/// ------------------------------- +/// 0 | null | alice | 0 | 4.1 +/// 1 | 2 | bob | 0 | 3.4 +/// 2 | null | craig | 0 | 5.6 +/// 3 | 4 | dave | 0 | 3.7 +/// 4 | null | erin | 0 | 2.3 +/// 5 | 6 | frank | 0 | 1.7 +/// 6 | null | grace | 1 | 3.8 +/// 7 | 8 | heidi | 2 | 4.9 +/// 8 | null | ivan | 3 | 3.2 +/// 9 | 10 | judy | 4 | 3.1 +fn write_array(ctx: &Context) -> TileDBResult<()> { + let a_data = Arc::new(aa::Int32Array::from(vec![ + None, + Some(2), + None, + Some(4), + None, + Some(6), + None, + Some(8), + None, + Some(10), + ])); + let b_data = Arc::new(aa::LargeStringArray::from(vec![ + "alice", "bob", "craig", "daeve", "erin", "frank", "grace", "heidi", + "ivan", "judy", + ])); + let c_data = + Arc::new(aa::Int32Array::from(vec![0i32, 0, 0, 0, 0, 0, 1, 2, 3, 4])); + let d_data = Arc::new(aa::Float32Array::from(vec![ + 4.1f32, 3.4, 5.6, 3.7, 2.3, 1.7, 3.8, 4.9, 3.2, 3.1, + ])); + + let array = + tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Write)?; + + let mut query = QueryBuilder::new(array, QueryType::Write) + .start_fields() + .field_with_buffer("a", a_data) + .field_with_buffer("b", b_data) + .field_with_buffer("c", c_data) + .field_with_buffer("d", d_data) + .end_fields() + .start_subarray() + .add_range("index", &[0i32, NUM_ELEMS - 1]) + .end_subarray() + .build() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + query + .submit() + .and_then(|_| query.finalize()) + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + Ok(()) +} diff --git a/tiledb/api/examples/query_condition_sparse_arrow.rs b/tiledb/api/examples/query_condition_sparse_arrow.rs new file mode 100644 index 00000000..3fae9b6d --- /dev/null +++ b/tiledb/api/examples/query_condition_sparse_arrow.rs @@ -0,0 +1,233 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::array as aa; +use itertools::izip; + +use tiledb::array::{ + Array, ArrayType, AttributeBuilder, CellOrder, DimensionBuilder, + DomainBuilder, SchemaBuilder, +}; +use tiledb::error::Error as TileDBError; +use tiledb::query::conditions::QueryConditionExpr as QC; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; +use tiledb::{Context, Datatype, Result as TileDBResult}; + +const ARRAY_URI: &str = "query_condition_sparse"; +const NUM_ELEMS: i32 = 10; +const C_FILL_VALUE: i32 = -1; +const D_FILL_VALUE: f32 = 0.0; + +/// Demonstrate reading sparse arrays with query conditions. +fn main() -> TileDBResult<()> { + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + let _ = std::env::set_current_dir( + PathBuf::from(manifest_dir).join("examples").join("output"), + ); + } + + let ctx = Context::new()?; + if Array::exists(&ctx, ARRAY_URI)? { + Array::delete(&ctx, ARRAY_URI)?; + } + + create_array(&ctx)?; + write_array(&ctx)?; + + println!("Reading the entire array:"); + read_array(&ctx, None)?; + + println!("Reading: a is null"); + let qc = QC::field("a").is_null(); + read_array(&ctx, Some(qc))?; + + println!("Reading: b < \"eve\""); + let qc = QC::field("b").lt("eve"); + read_array(&ctx, Some(qc))?; + + println!("Reading: c >= 1"); + let qc = QC::field("c").ge(1i32); + read_array(&ctx, Some(qc))?; + + println!("Reading: 3.0 <= d <= 4.0"); + let qc = QC::field("d").ge(3.0f32) & QC::field("d").le(4.0f32); + read_array(&ctx, Some(qc))?; + + println!("Reading: (a is not null) && (b < \"eve\") && (3.0 <= d <= 4.0)"); + let qc = QC::field("a").not_null() + & QC::field("b").lt("eve") + & QC::field("d").ge(3.0f32) + & QC::field("d").le(4.0f32); + read_array(&ctx, Some(qc))?; + + Ok(()) +} + +/// Read the array with the optional query condition and print the results +/// to stdout. +fn read_array(ctx: &Context, qc: Option) -> TileDBResult<()> { + let array = tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Read)?; + let mut query = QueryBuilder::new(array, QueryType::Read) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field("index") + .field("a") + .field("b") + .field("c") + .field("d") + .end_fields() + .start_subarray() + .add_range("index", &[0i32, NUM_ELEMS - 1]) + .end_subarray(); + + query = if let Some(qc) = qc { + query.with_query_condition(qc) + } else { + query + }; + + let mut query = query.build()?; + let status = query.submit()?; + + if !status.is_complete() { + return Err(TileDBError::Other("Query did not complete.".to_string())); + } + + let buffers = query.buffers()?; + + let index = buffers.get::("index").unwrap(); + let a = buffers.get::("a").unwrap(); + let b = buffers.get::("b").unwrap(); + let c = buffers.get::("c").unwrap(); + let d = buffers.get::("d").unwrap(); + + for (index, a, b, c, d) in izip!(index, a, b, c, d) { + if a.is_some() { + println!( + "{}: '{}' '{}' '{}', '{}'", + index.unwrap(), + a.unwrap(), + b.unwrap(), + c.unwrap(), + d.unwrap() + ) + } else { + println!( + "{}: null '{}', '{}', '{}'", + index.unwrap(), + b.unwrap(), + c.unwrap(), + d.unwrap() + ) + } + } + + Ok(()) +} + +/// Function to create the TileDB array used in this example. +/// The array will be 1D with size 1 with dimension "index". +/// The bounds on the index will be 0 through 9, inclusive. +/// +/// The array has four attributes. The four attributes are +/// - "a" (type i32) +/// - "b" (type String) +/// - "c" (type i32) +/// - "d" (type f32) +fn create_array(ctx: &Context) -> TileDBResult<()> { + let domain = { + let index = DimensionBuilder::new( + ctx, + "index", + Datatype::Int32, + ([0, NUM_ELEMS - 1], 4), + )? + .build(); + + DomainBuilder::new(ctx)?.add_dimension(index)?.build() + }; + + let attr_a = AttributeBuilder::new(ctx, "a", Datatype::Int32)? + .nullability(true)? + .build(); + + let attr_b = AttributeBuilder::new(ctx, "b", Datatype::StringAscii)? + .var_sized()? + .build(); + + let attr_c = AttributeBuilder::new(ctx, "c", Datatype::Int32)? + .fill_value(C_FILL_VALUE)? + .build(); + + let attr_d = AttributeBuilder::new(ctx, "d", Datatype::Float32)? + .fill_value(D_FILL_VALUE)? + .build(); + + let schema = SchemaBuilder::new(ctx, ArrayType::Sparse, domain)? + .cell_order(CellOrder::RowMajor)? + .add_attribute(attr_a)? + .add_attribute(attr_b)? + .add_attribute(attr_c)? + .add_attribute(attr_d)? + .build()?; + + Array::create(ctx, ARRAY_URI, schema) +} + +/// Write the following data to the array: +/// +/// index | a | b | c | d +/// ------------------------------- +/// 0 | null | alice | 0 | 4.1 +/// 1 | 2 | bob | 0 | 3.4 +/// 2 | null | craig | 0 | 5.6 +/// 3 | 4 | dave | 0 | 3.7 +/// 4 | null | erin | 0 | 2.3 +/// 5 | 6 | frank | 0 | 1.7 +/// 6 | null | grace | 1 | 3.8 +/// 7 | 8 | heidi | 2 | 4.9 +/// 8 | null | ivan | 3 | 3.2 +/// 9 | 10 | judy | 4 | 3.1 +fn write_array(ctx: &Context) -> TileDBResult<()> { + let index_data = + Arc::new(aa::Int32Array::from(vec![0i32, 1, 2, 3, 4, 5, 6, 7, 8, 9])); + let a_data = Arc::new(aa::Int32Array::from(vec![ + None, + Some(2i32), + None, + Some(4), + None, + Some(6), + None, + Some(8), + None, + Some(10), + ])); + let b_data = Arc::new(aa::LargeStringArray::from(vec![ + "alice", "bob", "craig", "dave", "erin", "frank", "grace", "heidi", + "ivan", "judy", + ])); + let c_data = + Arc::new(aa::Int32Array::from(vec![0i32, 0, 0, 0, 0, 0, 1, 2, 3, 4])); + let d_data = Arc::new(aa::Float32Array::from(vec![ + 4.1f32, 3.4, 5.6, 3.7, 2.3, 1.7, 3.8, 4.9, 3.2, 3.1, + ])); + + let array = + tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Write)?; + + let mut query = QueryBuilder::new(array, QueryType::Write) + .with_layout(QueryLayout::Unordered) + .start_fields() + .field_with_buffer("index", index_data) + .field_with_buffer("a", a_data) + .field_with_buffer("b", b_data) + .field_with_buffer("c", c_data) + .field_with_buffer("d", d_data) + .end_fields() + .build()?; + + query.submit().and_then(|_| query.finalize())?; + + Ok(()) +} diff --git a/tiledb/api/examples/quickstart_dense_arrow.rs b/tiledb/api/examples/quickstart_dense_arrow.rs new file mode 100644 index 00000000..de9f0961 --- /dev/null +++ b/tiledb/api/examples/quickstart_dense_arrow.rs @@ -0,0 +1,176 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::array::{Array as ArrowArray, Int32Array}; +use itertools::izip; + +use tiledb::array::{ + Array, ArrayType, AttributeBuilder, Dimension, DimensionBuilder, + DomainBuilder, Mode as ArrayMode, SchemaBuilder, +}; +use tiledb::context::Context; +use tiledb::error::Error as TileDBError; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryStatus, QueryType}; +use tiledb::Datatype; +use tiledb::Result as TileDBResult; + +const QUICKSTART_DENSE_ARRAY_URI: &str = "quickstart_dense"; +const QUICKSTART_ATTRIBUTE_NAME: &str = "a"; + +/// Returns whether the example array already exists +fn array_exists() -> bool { + let tdb = match Context::new() { + Err(_) => return false, + Ok(tdb) => tdb, + }; + + Array::exists(&tdb, QUICKSTART_DENSE_ARRAY_URI) + .expect("Error checking array existence") +} + +/// Creates a dense array at URI `QUICKSTART_DENSE_ARRAY_URI()`. +/// The array has two i32 dimensions ["rows", "columns"] with a single int32 +/// attribute "a" stored in each cell. +/// Both "rows" and "columns" dimensions range from 1 to 4, and the tiles +/// span all 4 elements on each dimension. +/// Hence we have 16 cells of data and a single tile for the whole array. +fn create_array() -> TileDBResult<()> { + let tdb = Context::new()?; + + let domain = { + let rows: tiledb::array::Dimension = + tiledb::array::DimensionBuilder::new( + &tdb, + "rows", + Datatype::Int32, + ([1, 4], 4), + )? + .build(); + let cols: Dimension = DimensionBuilder::new( + &tdb, + "columns", + Datatype::Int32, + ([1, 4], 4), + )? + .build(); + + DomainBuilder::new(&tdb)? + .add_dimension(rows)? + .add_dimension(cols)? + .build() + }; + + let attribute_a = AttributeBuilder::new( + &tdb, + QUICKSTART_ATTRIBUTE_NAME, + Datatype::Int32, + )? + .build(); + + let schema = SchemaBuilder::new(&tdb, ArrayType::Dense, domain)? + .add_attribute(attribute_a)? + .build()?; + + Array::create(&tdb, QUICKSTART_DENSE_ARRAY_URI, schema) +} + +/// Writes data into the array in row-major order from a 1D-array buffer. +/// After the write, the contents of the array will be: +/// [[ 1, 2, 3, 4], +/// [ 5, 6, 7, 8], +/// [ 9, 10, 11, 12], +/// [13, 14, 15, 16]] +fn write_array() -> TileDBResult<()> { + let tdb = Context::new()?; + + let array = + Array::open(&tdb, QUICKSTART_DENSE_ARRAY_URI, ArrayMode::Write)?; + + let data: Arc = Arc::new(Int32Array::from(vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ])); + + let mut query = QueryBuilder::new(array, QueryType::Write) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field_with_buffer(QUICKSTART_ATTRIBUTE_NAME, data) + .end_fields() + .build() + // TODO: Make this not suck + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + let status = query + .submit() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + if matches!(status, QueryStatus::Completed) { + return Ok(()); + } else { + return Err(TileDBError::Other("Something better here.".to_string())); + } +} + +/// Query back a slice of our array and print the results to stdout. +/// The slice on "rows" is [1, 2] and on "columns" is [2, 4], +/// so the returned data should look like: +/// [[ _, 2, 3, 4], +/// [ _, 6, 7, 8], +/// [ _, _, _, _], +/// [ _, _, _, _]]] +/// Data is emitted in row-major order, so this will print "2 3 4 6 7 8". +fn read_array() -> TileDBResult<()> { + let tdb = Context::new()?; + + let array = Array::open(&tdb, QUICKSTART_DENSE_ARRAY_URI, ArrayMode::Read)?; + + let mut query = QueryBuilder::new(array, QueryType::Read) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field("rows") + .field("columns") + .field(QUICKSTART_ATTRIBUTE_NAME) + .end_fields() + .start_subarray() + .add_range("rows", &[1i32, 2]) + .add_range("columns", &[2i32, 4]) + .end_subarray() + .build() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + let status = query + .submit() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + if !matches!(status, QueryStatus::Completed) { + return Err(TileDBError::Other("Make this better.".to_string())); + } + + let buffers = query + .buffers() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + let rows = buffers.get::("rows").unwrap(); + let cols = buffers.get::("columns").unwrap(); + let attrs = buffers + .get::(QUICKSTART_ATTRIBUTE_NAME) + .unwrap(); + + for (row, col, attr) in izip!(rows.values(), cols.values(), attrs.values()) + { + println!("{} {} {}", row, col, attr); + } + + Ok(()) +} + +fn main() { + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + let _ = std::env::set_current_dir( + PathBuf::from(manifest_dir).join("examples").join("output"), + ); + } + + if !array_exists() { + create_array().expect("Failed to create array"); + } + write_array().expect("Failed to write array"); + read_array().expect("Failed to read array"); +} diff --git a/tiledb/api/examples/quickstart_sparse_string_arrow.rs b/tiledb/api/examples/quickstart_sparse_string_arrow.rs new file mode 100644 index 00000000..7ca552bf --- /dev/null +++ b/tiledb/api/examples/quickstart_sparse_string_arrow.rs @@ -0,0 +1,139 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::array as aa; +use itertools::izip; + +use tiledb::array::dimension::DimensionConstraints; +use tiledb::array::{ + Array, ArrayType, AttributeData, CellOrder, DimensionData, DomainData, + SchemaData, TileOrder, +}; +use tiledb::context::Context; +use tiledb::error::Error as TileDBError; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryStatus, QueryType}; +use tiledb::Result as TileDBResult; +use tiledb::{Datatype, Factory}; + +const ARRAY_URI: &str = "quickstart_sparse_string"; + +fn main() -> TileDBResult<()> { + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + let _ = std::env::set_current_dir( + PathBuf::from(manifest_dir).join("examples").join("output"), + ); + } + + let ctx = Context::new()?; + if Array::exists(&ctx, ARRAY_URI)? { + Array::delete(&ctx, ARRAY_URI)?; + } + + create_array(&ctx)?; + write_array(&ctx)?; + + let array = Array::open(&ctx, ARRAY_URI, tiledb::array::Mode::Read)?; + let mut query = QueryBuilder::new(array, QueryType::Read) + .with_layout(QueryLayout::RowMajor) + .start_fields() + .field("rows") + .field("cols") + .field("a") + .end_fields() + .start_subarray() + .add_range("rows", &["a", "c"]) + .add_range("cols", &[2, 4]) + .end_subarray() + .build() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + let status = query + .submit() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + if !matches!(status, QueryStatus::Completed) { + return Err(TileDBError::Other("Make this better.".to_string())); + } + + let buffers = query + .buffers() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + let rows = buffers.get::("rows").unwrap(); + let cols = buffers.get::("cols").unwrap(); + let attr = buffers.get::("a").unwrap(); + + for (row, col, attr) in izip!(rows, cols.values(), attr.values()) { + println!("{} {} {}", row.unwrap(), col, attr); + } + + Ok(()) +} + +fn create_array(ctx: &Context) -> TileDBResult<()> { + let schema = SchemaData { + array_type: ArrayType::Sparse, + domain: DomainData { + dimension: vec![ + DimensionData { + name: "rows".to_owned(), + datatype: Datatype::StringAscii, + constraints: DimensionConstraints::StringAscii, + filters: None, + }, + DimensionData { + name: "cols".to_owned(), + datatype: Datatype::Int32, + constraints: ([1i32, 4], 4i32).into(), + filters: None, + }, + ], + }, + attributes: vec![AttributeData { + name: "a".to_owned(), + datatype: Datatype::Int32, + ..Default::default() + }], + tile_order: Some(TileOrder::RowMajor), + cell_order: Some(CellOrder::RowMajor), + + ..Default::default() + }; + + let schema = schema.create(ctx)?; + Array::create(ctx, ARRAY_URI, schema)?; + Ok(()) +} + +fn write_array(ctx: &Context) -> TileDBResult<()> { + let row_data = Arc::new(aa::LargeStringArray::from(vec!["a", "bb", "c"])); + let col_data = Arc::new(aa::Int32Array::from(vec![1, 4, 3])); + let a_data = Arc::new(aa::Int32Array::from(vec![1, 2, 3])); + + let array = + tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Write)?; + + let mut query = QueryBuilder::new(array, QueryType::Write) + .with_layout(CellOrder::Unordered) + .start_fields() + .field_with_buffer("rows", row_data) + .field_with_buffer("cols", col_data) + .field_with_buffer("a", a_data) + .end_fields() + .build() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + let status = query + .submit() + .map_err(|e| TileDBError::Other(format!("{}", e)))?; + + if !matches!(status, QueryStatus::Completed) { + return Err(TileDBError::Other("Make this better.".to_string())); + } + + let (_, _) = query + .finalize() + .map_err(|e| TileDBError::Other(format!("{e}")))?; + + Ok(()) +} diff --git a/tiledb/api/examples/reading_incomplete_arrow.rs b/tiledb/api/examples/reading_incomplete_arrow.rs new file mode 100644 index 00000000..17064551 --- /dev/null +++ b/tiledb/api/examples/reading_incomplete_arrow.rs @@ -0,0 +1,179 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use arrow::array as aa; +use itertools::izip; + +use tiledb::array::{ + Array, ArrayType, AttributeData, CellOrder, CellValNum, DimensionData, + DomainData, SchemaData, TileOrder, +}; +use tiledb::context::Context; +use tiledb::query_arrow::fields::QueryFieldsBuilder; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; +use tiledb::Result as TileDBResult; +use tiledb::{Datatype, Factory}; + +const ARRAY_URI: &str = "reading_incomplete"; + +fn main() -> TileDBResult<()> { + if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") { + let _ = std::env::set_current_dir( + PathBuf::from(manifest_dir).join("examples").join("output"), + ); + } + + let ctx = Context::new()?; + if Array::exists(&ctx, ARRAY_URI)? { + Array::delete(&ctx, ARRAY_URI)?; + } + + create_array(&ctx)?; + write_array(&ctx)?; + read_array(&ctx)?; + Ok(()) +} + +/// Creates a dense array at URI `ARRAY_NAME`. +/// The array has two i32 dimensions ["rows", "columns"] with two +/// attributes in each cell - (a1 INT32, a2 CHAR). +/// Both "rows" and "columns" dimensions range from 1 to 4, and the tiles +/// span all 4 elements on each dimension. +/// Hence we have 16 cells of data and a single tile for the whole array. +fn create_array(ctx: &Context) -> TileDBResult<()> { + let schema = SchemaData { + array_type: ArrayType::Sparse, + domain: DomainData { + dimension: vec![ + DimensionData { + name: "rows".to_owned(), + datatype: Datatype::Int32, + constraints: ([1i32, 4], 4i32).into(), + filters: None, + }, + DimensionData { + name: "cols".to_owned(), + datatype: Datatype::Int32, + constraints: ([1i32, 4], 4i32).into(), + filters: None, + }, + ], + }, + attributes: vec![ + AttributeData { + name: "a1".to_string(), + datatype: Datatype::Int32, + ..Default::default() + }, + AttributeData { + name: "a2".to_string(), + datatype: Datatype::StringUtf8, + cell_val_num: Some(CellValNum::Var), + ..Default::default() + }, + ], + tile_order: Some(TileOrder::RowMajor), + cell_order: Some(CellOrder::RowMajor), + + ..Default::default() + }; + + let schema = schema.create(ctx)?; + Array::create(ctx, ARRAY_URI, schema)?; + Ok(()) +} + +/// Writes data into the array. +/// After the write, the contents of the array will be: +/// [[ (1, "a"), (2, "bb"), _, _], +/// [ _, (3, "ccc"), _, _], +/// [ _, _, _, _], +/// [ _, _, _, _]] +fn write_array(ctx: &Context) -> TileDBResult<()> { + let rows_data = Arc::new(aa::Int32Array::from(vec![1, 2, 2, 2])); + let cols_data = Arc::new(aa::Int32Array::from(vec![1, 1, 2, 3])); + let a1_data = Arc::new(aa::Int32Array::from(vec![1, 2, 3, 3])); + let a2_data = + Arc::new(aa::LargeStringArray::from(vec!["a", "bb", "ccc", "dddd"])); + + let array = + tiledb::Array::open(&ctx, ARRAY_URI, tiledb::array::Mode::Write)?; + + let mut query = QueryBuilder::new(array, QueryType::Write) + .with_layout(QueryLayout::Global) + .start_fields() + .field_with_buffer("rows", rows_data) + .field_with_buffer("cols", cols_data) + .field_with_buffer("a1", a1_data) + .field_with_buffer("a2", a2_data) + .end_fields() + .build()?; + + query.submit().and_then(|_| query.finalize())?; + Ok(()) +} + +fn read_array(ctx: &Context) -> TileDBResult<()> { + let mut curr_capacity = 1; + + let array = tiledb::Array::open(ctx, ARRAY_URI, tiledb::array::Mode::Read)?; + + let make_fields = |capacity| { + QueryFieldsBuilder::new() + .field_with_capacity("rows", capacity) + .field_with_capacity("cols", capacity) + .field_with_capacity("a1", capacity) + .field_with_capacity("a2", capacity) + .build() + }; + + let mut query = QueryBuilder::new(array, QueryType::Read) + .with_layout(QueryLayout::RowMajor) + .with_fields(make_fields(curr_capacity)) + .start_subarray() + .add_range("rows", &[1i32, 4]) + .add_range("cols", &[1i32, 4]) + .end_subarray() + .build()?; + + loop { + let status = query.submit()?; + + // Double our buffer sizes if we didn't manage to get any data out + // of the query. + if !status.has_data() { + println!( + "Doubling buffer capacity: {} to {}", + curr_capacity, + curr_capacity * 2 + ); + curr_capacity = curr_capacity * 2; + query.replace_buffers(make_fields(curr_capacity))?; + continue; + } + + // Print any results we did get. + let buffers = query.buffers()?; + let rows = buffers.get::("rows").unwrap(); + let cols = buffers.get::("cols").unwrap(); + let a1 = buffers.get::("a1").unwrap(); + let a2 = buffers.get::("a2").unwrap(); + + for (r, c, a1, a2) in izip!(rows, cols, a1, a2) { + println!( + "\tCell ({}, {}) a1: {}, a2: {}", + r.unwrap(), + c.unwrap(), + a1.unwrap(), + a2.unwrap() + ); + } + + // Break from the loop when completed. + if status.is_complete() { + break; + } + } + + Ok(()) +} diff --git a/tiledb/api/src/array/attribute/mod.rs b/tiledb/api/src/array/attribute/mod.rs index 4c77f441..7dace0e5 100644 --- a/tiledb/api/src/array/attribute/mod.rs +++ b/tiledb/api/src/array/attribute/mod.rs @@ -751,7 +751,6 @@ impl Factory for AttributeData { } } -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(any(test, feature = "proptest-strategies"))] diff --git a/tiledb/api/src/array/dimension/mod.rs b/tiledb/api/src/array/dimension/mod.rs index 2ab0c073..5548a6e7 100644 --- a/tiledb/api/src/array/dimension/mod.rs +++ b/tiledb/api/src/array/dimension/mod.rs @@ -720,7 +720,6 @@ impl Factory for DimensionData { } } -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(any(test, feature = "proptest-strategies"))] diff --git a/tiledb/api/src/array/mod.rs b/tiledb/api/src/array/mod.rs index 2303a62f..610850dd 100644 --- a/tiledb/api/src/array/mod.rs +++ b/tiledb/api/src/array/mod.rs @@ -98,10 +98,6 @@ impl Display for Mode { #[derive( Clone, Copy, Debug, Deserialize, Eq, OptionSubset, PartialEq, Serialize, )] -#[cfg_attr( - any(test, feature = "proptest-strategies"), - derive(proptest_derive::Arbitrary) -)] pub enum TileOrder { RowMajor, ColumnMajor, diff --git a/tiledb/api/src/array/schema/mod.rs b/tiledb/api/src/array/schema/mod.rs index bc21b454..d8b166c3 100644 --- a/tiledb/api/src/array/schema/mod.rs +++ b/tiledb/api/src/array/schema/mod.rs @@ -32,10 +32,6 @@ use crate::{Factory, Result as TileDBResult}; PartialEq, Serialize, )] -#[cfg_attr( - any(test, feature = "proptest-strategies"), - derive(proptest_derive::Arbitrary) -)] pub enum ArrayType { #[default] Dense, @@ -1031,7 +1027,6 @@ impl<'a> Iterator for FieldDataIter<'a> { impl std::iter::FusedIterator for FieldDataIter<'_> {} -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(any(test, feature = "proptest-strategies"))] diff --git a/tiledb/api/src/array/schema/strategy.rs b/tiledb/api/src/array/schema/strategy.rs index 11a5ccff..edd43a5d 100644 --- a/tiledb/api/src/array/schema/strategy.rs +++ b/tiledb/api/src/array/schema/strategy.rs @@ -25,6 +25,15 @@ use crate::filter::strategy::{ StrategyContext as FilterContext, }; +impl Arbitrary for ArrayType { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + prop_oneof![Just(ArrayType::Dense), Just(ArrayType::Sparse),].boxed() + } +} + #[derive(Clone)] pub struct Requirements { pub domain: Option>, diff --git a/tiledb/api/src/array/strategy.rs b/tiledb/api/src/array/strategy.rs index 66a19cca..78254b60 100644 --- a/tiledb/api/src/array/strategy.rs +++ b/tiledb/api/src/array/strategy.rs @@ -1,6 +1,19 @@ -#[cfg(test)] +use proptest::prelude::*; + +use super::TileOrder; + +impl Arbitrary for TileOrder { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + prop_oneof![Just(TileOrder::RowMajor), Just(TileOrder::ColumnMajor),] + .boxed() + } +} + mod tests { - use proptest::prelude::*; + use super::*; use util::assert_option_subset; use util::option::OptionSubset; diff --git a/tiledb/api/src/datatype/mod.rs b/tiledb/api/src/datatype/mod.rs index ca7f1be2..15b12986 100644 --- a/tiledb/api/src/datatype/mod.rs +++ b/tiledb/api/src/datatype/mod.rs @@ -795,7 +795,6 @@ macro_rules! physical_type_go { }}; } -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(any(test, feature = "proptest-strategies"))] diff --git a/tiledb/api/src/filter/mod.rs b/tiledb/api/src/filter/mod.rs index 73893159..0cfdb354 100644 --- a/tiledb/api/src/filter/mod.rs +++ b/tiledb/api/src/filter/mod.rs @@ -631,7 +631,6 @@ impl PartialEq for Filter { } } -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(any(test, feature = "proptest-strategies"))] diff --git a/tiledb/api/src/lib.rs b/tiledb/api/src/lib.rs index 518388f0..fd7aa900 100644 --- a/tiledb/api/src/lib.rs +++ b/tiledb/api/src/lib.rs @@ -51,13 +51,13 @@ pub mod group; pub mod key; pub mod metadata; pub mod query; +pub mod query_arrow; #[macro_use] pub mod range; pub mod stats; pub mod string; pub mod vfs; -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(test)] diff --git a/tiledb/api/src/query/buffer/mod.rs b/tiledb/api/src/query/buffer/mod.rs index d793ec18..9d32622e 100644 --- a/tiledb/api/src/query/buffer/mod.rs +++ b/tiledb/api/src/query/buffer/mod.rs @@ -4,7 +4,6 @@ use std::ops::{Deref, DerefMut}; use crate::array::CellValNum; -#[cfg(feature = "arrow")] pub mod arrow; #[derive(Debug)] diff --git a/tiledb/api/src/query/read/output/mod.rs b/tiledb/api/src/query/read/output/mod.rs index 5c8cddf2..b4b01899 100644 --- a/tiledb/api/src/query/read/output/mod.rs +++ b/tiledb/api/src/query/read/output/mod.rs @@ -12,7 +12,6 @@ use crate::query::buffer::*; use crate::Result as TileDBResult; use crate::{typed_query_buffers_go, Datatype}; -#[cfg(feature = "arrow")] pub mod arrow; #[cfg(any(test, feature = "proptest-strategies"))] pub mod strategy; diff --git a/tiledb/api/src/query/write/input/mod.rs b/tiledb/api/src/query/write/input/mod.rs index afa2e743..43e27bde 100644 --- a/tiledb/api/src/query/write/input/mod.rs +++ b/tiledb/api/src/query/write/input/mod.rs @@ -9,7 +9,6 @@ use crate::query::buffer::{ }; use crate::Result as TileDBResult; -#[cfg(feature = "arrow")] pub mod arrow; pub trait DataProvider { diff --git a/tiledb/api/src/query_arrow/arrow.rs b/tiledb/api/src/query_arrow/arrow.rs new file mode 100644 index 00000000..5c73ee10 --- /dev/null +++ b/tiledb/api/src/query_arrow/arrow.rs @@ -0,0 +1,520 @@ +use std::sync::Arc; + +use arrow::datatypes as adt; + +use thiserror::Error; + +use crate::array::schema::CellValNum; +use crate::datatype::Datatype; + +#[derive(Error, Debug, PartialEq, Eq)] +pub enum Error { + #[error("Cell value size '{0}' is out of range.")] + CellValNumOutOfRange(u32), + + #[error("Internal type error: Unhandled Arrow type: {0}")] + InternalTypeError(adt::DataType), + + #[error("Invalid fixed sized length: {0}")] + InvalidFixedSize(i32), + + #[error("Invalid Arrow type for conversion: '{0}'")] + InvalidTargetType(adt::DataType), + + #[error("Failed to convert Arrow list element type: {0}")] + ListElementTypeConversionFailed(Box), + + #[error( + "The TileDB datatype '{0}' does not have a default Arrow DataType." + )] + NoDefaultArrowType(Datatype), + + #[error("Arrow type '{0} requires the TileDB field to be single valued.")] + RequiresSingleValued(adt::DataType), + + #[error("Arrow type {0} requires the TileDB field be var sized.")] + RequiresVarSized(adt::DataType), + + #[error("TileDB does not support timezones on timestamps")] + TimeZonesNotSupported, + + #[error( + "TileDB type '{0}' and Arrow type '{1}' have different physical sizes" + )] + PhysicalSizeMismatch(Datatype, adt::DataType), + + #[error("Unsupported Arrow DataType: {0}")] + UnsupportedArrowDataType(adt::DataType), + + #[error("TileDB does not support lists with element type: '{0}'")] + UnsupportedListElementType(adt::DataType), + + #[error("The Arrow DataType '{0}' is not supported.")] + ArrowTypeNotSupported(adt::DataType), + #[error("DataFusion does not support multi-value cells.")] + InvalidMultiCellValNum, + #[error("The TileDB Datatype '{0}' is not supported by DataFusion")] + UnsupportedTileDBDatatype(Datatype), + #[error("Variable-length datatypes as list type elements are not supported by TileDB")] + UnsupportedListVariableLengthElement, +} + +pub type Result = std::result::Result; + +/// ConversionMode dictates whether certain conversions are allowed +pub enum ConversionMode { + /// Only allow conversions that are semantically equivalent + Strict, + /// Allow conversions as long as the physical type is maintained. + Relaxed, +} + +pub struct ToArrowConverter { + mode: ConversionMode, +} + +impl ToArrowConverter { + pub fn strict() -> Self { + Self { + mode: ConversionMode::Strict, + } + } + + pub fn physical() -> Self { + Self { + mode: ConversionMode::Relaxed, + } + } + + pub fn convert_datatype( + &self, + dtype: &Datatype, + cvn: &CellValNum, + nullable: bool, + ) -> Result { + if let Some(arrow_type) = self.default_arrow_type(dtype) { + self.convert_datatype_to(dtype, cvn, nullable, arrow_type) + } else { + Err(Error::NoDefaultArrowType(*dtype)) + } + } + + pub fn convert_datatype_to( + &self, + dtype: &Datatype, + cvn: &CellValNum, + nullable: bool, + arrow_type: adt::DataType, + ) -> Result { + if matches!(arrow_type, adt::DataType::Null) { + return Err(Error::InvalidTargetType(arrow_type)); + } + + if arrow_type.is_primitive() { + let width = arrow_type.primitive_width().unwrap(); + if width != dtype.size() as usize { + return Err(Error::PhysicalSizeMismatch(*dtype, arrow_type)); + } + + if cvn.is_single_valued() { + return Ok(arrow_type); + } else if cvn.is_var_sized() { + let field = + Arc::new(adt::Field::new("item", arrow_type, nullable)); + return Ok(adt::DataType::LargeList(field)); + } else { + // SAFETY: Due to the logic above we can guarantee that this + // is a fixed length cvn. + let cvn = cvn.fixed().unwrap().get(); + if cvn > i32::MAX as u32 { + return Err(Error::CellValNumOutOfRange(cvn)); + } + let field = + Arc::new(adt::Field::new("item", arrow_type, nullable)); + return Ok(adt::DataType::FixedSizeList(field, cvn as i32)); + } + } else if matches!(arrow_type, adt::DataType::Boolean) { + if !cvn.is_single_valued() { + return Err(Error::RequiresSingleValued(arrow_type)); + } + return Ok(arrow_type); + } else if matches!( + arrow_type, + adt::DataType::LargeBinary | adt::DataType::LargeUtf8 + ) { + if !cvn.is_var_sized() { + return Err(Error::RequiresVarSized(arrow_type)); + } + return Ok(arrow_type); + } else { + return Err(Error::InternalTypeError(arrow_type)); + } + } + + fn default_arrow_type(&self, dtype: &Datatype) -> Option { + use crate::datatype::Datatype as tiledb; + use arrow::datatypes::DataType as arrow; + let arrow_type = match dtype { + // Any <-> Null, both indicate lack of a type + tiledb::Any => Some(arrow::Null), + + // Boolean, n.b., this requires a byte array to bit array converesion + tiledb::Boolean => Some(arrow::Boolean), + + // Char -> Int8 + tiledb::Char => Some(arrow::Int8), + + // Standard primitive types + tiledb::Int8 => Some(arrow::Int8), + tiledb::Int16 => Some(arrow::Int16), + tiledb::Int32 => Some(arrow::Int32), + tiledb::Int64 => Some(arrow::Int64), + tiledb::UInt8 => Some(arrow::UInt8), + tiledb::UInt16 => Some(arrow::UInt16), + tiledb::UInt32 => Some(arrow::UInt32), + tiledb::UInt64 => Some(arrow::UInt64), + tiledb::Float32 => Some(arrow::Float32), + tiledb::Float64 => Some(arrow::Float64), + + // Supportable datetime types + tiledb::DateTimeSecond => { + Some(arrow::Timestamp(adt::TimeUnit::Second, None)) + } + tiledb::DateTimeMillisecond => { + Some(arrow::Timestamp(adt::TimeUnit::Millisecond, None)) + } + tiledb::DateTimeMicrosecond => { + Some(arrow::Timestamp(adt::TimeUnit::Microsecond, None)) + } + tiledb::DateTimeNanosecond => { + Some(arrow::Timestamp(adt::TimeUnit::Nanosecond, None)) + } + + // Supportable time types + tiledb::TimeSecond => Some(arrow::Time64(adt::TimeUnit::Second)), + tiledb::TimeMillisecond => { + Some(arrow::Time64(adt::TimeUnit::Millisecond)) + } + tiledb::TimeMicrosecond => { + Some(arrow::Time64(adt::TimeUnit::Microsecond)) + } + tiledb::TimeNanosecond => { + Some(arrow::Time64(adt::TimeUnit::Nanosecond)) + } + + // Supported string types + tiledb::StringAscii => Some(arrow::LargeUtf8), + tiledb::StringUtf8 => Some(arrow::LargeUtf8), + + // Blob <-> Binary + tiledb::Blob => Some(arrow::LargeBinary), + + tiledb::StringUtf16 + | tiledb::StringUtf32 + | tiledb::StringUcs2 + | tiledb::StringUcs4 + | tiledb::DateTimeYear + | tiledb::DateTimeMonth + | tiledb::DateTimeWeek + | tiledb::DateTimeDay + | tiledb::DateTimeHour + | tiledb::DateTimeMinute + | tiledb::DateTimePicosecond + | tiledb::DateTimeFemtosecond + | tiledb::DateTimeAttosecond + | tiledb::TimeHour + | tiledb::TimeMinute + | tiledb::TimePicosecond + | tiledb::TimeFemtosecond + | tiledb::TimeAttosecond + | tiledb::GeometryWkb + | tiledb::GeometryWkt => None, + }; + + if arrow_type.is_some() { + return arrow_type; + } + + // If we're doing a strict semantic conversion we don't attempt to find + // a matching physical type. + if matches!(self.mode, ConversionMode::Strict) { + return None; + } + + // Assert in case we add more conversion modes in the future. + assert!(matches!(self.mode, ConversionMode::Relaxed)); + + // Physical conversions means we'll allow dropping the TileDB semantic + // information to allow for raw data access. + match dtype { + // Uncommon string types + tiledb::StringUtf16 => Some(arrow::UInt16), + tiledb::StringUtf32 => Some(arrow::UInt32), + tiledb::StringUcs2 => Some(arrow::UInt16), + tiledb::StringUcs4 => Some(arrow::UInt32), + + // Time types that could lose data if converted to Arrow's + // time resolution. + tiledb::DateTimeYear => Some(arrow::Int64), + tiledb::DateTimeMonth => Some(arrow::Int64), + tiledb::DateTimeWeek => Some(arrow::Int64), + tiledb::DateTimeDay => Some(arrow::Int64), + tiledb::DateTimeHour => Some(arrow::Int64), + tiledb::DateTimeMinute => Some(arrow::Int64), + tiledb::DateTimePicosecond => Some(arrow::Int64), + tiledb::DateTimeFemtosecond => Some(arrow::Int64), + tiledb::DateTimeAttosecond => Some(arrow::Int64), + tiledb::TimeHour => Some(arrow::Int64), + tiledb::TimeMinute => Some(arrow::Int64), + tiledb::TimePicosecond => Some(arrow::Int64), + tiledb::TimeFemtosecond => Some(arrow::Int64), + tiledb::TimeAttosecond => Some(arrow::Int64), + + // Geometry types + tiledb::GeometryWkb => Some(arrow::LargeBinary), + tiledb::GeometryWkt => Some(arrow::LargeUtf8), + + // These are all of the types that have strict equivalents and + // should have already been handled above. + tiledb::Any + | tiledb::Boolean + | tiledb::Char + | tiledb::Int8 + | tiledb::Int16 + | tiledb::Int32 + | tiledb::Int64 + | tiledb::UInt8 + | tiledb::UInt16 + | tiledb::UInt32 + | tiledb::UInt64 + | tiledb::Float32 + | tiledb::Float64 + | tiledb::DateTimeSecond + | tiledb::DateTimeMillisecond + | tiledb::DateTimeMicrosecond + | tiledb::DateTimeNanosecond + | tiledb::TimeSecond + | tiledb::TimeMillisecond + | tiledb::TimeMicrosecond + | tiledb::TimeNanosecond + | tiledb::StringAscii + | tiledb::StringUtf8 + | tiledb::Blob => unreachable!("Strict conversion failed"), + } + } +} + +pub struct FromArrowConverter { + mode: ConversionMode, +} + +impl FromArrowConverter { + pub fn strict() -> Self { + Self { + mode: ConversionMode::Strict, + } + } + + pub fn relaxed() -> Self { + Self { + mode: ConversionMode::Relaxed, + } + } + + pub fn convert_datatype( + &self, + arrow_type: adt::DataType, + ) -> Result<(Datatype, CellValNum, Option)> { + use adt::DataType as arrow; + use Datatype as tiledb; + + let single = CellValNum::single(); + let var = CellValNum::Var; + + match arrow_type { + arrow::Null => Ok((tiledb::Any, single, None)), + arrow::Boolean => Ok((tiledb::Boolean, single, None)), + + arrow::Int8 => Ok((tiledb::Int8, single, None)), + arrow::Int16 => Ok((tiledb::Int16, single, None)), + arrow::Int32 => Ok((tiledb::Int32, single, None)), + arrow::Int64 => Ok((tiledb::Int64, single, None)), + arrow::UInt8 => Ok((tiledb::UInt8, single, None)), + arrow::UInt16 => Ok((tiledb::UInt16, single, None)), + arrow::UInt32 => Ok((tiledb::UInt32, single, None)), + arrow::UInt64 => Ok((tiledb::UInt64, single, None)), + arrow::Float32 => Ok((tiledb::Float32, single, None)), + arrow::Float64 => Ok((tiledb::Float64, single, None)), + + arrow::Timestamp(adt::TimeUnit::Second, None) => { + Ok((tiledb::DateTimeSecond, single, None)) + } + arrow::Timestamp(adt::TimeUnit::Millisecond, None) => { + Ok((tiledb::DateTimeMillisecond, single, None)) + } + arrow::Timestamp(adt::TimeUnit::Microsecond, None) => { + Ok((tiledb::DateTimeMicrosecond, single, None)) + } + arrow::Timestamp(adt::TimeUnit::Nanosecond, None) => { + Ok((tiledb::DateTimeNanosecond, single, None)) + } + arrow::Timestamp(_, Some(_)) => { + return Err(Error::TimeZonesNotSupported); + } + + arrow::Time64(adt::TimeUnit::Second) => { + Ok((tiledb::TimeSecond, single, None)) + } + arrow::Time64(adt::TimeUnit::Millisecond) => { + Ok((tiledb::TimeMillisecond, single, None)) + } + arrow::Time64(adt::TimeUnit::Microsecond) => { + Ok((tiledb::TimeMicrosecond, single, None)) + } + arrow::Time64(adt::TimeUnit::Nanosecond) => { + Ok((tiledb::TimeNanosecond, single, None)) + } + + arrow::Utf8 => Ok((tiledb::StringUtf8, var, None)), + arrow::LargeUtf8 => Ok((tiledb::StringUtf8, var, None)), + arrow::Binary => Ok((tiledb::Blob, var, None)), + arrow::FixedSizeBinary(cvn) => { + if cvn < 1 { + return Err(Error::InvalidFixedSize(cvn)); + } + let cvn = if cvn == 1 { + CellValNum::single() + } else { + CellValNum::try_from(cvn as u32).unwrap() + }; + Ok((tiledb::Blob, cvn, None)) + } + arrow::LargeBinary => Ok((tiledb::Blob, var, None)), + + arrow::List(field) | arrow::LargeList(field) => { + let dtype = field.data_type(); + if !dtype.is_primitive() { + return Err(Error::UnsupportedListElementType( + dtype.clone(), + )); + } + + let (tdb_type, _, _) = + self.convert_datatype(dtype.clone()).map_err(|e| { + Error::ListElementTypeConversionFailed(Box::new(e)) + })?; + + Ok((tdb_type, var, Some(field.is_nullable()))) + } + + arrow::FixedSizeList(field, cvn) => { + let dtype = field.data_type(); + if !dtype.is_primitive() { + return Err(Error::UnsupportedListElementType( + dtype.clone(), + )); + } + + let (tdb_type, _, _) = + self.convert_datatype(dtype.clone()).map_err(|e| { + Error::ListElementTypeConversionFailed(Box::new(e)) + })?; + + Ok(( + tdb_type, + CellValNum::try_from(cvn as u32).unwrap(), + Some(field.is_nullable()), + )) + } + + // A few relaxed conversions for accepting Arrow types that don't + // line up directly with TileDB. + arrow::Date32 if self.is_relaxed() => { + Ok((tiledb::Int32, single, None)) + } + + arrow::Date64 if self.is_relaxed() => { + Ok((tiledb::Int64, single, None)) + } + + arrow::Time32(_) if self.is_relaxed() => { + Ok((tiledb::Int32, single, None)) + } + + // Notes on other possible relaxed conversions: + // + // Duration and some intervals are likely supportable, but + // leaving them off for now as the docs aren't clear. + // + // Views are also likely supportable, but will likely require + // separate buffer allocations since individual values are not + // contiguous. + // + // Struct and Union are never supportable (given current core) + // + // Dictionary is, but they should be handled higher up the stack + // to ensure that things line up with enumerations. + // + // Decimal128 and Decimal256 might be supportable using Float64 + // and 2 or 4 fixed length cell val num. Though it'd be fairly + // hacky. + // + // Map isn't supported in TileDB (given current core) + // + // RunEndEncoded is probably supportable, but like views will + // require separate buffer allocations so leaving for now. + arrow::Float16 + | arrow::Date32 + | arrow::Date64 + | arrow::Time32(_) + | arrow::Duration(_) + | arrow::Interval(_) + | arrow::BinaryView + | arrow::Utf8View + | arrow::ListView(_) + | arrow::LargeListView(_) + | arrow::Struct(_) + | arrow::Union(_, _) + | arrow::Dictionary(_, _) + | arrow::Decimal128(_, _) + | arrow::Decimal256(_, _) + | arrow::Map(_, _) + | arrow::RunEndEncoded(_, _) => { + return Err(Error::UnsupportedArrowDataType(arrow_type)); + } + } + } + + fn is_relaxed(&self) -> bool { + matches!(self.mode, ConversionMode::Relaxed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Datatype; + + /// Test that a datatype is supported as a scalar type + /// if and only if it is also supported as a list element type + #[test] + fn list_unsupported_element() { + let conv = ToArrowConverter::strict(); + for dt in Datatype::iter() { + let single_to_arrow = + conv.convert_datatype(&dt, &CellValNum::single(), false); + let var_to_arrow = + conv.convert_datatype(&dt, &CellValNum::Var, false); + + if let Err(Error::RequiresVarSized(_)) = single_to_arrow { + assert!(var_to_arrow.is_ok()); + } else if single_to_arrow.is_err() { + assert_eq!(single_to_arrow, var_to_arrow); + } + + if var_to_arrow.is_err() { + assert_eq!(var_to_arrow, single_to_arrow); + } + } + } +} diff --git a/tiledb/api/src/query_arrow/buffers.rs b/tiledb/api/src/query_arrow/buffers.rs new file mode 100644 index 00000000..b2195d6d --- /dev/null +++ b/tiledb/api/src/query_arrow/buffers.rs @@ -0,0 +1,1645 @@ +use std::any::Any; +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; + +use arrow::array as aa; +use arrow::buffer::{ + self as abuf, Buffer as ArrowBuffer, MutableBuffer as ArrowBufferMut, +}; +use arrow::datatypes as adt; +use arrow::error::ArrowError; +use thiserror::Error; + +use super::arrow::ToArrowConverter; +use super::fields::{QueryField, QueryFields}; +use crate::array::schema::{CellValNum, Field, Schema}; +use crate::error::Error as TileDBError; + +const AVERAGE_STRING_LENGTH: usize = 64; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Provided Arrow Array is externally referenced.")] + ArrayInUse, + #[error("Error converting to Arrow for field '{0}': {1}")] + ArrowConversionError(String, super::arrow::Error), + #[error("Failed to convert Arrow Array for field '{0}': {1}")] + FailedConversionFromArrow(String, Box), + #[error("Failed to allocate Arrow array: {0}")] + ArrayCreationFailed(ArrowError), + #[error("Capacity {0} is to small to hold {1} bytes per cell.")] + CapacityTooSmall(usize, usize), + #[error("Failed to convert owned buffers into a list array: {0}")] + FailedListArrayCreation(ArrowError), + #[error("Buffer is immutable")] + ImmutableBuffer, + #[error("Internal binary type mismatch error")] + InternalBinaryType, + #[error("Invalid buffer, no offsets present.")] + InvalidBufferNoOffsets, + #[error("Invalid buffer, no validity present.")] + InvalidBufferNoValidity, + #[error("Invalid data type for bytes array: {0}")] + InvalidBytesType(adt::DataType), + #[error("Invalid fixed sized list length {0} is less than 2")] + InvalidFixedSizeListLength(i32), + #[error("TileDB does not support nullable list elements")] + InvalidNullableListElements, + #[error("Invalid data type for primitive data: {0}")] + InvalidPrimitiveType(adt::DataType), + #[error("Internal error: Converted primitive array is not scalar")] + InternalListTypeMismatch, + #[error("Internal string type mismatch error")] + InternalStringType, + #[error("Error converting var sized buffers to arrow: {0}")] + InvalidVarBuffers(ArrowError), + #[error("Only the large variant is supported: {0}")] + LargeVariantOnly(adt::DataType), + #[error("Failed to convert internal list array: {0}")] + ListSubarrayConversion(Box), + #[error("Provided array had external references to its offsets buffer.")] + OffsetsInUse, + #[error("Internal TileDB Error: {0}")] + TileDB(#[from] TileDBError), + #[error("Unexpected var sized arrow type: {0}")] + UnexpectedArrowVarType(adt::DataType), + #[error("Mutable buffers are not shareable.")] + UnshareableMutableBuffer, + #[error("Unsupported arrow array type: {0}")] + UnsupportedArrowType(adt::DataType), + #[error( + "TileDB only supports fixed size lists of primtiive types, not {0}" + )] + UnsupportedFixedSizeListType(adt::DataType), + #[error("TileDB does not support timezones")] + UnsupportedTimeZones, +} + +type Result = std::result::Result; + +// The Arrow downcast_array function doesn't take an Arc which leaves us +// with an outstanding reference when we attempt the Buffer::into_mutable +// call. This function exists to consume the Arc after the cast. +fn downcast_consume(array: Arc) -> T +where + T: From, +{ + aa::downcast_array(&array) +} + +#[derive(Clone)] +struct SizeInfo { + data: Pin>, + offsets: Pin>, + validity: Pin>, +} + +impl Default for SizeInfo { + fn default() -> Self { + Self { + data: Box::pin(0), + offsets: Box::pin(0), + validity: Box::pin(0), + } + } +} + +/// The return type for the NewBufferTraitThing's into_arrow method. This +/// allows for fallible conversion without dropping the underlying buffers. +type IntoArrowResult = std::result::Result< + Arc, + (Box, Error), +>; + +/// The error type to use on TryFrom> implementations +type FromArrowError = (Arc, Error); + +/// The return type to use when implementing TryFrom +type FromArrowResult = std::result::Result; + +/// An interface to our mutable buffer implementations. +trait NewBufferTraitThing { + /// Return this trait object as any for downcasting. + fn as_any(&self) -> &dyn Any; + + /// The length of the buffer, in cells + fn len(&self) -> usize; + + /// The data buffer + fn data(&mut self) -> &mut ArrowBufferMut; + + /// The offsets buffer, for variants that have one + fn offsets(&mut self) -> Option<&mut ArrowBufferMut>; + + /// The validity buffer, when present + fn validity(&mut self) -> Option<&mut ArrowBufferMut>; + + /// The SizeInfo struct + fn sizes(&mut self) -> &mut SizeInfo; + + /// Check if another buffer is compatible with this buffer + fn is_compatible(&self, other: &Box) -> bool; + + /// Consume self and return an Arc + fn into_arrow(self: Box) -> IntoArrowResult; + + /// Reset all buffer lengths to match capacity. + /// + /// This should happen before a read query so that we're sure to be using + /// the entire available buffer rather than just whatever fit in the + /// previous iteration. + fn reset_len(&mut self) { + let data = self.data(); + data.resize(data.capacity(), 0); + self.offsets().map(|o| { + // Arrow requires an extra offset that TileDB elides so we need + // to leave room for it later. + assert!(o.capacity() >= std::mem::size_of::()); + o.resize(o.capacity() - std::mem::size_of::(), 0); + }); + self.validity().map(|v| v.resize(v.capacity(), 0)); + + let data_size = self.data().len(); + let offsets_size = self.offsets().map_or(0, |o| o.len()); + let validity_size = self.validity().map_or(0, |v| v.len()); + let sizes = self.sizes(); + *sizes.data = data_size as u64; + *sizes.offsets = offsets_size as u64; + *sizes.validity = validity_size as u64; + } + + /// Shrink len to data + /// + /// After a read query, this method is used to update the lenght of all + /// buffers to match the number of bytes written by TileDB. + fn shrink_len(&mut self) { + let sizes = self.sizes().clone(); + assert!((*sizes.data as usize) <= self.data().capacity()); + self.data().resize(*sizes.data as usize, 0); + + self.offsets().map(|o| { + assert!( + (*sizes.offsets as usize) + <= o.capacity() - std::mem::size_of::() + ); + o.resize(*sizes.offsets as usize, 0); + }); + + self.validity().map(|v| { + assert!((*sizes.validity as usize) <= v.capacity()); + v.resize(*sizes.validity as usize, 0); + }); + } + + /// Returns a mutable pointer to the data buffer + fn data_ptr(&mut self) -> *mut std::ffi::c_void { + self.data().as_mut_ptr() as *mut std::ffi::c_void + } + + /// Returns a mutable pointer to the data size + fn data_size_ptr(&mut self) -> *mut u64 { + self.sizes().data.as_mut().get_mut() + } + + /// Returns a mutable poiniter to the offsets buffer. + /// + /// For variants that don't have offsets, it returns a null pointer. + fn offsets_ptr(&mut self) -> *mut u64 { + let Some(offsets) = self.offsets() else { + return std::ptr::null_mut(); + }; + + offsets.as_mut_ptr() as *mut u64 + } + + /// Returns a mutable pointer to the offsets size. + /// + /// For variants that don't have offsets, it returns a null pointer. + fn offsets_size_ptr(&mut self) -> *mut u64 { + let Some(_) = self.offsets() else { + return std::ptr::null_mut(); + }; + + self.sizes().offsets.as_mut().get_mut() + } + + /// Returns a mutable pointer to the validity buffer, when present + /// + /// When validity is not present, it returns a null pointer. + fn validity_ptr(&mut self) -> *mut u8 { + let Some(validity) = self.validity() else { + return std::ptr::null_mut(); + }; + + validity.as_mut_ptr() + } + + /// Returns a mutable pointer to the validity size, when present + /// + /// When validity is not present, it returns a null pointer. + fn validity_size_ptr(&mut self) -> *mut u64 { + let Some(_) = self.validity() else { + return std::ptr::null_mut(); + }; + + self.sizes().validity.as_mut().get_mut() + } +} + +struct BooleanBuffers { + data: ArrowBufferMut, + validity: Option, + sizes: SizeInfo, +} + +impl TryFrom> for BooleanBuffers { + type Error = FromArrowError; + fn try_from(array: Arc) -> FromArrowResult { + let array: aa::BooleanArray = downcast_consume(array); + let (data, validity) = array.into_parts(); + let data = data + .iter() + .map(|v| if v { 1u8 } else { 0 }) + .collect::>(); + let validity = to_tdb_validity(validity); + let mut sizes = SizeInfo::default(); + *sizes.data = data.len() as u64; + *sizes.validity = validity.as_ref().map_or(0, |v| v.len()) as u64; + Ok(BooleanBuffers { + data: abuf::MutableBuffer::from(data), + validity: validity.map(abuf::MutableBuffer::from), + sizes, + }) + } +} + +impl NewBufferTraitThing for BooleanBuffers { + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.data.len() + } + + fn data(&mut self) -> &mut ArrowBufferMut { + &mut self.data + } + + fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + None + } + + fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + self.validity.as_mut() + } + + fn sizes(&mut self) -> &mut SizeInfo { + &mut self.sizes + } + + fn is_compatible(&self, other: &Box) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + self.validity.is_some() == other.validity.is_some() + } + + fn into_arrow(self: Box) -> IntoArrowResult { + let data = + abuf::BooleanBuffer::from_iter(self.data.iter().map(|b| *b != 0)); + Ok(Arc::new(aa::BooleanArray::new( + data, + from_tdb_validity(self.validity), + ))) + } +} + +struct ByteBuffers { + dtype: adt::DataType, + data: ArrowBufferMut, + offsets: ArrowBufferMut, + validity: Option, + sizes: SizeInfo, +} + +macro_rules! to_byte_buffers { + ($ARRAY:expr, $ARROW_TYPE:expr, $ARROW_DT:ty) => {{ + let array: $ARROW_DT = downcast_consume($ARRAY); + let (offsets, data, nulls) = array.into_parts(); + + let data = data.into_mutable(); + let offsets = offsets.into_inner().into_inner().into_mutable(); + + if data.is_ok() && offsets.is_ok() { + let data = data.ok().unwrap(); + let offsets = offsets.ok().unwrap(); + let validity = to_tdb_validity(nulls); + let mut sizes = SizeInfo::default(); + *sizes.data = data.len() as u64; + *sizes.offsets = + (offsets.len() - std::mem::size_of::()) as u64; + *sizes.validity = validity.as_ref().map_or(0, |v| v.len()) as u64; + return Ok(ByteBuffers { + dtype: $ARROW_TYPE, + data, + offsets, + validity, + sizes, + }); + } + + let data = if data.is_ok() { + ArrowBuffer::from(data.ok().unwrap()) + } else { + data.err().unwrap() + }; + + let offsets = if offsets.is_ok() { + offsets + .map(abuf::ScalarBuffer::::from) + .map(abuf::OffsetBuffer::new) + .ok() + .unwrap() + } else { + offsets + .map_err(abuf::ScalarBuffer::::from) + .map_err(abuf::OffsetBuffer::new) + .err() + .unwrap() + }; + + let array: Arc = Arc::new( + aa::LargeBinaryArray::try_new(offsets, data.into(), nulls).unwrap(), + ); + Err((array, Error::OffsetsInUse)) + }}; +} + +impl TryFrom> for ByteBuffers { + type Error = FromArrowError; + + fn try_from(array: Arc) -> FromArrowResult { + let dtype = array.data_type().clone(); + match dtype { + adt::DataType::LargeBinary => { + to_byte_buffers!(array, dtype.clone(), aa::LargeBinaryArray) + } + adt::DataType::LargeUtf8 => { + to_byte_buffers!(array, dtype.clone(), aa::LargeStringArray) + } + t => Err((array, Error::InvalidBytesType(t))), + } + } +} + +impl NewBufferTraitThing for ByteBuffers { + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.offsets.len() / std::mem::size_of::() + } + + fn data(&mut self) -> &mut ArrowBufferMut { + &mut self.data + } + + fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + Some(&mut self.offsets) + } + + fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + self.validity.as_mut() + } + + fn sizes(&mut self) -> &mut SizeInfo { + &mut self.sizes + } + + fn is_compatible(&self, other: &Box) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + if self.dtype != other.dtype { + return false; + } + + self.validity.is_some() == other.validity.is_some() + } + + fn into_arrow(self: Box) -> IntoArrowResult { + // Make sure we add back the extra offsets element that arrow expects + // which TileDB doesn't use. + assert!( + self.offsets.len() + <= self.offsets.capacity() - std::mem::size_of::() + ); + + let mut offsets = self.offsets; + offsets.extend_from_slice(&[self.data.len() as i64]); + + let dtype = self.dtype; + let data = ArrowBuffer::from(self.data); + let offsets = ArrowBuffer::from(offsets); + let validity = from_tdb_validity(self.validity); + let sizes = self.sizes; + // N.B., the calls to cloning the data/offsets/validity are as cheap + // as an Arc::clone plus pointer and usize copy. They are *not* cloning + // the underlying allocated data. + aa::ArrayData::try_new( + dtype.clone(), + *sizes.offsets as usize / std::mem::size_of::(), + validity.clone().map(|v| v.into_inner().into_inner()), + 0, + vec![offsets.clone(), data.clone()], + vec![], + ) + .map(|data| aa::make_array(data)) + .map_err(|e| { + // SAFETY: These unwraps are fine because the only other reference + // was consumed in the failed try_new call. Unless of course + // the ArrowError `e` ever ends up carrying a reference. + let boxed: Box = Box::new(ByteBuffers { + dtype, + data: data.into_mutable().unwrap(), + offsets: offsets.into_mutable().unwrap(), + validity: to_tdb_validity(validity), + sizes, + }); + + (boxed, Error::ArrayCreationFailed(e)) + }) + } +} + +struct FixedListBuffers { + field: Arc, + cell_val_num: CellValNum, + data: ArrowBufferMut, + validity: Option, + sizes: SizeInfo, +} + +impl TryFrom> for FixedListBuffers { + type Error = FromArrowError; + + fn try_from(array: Arc) -> FromArrowResult { + assert!(matches!( + array.data_type(), + adt::DataType::FixedSizeList(_, _) + )); + + let array: aa::FixedSizeListArray = downcast_consume(array); + let (field, cvn, array, nulls) = array.into_parts(); + + if field.is_nullable() { + return Err((array, Error::InvalidNullableListElements)); + } + + if cvn < 2 { + return Err((array, Error::InvalidFixedSizeListLength(cvn))); + } + + // SAFETY: We just showed cvn >= 2 && cvn is i32 whicih means + // it can't be u32::MAX + let cvn = CellValNum::try_from(cvn as u32) + .expect("Internal cell val num error"); + + let dtype = field.data_type().clone(); + if !dtype.is_primitive() { + return Err((array, Error::UnsupportedFixedSizeListType(dtype))); + } + + PrimitiveBuffers::try_from(array) + .map(|mut buffers| { + assert_eq!(buffers.dtype, dtype); + let validity = to_tdb_validity(nulls.clone()); + *buffers.sizes.validity = + validity.as_ref().map_or(0, |v| v.len()) as u64; + FixedListBuffers { + field: Arc::clone(&field), + cell_val_num: cvn, + data: buffers.data, + validity, + sizes: buffers.sizes, + } + }) + .map_err(|(array, e)| { + let array: Arc = Arc::new( + aa::FixedSizeListArray::try_new( + field, + u32::from(cvn) as i32, + array, + nulls, + ) + .unwrap(), + ); + (array, e) + }) + } +} + +impl NewBufferTraitThing for FixedListBuffers { + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.data.len() / u32::from(self.cell_val_num) as usize + } + + fn data(&mut self) -> &mut ArrowBufferMut { + &mut self.data + } + + fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + None + } + + fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + self.validity.as_mut() + } + + fn sizes(&mut self) -> &mut SizeInfo { + &mut self.sizes + } + + fn is_compatible(&self, other: &Box) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + if self.field != other.field { + return false; + } + + if self.cell_val_num != other.cell_val_num { + return false; + } + + self.validity.is_some() == other.validity.is_some() + } + + fn into_arrow(self: Box) -> IntoArrowResult { + let field = self.field; + let cell_val_num = self.cell_val_num; + let data = ArrowBuffer::from(self.data); + let validity = from_tdb_validity(self.validity); + let sizes = self.sizes; + + assert!(field.data_type().is_primitive()); + let num_values = + *sizes.data as usize / field.data_type().primitive_width().unwrap(); + let cvn = u32::from(cell_val_num) as i32; + let len = num_values / cvn as usize; + + // N.B., data/validity clones are cheap. They are not cloning the + // underlying data buffers. We have to clone so that we can put ourself + // back together if the array conversion failes. + aa::ArrayData::try_new( + field.data_type().clone(), + len, + validity.clone().map(|v| v.into_inner().into_inner()), + 0, + vec![data.clone().into()], + vec![], + ) + .map(|data| { + let array: Arc = + Arc::new(aa::FixedSizeListArray::new( + Arc::clone(&field), + u32::from(cell_val_num) as i32, + aa::make_array(data), + validity.clone(), + )); + array + }) + .map_err(|e| { + let boxed: Box = + Box::new(FixedListBuffers { + field, + cell_val_num, + data: data.into_mutable().unwrap(), + validity: to_tdb_validity(validity), + sizes, + }); + + (boxed, Error::ArrayCreationFailed(e)) + }) + } +} + +struct ListBuffers { + field: Arc, + data: ArrowBufferMut, + offsets: ArrowBufferMut, + validity: Option, + sizes: SizeInfo, +} + +impl TryFrom> for ListBuffers { + type Error = FromArrowError; + + fn try_from(array: Arc) -> FromArrowResult { + assert!(matches!(array.data_type(), adt::DataType::LargeList(_))); + + let array: aa::LargeListArray = downcast_consume(array); + let (field, offsets, array, nulls) = array.into_parts(); + + if field.is_nullable() { + return Err((array, Error::InvalidNullableListElements)); + } + + let dtype = field.data_type().clone(); + if !dtype.is_primitive() { + return Err((array, Error::UnsupportedFixedSizeListType(dtype))); + } + + // N.B., I really, really tried to make this a fancy map/map_err + // cascade like all of the others. But it turns out that keeping the + // proper refcounts on either array or offsets turns into a bit of + // an issue when passing things through multiple closures. + let result = PrimitiveBuffers::try_from(array); + if result.is_err() { + let (array, err) = result.err().unwrap(); + let array: Arc = Arc::new( + aa::LargeListArray::try_new( + Arc::clone(&field), + offsets, + array, + nulls.clone(), + ) + .unwrap(), + ); + return Err((array, err)); + } + + let mut data = result.ok().unwrap(); + + let result = offsets.into_inner().into_inner().into_mutable(); + if result.is_err() { + let offsets_buffer = result.err().unwrap(); + let offsets = abuf::OffsetBuffer::new( + abuf::ScalarBuffer::::from(offsets_buffer), + ); + let array: Arc = Arc::new( + aa::LargeListArray::try_new( + Arc::clone(&field), + offsets, + // Safety: We just turned this into a mutable buffer, so + // the inversion should never fail. + Box::new(data).into_arrow().ok().unwrap(), + nulls.clone(), + ) + .unwrap(), + ); + return Err((array, Error::ArrayInUse)); + } + + // Safety: We already yeeted an error on non-primitive types. + let width = dtype.primitive_width().unwrap() as i64; + let mut offsets = result.ok().unwrap(); + + // TileDB works in offsets of bytes, so we re-map all of our arrow + // offsets here. + for elem in offsets.typed_data_mut::() { + *elem = *elem * width; + } + + let validity = to_tdb_validity(nulls); + *data.sizes.offsets = + (offsets.len() - std::mem::size_of::()) as u64; + *data.sizes.validity = validity.as_ref().map_or(0, |v| v.len()) as u64; + + Ok(ListBuffers { + field, + data: data.data, + offsets, + validity, + sizes: data.sizes, + }) + } +} + +impl NewBufferTraitThing for ListBuffers { + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.offsets.len() / std::mem::size_of::() + } + + fn data(&mut self) -> &mut ArrowBufferMut { + &mut self.data + } + + fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + Some(&mut self.offsets) + } + + fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + self.validity.as_mut() + } + + fn sizes(&mut self) -> &mut SizeInfo { + &mut self.sizes + } + + fn is_compatible(&self, other: &Box) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + if self.field != other.field { + return false; + } + + self.validity.is_some() == other.validity.is_some() + } + + fn into_arrow(self: Box) -> IntoArrowResult { + let field = self.field; + + assert!(field.data_type().is_primitive()); + let width = field.data_type().primitive_width().unwrap() as i64; + + // Update the last offset to be the total number of bytes written + let mut offsets = self.offsets; + assert!( + offsets.len() <= offsets.capacity() - std::mem::size_of::() + ); + offsets.extend_from_slice(&[self.data.len() as i64]); + + // Arrow does offsets in terms of elements, not bytes. Here we rewrite + // that by dividing all offsets by the width of the primitive type in + // the data buffer. + let offset_slice = offsets.typed_data_mut::(); + for elem in offset_slice { + assert!(*elem % width == 0); + *elem = *elem / width; + } + + let data = ArrowBuffer::from(self.data); + let offsets = from_tdb_offsets(offsets); + let validity = from_tdb_validity(self.validity); + let sizes = self.sizes; + + // N.B., the calls to cloning the data/offsets/validity are as cheap + // as an Arc::clone plus pointer and usize copy. They are *not* cloning + // the underlying allocated data. + aa::ArrayData::try_new( + field.data_type().clone(), + *sizes.offsets as usize / std::mem::size_of::(), + None, + 0, + vec![data.clone().into()], + vec![], + ) + .and_then(|data| { + aa::LargeListArray::try_new( + field.clone(), + offsets.clone(), + aa::make_array(data), + validity.clone(), + ) + }) + .map(|array| { + let array: Arc = Arc::new(array); + array + }) + .map_err(|e| { + let boxed: Box = Box::new(ListBuffers { + field, + data: data.into_mutable().unwrap(), + offsets: to_tdb_offsets(offsets).unwrap(), + validity: to_tdb_validity(validity), + sizes, + }); + + (boxed, Error::ArrayCreationFailed(e)) + }) + } +} + +struct PrimitiveBuffers { + dtype: adt::DataType, + data: ArrowBufferMut, + validity: Option, + sizes: SizeInfo, +} + +macro_rules! to_primitive { + ($ARRAY:expr, $ARROW_DT:ty) => {{ + let array: $ARROW_DT = downcast_consume($ARRAY); + let len = array.len(); + let (dtype, buffer, nulls) = array.into_parts(); + + buffer + .into_inner() + .into_mutable() + .map(|data| { + let validity = to_tdb_validity(nulls.clone()); + let mut sizes = SizeInfo::default(); + *sizes.data = data.len() as u64; + *sizes.validity = + validity.as_ref().map_or(0, |v| v.len()) as u64; + PrimitiveBuffers { + dtype: dtype.clone(), + data, + validity, + sizes, + } + }) + .map_err(|buffer| { + // Safety: We just broke an array open to get these so + // unless someone did something unsafe they should go + // right back together again. Sorry, Humpty. + let data = aa::ArrayData::try_new( + dtype, + len, + nulls.map(|n| n.into_inner().into_inner()), + 0, + vec![buffer], + vec![], + ) + .unwrap(); + (aa::make_array(data), Error::ArrayInUse) + }) + }}; +} + +impl TryFrom> for PrimitiveBuffers { + type Error = FromArrowError; + fn try_from(array: Arc) -> FromArrowResult { + assert!(array.data_type().is_primitive()); + + match array.data_type().clone() { + adt::DataType::Int8 => to_primitive!(array, aa::Int8Array), + adt::DataType::Int16 => to_primitive!(array, aa::Int16Array), + adt::DataType::Int32 => to_primitive!(array, aa::Int32Array), + adt::DataType::Int64 => to_primitive!(array, aa::Int64Array), + adt::DataType::UInt8 => to_primitive!(array, aa::UInt8Array), + adt::DataType::UInt16 => to_primitive!(array, aa::UInt16Array), + adt::DataType::UInt32 => to_primitive!(array, aa::UInt32Array), + adt::DataType::UInt64 => to_primitive!(array, aa::UInt64Array), + adt::DataType::Float32 => to_primitive!(array, aa::Float32Array), + adt::DataType::Float64 => to_primitive!(array, aa::Float64Array), + t => Err((array, Error::InvalidPrimitiveType(t))), + } + } +} + +impl NewBufferTraitThing for PrimitiveBuffers { + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + assert!(self.dtype.is_primitive()); + self.data.len() / self.dtype.primitive_width().unwrap() + } + + fn data(&mut self) -> &mut ArrowBufferMut { + &mut self.data + } + + fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + None + } + + fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + self.validity.as_mut() + } + + fn sizes(&mut self) -> &mut SizeInfo { + &mut self.sizes + } + + fn is_compatible(&self, other: &Box) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + if self.dtype != other.dtype { + return false; + } + + self.validity.is_some() == other.validity.is_some() + } + + fn into_arrow(self: Box) -> IntoArrowResult { + let dtype = self.dtype; + let data = ArrowBuffer::from(self.data); + let validity = from_tdb_validity(self.validity); + let sizes = self.sizes; + + assert!(dtype.is_primitive()); + + // N.B., data/validity clones are cheap. They are not cloning the + // underlying data buffers. We have to clone so that we can put ourself + // back together if the array conversion failes. + aa::ArrayData::try_new( + dtype.clone(), + *sizes.data as usize / dtype.primitive_width().unwrap(), + validity.clone().map(|v| v.into_inner().into_inner()), + 0, + vec![data.clone()], + vec![], + ) + .map(aa::make_array) + .map_err(|e| { + let boxed: Box = + Box::new(PrimitiveBuffers { + dtype, + data: data.into_mutable().unwrap(), + validity: to_tdb_validity(validity), + sizes, + }); + + (boxed, Error::ArrayCreationFailed(e)) + }) + } +} + +impl TryFrom> for Box { + type Error = FromArrowError; + + fn try_from(array: Arc) -> FromArrowResult { + let dtype = array.data_type().clone(); + match dtype { + adt::DataType::Boolean => { + BooleanBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::LargeBinary | adt::DataType::LargeUtf8 => { + ByteBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::FixedSizeList(_, _) => { + FixedListBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::LargeList(_) => { + ListBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::Int8 + | adt::DataType::Int16 + | adt::DataType::Int32 + | adt::DataType::Int64 + | adt::DataType::UInt8 + | adt::DataType::UInt16 + | adt::DataType::UInt32 + | adt::DataType::UInt64 + | adt::DataType::Float32 + | adt::DataType::Float64 + | adt::DataType::Timestamp(_, None) + | adt::DataType::Time64(_) => PrimitiveBuffers::try_from(array) + .map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }), + + adt::DataType::Timestamp(_, Some(_)) => { + Err((array, Error::UnsupportedTimeZones)) + } + + adt::DataType::Binary + | adt::DataType::List(_) + | adt::DataType::Utf8 => { + Err((array, Error::LargeVariantOnly(dtype))) + } + + adt::DataType::FixedSizeBinary(_) => { + todo!("This can probably be supported.") + } + + adt::DataType::Null + | adt::DataType::Float16 + | adt::DataType::Date32 + | adt::DataType::Date64 + | adt::DataType::Time32(_) + | adt::DataType::Duration(_) + | adt::DataType::Interval(_) + | adt::DataType::BinaryView + | adt::DataType::Utf8View + | adt::DataType::ListView(_) + | adt::DataType::LargeListView(_) + | adt::DataType::Struct(_) + | adt::DataType::Union(_, _) + | adt::DataType::Dictionary(_, _) + | adt::DataType::Decimal128(_, _) + | adt::DataType::Decimal256(_, _) + | adt::DataType::Map(_, _) + | adt::DataType::RunEndEncoded(_, _) => { + return Err((array, Error::UnsupportedArrowType(dtype))); + } + } + } +} + +/// A utility that requires it contains exactly one of two variants +/// +/// Unfortuantely, mutating enum variants through a mutable reference isn't +/// a thing that can be done safely, so we have a specialized utility struct +/// that does the same idea, at the cost of an extra None in the struct. +struct MutableOrShared { + mutable: Option>, + shared: Option>, +} + +impl MutableOrShared { + pub fn new(value: Arc) -> Self { + Self { + mutable: None, + shared: Some(value), + } + } + + pub fn mutable(&mut self) -> Option<&mut Box> { + self.mutable.as_mut() + } + + pub fn shared(&self) -> Option> { + self.shared.as_ref().map(Arc::clone) + } + + pub fn to_mutable(&mut self) -> Result<()> { + self.validate(); + + if self.mutable.is_some() { + return Ok(()); + } + + let shared = self.shared.take().unwrap(); + let mutable: FromArrowResult> = + shared.try_into(); + + let ret = if mutable.is_ok() { + self.mutable = mutable.ok(); + Ok(()) + } else { + let (array, err) = mutable.err().unwrap(); + self.shared = Some(array); + Err(err) + }; + + self.validate(); + ret + } + + pub fn to_shared(&mut self) -> Result<()> { + self.validate(); + + if self.shared.is_some() { + return Ok(()); + } + + let mutable = self.mutable.take().unwrap(); + let shared = mutable.into_arrow(); + + let ret = if shared.is_ok() { + self.shared = shared.ok(); + Ok(()) + } else { + let (mutable, err) = shared.err().unwrap(); + self.mutable = Some(mutable); + Err(err) + }; + + self.validate(); + ret + } + + fn validate(&self) { + assert!( + (self.shared.is_some() && self.mutable.is_none()) + || (self.shared.is_none() && self.mutable.is_some()) + ) + } +} + +pub struct BufferEntry { + entry: MutableOrShared, +} + +impl BufferEntry { + pub fn as_shared(&self) -> Result> { + let Some(ref array) = self.entry.shared() else { + return Err(Error::UnshareableMutableBuffer); + }; + + return Ok(Arc::clone(array)); + } + + pub fn len(&self) -> usize { + if self.entry.shared.is_some() { + return self.entry.shared.as_ref().unwrap().len(); + } else { + self.entry.mutable.as_ref().unwrap().len() + } + } + + pub fn is_compatible(&self, other: &BufferEntry) -> bool { + if self.entry.mutable.is_some() && other.entry.mutable.is_some() { + return self + .entry + .mutable + .as_ref() + .unwrap() + .is_compatible(other.entry.mutable.as_ref().unwrap()); + } + + false + } + + pub fn reset_len(&mut self) -> Result<()> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + mutable.reset_len(); + Ok(()) + } + + pub fn shrink_len(&mut self) -> Result<()> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + mutable.shrink_len(); + Ok(()) + } + + pub fn data_ptr(&mut self) -> Result<*mut std::ffi::c_void> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.data_ptr()) + } + + pub fn data_size_ptr(&mut self) -> Result<*mut u64> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.data_size_ptr()) + } + + pub fn has_offsets(&mut self) -> Result { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.offsets_ptr() != std::ptr::null_mut()) + } + + pub fn offsets_ptr(&mut self) -> Result<*mut u64> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.offsets_ptr()) + } + + pub fn offsets_size_ptr(&mut self) -> Result<*mut u64> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.offsets_size_ptr()) + } + + pub fn has_validity(&mut self) -> Result { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.validity_ptr() != std::ptr::null_mut()) + } + + pub fn validity_ptr(&mut self) -> Result<*mut u8> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.validity_ptr()) + } + + pub fn validity_size_ptr(&mut self) -> Result<*mut u64> { + let Some(mutable) = self.entry.mutable() else { + return Err(Error::ImmutableBuffer); + }; + + Ok(mutable.validity_size_ptr()) + } + + fn to_mutable(&mut self) -> Result<()> { + self.entry.to_mutable() + } + + fn to_shared(&mut self) -> Result<()> { + self.entry.to_shared() + } +} + +impl From> for BufferEntry { + fn from(array: Arc) -> Self { + Self { + entry: MutableOrShared::new(array), + } + } +} + +pub struct QueryBuffers { + buffers: HashMap, +} + +impl QueryBuffers { + pub fn new(buffers: HashMap>) -> Self { + let mut new_buffers = HashMap::with_capacity(buffers.len()); + for (field, array) in buffers.into_iter() { + new_buffers.insert(field, BufferEntry::from(array)); + } + Self { + buffers: new_buffers, + } + } + + /// Reset all mutable buffers' len to match its total capacity. + pub fn reset_lengths(&mut self) -> Result<()> { + for array in self.buffers.values_mut() { + array.reset_len()?; + } + Ok(()) + } + + /// Shrink all mutable buffers' len to match what the TileDB size read. + pub fn shrink_lengths(&mut self) -> Result<()> { + for array in self.buffers.values_mut() { + array.shrink_len()?; + } + Ok(()) + } + + pub fn len(&self) -> usize { + self.buffers.len() + } + + pub fn fields(&self) -> Vec { + self.buffers.keys().cloned().collect::>() + } + + pub fn get(&self, key: &String) -> Option<&BufferEntry> { + self.buffers.get(key) + } + + pub fn get_mut(&mut self, key: &String) -> Option<&mut BufferEntry> { + self.buffers.get_mut(key) + } + + pub fn from_fields(schema: Schema, fields: QueryFields) -> Result { + let conv = ToArrowConverter::strict(); + let mut ret = HashMap::with_capacity(fields.fields.len()); + for (name, field) in fields.fields.into_iter() { + let tdb_field = schema.field(name.clone())?; + + if let QueryField::Buffer(array) = field { + Self::validate_buffer(&tdb_field, &array)?; + ret.insert(name.clone(), array); + continue; + } + + // ToDo: Clean these error conversions up so they clearly indicate + // a failed buffer creation. + let tdb_dtype = tdb_field.datatype()?; + let tdb_cvn = tdb_field.cell_val_num()?; + let tdb_nullable = tdb_field.nullability()?; + let arrow_type = if let Some(dtype) = field.target_type() { + conv.convert_datatype_to( + &tdb_dtype, + &tdb_cvn, + tdb_nullable, + dtype, + ) + } else { + conv.convert_datatype(&tdb_dtype, &tdb_cvn, tdb_nullable) + } + .map_err(|e| Error::ArrowConversionError(name.clone(), e))?; + + let array = alloc_array( + arrow_type, + tdb_nullable, + field.capacity().unwrap(), + )?; + ret.insert(name.clone(), array); + } + + Ok(Self::new(ret)) + } + + pub fn is_compatible(&self, other: &Self) -> bool { + let mut my_keys = self.buffers.keys().collect::>(); + let mut their_keys = other.buffers.keys().collect::>(); + + my_keys.sort(); + their_keys.sort(); + if my_keys != their_keys { + return false; + } + + for key in my_keys { + let mine = self.buffers.get(key); + let theirs = other.buffers.get(key); + + if mine.is_none() || theirs.is_none() { + return false; + } + + if !mine.unwrap().is_compatible(theirs.unwrap()) { + return false; + } + } + + return true; + } + + pub fn iter(&self) -> impl Iterator { + self.buffers.iter() + } + + pub fn iter_mut( + &mut self, + ) -> impl Iterator { + self.buffers.iter_mut() + } + + pub fn to_mutable(&mut self) -> Result<()> { + for value in self.buffers.values_mut() { + value.to_mutable()? + } + Ok(()) + } + + pub fn to_shared(&mut self) -> Result<()> { + for value in self.buffers.values_mut() { + value.to_shared()? + } + Ok(()) + } + + /// When I get to it, this needs to ensure that the provided array matches + /// the field's TileDB datatype. + fn validate_buffer( + _field: &Field, + _buffer: &Arc, + ) -> Result<()> { + Ok(()) + } +} + +/// A small helper for users writing code directly against the TileDB API +/// +/// This struct is freely convertible to and from a HashMap of Arrow arrays. +pub struct SharedBuffers { + buffers: HashMap>, +} + +impl SharedBuffers { + pub fn get(&self, key: &str) -> Option<&T> + where + T: Any, + { + self.buffers.get(key)?.as_any().downcast_ref::() + } +} + +impl From>> for SharedBuffers { + fn from(buffers: HashMap>) -> Self { + Self { buffers } + } +} + +impl From for HashMap> { + fn from(buffers: SharedBuffers) -> Self { + buffers.buffers + } +} + +fn alloc_array( + dtype: adt::DataType, + nullable: bool, + capacity: usize, +) -> Result> { + let num_cells = calculate_num_cells(dtype.clone(), nullable, capacity)?; + + match dtype { + adt::DataType::Boolean => { + Ok(Arc::new(aa::BooleanArray::new_null(num_cells))) + } + adt::DataType::LargeList(field) => { + let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); + let value_capacity = + capacity - (num_cells * std::mem::size_of::()); + let values = + alloc_array(field.data_type().clone(), false, value_capacity)?; + let nulls = if nullable { + Some(abuf::NullBuffer::new_null(num_cells)) + } else { + None + }; + Ok(Arc::new( + aa::LargeListArray::try_new(field, offsets, values, nulls) + .map_err(Error::ArrayCreationFailed)?, + )) + } + adt::DataType::FixedSizeList(field, cvn) => { + let nulls = if nullable { + Some(abuf::NullBuffer::new_null(num_cells)) + } else { + None + }; + let values = + alloc_array(field.data_type().clone(), false, capacity)?; + Ok(Arc::new( + aa::FixedSizeListArray::try_new(field, cvn, values, nulls) + .map_err(Error::ArrayCreationFailed)?, + )) + } + adt::DataType::LargeUtf8 => { + let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); + let values = ArrowBufferMut::from_len_zeroed( + capacity - (num_cells * std::mem::size_of::()), + ); + let nulls = if nullable { + Some(abuf::NullBuffer::new_null(num_cells)) + } else { + None + }; + Ok(Arc::new( + aa::LargeStringArray::try_new(offsets, values.into(), nulls) + .map_err(Error::ArrayCreationFailed)?, + )) + } + adt::DataType::LargeBinary => { + let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); + let values = ArrowBufferMut::from_len_zeroed( + capacity - (num_cells * std::mem::size_of::()), + ); + let nulls = if nullable { + Some(abuf::NullBuffer::new_null(num_cells)) + } else { + None + }; + Ok(Arc::new( + aa::LargeBinaryArray::try_new(offsets, values.into(), nulls) + .map_err(Error::ArrayCreationFailed)?, + )) + } + _ if dtype.is_primitive() => { + let data = ArrowBufferMut::from_len_zeroed( + num_cells * dtype.primitive_width().unwrap(), + ); + + let nulls = if nullable { + Some(ArrowBufferMut::from_len_zeroed(num_cells).into()) + } else { + None + }; + + let data = aa::ArrayData::try_new( + dtype, + num_cells, + nulls, + 0, + vec![data.into()], + vec![], + ) + .map_err(|e| Error::ArrayCreationFailed(e))?; + + Ok(aa::make_array(data)) + } + _ => todo!(), + } +} + +fn calculate_num_cells( + dtype: adt::DataType, + nullable: bool, + capacity: usize, +) -> Result { + match dtype { + adt::DataType::Boolean => { + if nullable { + Ok(capacity * 8 / 2) + } else { + Ok(capacity * 8) + } + } + adt::DataType::LargeList(ref field) => { + if !field.data_type().is_primitive() { + return Err(Error::UnsupportedArrowType(dtype.clone())); + } + + // Todo: Figure out a better way to approximate values to offsets ratios + // based on whatever Python does or some such. + // + // For now, I'll pull a guess at of the ether and assume on average a + // var sized primitive array averages two values per cell. Becuase why + // not? + let width = field.data_type().primitive_width().unwrap(); + let bytes_per_cell = (width * 2) + + std::mem::size_of::() + + if nullable { 1 } else { 0 }; + Ok(capacity / bytes_per_cell) + } + adt::DataType::FixedSizeList(ref field, cvn) => { + if !field.data_type().is_primitive() { + return Err(Error::UnsupportedArrowType(dtype)); + } + + if cvn < 2 { + return Err(Error::InvalidFixedSizeListLength(cvn)); + } + + let cvn = cvn as usize; + let width = field.data_type().primitive_width().unwrap(); + let bytes_per_cell = capacity / (width * cvn); + let bytes_per_cell = if nullable { + bytes_per_cell + 1 + } else { + bytes_per_cell + }; + Ok(capacity / bytes_per_cell) + } + adt::DataType::LargeUtf8 | adt::DataType::LargeBinary => { + let bytes_per_cell = + AVERAGE_STRING_LENGTH + std::mem::size_of::(); + let bytes_per_cell = if nullable { + bytes_per_cell + 1 + } else { + bytes_per_cell + }; + Ok(capacity / bytes_per_cell) + } + _ if dtype.is_primitive() => { + let width = dtype.primitive_width().unwrap(); + let bytes_per_cell = width + if nullable { 1 } else { 0 }; + Ok(capacity / bytes_per_cell) + } + _ => Err(Error::UnsupportedArrowType(dtype.clone())), + } +} + +// Private utility functions + +fn to_tdb_offsets(offsets: abuf::OffsetBuffer) -> Result { + offsets + .into_inner() + .into_inner() + .into_mutable() + .map_err(|_| Error::ArrayInUse) +} + +fn to_tdb_validity(nulls: Option) -> Option { + nulls.map(|nulls| { + ArrowBufferMut::from( + nulls + .iter() + .map(|v| if v { 1u8 } else { 0 }) + .collect::>(), + ) + }) +} + +fn from_tdb_offsets(offsets: ArrowBufferMut) -> abuf::OffsetBuffer { + let buffer = abuf::ScalarBuffer::::from(offsets); + abuf::OffsetBuffer::new(buffer) +} + +fn from_tdb_validity( + validity: Option, +) -> Option { + validity.map(|v| { + abuf::NullBuffer::from( + v.into_iter() + .map(|f| if *f != 0 { true } else { false }) + .collect::>(), + ) + }) +} diff --git a/tiledb/api/src/query_arrow/fields.rs b/tiledb/api/src/query_arrow/fields.rs new file mode 100644 index 00000000..770d4f0b --- /dev/null +++ b/tiledb/api/src/query_arrow/fields.rs @@ -0,0 +1,171 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array as aa; +use arrow::datatypes as adt; + +use super::QueryBuilder; + +/// Default field capacity is 10MiB +const DEFAULT_CAPACITY: usize = 1024 * 1024 * 10; + +#[derive(Debug, Default)] +pub enum QueryField { + #[default] + Default, + WithCapacity(usize), + WithCapacityAndType(usize, adt::DataType), + WithType(adt::DataType), + Buffer(Arc), +} + +impl QueryField { + pub fn capacity(&self) -> Option { + match self { + Self::Default => Some(DEFAULT_CAPACITY), + Self::WithCapacity(capacity) => Some(*capacity), + Self::WithCapacityAndType(capacity, _) => Some(*capacity), + Self::WithType(_) => Some(DEFAULT_CAPACITY), + Self::Buffer(_) => None, + } + } + + pub fn target_type(&self) -> Option { + match self { + Self::Default => None, + Self::WithCapacity(_) => None, + Self::WithCapacityAndType(_, dtype) => Some(dtype.clone()), + Self::WithType(dtype) => Some(dtype.clone()), + Self::Buffer(array) => Some(array.data_type().clone()), + } + } +} + +#[derive(Debug, Default)] +pub struct QueryFields { + pub fields: HashMap, +} + +impl QueryFields { + pub fn insert>(&mut self, name: S, field: QueryField) { + let name: String = name.into(); + self.fields.insert(name.clone(), field); + } +} + +pub struct QueryFieldsBuilder { + fields: QueryFields, +} + +impl QueryFieldsBuilder { + pub fn new() -> Self { + Self { + fields: Default::default(), + } + } + + pub fn build(self) -> QueryFields { + self.fields + } + + pub fn field(mut self, name: &str) -> Self { + self.fields.insert(name, Default::default()); + self + } + + pub fn field_with_buffer( + mut self, + name: &str, + buffer: Arc, + ) -> Self { + self.fields.insert(name, QueryField::Buffer(buffer)); + self + } + + pub fn field_with_capacity(mut self, name: &str, capacity: usize) -> Self { + self.fields.insert(name, QueryField::WithCapacity(capacity)); + self + } + + pub fn field_with_capacity_and_type( + mut self, + name: &str, + capacity: usize, + dtype: adt::DataType, + ) -> Self { + self.fields + .insert(name, QueryField::WithCapacityAndType(capacity, dtype)); + self + } + + pub fn field_with_type(mut self, name: &str, dtype: adt::DataType) -> Self { + self.fields.insert(name, QueryField::WithType(dtype)); + self + } +} + +pub struct QueryFieldsBuilderForQuery { + query_builder: QueryBuilder, + fields_builder: QueryFieldsBuilder, +} + +impl QueryFieldsBuilderForQuery { + pub(crate) fn new(query_builder: QueryBuilder) -> Self { + Self { + query_builder, + fields_builder: QueryFieldsBuilder::new(), + } + } + + pub fn end_fields(self) -> QueryBuilder { + self.query_builder.with_fields(self.fields_builder.build()) + } + + pub fn field(self, name: &str) -> Self { + Self { + fields_builder: self.fields_builder.field(name), + ..self + } + } + + pub fn field_with_buffer( + self, + name: &str, + buffer: Arc, + ) -> Self { + Self { + fields_builder: self.fields_builder.field_with_buffer(name, buffer), + ..self + } + } + + pub fn field_with_capacity(self, name: &str, capacity: usize) -> Self { + Self { + fields_builder: self + .fields_builder + .field_with_capacity(name, capacity), + ..self + } + } + + pub fn field_with_capacity_and_type( + self, + name: &str, + capacity: usize, + dtype: adt::DataType, + ) -> Self { + Self { + fields_builder: self + .fields_builder + .field_with_capacity_and_type(name, capacity, dtype), + ..self + } + } + + pub fn field_with_type(self, name: &str, dtype: adt::DataType) -> Self { + Self { + fields_builder: self.fields_builder.field_with_type(name, dtype), + ..self + } + } +} diff --git a/tiledb/api/src/query_arrow/mod.rs b/tiledb/api/src/query_arrow/mod.rs new file mode 100644 index 00000000..984cedcb --- /dev/null +++ b/tiledb/api/src/query_arrow/mod.rs @@ -0,0 +1,578 @@ +///! The TileDB Query interface and supporting utilities +use std::collections::HashMap; +use std::ops::Deref; + +use thiserror::Error; + +use crate::array::Array; +use crate::config::Config; +use crate::context::{CApiInterface, Context, ContextBound}; +use crate::error::Error as TileDBError; +use crate::key::LookupKey; +use crate::query::conditions::QueryConditionExpr; +use crate::range::{Range, SingleValueRange, VarValueRange}; + +use buffers::{Error as QueryBuffersError, QueryBuffers, SharedBuffers}; +use fields::{QueryFields, QueryFieldsBuilderForQuery}; +use subarray::{SubarrayBuilderForQuery, SubarrayData}; + +pub mod arrow; +pub mod buffers; +pub mod fields; +pub mod subarray; + +pub type QueryType = crate::array::Mode; +pub type QueryLayout = crate::array::CellOrder; + +/// Errors related to query creation and execution +#[derive(Debug, Error)] +pub enum Error { + #[error("Incompatible buffer specification when replacing buffers.")] + IncompatibleReplacementBuffers, + #[error("Internal TileDB Error: {0}")] + InternalError(String), + #[error("Invalid string for C API calls: {0}")] + NulError(#[from] std::ffi::NulError), + #[error("Error building query buffers: {0}")] + QueryBuffersError(#[from] QueryBuffersError), + #[error("Encountered internal libtiledb error: {0}")] + TileDBError(#[from] TileDBError), +} + +impl From for TileDBError { + fn from(err: Error) -> TileDBError { + TileDBError::Other(format!("{err}")) + } +} + +type Result = std::result::Result; + +/// The status of a query submission +/// +/// Note that BuffersTooSmall is a Rust invention. But given that we never +/// attempt to translate this status object back into a capi value its fine. +pub enum QueryStatus { + Uninitialized, + Initialized, + InProgress, + Incomplete, + BuffersTooSmall, + Completed, + Failed, +} + +impl QueryStatus { + pub fn is_complete(&self) -> bool { + matches!(self, QueryStatus::Completed) + } + + pub fn has_data(&self) -> bool { + !matches!(self, QueryStatus::BuffersTooSmall) + } +} + +impl TryFrom for QueryStatus { + type Error = Error; + fn try_from(status: ffi::tiledb_query_status_t) -> Result { + match status { + ffi::tiledb_query_status_t_TILEDB_UNINITIALIZED => { + Ok(QueryStatus::Uninitialized) + } + ffi::tiledb_query_status_t_TILEDB_INITIALIZED => { + Ok(QueryStatus::Initialized) + } + ffi::tiledb_query_status_t_TILEDB_INPROGRESS => { + Ok(QueryStatus::InProgress) + } + ffi::tiledb_query_status_t_TILEDB_INCOMPLETE => { + Ok(QueryStatus::Incomplete) + } + ffi::tiledb_query_status_t_TILEDB_COMPLETED => { + Ok(QueryStatus::Completed) + } + ffi::tiledb_query_status_t_TILEDB_FAILED => Ok(QueryStatus::Failed), + invalid => Err(Error::InternalError(format!( + "Invaldi query status: {}", + invalid + ))), + } + } +} + +pub(crate) enum RawQuery { + Owned(*mut ffi::tiledb_query_t), +} + +impl Deref for RawQuery { + type Target = *mut ffi::tiledb_query_t; + fn deref(&self) -> &Self::Target { + let RawQuery::Owned(ref ffi) = self; + ffi + } +} + +impl Drop for RawQuery { + fn drop(&mut self) { + let RawQuery::Owned(ref mut ffi) = *self; + unsafe { ffi::tiledb_query_free(ffi) } + } +} + +pub(crate) enum RawSubarray { + Owned(*mut ffi::tiledb_subarray_t), +} + +impl Deref for RawSubarray { + type Target = *mut ffi::tiledb_subarray_t; + fn deref(&self) -> &Self::Target { + match *self { + RawSubarray::Owned(ref ffi) => ffi, + } + } +} + +impl Drop for RawSubarray { + fn drop(&mut self) { + let RawSubarray::Owned(ref mut ffi) = *self; + unsafe { ffi::tiledb_subarray_free(ffi) }; + } +} + +/// The main Query interface +/// +/// This struct is responsible for executing queries against TileDB arrays. +pub struct Query { + context: Context, + raw: RawQuery, + query_type: QueryType, + array: Array, + buffers: QueryBuffers, +} + +impl ContextBound for Query { + fn context(&self) -> Context { + self.array.context() + } +} + +impl Query { + pub(crate) fn capi(&mut self) -> *mut ffi::tiledb_query_t { + *self.raw + } + + pub fn submit(&mut self) -> Result { + self.buffers.to_mutable()?; + if matches!(self.query_type, QueryType::Read) { + self.buffers.reset_lengths()?; + } + self.set_buffers()?; + + let c_query = self.capi(); + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_submit(ctx, c_query) + })?; + + if matches!(self.query_type, QueryType::Read) { + self.buffers.shrink_lengths()?; + } + + match self.curr_status()? { + QueryStatus::Uninitialized + | QueryStatus::Initialized + | QueryStatus::InProgress => { + return Err(Error::InternalError( + "Invalid query status after submit".to_string(), + )) + } + QueryStatus::Failed => { + return Err(self.context.expect_last_error().into()); + } + QueryStatus::Incomplete => { + if self.buffers.iter().any(|(_, b)| b.len() > 0) { + Ok(QueryStatus::Incomplete) + } else { + Ok(QueryStatus::BuffersTooSmall) + } + } + QueryStatus::BuffersTooSmall => { + panic!("TileDB does not generate this variant.") + } + QueryStatus::Completed => Ok(QueryStatus::Completed), + } + } + + pub fn finalize(mut self) -> Result<(Array, SharedBuffers)> { + let c_query = self.capi(); + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_finalize(ctx, c_query) + })?; + + self.buffers.to_shared()?; + let mut ret = HashMap::with_capacity(self.buffers.len()); + for (field, buffer) in self.buffers.iter() { + ret.insert(field.clone(), buffer.as_shared()?); + } + + Ok((self.array, ret.into())) + } + + pub fn buffers(&mut self) -> Result { + self.buffers.to_shared()?; + let mut ret = HashMap::with_capacity(self.buffers.len()); + for (field, buffer) in self.buffers.iter() { + ret.insert(field.clone(), buffer.as_shared()?); + } + + Ok(ret.into()) + } + + /// Replace this queries buffers with a new set specified by fields + /// + /// This can be used to reallocate buffers with a larger capacity. + pub fn replace_buffers( + &mut self, + fields: QueryFields, + ) -> Result { + let mut tmp_buffers = + QueryBuffers::from_fields(self.array.schema()?, fields)?; + tmp_buffers.to_mutable()?; + if self.buffers.is_compatible(&tmp_buffers) { + std::mem::swap(&mut self.buffers, &mut tmp_buffers); + Ok(tmp_buffers) + } else { + Err(Error::IncompatibleReplacementBuffers) + } + } + + fn set_buffers(&mut self) -> Result<()> { + let c_query = self.capi(); + for (field, buffer) in self.buffers.iter_mut() { + let c_name = std::ffi::CString::new(field.as_bytes())?; + + let c_data_ptr = buffer.data_ptr()?; + let c_data_size_ptr = buffer.data_size_ptr()?; + + self.context.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_data_buffer( + ctx, + c_query, + c_name.as_ptr(), + c_data_ptr, + c_data_size_ptr, + ) + })?; + + if buffer.has_offsets()? { + let c_offsets_ptr = buffer.offsets_ptr()?; + let c_offsets_size_ptr = buffer.offsets_size_ptr()?; + self.context.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_offsets_buffer( + ctx, + c_query, + c_name.as_ptr(), + c_offsets_ptr, + c_offsets_size_ptr, + ) + })?; + } + + if buffer.has_validity()? { + let c_validity_ptr = buffer.validity_ptr()?; + let c_validity_size_ptr = buffer.validity_size_ptr()?; + self.context.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_validity_buffer( + ctx, + c_query, + c_name.as_ptr(), + c_validity_ptr, + c_validity_size_ptr, + ) + })?; + } + } + Ok(()) + } + + fn curr_status(&mut self) -> Result { + let c_query = self.capi(); + let mut c_status: ffi::tiledb_query_status_t = out_ptr!(); + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_get_status(ctx, c_query, &mut c_status) + })?; + + QueryStatus::try_from(c_status) + } +} + +/// The main interface to creating Query instances +pub struct QueryBuilder { + context: Context, + array: Array, + query_type: QueryType, + config: Option, + layout: Option, + subarray: Option, + query_condition: Option, + fields: QueryFields, +} + +impl ContextBound for QueryBuilder { + fn context(&self) -> Context { + self.context.clone() + } +} + +impl QueryBuilder { + pub fn new(array: Array, query_type: QueryType) -> Self { + Self { + context: array.context(), + array, + query_type, + config: None, + layout: None, + subarray: None, + query_condition: None, + fields: Default::default(), + } + } + + pub fn read(array: Array) -> Self { + Self::new(array, QueryType::Read) + } + + pub fn write(array: Array) -> Self { + Self::new(array, QueryType::Write) + } + + pub fn build(mut self) -> Result { + let raw = self.alloc_query()?; + + let schema = self.array.schema()?; + self.set_config(&raw)?; + self.set_layout(&raw)?; + self.set_subarray(&raw)?; + self.set_query_condition(&raw)?; + + Ok(Query { + context: self.array.context(), + raw, + query_type: self.query_type, + array: self.array, + buffers: QueryBuffers::from_fields(schema, self.fields)?, + }) + } + + pub fn with_config(mut self, config: Config) -> Self { + self.config = Some(config); + self + } + + pub fn with_layout(mut self, layout: QueryLayout) -> Self { + self.layout = Some(layout); + self + } + + pub fn with_query_condition( + mut self, + query_condition: QueryConditionExpr, + ) -> Self { + self.query_condition = Some(query_condition); + self + } + + pub fn with_subarray_data(mut self, subarray: SubarrayData) -> Self { + self.subarray = Some(subarray); + self + } + + pub fn start_subarray(self) -> SubarrayBuilderForQuery { + SubarrayBuilderForQuery::new(self) + } + + pub fn with_fields(mut self, fields: QueryFields) -> Self { + self.fields = fields; + self + } + + pub fn start_fields(self) -> QueryFieldsBuilderForQuery { + QueryFieldsBuilderForQuery::new(self) + } + + // Internal builder methods below + + fn alloc_query(&self) -> Result { + let c_array = **self.array.capi(); + let c_query_type = self.query_type.capi_enum(); + let mut c_query: *mut ffi::tiledb_query_t = out_ptr!(); + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_alloc(ctx, c_array, c_query_type, &mut c_query) + })?; + + Ok(RawQuery::Owned(c_query)) + } + + fn alloc_subarray(&self) -> Result { + let c_array = **self.array.capi(); + let mut c_subarray: *mut ffi::tiledb_subarray_t = out_ptr!(); + + self.capi_call(|ctx| unsafe { + ffi::tiledb_subarray_alloc(ctx, c_array, &mut c_subarray) + })?; + + Ok(RawSubarray::Owned(c_subarray)) + } + + fn set_config(&mut self, raw: &RawQuery) -> Result<()> { + if self.config.is_none() { + return Ok(()); + } + + // TODO: Reject configurations that will break out buffer management + // logic. Specifically, the various sm.var_offsets.* keys. + let c_query = **raw; + let c_cfg = self.config.as_mut().unwrap().capi(); + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_config(ctx, c_query, c_cfg) + })?; + + Ok(()) + } + + fn set_layout(&mut self, raw: &RawQuery) -> Result<()> { + let Some(layout) = self.layout.as_ref() else { + return Ok(()); + }; + + let c_query = **raw; + let c_layout = layout.capi_enum(); + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_layout(ctx, c_query, c_layout) + })?; + + Ok(()) + } + + fn set_subarray(&self, raw: &RawQuery) -> Result<()> { + let Some(subarray_data) = self.subarray.as_ref() else { + return Ok(()); + }; + + let raw_subarray = self.alloc_subarray()?; + for (key, ranges) in subarray_data.iter() { + for range in ranges { + self.set_subarray_range(*raw_subarray, &key.into(), range)?; + } + } + + let c_query = **raw; + let c_subarray = *raw_subarray; + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_subarray_t(ctx, c_query, c_subarray) + })?; + + Ok(()) + } + + fn set_subarray_range( + &self, + c_subarray: *mut ffi::tiledb_subarray_t, + key: &LookupKey, + range: &Range, + ) -> Result<()> { + let schema = self.array.schema()?; + let dim = schema.domain()?.dimension(key.clone())?; + + range + .check_dimension_compatibility(dim.datatype()?, dim.cell_val_num()?) + .map_err(Error::from)?; + + match range { + Range::Single(range) => { + crate::single_value_range_go!(range, _DT, start, end, { + let start = start.to_le_bytes(); + let end = end.to_le_bytes(); + match key { + LookupKey::Index(idx) => { + self.capi_call(|ctx| unsafe { + ffi::tiledb_subarray_add_range( + ctx, + c_subarray, + *idx as u32, + start.as_ptr() as *const std::ffi::c_void, + end.as_ptr() as *const std::ffi::c_void, + std::ptr::null(), + ) + })?; + } + LookupKey::Name(name) => { + let c_name = std::ffi::CString::new(name.clone())?; + self.capi_call(|ctx| unsafe { + ffi::tiledb_subarray_add_range_by_name( + ctx, + c_subarray, + c_name.as_ptr(), + start.as_ptr() as *const std::ffi::c_void, + end.as_ptr() as *const std::ffi::c_void, + std::ptr::null(), + ) + })?; + } + } + }) + } + Range::Multi(_) => unreachable!( + "This is rejected by range.check_dimension_compatibility" + ), + Range::Var(range) => { + crate::var_value_range_go!(range, _DT, start, end, { + match key { + LookupKey::Index(idx) => { + self.capi_call(|ctx| unsafe { + ffi::tiledb_subarray_add_range_var( + ctx, + c_subarray, + *idx as u32, + start.as_ptr() as *const std::ffi::c_void, + start.len() as u64, + end.as_ptr() as *const std::ffi::c_void, + end.len() as u64, + ) + })?; + } + LookupKey::Name(name) => { + let c_name = std::ffi::CString::new(name.clone())?; + self.capi_call(|ctx| unsafe { + ffi::tiledb_subarray_add_range_var_by_name( + ctx, + c_subarray, + c_name.as_ptr(), + start.as_ptr() as *const std::ffi::c_void, + start.len() as u64, + end.as_ptr() as *const std::ffi::c_void, + end.len() as u64, + ) + })?; + } + } + }) + } + } + + Ok(()) + } + + fn set_query_condition(&self, raw: &RawQuery) -> Result<()> { + let Some(query_condition) = self.query_condition.as_ref() else { + return Ok(()); + }; + + let cq_raw = query_condition.build(&self.context)?; + let c_query = **raw; + let c_cond = *cq_raw; + self.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_condition(ctx, c_query, c_cond) + })?; + + Ok(()) + } +} diff --git a/tiledb/api/src/query_arrow/subarray.rs b/tiledb/api/src/query_arrow/subarray.rs new file mode 100644 index 00000000..038dc7ae --- /dev/null +++ b/tiledb/api/src/query_arrow/subarray.rs @@ -0,0 +1,64 @@ +use std::collections::HashMap; + +use crate::range::Range; + +use super::QueryBuilder; + +pub type SubarrayData = HashMap>; + +pub struct SubarrayBuilder { + subarray: SubarrayData, +} + +impl SubarrayBuilder { + pub fn new() -> Self { + Self { + subarray: Default::default(), + } + } + + pub fn add_range>( + mut self, + dimension: &str, + range: IntoRange, + ) -> Self { + self.subarray + .entry(dimension.to_string()) + .or_default() + .push(range.into()); + self + } + + pub fn build(self) -> SubarrayData { + self.subarray + } +} + +pub struct SubarrayBuilderForQuery { + query_builder: QueryBuilder, + subarray_builder: SubarrayBuilder, +} + +impl SubarrayBuilderForQuery { + pub(crate) fn new(query_builder: QueryBuilder) -> Self { + Self { + query_builder, + subarray_builder: SubarrayBuilder::new(), + } + } + + pub fn end_subarray(self) -> QueryBuilder { + self.query_builder + .with_subarray_data(self.subarray_builder.build()) + } + + pub fn add_range>( + mut self, + dimension: &str, + range: IntoRange, + ) -> Self { + self.subarray_builder = + self.subarray_builder.add_range(dimension, range); + self + } +} diff --git a/tiledb/api/src/range.rs b/tiledb/api/src/range.rs index c5dca513..a35c18aa 100644 --- a/tiledb/api/src/range.rs +++ b/tiledb/api/src/range.rs @@ -224,6 +224,12 @@ impl Hash for SingleValueRange { macro_rules! single_value_range_from { ($($V:ident : $U:ty),+) => { $( + impl From<[$U; 2]> for SingleValueRange { + fn from(value: [$U; 2]) -> SingleValueRange { + SingleValueRange::$V(value[0], value[1]) + } + } + impl From<&[$U; 2]> for SingleValueRange { fn from(value: &[$U; 2]) -> SingleValueRange { SingleValueRange::$V(value[0], value[1]) From 905e5c5a1452f2fbe67090acab9f9f98f71acb4b Mon Sep 17 00:00:00 2001 From: "Paul J. Davis" Date: Fri, 4 Oct 2024 10:59:51 -0500 Subject: [PATCH 02/42] Cleanup error handling --- .../examples/multi_range_subarray_arrow.rs | 27 ++++---------- .../examples/query_condition_dense_arrow.rs | 27 ++++---------- .../examples/query_condition_sparse_arrow.rs | 7 +--- tiledb/api/examples/quickstart_dense_arrow.rs | 36 ++++++------------- .../quickstart_sparse_string_arrow.rs | 35 ++++-------------- 5 files changed, 31 insertions(+), 101 deletions(-) diff --git a/tiledb/api/examples/multi_range_subarray_arrow.rs b/tiledb/api/examples/multi_range_subarray_arrow.rs index 1019a5d1..5d8c6a05 100644 --- a/tiledb/api/examples/multi_range_subarray_arrow.rs +++ b/tiledb/api/examples/multi_range_subarray_arrow.rs @@ -9,8 +9,7 @@ use tiledb::array::{ SchemaData, TileOrder, }; use tiledb::context::Context; -use tiledb::error::Error as TileDBError; -use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryStatus, QueryType}; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb::Result as TileDBResult; use tiledb::{Datatype, Factory}; @@ -68,20 +67,12 @@ fn main() -> TileDBResult<()> { .add_range("rows", &[4, 4]) .add_range("cols", &[1, 4]) .end_subarray() - .build() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + .build()?; - let status = query - .submit() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + let status = query.submit()?; + assert!(status.is_complete()); - if !matches!(status, QueryStatus::Completed) { - return Err(TileDBError::Other("Make this better.".to_string())); - } - - let buffers = query - .buffers() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + let buffers = query.buffers()?; let rows = buffers.get::("rows").unwrap(); let cols = buffers.get::("cols").unwrap(); @@ -142,13 +133,9 @@ fn write_array(ctx: &Context) -> TileDBResult<()> { .start_fields() .field_with_buffer("a", data) .end_fields() - .build() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + .build()?; - let (_, _) = query - .submit() - .and_then(|_| query.finalize()) - .map_err(|e| TileDBError::Other(format!("{e}")))?; + let (_, _) = query.submit().and_then(|_| query.finalize())?; Ok(()) } diff --git a/tiledb/api/examples/query_condition_dense_arrow.rs b/tiledb/api/examples/query_condition_dense_arrow.rs index 5cef4450..aa91960a 100644 --- a/tiledb/api/examples/query_condition_dense_arrow.rs +++ b/tiledb/api/examples/query_condition_dense_arrow.rs @@ -8,7 +8,6 @@ use tiledb::array::{ Array, ArrayType, AttributeBuilder, DimensionBuilder, DomainBuilder, SchemaBuilder, }; -use tiledb::error::Error as TileDBError; use tiledb::query::conditions::QueryConditionExpr as QC; use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb::{Context, Datatype, Result as TileDBResult}; @@ -86,22 +85,12 @@ fn read_array(ctx: &Context, qc: Option) -> TileDBResult<()> { query }; - let mut query = query - .build() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + let mut query = query.build()?; - let status = query - .submit() - .map_err(|e| TileDBError::Other(format!("{e}")))?; - - if !status.is_complete() { - return Err(TileDBError::Other("Query incomplete.".to_string())); - } - - let buffers = query.buffers().map_err(|e| { - TileDBError::Other(format!("Error getting buffers: {e}")) - })?; + let status = query.submit()?; + assert!(status.is_complete()); + let buffers = query.buffers()?; let index = buffers.get::("index").unwrap(); let a = buffers.get::("a").unwrap(); let b = buffers.get::("b").unwrap(); @@ -230,13 +219,9 @@ fn write_array(ctx: &Context) -> TileDBResult<()> { .start_subarray() .add_range("index", &[0i32, NUM_ELEMS - 1]) .end_subarray() - .build() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + .build()?; - query - .submit() - .and_then(|_| query.finalize()) - .map_err(|e| TileDBError::Other(format!("{e}")))?; + query.submit().and_then(|_| query.finalize())?; Ok(()) } diff --git a/tiledb/api/examples/query_condition_sparse_arrow.rs b/tiledb/api/examples/query_condition_sparse_arrow.rs index 3fae9b6d..6e6e4969 100644 --- a/tiledb/api/examples/query_condition_sparse_arrow.rs +++ b/tiledb/api/examples/query_condition_sparse_arrow.rs @@ -8,7 +8,6 @@ use tiledb::array::{ Array, ArrayType, AttributeBuilder, CellOrder, DimensionBuilder, DomainBuilder, SchemaBuilder, }; -use tiledb::error::Error as TileDBError; use tiledb::query::conditions::QueryConditionExpr as QC; use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb::{Context, Datatype, Result as TileDBResult}; @@ -88,13 +87,9 @@ fn read_array(ctx: &Context, qc: Option) -> TileDBResult<()> { let mut query = query.build()?; let status = query.submit()?; - - if !status.is_complete() { - return Err(TileDBError::Other("Query did not complete.".to_string())); - } + assert!(status.is_complete()); let buffers = query.buffers()?; - let index = buffers.get::("index").unwrap(); let a = buffers.get::("a").unwrap(); let b = buffers.get::("b").unwrap(); diff --git a/tiledb/api/examples/quickstart_dense_arrow.rs b/tiledb/api/examples/quickstart_dense_arrow.rs index de9f0961..e58f582c 100644 --- a/tiledb/api/examples/quickstart_dense_arrow.rs +++ b/tiledb/api/examples/quickstart_dense_arrow.rs @@ -9,8 +9,7 @@ use tiledb::array::{ DomainBuilder, Mode as ArrayMode, SchemaBuilder, }; use tiledb::context::Context; -use tiledb::error::Error as TileDBError; -use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryStatus, QueryType}; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb::Datatype; use tiledb::Result as TileDBResult; @@ -95,18 +94,11 @@ fn write_array() -> TileDBResult<()> { .start_fields() .field_with_buffer(QUICKSTART_ATTRIBUTE_NAME, data) .end_fields() - .build() - // TODO: Make this not suck - .map_err(|e| TileDBError::Other(format!("{}", e)))?; - - let status = query - .submit() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; - if matches!(status, QueryStatus::Completed) { - return Ok(()); - } else { - return Err(TileDBError::Other("Something better here.".to_string())); - } + .build()?; + + query.submit().and_then(|_| query.finalize())?; + + Ok(()) } /// Query back a slice of our array and print the results to stdout. @@ -133,20 +125,12 @@ fn read_array() -> TileDBResult<()> { .add_range("rows", &[1i32, 2]) .add_range("columns", &[2i32, 4]) .end_subarray() - .build() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; - - let status = query - .submit() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; + .build()?; - if !matches!(status, QueryStatus::Completed) { - return Err(TileDBError::Other("Make this better.".to_string())); - } + let status = query.submit()?; + assert!(status.is_complete()); - let buffers = query - .buffers() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; + let buffers = query.buffers()?; let rows = buffers.get::("rows").unwrap(); let cols = buffers.get::("columns").unwrap(); let attrs = buffers diff --git a/tiledb/api/examples/quickstart_sparse_string_arrow.rs b/tiledb/api/examples/quickstart_sparse_string_arrow.rs index 7ca552bf..3ff1d389 100644 --- a/tiledb/api/examples/quickstart_sparse_string_arrow.rs +++ b/tiledb/api/examples/quickstart_sparse_string_arrow.rs @@ -10,8 +10,7 @@ use tiledb::array::{ SchemaData, TileOrder, }; use tiledb::context::Context; -use tiledb::error::Error as TileDBError; -use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryStatus, QueryType}; +use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb::Result as TileDBResult; use tiledb::{Datatype, Factory}; @@ -44,21 +43,12 @@ fn main() -> TileDBResult<()> { .add_range("rows", &["a", "c"]) .add_range("cols", &[2, 4]) .end_subarray() - .build() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; + .build()?; - let status = query - .submit() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; - - if !matches!(status, QueryStatus::Completed) { - return Err(TileDBError::Other("Make this better.".to_string())); - } - - let buffers = query - .buffers() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; + let status = query.submit()?; + assert!(status.is_complete()); + let buffers = query.buffers()?; let rows = buffers.get::("rows").unwrap(); let cols = buffers.get::("cols").unwrap(); let attr = buffers.get::("a").unwrap(); @@ -120,20 +110,9 @@ fn write_array(ctx: &Context) -> TileDBResult<()> { .field_with_buffer("cols", col_data) .field_with_buffer("a", a_data) .end_fields() - .build() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; - - let status = query - .submit() - .map_err(|e| TileDBError::Other(format!("{}", e)))?; - - if !matches!(status, QueryStatus::Completed) { - return Err(TileDBError::Other("Make this better.".to_string())); - } + .build()?; - let (_, _) = query - .finalize() - .map_err(|e| TileDBError::Other(format!("{e}")))?; + query.submit().and_then(|_| query.finalize())?; Ok(()) } From a0df7085d9e07e9b161a24bdf8918e35eca4ce16 Mon Sep 17 00:00:00 2001 From: "Paul J. Davis" Date: Wed, 9 Oct 2024 16:02:28 -0500 Subject: [PATCH 03/42] Show external reference error --- .../api/examples/reading_incomplete_arrow.rs | 31 +++++++++++++++++-- tiledb/api/src/query_arrow/buffers.rs | 1 + tiledb/api/src/query_arrow/mod.rs | 4 ++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tiledb/api/examples/reading_incomplete_arrow.rs b/tiledb/api/examples/reading_incomplete_arrow.rs index 17064551..68dec4aa 100644 --- a/tiledb/api/examples/reading_incomplete_arrow.rs +++ b/tiledb/api/examples/reading_incomplete_arrow.rs @@ -9,8 +9,11 @@ use tiledb::array::{ DomainData, SchemaData, TileOrder, }; use tiledb::context::Context; +use tiledb::query_arrow::buffers::Error as BuffersError; use tiledb::query_arrow::fields::QueryFieldsBuilder; -use tiledb::query_arrow::{QueryBuilder, QueryLayout, QueryType}; +use tiledb::query_arrow::{ + Error as QueryError, QueryBuilder, QueryLayout, QueryType, SharedBuffers, +}; use tiledb::Result as TileDBResult; use tiledb::{Datatype, Factory}; @@ -136,8 +139,27 @@ fn read_array(ctx: &Context) -> TileDBResult<()> { .end_subarray() .build()?; + let mut external_ref: Option = None; + loop { - let status = query.submit()?; + let result = query.submit(); + + if result.is_err() { + let err = result.err().unwrap(); + println!("ERROR: {:?}", err); + + if matches!( + err, + QueryError::QueryBuffersError(BuffersError::ArrayInUse) + ) { + drop(external_ref.take()); + continue; + } + + return Err(err.into()); + } + + let status = result.ok().unwrap(); // Double our buffer sizes if we didn't manage to get any data out // of the query. @@ -154,6 +176,11 @@ fn read_array(ctx: &Context) -> TileDBResult<()> { // Print any results we did get. let buffers = query.buffers()?; + + // Simulate what happens if the client doesn't let go of their + // SharedBuffers reference. + external_ref = Some(buffers.clone()); + let rows = buffers.get::("rows").unwrap(); let cols = buffers.get::("cols").unwrap(); let a1 = buffers.get::("a1").unwrap(); diff --git a/tiledb/api/src/query_arrow/buffers.rs b/tiledb/api/src/query_arrow/buffers.rs index b2195d6d..6b5d4f13 100644 --- a/tiledb/api/src/query_arrow/buffers.rs +++ b/tiledb/api/src/query_arrow/buffers.rs @@ -1416,6 +1416,7 @@ impl QueryBuffers { /// A small helper for users writing code directly against the TileDB API /// /// This struct is freely convertible to and from a HashMap of Arrow arrays. +#[derive(Clone)] pub struct SharedBuffers { buffers: HashMap>, } diff --git a/tiledb/api/src/query_arrow/mod.rs b/tiledb/api/src/query_arrow/mod.rs index 984cedcb..1392efbe 100644 --- a/tiledb/api/src/query_arrow/mod.rs +++ b/tiledb/api/src/query_arrow/mod.rs @@ -12,10 +12,12 @@ use crate::key::LookupKey; use crate::query::conditions::QueryConditionExpr; use crate::range::{Range, SingleValueRange, VarValueRange}; -use buffers::{Error as QueryBuffersError, QueryBuffers, SharedBuffers}; +use buffers::{Error as QueryBuffersError, QueryBuffers}; use fields::{QueryFields, QueryFieldsBuilderForQuery}; use subarray::{SubarrayBuilderForQuery, SubarrayData}; +pub use buffers::SharedBuffers; + pub mod arrow; pub mod buffers; pub mod fields; From 6ce17e701a2fd620bdccabb1e93abc4c4eae0a6b Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 09:04:07 -0500 Subject: [PATCH 04/42] Split arrow queries out to new tiledb-query-core crate --- Cargo.lock | 13 ++++++++++ Cargo.toml | 4 ++- tiledb/api/Cargo.toml | 13 +--------- tiledb/api/src/array/mod.rs | 6 +++++ tiledb/api/src/lib.rs | 1 - tiledb/api/src/query/conditions.rs | 12 ++++++++- tiledb/query-core/Cargo.toml | 17 ++++++++++++ .../examples/multi_range_subarray_arrow.rs | 2 +- .../examples/query_condition_dense_arrow.rs | 2 +- .../examples/query_condition_sparse_arrow.rs | 2 +- .../examples/quickstart_dense_arrow.rs | 2 +- .../quickstart_sparse_string_arrow.rs | 2 +- .../examples/reading_incomplete_arrow.rs | 10 +++---- .../query_arrow => query-core/src}/arrow.rs | 7 +++-- .../query_arrow => query-core/src}/buffers.rs | 5 ++-- .../query_arrow => query-core/src}/fields.rs | 0 .../mod.rs => query-core/src/lib.rs} | 26 ++++++++++++------- .../src}/subarray.rs | 2 +- 18 files changed, 85 insertions(+), 41 deletions(-) create mode 100644 tiledb/query-core/Cargo.toml rename tiledb/{api => query-core}/examples/multi_range_subarray_arrow.rs (98%) rename tiledb/{api => query-core}/examples/query_condition_dense_arrow.rs (98%) rename tiledb/{api => query-core}/examples/query_condition_sparse_arrow.rs (98%) rename tiledb/{api => query-core}/examples/quickstart_dense_arrow.rs (98%) rename tiledb/{api => query-core}/examples/quickstart_sparse_string_arrow.rs (98%) rename tiledb/{api => query-core}/examples/reading_incomplete_arrow.rs (97%) rename tiledb/{api/src/query_arrow => query-core/src}/arrow.rs (99%) rename tiledb/{api/src/query_arrow => query-core/src}/buffers.rs (99%) rename tiledb/{api/src/query_arrow => query-core/src}/fields.rs (100%) rename tiledb/{api/src/query_arrow/mod.rs => query-core/src/lib.rs} (96%) rename tiledb/{api/src/query_arrow => query-core/src}/subarray.rs (97%) diff --git a/Cargo.lock b/Cargo.lock index 392c7e4e..6a186303 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1878,6 +1878,19 @@ dependencies = [ "tiledb-sys-cfg", ] +[[package]] +name = "tiledb-query-core" +version = "0.1.0" +dependencies = [ + "arrow", + "itertools 0.12.1", + "thiserror", + "tiledb-api", + "tiledb-common", + "tiledb-pod", + "tiledb-sys", +] + [[package]] name = "tiledb-sys" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 9da06854..40b89b9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "tiledb/pod", "tiledb/proc-macro", "tiledb/queries", + "tiledb/query-core", "tiledb/sys", "tiledb/sys-cfg", "tiledb/sys-defs", @@ -15,7 +16,7 @@ members = [ "test-utils/signal", "test-utils/strategy-ext", "test-utils/uri", - "tools/api-coverage" + "tools/api-coverage", ] default-members = [ "tiledb/api", @@ -23,6 +24,7 @@ default-members = [ "tiledb/pod", "tiledb/proc-macro", "tiledb/queries", + "tiledb/query-core", "tiledb/utils", ] diff --git a/tiledb/api/Cargo.toml b/tiledb/api/Cargo.toml index 591352e5..735d71cb 100644 --- a/tiledb/api/Cargo.toml +++ b/tiledb/api/Cargo.toml @@ -39,24 +39,13 @@ default = [] arrow = ["dep:serde", "dep:serde_json", "tiledb-common/arrow", "tiledb-common/serde", "tiledb-pod/serde"] pod = ["dep:tiledb-pod"] proptest-strategies = ["dep:cells", "dep:proptest"] +raw = [] serde = ["dep:serde", "dep:serde_json", "dep:tiledb-pod"] [[example]] name = "fragment_info" required-features = ["serde"] -[[example]] -name = "multi_range_subarray_arrow" -required-features = ["pod"] - -[[example]] -name = "quickstart_sparse_string_arrow" -required-features = ["pod"] - -[[example]] -name = "reading_incomplete_arrow" -required-features = ["pod"] - [[example]] name = "using_tiledb_stats" required-features = ["serde"] diff --git a/tiledb/api/src/array/mod.rs b/tiledb/api/src/array/mod.rs index e35829d9..c64f31bb 100644 --- a/tiledb/api/src/array/mod.rs +++ b/tiledb/api/src/array/mod.rs @@ -167,6 +167,12 @@ fn unwrap_config_to_ptr(context: Option<&Config>) -> *mut tiledb_config_t { } impl Array { + #[cfg(feature = "raw")] + pub fn capi(&self) -> &RawArray { + &self.raw + } + + #[cfg(not(feature = "raw"))] pub(crate) fn capi(&self) -> &RawArray { &self.raw } diff --git a/tiledb/api/src/lib.rs b/tiledb/api/src/lib.rs index 86fd6d9f..f2fee1a4 100644 --- a/tiledb/api/src/lib.rs +++ b/tiledb/api/src/lib.rs @@ -52,7 +52,6 @@ pub mod filter; pub mod group; pub mod metadata; pub mod query; -pub mod query_arrow; pub mod stats; pub mod string; pub mod vfs; diff --git a/tiledb/api/src/query/conditions.rs b/tiledb/api/src/query/conditions.rs index 9b66252c..b848f7e5 100644 --- a/tiledb/api/src/query/conditions.rs +++ b/tiledb/api/src/query/conditions.rs @@ -795,10 +795,20 @@ impl QueryConditionExpr { } } + #[cfg(feature = "raw")] + pub fn build(&self, ctx: &Context) -> TileDBResult { + self.build_impl(ctx) + } + + #[cfg(not(feature = "raw"))] pub(crate) fn build( &self, ctx: &Context, ) -> TileDBResult { + self.build_impl(ctx) + } + + fn build_impl(&self, ctx: &Context) -> TileDBResult { match self { Self::Cond(cond) => cond.build(ctx), Self::Comb { lhs, rhs, op } => { @@ -885,7 +895,7 @@ impl Display for QueryConditionExpr { } } -pub(crate) enum RawQueryCondition { +pub enum RawQueryCondition { Owned(*mut ffi::tiledb_query_condition_t), } diff --git a/tiledb/query-core/Cargo.toml b/tiledb/query-core/Cargo.toml new file mode 100644 index 00000000..a186ecd9 --- /dev/null +++ b/tiledb/query-core/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "tiledb-query-core" +edition.workspace = true +rust-version.workspace = true +version.workspace = true + +[dependencies] +arrow = { workspace = true } +thiserror = { workspace = true } +tiledb-api = { workspace = true, features = ["raw"] } +tiledb-common = { workspace = true } +tiledb-sys = { workspace = true } + +[dev-dependencies] +itertools = { workspace = true } +tiledb-api = { workspace = true, features = ["pod", "raw"] } +tiledb-pod = { workspace = true } diff --git a/tiledb/api/examples/multi_range_subarray_arrow.rs b/tiledb/query-core/examples/multi_range_subarray_arrow.rs similarity index 98% rename from tiledb/api/examples/multi_range_subarray_arrow.rs rename to tiledb/query-core/examples/multi_range_subarray_arrow.rs index ea3fec8f..cb017f77 100644 --- a/tiledb/api/examples/multi_range_subarray_arrow.rs +++ b/tiledb/query-core/examples/multi_range_subarray_arrow.rs @@ -6,11 +6,11 @@ use itertools::izip; use tiledb_api::array::Array; use tiledb_api::context::Context; -use tiledb_api::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb_api::{Factory, Result as TileDBResult}; use tiledb_common::array::{ArrayType, CellOrder, Mode, TileOrder}; use tiledb_common::Datatype; use tiledb_pod::array::{AttributeData, DimensionData, DomainData, SchemaData}; +use tiledb_query_core::{QueryBuilder, QueryLayout, QueryType}; const ARRAY_URI: &str = "multi_range_slicing"; diff --git a/tiledb/api/examples/query_condition_dense_arrow.rs b/tiledb/query-core/examples/query_condition_dense_arrow.rs similarity index 98% rename from tiledb/api/examples/query_condition_dense_arrow.rs rename to tiledb/query-core/examples/query_condition_dense_arrow.rs index 7ff4e9f2..74bfbed4 100644 --- a/tiledb/api/examples/query_condition_dense_arrow.rs +++ b/tiledb/query-core/examples/query_condition_dense_arrow.rs @@ -9,10 +9,10 @@ use tiledb_api::array::{ }; use tiledb_api::context::Context; use tiledb_api::query::conditions::QueryConditionExpr as QC; -use tiledb_api::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb_api::Result as TileDBResult; use tiledb_common::array::{ArrayType, Mode}; use tiledb_common::Datatype; +use tiledb_query_core::{QueryBuilder, QueryLayout, QueryType}; const ARRAY_URI: &str = "query_condition_dense"; const NUM_ELEMS: i32 = 10; diff --git a/tiledb/api/examples/query_condition_sparse_arrow.rs b/tiledb/query-core/examples/query_condition_sparse_arrow.rs similarity index 98% rename from tiledb/api/examples/query_condition_sparse_arrow.rs rename to tiledb/query-core/examples/query_condition_sparse_arrow.rs index cc65ec70..7798e07f 100644 --- a/tiledb/api/examples/query_condition_sparse_arrow.rs +++ b/tiledb/query-core/examples/query_condition_sparse_arrow.rs @@ -9,10 +9,10 @@ use tiledb_api::array::{ }; use tiledb_api::context::Context; use tiledb_api::query::conditions::QueryConditionExpr as QC; -use tiledb_api::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb_api::Result as TileDBResult; use tiledb_common::array::{ArrayType, CellOrder, Mode}; use tiledb_common::Datatype; +use tiledb_query_core::{QueryBuilder, QueryLayout, QueryType}; const ARRAY_URI: &str = "query_condition_sparse"; const NUM_ELEMS: i32 = 10; diff --git a/tiledb/api/examples/quickstart_dense_arrow.rs b/tiledb/query-core/examples/quickstart_dense_arrow.rs similarity index 98% rename from tiledb/api/examples/quickstart_dense_arrow.rs rename to tiledb/query-core/examples/quickstart_dense_arrow.rs index 8ffdf5d1..7081b704 100644 --- a/tiledb/api/examples/quickstart_dense_arrow.rs +++ b/tiledb/query-core/examples/quickstart_dense_arrow.rs @@ -9,10 +9,10 @@ use tiledb_api::array::{ SchemaBuilder, }; use tiledb_api::context::Context; -use tiledb_api::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb_api::Result as TileDBResult; use tiledb_common::array::{ArrayType, Mode}; use tiledb_common::Datatype; +use tiledb_query_core::{QueryBuilder, QueryLayout, QueryType}; const QUICKSTART_DENSE_ARRAY_URI: &str = "quickstart_dense"; const QUICKSTART_ATTRIBUTE_NAME: &str = "a"; diff --git a/tiledb/api/examples/quickstart_sparse_string_arrow.rs b/tiledb/query-core/examples/quickstart_sparse_string_arrow.rs similarity index 98% rename from tiledb/api/examples/quickstart_sparse_string_arrow.rs rename to tiledb/query-core/examples/quickstart_sparse_string_arrow.rs index 3f439f52..2b2d8971 100644 --- a/tiledb/api/examples/quickstart_sparse_string_arrow.rs +++ b/tiledb/query-core/examples/quickstart_sparse_string_arrow.rs @@ -6,12 +6,12 @@ use itertools::izip; use tiledb_api::array::Array; use tiledb_api::context::Context; -use tiledb_api::query_arrow::{QueryBuilder, QueryLayout, QueryType}; use tiledb_api::{Factory, Result as TileDBResult}; use tiledb_common::array::dimension::DimensionConstraints; use tiledb_common::array::{ArrayType, CellOrder, Mode, TileOrder}; use tiledb_common::Datatype; use tiledb_pod::array::{AttributeData, DimensionData, DomainData, SchemaData}; +use tiledb_query_core::{QueryBuilder, QueryLayout, QueryType}; const ARRAY_URI: &str = "quickstart_sparse_string"; diff --git a/tiledb/api/examples/reading_incomplete_arrow.rs b/tiledb/query-core/examples/reading_incomplete_arrow.rs similarity index 97% rename from tiledb/api/examples/reading_incomplete_arrow.rs rename to tiledb/query-core/examples/reading_incomplete_arrow.rs index 587d263e..218e3d58 100644 --- a/tiledb/api/examples/reading_incomplete_arrow.rs +++ b/tiledb/query-core/examples/reading_incomplete_arrow.rs @@ -6,15 +6,15 @@ use itertools::izip; use tiledb_api::array::Array; use tiledb_api::context::Context; -use tiledb_api::query_arrow::buffers::Error as BuffersError; -use tiledb_api::query_arrow::fields::QueryFieldsBuilder; -use tiledb_api::query_arrow::{ - Error as QueryError, QueryBuilder, QueryLayout, QueryType, SharedBuffers, -}; use tiledb_api::{Factory, Result as TileDBResult}; use tiledb_common::array::{ArrayType, CellOrder, CellValNum, Mode, TileOrder}; use tiledb_common::Datatype; use tiledb_pod::array::{AttributeData, DimensionData, DomainData, SchemaData}; +use tiledb_query_core::buffers::Error as BuffersError; +use tiledb_query_core::fields::QueryFieldsBuilder; +use tiledb_query_core::{ + Error as QueryError, QueryBuilder, QueryLayout, QueryType, SharedBuffers, +}; const ARRAY_URI: &str = "reading_incomplete"; diff --git a/tiledb/api/src/query_arrow/arrow.rs b/tiledb/query-core/src/arrow.rs similarity index 99% rename from tiledb/api/src/query_arrow/arrow.rs rename to tiledb/query-core/src/arrow.rs index 5c73ee10..00f303de 100644 --- a/tiledb/api/src/query_arrow/arrow.rs +++ b/tiledb/query-core/src/arrow.rs @@ -4,8 +4,8 @@ use arrow::datatypes as adt; use thiserror::Error; -use crate::array::schema::CellValNum; -use crate::datatype::Datatype; +use tiledb_common::array::CellValNum; +use tiledb_common::Datatype; #[derive(Error, Debug, PartialEq, Eq)] pub enum Error { @@ -152,8 +152,8 @@ impl ToArrowConverter { } fn default_arrow_type(&self, dtype: &Datatype) -> Option { - use crate::datatype::Datatype as tiledb; use arrow::datatypes::DataType as arrow; + use tiledb_common::Datatype as tiledb; let arrow_type = match dtype { // Any <-> Null, both indicate lack of a type tiledb::Any => Some(arrow::Null), @@ -493,7 +493,6 @@ impl FromArrowConverter { #[cfg(test)] mod tests { use super::*; - use crate::Datatype; /// Test that a datatype is supported as a scalar type /// if and only if it is also supported as a list element type diff --git a/tiledb/api/src/query_arrow/buffers.rs b/tiledb/query-core/src/buffers.rs similarity index 99% rename from tiledb/api/src/query_arrow/buffers.rs rename to tiledb/query-core/src/buffers.rs index 6b5d4f13..161f1d14 100644 --- a/tiledb/api/src/query_arrow/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -10,11 +10,12 @@ use arrow::buffer::{ use arrow::datatypes as adt; use arrow::error::ArrowError; use thiserror::Error; +use tiledb_api::array::schema::{Field, Schema}; +use tiledb_api::error::Error as TileDBError; +use tiledb_common::array::CellValNum; use super::arrow::ToArrowConverter; use super::fields::{QueryField, QueryFields}; -use crate::array::schema::{CellValNum, Field, Schema}; -use crate::error::Error as TileDBError; const AVERAGE_STRING_LENGTH: usize = 64; diff --git a/tiledb/api/src/query_arrow/fields.rs b/tiledb/query-core/src/fields.rs similarity index 100% rename from tiledb/api/src/query_arrow/fields.rs rename to tiledb/query-core/src/fields.rs diff --git a/tiledb/api/src/query_arrow/mod.rs b/tiledb/query-core/src/lib.rs similarity index 96% rename from tiledb/api/src/query_arrow/mod.rs rename to tiledb/query-core/src/lib.rs index afee05d2..6786f3e0 100644 --- a/tiledb/api/src/query_arrow/mod.rs +++ b/tiledb/query-core/src/lib.rs @@ -1,17 +1,19 @@ ///! The TileDB Query interface and supporting utilities +extern crate tiledb_sys as ffi; + use std::collections::HashMap; use std::ops::Deref; use thiserror::Error; use tiledb_common::{single_value_range_go, var_value_range_go}; -use crate::array::Array; -use crate::config::Config; -use crate::context::{CApiInterface, Context, ContextBound}; -use crate::error::Error as TileDBError; -use crate::key::LookupKey; -use crate::query::conditions::QueryConditionExpr; -use crate::range::{Range, SingleValueRange, VarValueRange}; +use tiledb_api::array::Array; +use tiledb_api::config::Config; +use tiledb_api::context::{CApiInterface, Context, ContextBound}; +use tiledb_api::error::Error as TileDBError; +use tiledb_api::key::LookupKey; +use tiledb_api::query::conditions::QueryConditionExpr; +use tiledb_common::range::{Range, SingleValueRange, VarValueRange}; use buffers::{Error as QueryBuffersError, QueryBuffers}; use fields::{QueryFields, QueryFieldsBuilderForQuery}; @@ -24,8 +26,14 @@ pub mod buffers; pub mod fields; pub mod subarray; -pub type QueryType = crate::array::Mode; -pub type QueryLayout = crate::array::CellOrder; +pub type QueryType = tiledb_common::array::Mode; +pub type QueryLayout = tiledb_common::array::CellOrder; + +macro_rules! out_ptr { + () => { + unsafe { std::mem::MaybeUninit::zeroed().assume_init() } + }; +} /// Errors related to query creation and execution #[derive(Debug, Error)] diff --git a/tiledb/api/src/query_arrow/subarray.rs b/tiledb/query-core/src/subarray.rs similarity index 97% rename from tiledb/api/src/query_arrow/subarray.rs rename to tiledb/query-core/src/subarray.rs index 038dc7ae..3bd6e881 100644 --- a/tiledb/api/src/query_arrow/subarray.rs +++ b/tiledb/query-core/src/subarray.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::range::Range; +use tiledb_common::range::Range; use super::QueryBuilder; From 274446ca72542132945c14ab3240e94c90114c42 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 13:25:18 -0500 Subject: [PATCH 05/42] Remove SizeInfo, introduce struct QueryBuffer --- tiledb/query-core/src/buffers.rs | 521 ++++++++++++++----------------- 1 file changed, 230 insertions(+), 291 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index 161f1d14..cca101b0 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -89,23 +89,6 @@ where aa::downcast_array(&array) } -#[derive(Clone)] -struct SizeInfo { - data: Pin>, - offsets: Pin>, - validity: Pin>, -} - -impl Default for SizeInfo { - fn default() -> Self { - Self { - data: Box::pin(0), - offsets: Box::pin(0), - validity: Box::pin(0), - } - } -} - /// The return type for the NewBufferTraitThing's into_arrow method. This /// allows for fallible conversion without dropping the underlying buffers. type IntoArrowResult = std::result::Result< @@ -128,16 +111,13 @@ trait NewBufferTraitThing { fn len(&self) -> usize; /// The data buffer - fn data(&mut self) -> &mut ArrowBufferMut; + fn data(&mut self) -> &mut QueryBuffer; /// The offsets buffer, for variants that have one - fn offsets(&mut self) -> Option<&mut ArrowBufferMut>; + fn offsets(&mut self) -> Option<&mut QueryBuffer>; /// The validity buffer, when present - fn validity(&mut self) -> Option<&mut ArrowBufferMut>; - - /// The SizeInfo struct - fn sizes(&mut self) -> &mut SizeInfo; + fn validity(&mut self) -> Option<&mut QueryBuffer>; /// Check if another buffer is compatible with this buffer fn is_compatible(&self, other: &Box) -> bool; @@ -151,56 +131,37 @@ trait NewBufferTraitThing { /// the entire available buffer rather than just whatever fit in the /// previous iteration. fn reset_len(&mut self) { - let data = self.data(); - data.resize(data.capacity(), 0); - self.offsets().map(|o| { - // Arrow requires an extra offset that TileDB elides so we need - // to leave room for it later. - assert!(o.capacity() >= std::mem::size_of::()); - o.resize(o.capacity() - std::mem::size_of::(), 0); - }); - self.validity().map(|v| v.resize(v.capacity(), 0)); - - let data_size = self.data().len(); - let offsets_size = self.offsets().map_or(0, |o| o.len()); - let validity_size = self.validity().map_or(0, |v| v.len()); - let sizes = self.sizes(); - *sizes.data = data_size as u64; - *sizes.offsets = offsets_size as u64; - *sizes.validity = validity_size as u64; + self.data().reset(); + if let Some(offsets) = self.offsets() { + offsets.reset(); + } + if let Some(validity) = self.validity() { + validity.reset(); + } } /// Shrink len to data /// - /// After a read query, this method is used to update the lenght of all + /// After a read query, this method is used to update the length of all /// buffers to match the number of bytes written by TileDB. fn shrink_len(&mut self) { - let sizes = self.sizes().clone(); - assert!((*sizes.data as usize) <= self.data().capacity()); - self.data().resize(*sizes.data as usize, 0); - - self.offsets().map(|o| { - assert!( - (*sizes.offsets as usize) - <= o.capacity() - std::mem::size_of::() - ); - o.resize(*sizes.offsets as usize, 0); - }); - - self.validity().map(|v| { - assert!((*sizes.validity as usize) <= v.capacity()); - v.resize(*sizes.validity as usize, 0); - }); + self.data().resize(); + if let Some(offsets) = self.offsets() { + offsets.resize(); + } + if let Some(validity) = self.validity() { + validity.resize(); + } } /// Returns a mutable pointer to the data buffer fn data_ptr(&mut self) -> *mut std::ffi::c_void { - self.data().as_mut_ptr() as *mut std::ffi::c_void + self.data().buffer.as_mut_ptr() as *mut std::ffi::c_void } /// Returns a mutable pointer to the data size fn data_size_ptr(&mut self) -> *mut u64 { - self.sizes().data.as_mut().get_mut() + self.data().size.as_mut().get_mut() } /// Returns a mutable poiniter to the offsets buffer. @@ -211,18 +172,20 @@ trait NewBufferTraitThing { return std::ptr::null_mut(); }; - offsets.as_mut_ptr() as *mut u64 + offsets.buffer.as_mut_ptr() as *mut u64 } /// Returns a mutable pointer to the offsets size. /// /// For variants that don't have offsets, it returns a null pointer. + /// + /// FIXME: why does this not return `Option` fn offsets_size_ptr(&mut self) -> *mut u64 { - let Some(_) = self.offsets() else { + let Some(offsets) = self.offsets() else { return std::ptr::null_mut(); }; - self.sizes().offsets.as_mut().get_mut() + offsets.size.as_mut().get_mut() } /// Returns a mutable pointer to the validity buffer, when present @@ -233,25 +196,24 @@ trait NewBufferTraitThing { return std::ptr::null_mut(); }; - validity.as_mut_ptr() + validity.buffer.as_mut_ptr() } /// Returns a mutable pointer to the validity size, when present /// /// When validity is not present, it returns a null pointer. fn validity_size_ptr(&mut self) -> *mut u64 { - let Some(_) = self.validity() else { + let Some(validity) = self.validity() else { return std::ptr::null_mut(); }; - self.sizes().validity.as_mut().get_mut() + validity.size.as_mut().get_mut() } } struct BooleanBuffers { - data: ArrowBufferMut, - validity: Option, - sizes: SizeInfo, + data: QueryBuffer, + validity: Option, } impl TryFrom> for BooleanBuffers { @@ -264,13 +226,11 @@ impl TryFrom> for BooleanBuffers { .map(|v| if v { 1u8 } else { 0 }) .collect::>(); let validity = to_tdb_validity(validity); - let mut sizes = SizeInfo::default(); - *sizes.data = data.len() as u64; - *sizes.validity = validity.as_ref().map_or(0, |v| v.len()) as u64; + Ok(BooleanBuffers { - data: abuf::MutableBuffer::from(data), - validity: validity.map(abuf::MutableBuffer::from), - sizes, + data: QueryBuffer::new(abuf::MutableBuffer::from(data)), + validity: validity + .map(|v| QueryBuffer::new(abuf::MutableBuffer::from(v))), }) } } @@ -281,25 +241,21 @@ impl NewBufferTraitThing for BooleanBuffers { } fn len(&self) -> usize { - self.data.len() + self.data.buffer.len() } - fn data(&mut self) -> &mut ArrowBufferMut { + fn data(&mut self) -> &mut QueryBuffer { &mut self.data } - fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + fn offsets(&mut self) -> Option<&mut QueryBuffer> { None } - fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + fn validity(&mut self) -> Option<&mut QueryBuffer> { self.validity.as_mut() } - fn sizes(&mut self) -> &mut SizeInfo { - &mut self.sizes - } - fn is_compatible(&self, other: &Box) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; @@ -309,21 +265,21 @@ impl NewBufferTraitThing for BooleanBuffers { } fn into_arrow(self: Box) -> IntoArrowResult { - let data = - abuf::BooleanBuffer::from_iter(self.data.iter().map(|b| *b != 0)); + let data = abuf::BooleanBuffer::from_iter( + self.data.buffer.iter().map(|b| *b != 0), + ); Ok(Arc::new(aa::BooleanArray::new( data, - from_tdb_validity(self.validity), + from_tdb_validity(&self.validity), ))) } } struct ByteBuffers { dtype: adt::DataType, - data: ArrowBufferMut, - offsets: ArrowBufferMut, - validity: Option, - sizes: SizeInfo, + data: QueryBuffer, + offsets: QueryBuffer, + validity: Option, } macro_rules! to_byte_buffers { @@ -335,20 +291,14 @@ macro_rules! to_byte_buffers { let offsets = offsets.into_inner().into_inner().into_mutable(); if data.is_ok() && offsets.is_ok() { - let data = data.ok().unwrap(); - let offsets = offsets.ok().unwrap(); - let validity = to_tdb_validity(nulls); - let mut sizes = SizeInfo::default(); - *sizes.data = data.len() as u64; - *sizes.offsets = - (offsets.len() - std::mem::size_of::()) as u64; - *sizes.validity = validity.as_ref().map_or(0, |v| v.len()) as u64; + let data = QueryBuffer::new(data.ok().unwrap()); + let offsets = QueryBuffer::new(offsets.ok().unwrap()); + let validity = to_tdb_validity(nulls).map(QueryBuffer::new); return Ok(ByteBuffers { dtype: $ARROW_TYPE, data, offsets, validity, - sizes, }); } @@ -402,25 +352,21 @@ impl NewBufferTraitThing for ByteBuffers { } fn len(&self) -> usize { - self.offsets.len() / std::mem::size_of::() + (self.offsets.buffer.len() / std::mem::size_of::()) - 1 } - fn data(&mut self) -> &mut ArrowBufferMut { + fn data(&mut self) -> &mut QueryBuffer { &mut self.data } - fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + fn offsets(&mut self) -> Option<&mut QueryBuffer> { Some(&mut self.offsets) } - fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + fn validity(&mut self) -> Option<&mut QueryBuffer> { self.validity.as_mut() } - fn sizes(&mut self) -> &mut SizeInfo { - &mut self.sizes - } - fn is_compatible(&self, other: &Box) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; @@ -434,56 +380,57 @@ impl NewBufferTraitThing for ByteBuffers { } fn into_arrow(self: Box) -> IntoArrowResult { - // Make sure we add back the extra offsets element that arrow expects - // which TileDB doesn't use. - assert!( - self.offsets.len() - <= self.offsets.capacity() - std::mem::size_of::() - ); - - let mut offsets = self.offsets; - offsets.extend_from_slice(&[self.data.len() as i64]); + // NB: by default the offsets are not arrow-shaped. + // However we use the configuration options to make them so. let dtype = self.dtype; - let data = ArrowBuffer::from(self.data); - let offsets = ArrowBuffer::from(offsets); - let validity = from_tdb_validity(self.validity); - let sizes = self.sizes; + let data = ArrowBuffer::from(self.data.buffer); + let offsets = ArrowBuffer::from(self.offsets.buffer); + let validity = from_tdb_validity(&self.validity); + // N.B., the calls to cloning the data/offsets/validity are as cheap // as an Arc::clone plus pointer and usize copy. They are *not* cloning // the underlying allocated data. - aa::ArrayData::try_new( + match aa::ArrayData::try_new( dtype.clone(), - *sizes.offsets as usize / std::mem::size_of::(), + (*self.offsets.size as usize / std::mem::size_of::()) - 1, validity.clone().map(|v| v.into_inner().into_inner()), 0, vec![offsets.clone(), data.clone()], vec![], ) .map(|data| aa::make_array(data)) - .map_err(|e| { - // SAFETY: These unwraps are fine because the only other reference - // was consumed in the failed try_new call. Unless of course - // the ArrowError `e` ever ends up carrying a reference. - let boxed: Box = Box::new(ByteBuffers { - dtype, - data: data.into_mutable().unwrap(), - offsets: offsets.into_mutable().unwrap(), - validity: to_tdb_validity(validity), - sizes, - }); - - (boxed, Error::ArrayCreationFailed(e)) - }) + { + Ok(arrow) => Ok(arrow), + Err(e) => { + // SAFETY: These unwraps are fine because the only other reference + // was consumed in the failed try_new call. Unless of course + // the ArrowError `e` ever ends up carrying a reference. + let boxed: Box = + Box::new(ByteBuffers { + dtype, + data: QueryBuffer { + buffer: data.into_mutable().unwrap(), + size: self.data.size, + }, + offsets: QueryBuffer { + buffer: offsets.into_mutable().unwrap(), + size: self.offsets.size, + }, + validity: self.validity, + }); + + Err((boxed, Error::ArrayCreationFailed(e))) + } + } } } struct FixedListBuffers { field: Arc, cell_val_num: CellValNum, - data: ArrowBufferMut, - validity: Option, - sizes: SizeInfo, + data: QueryBuffer, + validity: Option, } impl TryFrom> for FixedListBuffers { @@ -517,17 +464,15 @@ impl TryFrom> for FixedListBuffers { } PrimitiveBuffers::try_from(array) - .map(|mut buffers| { + .map(|buffers| { assert_eq!(buffers.dtype, dtype); - let validity = to_tdb_validity(nulls.clone()); - *buffers.sizes.validity = - validity.as_ref().map_or(0, |v| v.len()) as u64; + let validity = + to_tdb_validity(nulls.clone()).map(QueryBuffer::new); FixedListBuffers { field: Arc::clone(&field), cell_val_num: cvn, data: buffers.data, validity, - sizes: buffers.sizes, } }) .map_err(|(array, e)| { @@ -551,25 +496,21 @@ impl NewBufferTraitThing for FixedListBuffers { } fn len(&self) -> usize { - self.data.len() / u32::from(self.cell_val_num) as usize + self.data.buffer.len() / u32::from(self.cell_val_num) as usize } - fn data(&mut self) -> &mut ArrowBufferMut { + fn data(&mut self) -> &mut QueryBuffer { &mut self.data } - fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + fn offsets(&mut self) -> Option<&mut QueryBuffer> { None } - fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + fn validity(&mut self) -> Option<&mut QueryBuffer> { self.validity.as_mut() } - fn sizes(&mut self) -> &mut SizeInfo { - &mut self.sizes - } - fn is_compatible(&self, other: &Box) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; @@ -589,20 +530,19 @@ impl NewBufferTraitThing for FixedListBuffers { fn into_arrow(self: Box) -> IntoArrowResult { let field = self.field; let cell_val_num = self.cell_val_num; - let data = ArrowBuffer::from(self.data); - let validity = from_tdb_validity(self.validity); - let sizes = self.sizes; + let data = ArrowBuffer::from(self.data.buffer); + let validity = from_tdb_validity(&self.validity); assert!(field.data_type().is_primitive()); - let num_values = - *sizes.data as usize / field.data_type().primitive_width().unwrap(); + let num_values = *self.data.size as usize + / field.data_type().primitive_width().unwrap(); let cvn = u32::from(cell_val_num) as i32; let len = num_values / cvn as usize; // N.B., data/validity clones are cheap. They are not cloning the // underlying data buffers. We have to clone so that we can put ourself // back together if the array conversion failes. - aa::ArrayData::try_new( + match aa::ArrayData::try_new( field.data_type().clone(), len, validity.clone().map(|v| v.into_inner().into_inner()), @@ -619,28 +559,52 @@ impl NewBufferTraitThing for FixedListBuffers { validity.clone(), )); array - }) - .map_err(|e| { - let boxed: Box = - Box::new(FixedListBuffers { - field, - cell_val_num, - data: data.into_mutable().unwrap(), - validity: to_tdb_validity(validity), - sizes, - }); - - (boxed, Error::ArrayCreationFailed(e)) - }) + }) { + Ok(arrow) => Ok(arrow), + Err(e) => { + let boxed: Box = + Box::new(FixedListBuffers { + field, + cell_val_num, + data: QueryBuffer { + buffer: data.into_mutable().unwrap(), + size: self.data.size, + }, + validity: self.validity, + }); + + Err((boxed, Error::ArrayCreationFailed(e))) + } + } + } +} + +struct QueryBuffer { + buffer: ArrowBufferMut, + size: Pin>, +} + +impl QueryBuffer { + pub fn new(buffer: ArrowBufferMut) -> Self { + let size = Box::pin(buffer.len() as u64); + Self { buffer, size } + } + + pub fn reset(&mut self) { + *self.size = self.buffer.capacity() as u64; + self.resize() + } + + pub fn resize(&mut self) { + self.buffer.resize(*self.size as usize, 0); } } struct ListBuffers { field: Arc, - data: ArrowBufferMut, - offsets: ArrowBufferMut, - validity: Option, - sizes: SizeInfo, + data: QueryBuffer, + offsets: QueryBuffer, + validity: Option, } impl TryFrom> for ListBuffers { @@ -680,49 +644,40 @@ impl TryFrom> for ListBuffers { return Err((array, err)); } - let mut data = result.ok().unwrap(); + let data = result.ok().unwrap(); - let result = offsets.into_inner().into_inner().into_mutable(); - if result.is_err() { - let offsets_buffer = result.err().unwrap(); - let offsets = abuf::OffsetBuffer::new( - abuf::ScalarBuffer::::from(offsets_buffer), - ); - let array: Arc = Arc::new( - aa::LargeListArray::try_new( - Arc::clone(&field), - offsets, - // Safety: We just turned this into a mutable buffer, so - // the inversion should never fail. - Box::new(data).into_arrow().ok().unwrap(), - nulls.clone(), - ) - .unwrap(), - ); - return Err((array, Error::ArrayInUse)); - } - - // Safety: We already yeeted an error on non-primitive types. - let width = dtype.primitive_width().unwrap() as i64; - let mut offsets = result.ok().unwrap(); + let offsets = match offsets.into_inner().into_inner().into_mutable() { + Ok(offsets) => offsets, + Err(e) => { + let offsets_buffer = e; + let offsets = abuf::OffsetBuffer::new( + abuf::ScalarBuffer::::from(offsets_buffer), + ); + let array: Arc = Arc::new( + aa::LargeListArray::try_new( + Arc::clone(&field), + offsets, + // Safety: We just turned this into a mutable buffer, so + // the inversion should never fail. + Box::new(data).into_arrow().ok().unwrap(), + nulls.clone(), + ) + .unwrap(), + ); + return Err((array, Error::ArrayInUse)); + } + }; - // TileDB works in offsets of bytes, so we re-map all of our arrow - // offsets here. - for elem in offsets.typed_data_mut::() { - *elem = *elem * width; - } + // NB: by default the offsets are not arrow-shaped. + // However we use the configuration options to make them so. - let validity = to_tdb_validity(nulls); - *data.sizes.offsets = - (offsets.len() - std::mem::size_of::()) as u64; - *data.sizes.validity = validity.as_ref().map_or(0, |v| v.len()) as u64; + let validity = to_tdb_validity(nulls).map(QueryBuffer::new); Ok(ListBuffers { field, data: data.data, - offsets, + offsets: QueryBuffer::new(offsets), validity, - sizes: data.sizes, }) } } @@ -733,25 +688,21 @@ impl NewBufferTraitThing for ListBuffers { } fn len(&self) -> usize { - self.offsets.len() / std::mem::size_of::() + (self.offsets.buffer.len() / std::mem::size_of::()) - 1 } - fn data(&mut self) -> &mut ArrowBufferMut { + fn data(&mut self) -> &mut QueryBuffer { &mut self.data } - fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + fn offsets(&mut self) -> Option<&mut QueryBuffer> { Some(&mut self.offsets) } - fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + fn validity(&mut self) -> Option<&mut QueryBuffer> { self.validity.as_mut() } - fn sizes(&mut self) -> &mut SizeInfo { - &mut self.sizes - } - fn is_compatible(&self, other: &Box) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; @@ -768,35 +719,20 @@ impl NewBufferTraitThing for ListBuffers { let field = self.field; assert!(field.data_type().is_primitive()); - let width = field.data_type().primitive_width().unwrap() as i64; - // Update the last offset to be the total number of bytes written - let mut offsets = self.offsets; - assert!( - offsets.len() <= offsets.capacity() - std::mem::size_of::() - ); - offsets.extend_from_slice(&[self.data.len() as i64]); - - // Arrow does offsets in terms of elements, not bytes. Here we rewrite - // that by dividing all offsets by the width of the primitive type in - // the data buffer. - let offset_slice = offsets.typed_data_mut::(); - for elem in offset_slice { - assert!(*elem % width == 0); - *elem = *elem / width; - } + // NB: by default the offsets are not arrow-shaped. + // However we use the configuration options to make them so. - let data = ArrowBuffer::from(self.data); - let offsets = from_tdb_offsets(offsets); - let validity = from_tdb_validity(self.validity); - let sizes = self.sizes; + let data = ArrowBuffer::from(self.data.buffer); + let offsets = from_tdb_offsets(self.offsets.buffer); + let validity = from_tdb_validity(&self.validity); // N.B., the calls to cloning the data/offsets/validity are as cheap // as an Arc::clone plus pointer and usize copy. They are *not* cloning // the underlying allocated data. - aa::ArrayData::try_new( + match aa::ArrayData::try_new( field.data_type().clone(), - *sizes.offsets as usize / std::mem::size_of::(), + (*self.offsets.size as usize / std::mem::size_of::()) - 1, None, 0, vec![data.clone().into()], @@ -813,26 +749,33 @@ impl NewBufferTraitThing for ListBuffers { .map(|array| { let array: Arc = Arc::new(array); array - }) - .map_err(|e| { - let boxed: Box = Box::new(ListBuffers { - field, - data: data.into_mutable().unwrap(), - offsets: to_tdb_offsets(offsets).unwrap(), - validity: to_tdb_validity(validity), - sizes, - }); - - (boxed, Error::ArrayCreationFailed(e)) - }) + }) { + Ok(arrow) => Ok(arrow), + Err(e) => { + let boxed: Box = + Box::new(ListBuffers { + field, + data: QueryBuffer { + buffer: data.into_mutable().unwrap(), + size: self.data.size, + }, + offsets: QueryBuffer { + buffer: to_tdb_offsets(offsets).unwrap(), + size: self.offsets.size, + }, + validity: self.validity, + }); + + Err((boxed, Error::ArrayCreationFailed(e))) + } + } } } struct PrimitiveBuffers { dtype: adt::DataType, - data: ArrowBufferMut, - validity: Option, - sizes: SizeInfo, + data: QueryBuffer, + validity: Option, } macro_rules! to_primitive { @@ -845,16 +788,12 @@ macro_rules! to_primitive { .into_inner() .into_mutable() .map(|data| { - let validity = to_tdb_validity(nulls.clone()); - let mut sizes = SizeInfo::default(); - *sizes.data = data.len() as u64; - *sizes.validity = - validity.as_ref().map_or(0, |v| v.len()) as u64; + let validity = + to_tdb_validity(nulls.clone()).map(QueryBuffer::new); PrimitiveBuffers { dtype: dtype.clone(), - data, + data: QueryBuffer::new(data), validity, - sizes, } }) .map_err(|buffer| { @@ -903,25 +842,21 @@ impl NewBufferTraitThing for PrimitiveBuffers { fn len(&self) -> usize { assert!(self.dtype.is_primitive()); - self.data.len() / self.dtype.primitive_width().unwrap() + self.data.buffer.len() / self.dtype.primitive_width().unwrap() } - fn data(&mut self) -> &mut ArrowBufferMut { + fn data(&mut self) -> &mut QueryBuffer { &mut self.data } - fn offsets(&mut self) -> Option<&mut ArrowBufferMut> { + fn offsets(&mut self) -> Option<&mut QueryBuffer> { None } - fn validity(&mut self) -> Option<&mut ArrowBufferMut> { + fn validity(&mut self) -> Option<&mut QueryBuffer> { self.validity.as_mut() } - fn sizes(&mut self) -> &mut SizeInfo { - &mut self.sizes - } - fn is_compatible(&self, other: &Box) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; @@ -935,36 +870,39 @@ impl NewBufferTraitThing for PrimitiveBuffers { } fn into_arrow(self: Box) -> IntoArrowResult { - let dtype = self.dtype; - let data = ArrowBuffer::from(self.data); - let validity = from_tdb_validity(self.validity); - let sizes = self.sizes; + let data = ArrowBuffer::from(self.data.buffer); + let validity = from_tdb_validity(&self.validity); - assert!(dtype.is_primitive()); + assert!(self.dtype.is_primitive()); // N.B., data/validity clones are cheap. They are not cloning the // underlying data buffers. We have to clone so that we can put ourself // back together if the array conversion failes. - aa::ArrayData::try_new( - dtype.clone(), - *sizes.data as usize / dtype.primitive_width().unwrap(), + match aa::ArrayData::try_new( + self.dtype.clone(), + *self.data.size as usize / self.dtype.primitive_width().unwrap(), validity.clone().map(|v| v.into_inner().into_inner()), 0, vec![data.clone()], vec![], ) .map(aa::make_array) - .map_err(|e| { - let boxed: Box = - Box::new(PrimitiveBuffers { - dtype, - data: data.into_mutable().unwrap(), - validity: to_tdb_validity(validity), - sizes, - }); - - (boxed, Error::ArrayCreationFailed(e)) - }) + { + Ok(arrow) => Ok(arrow), + Err(e) => { + let boxed: Box = + Box::new(PrimitiveBuffers { + dtype: self.dtype, + data: QueryBuffer { + buffer: data.into_mutable().unwrap(), + size: self.data.size, + }, + validity: self.validity, + }); + + Err((boxed, Error::ArrayCreationFailed(e))) + } + } } } @@ -1635,11 +1573,12 @@ fn from_tdb_offsets(offsets: ArrowBufferMut) -> abuf::OffsetBuffer { } fn from_tdb_validity( - validity: Option, + validity: &Option, ) -> Option { - validity.map(|v| { + validity.as_ref().map(|v| { abuf::NullBuffer::from( - v.into_iter() + v.buffer + .iter() .map(|f| if *f != 0 { true } else { false }) .collect::>(), ) From 48bc1fc1872df83191c1ff6c78b60adca4635727 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 13:35:05 -0500 Subject: [PATCH 06/42] build.rs --- Cargo.lock | 1 + tiledb/query-core/Cargo.toml | 3 +++ tiledb/query-core/build.rs | 3 +++ 3 files changed, 7 insertions(+) create mode 100644 tiledb/query-core/build.rs diff --git a/Cargo.lock b/Cargo.lock index 6a186303..9264d592 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1889,6 +1889,7 @@ dependencies = [ "tiledb-common", "tiledb-pod", "tiledb-sys", + "tiledb-sys-cfg", ] [[package]] diff --git a/tiledb/query-core/Cargo.toml b/tiledb/query-core/Cargo.toml index a186ecd9..d09fac6c 100644 --- a/tiledb/query-core/Cargo.toml +++ b/tiledb/query-core/Cargo.toml @@ -15,3 +15,6 @@ tiledb-sys = { workspace = true } itertools = { workspace = true } tiledb-api = { workspace = true, features = ["pod", "raw"] } tiledb-pod = { workspace = true } + +[build-dependencies] +tiledb-sys-cfg = { workspace = true } diff --git a/tiledb/query-core/build.rs b/tiledb/query-core/build.rs new file mode 100644 index 00000000..82ad8ae6 --- /dev/null +++ b/tiledb/query-core/build.rs @@ -0,0 +1,3 @@ +fn main() { + tiledb_sys_cfg::rpath(); +} From 5fa66b0b970ab3f45c303d6e015392c21347b36e Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 13:41:22 -0500 Subject: [PATCH 07/42] Config::capi behind raw feature --- tiledb/api/src/config.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tiledb/api/src/config.rs b/tiledb/api/src/config.rs index 41351810..b579e672 100644 --- a/tiledb/api/src/config.rs +++ b/tiledb/api/src/config.rs @@ -55,10 +55,16 @@ pub struct ConfigIterator<'cfg> { } impl Config { + #[cfg(feature = "raw")] pub fn capi(&self) -> *mut ffi::tiledb_config_t { *self.raw } + #[cfg(not(feature = "raw"))] + pub(crate) fn capi(&self) -> *mut ffi::tiledb_config_t { + *self.raw + } + pub fn new() -> TileDBResult { let mut c_cfg: *mut ffi::tiledb_config_t = out_ptr!(); let mut c_err: *mut ffi::tiledb_error_t = std::ptr::null_mut(); From a2a79514e16a31f12fea9b878c10daff961a5ec1 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 13:41:50 -0500 Subject: [PATCH 08/42] Configure query-core to do arrow-shaped offsets --- tiledb/query-core/src/lib.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 6786f3e0..6d303a56 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -423,7 +423,19 @@ impl QueryBuilder { ffi::tiledb_query_alloc(ctx, c_array, c_query_type, &mut c_query) })?; - Ok(RawQuery::Owned(c_query)) + let raw = RawQuery::Owned(c_query); + + // configure the query to use arrow-shaped offsets + let mut qconf = Config::new()?; + qconf.set("sm.var_offsets.bitsize", "64")?; + qconf.set("sm.var_offsets.mode", "elements")?; + qconf.set("sm.var_offsets.extra_element", "true")?; + + self.capi_call(|c_context| unsafe { + ffi::tiledb_query_set_config(c_context, c_query, qconf.capi()) + })?; + + Ok(raw) } fn alloc_subarray(&self) -> Result { From 1a2bae88316b8680fbd9b60452ff0a42a532cf88 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 13:42:00 -0500 Subject: [PATCH 09/42] Fix empty offsets --- tiledb/query-core/src/buffers.rs | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index cca101b0..cfeabc8f 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -352,7 +352,7 @@ impl NewBufferTraitThing for ByteBuffers { } fn len(&self) -> usize { - (self.offsets.buffer.len() / std::mem::size_of::()) - 1 + self.offsets.capacity_var_cells() } fn data(&mut self) -> &mut QueryBuffer { @@ -382,6 +382,7 @@ impl NewBufferTraitThing for ByteBuffers { fn into_arrow(self: Box) -> IntoArrowResult { // NB: by default the offsets are not arrow-shaped. // However we use the configuration options to make them so. + let num_cells = self.offsets.num_var_cells(); let dtype = self.dtype; let data = ArrowBuffer::from(self.data.buffer); @@ -393,7 +394,7 @@ impl NewBufferTraitThing for ByteBuffers { // the underlying allocated data. match aa::ArrayData::try_new( dtype.clone(), - (*self.offsets.size as usize / std::mem::size_of::()) - 1, + num_cells, validity.clone().map(|v| v.into_inner().into_inner()), 0, vec![offsets.clone(), data.clone()], @@ -598,6 +599,26 @@ impl QueryBuffer { pub fn resize(&mut self) { self.buffer.resize(*self.size as usize, 0); } + + /// Returns the number of variable-length cells which this buffer + /// has room to hold offsets for + pub fn capacity_var_cells(&self) -> usize { + if self.buffer.len() == 0 { + 0 + } else { + (self.buffer.len() / std::mem::size_of::()) - 1 + } + } + + /// Returns the number of variable-length cells which the offsets + /// in this buffer describe + pub fn num_var_cells(&self) -> usize { + if *self.size == 0 { + 0 + } else { + (*self.size as usize / std::mem::size_of::()) - 1 + } + } } struct ListBuffers { @@ -688,7 +709,7 @@ impl NewBufferTraitThing for ListBuffers { } fn len(&self) -> usize { - (self.offsets.buffer.len() / std::mem::size_of::()) - 1 + self.offsets.num_var_cells() } fn data(&mut self) -> &mut QueryBuffer { @@ -720,6 +741,8 @@ impl NewBufferTraitThing for ListBuffers { assert!(field.data_type().is_primitive()); + let num_cells = self.offsets.num_var_cells(); + // NB: by default the offsets are not arrow-shaped. // However we use the configuration options to make them so. @@ -732,7 +755,7 @@ impl NewBufferTraitThing for ListBuffers { // the underlying allocated data. match aa::ArrayData::try_new( field.data_type().clone(), - (*self.offsets.size as usize / std::mem::size_of::()) - 1, + num_cells, None, 0, vec![data.clone().into()], From 3ffbbcb3854a0cca21eee91694561a046954140b Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 13:45:01 -0500 Subject: [PATCH 10/42] Example typo --- tiledb/query-core/examples/query_condition_dense_arrow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiledb/query-core/examples/query_condition_dense_arrow.rs b/tiledb/query-core/examples/query_condition_dense_arrow.rs index 74bfbed4..7b1f6966 100644 --- a/tiledb/query-core/examples/query_condition_dense_arrow.rs +++ b/tiledb/query-core/examples/query_condition_dense_arrow.rs @@ -199,7 +199,7 @@ fn write_array(ctx: &Context) -> TileDBResult<()> { Some(10), ])); let b_data = Arc::new(aa::LargeStringArray::from(vec![ - "alice", "bob", "craig", "daeve", "erin", "frank", "grace", "heidi", + "alice", "bob", "craig", "dave", "erin", "frank", "grace", "heidi", "ivan", "judy", ])); let c_data = From 92ec59c87007e0e3c6dbe106f9106b8a62c11402 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 14:03:38 -0500 Subject: [PATCH 11/42] Move raw ptr methods into QueryBuffer --- tiledb/query-core/src/buffers.rs | 126 ++++++------------------------- tiledb/query-core/src/lib.rs | 51 +++++++------ 2 files changed, 52 insertions(+), 125 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index cfeabc8f..43e5e8aa 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -153,62 +153,6 @@ trait NewBufferTraitThing { validity.resize(); } } - - /// Returns a mutable pointer to the data buffer - fn data_ptr(&mut self) -> *mut std::ffi::c_void { - self.data().buffer.as_mut_ptr() as *mut std::ffi::c_void - } - - /// Returns a mutable pointer to the data size - fn data_size_ptr(&mut self) -> *mut u64 { - self.data().size.as_mut().get_mut() - } - - /// Returns a mutable poiniter to the offsets buffer. - /// - /// For variants that don't have offsets, it returns a null pointer. - fn offsets_ptr(&mut self) -> *mut u64 { - let Some(offsets) = self.offsets() else { - return std::ptr::null_mut(); - }; - - offsets.buffer.as_mut_ptr() as *mut u64 - } - - /// Returns a mutable pointer to the offsets size. - /// - /// For variants that don't have offsets, it returns a null pointer. - /// - /// FIXME: why does this not return `Option` - fn offsets_size_ptr(&mut self) -> *mut u64 { - let Some(offsets) = self.offsets() else { - return std::ptr::null_mut(); - }; - - offsets.size.as_mut().get_mut() - } - - /// Returns a mutable pointer to the validity buffer, when present - /// - /// When validity is not present, it returns a null pointer. - fn validity_ptr(&mut self) -> *mut u8 { - let Some(validity) = self.validity() else { - return std::ptr::null_mut(); - }; - - validity.buffer.as_mut_ptr() - } - - /// Returns a mutable pointer to the validity size, when present - /// - /// When validity is not present, it returns a null pointer. - fn validity_size_ptr(&mut self) -> *mut u64 { - let Some(validity) = self.validity() else { - return std::ptr::null_mut(); - }; - - validity.size.as_mut().get_mut() - } } struct BooleanBuffers { @@ -580,7 +524,7 @@ impl NewBufferTraitThing for FixedListBuffers { } } -struct QueryBuffer { +pub struct QueryBuffer { buffer: ArrowBufferMut, size: Pin>, } @@ -591,6 +535,22 @@ impl QueryBuffer { Self { buffer, size } } + pub fn data_ptr(&mut self) -> *mut std::ffi::c_void { + self.buffer.as_mut_ptr() as *mut std::ffi::c_void + } + + pub fn offsets_ptr(&mut self) -> *mut u64 { + self.buffer.as_mut_ptr() as *mut u64 + } + + pub fn validity_ptr(&mut self) -> *mut u8 { + self.buffer.as_mut_ptr() as *mut u8 + } + + pub fn size_ptr(&mut self) -> *mut u64 { + self.size.as_mut().get_mut() + } + pub fn reset(&mut self) { *self.size = self.buffer.capacity() as u64; self.resize() @@ -1147,68 +1107,28 @@ impl BufferEntry { Ok(()) } - pub fn data_ptr(&mut self) -> Result<*mut std::ffi::c_void> { - let Some(mutable) = self.entry.mutable() else { - return Err(Error::ImmutableBuffer); - }; - - Ok(mutable.data_ptr()) - } - - pub fn data_size_ptr(&mut self) -> Result<*mut u64> { - let Some(mutable) = self.entry.mutable() else { - return Err(Error::ImmutableBuffer); - }; - - Ok(mutable.data_size_ptr()) - } - - pub fn has_offsets(&mut self) -> Result { - let Some(mutable) = self.entry.mutable() else { - return Err(Error::ImmutableBuffer); - }; - - Ok(mutable.offsets_ptr() != std::ptr::null_mut()) - } - - pub fn offsets_ptr(&mut self) -> Result<*mut u64> { - let Some(mutable) = self.entry.mutable() else { - return Err(Error::ImmutableBuffer); - }; - - Ok(mutable.offsets_ptr()) - } - - pub fn offsets_size_ptr(&mut self) -> Result<*mut u64> { - let Some(mutable) = self.entry.mutable() else { - return Err(Error::ImmutableBuffer); - }; - - Ok(mutable.offsets_size_ptr()) - } - - pub fn has_validity(&mut self) -> Result { + pub fn data_mut(&mut self) -> Result<&mut QueryBuffer> { let Some(mutable) = self.entry.mutable() else { return Err(Error::ImmutableBuffer); }; - Ok(mutable.validity_ptr() != std::ptr::null_mut()) + Ok(mutable.data()) } - pub fn validity_ptr(&mut self) -> Result<*mut u8> { + pub fn offsets_mut(&mut self) -> Result> { let Some(mutable) = self.entry.mutable() else { return Err(Error::ImmutableBuffer); }; - Ok(mutable.validity_ptr()) + Ok(mutable.offsets()) } - pub fn validity_size_ptr(&mut self) -> Result<*mut u64> { + pub fn validity_mut(&mut self) -> Result> { let Some(mutable) = self.entry.mutable() else { return Err(Error::ImmutableBuffer); }; - Ok(mutable.validity_size_ptr()) + Ok(mutable.validity()) } fn to_mutable(&mut self) -> Result<()> { diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 6d303a56..7edf08f9 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -264,36 +264,43 @@ impl Query { for (field, buffer) in self.buffers.iter_mut() { let c_name = std::ffi::CString::new(field.as_bytes())?; - let c_data_ptr = buffer.data_ptr()?; - let c_data_size_ptr = buffer.data_size_ptr()?; - - self.context.capi_call(|ctx| unsafe { - ffi::tiledb_query_set_data_buffer( - ctx, - c_query, - c_name.as_ptr(), - c_data_ptr, - c_data_size_ptr, - ) - })?; - - if buffer.has_offsets()? { - let c_offsets_ptr = buffer.offsets_ptr()?; - let c_offsets_size_ptr = buffer.offsets_size_ptr()?; + { + let data = buffer.data_mut()?; + let c_data_ptr = data.data_ptr(); + let c_data_size_ptr = data.size_ptr(); + self.context.capi_call(|ctx| unsafe { - ffi::tiledb_query_set_offsets_buffer( + ffi::tiledb_query_set_data_buffer( ctx, c_query, c_name.as_ptr(), - c_offsets_ptr, - c_offsets_size_ptr, + c_data_ptr, + c_data_size_ptr, ) })?; } - if buffer.has_validity()? { - let c_validity_ptr = buffer.validity_ptr()?; - let c_validity_size_ptr = buffer.validity_size_ptr()?; + { + // NB: `if let` binding is longer than it looks + if let Some(offsets) = buffer.offsets_mut()? { + let c_offsets_ptr = offsets.offsets_ptr(); + let c_offsets_size_ptr = offsets.size_ptr(); + + self.context.capi_call(|ctx| unsafe { + ffi::tiledb_query_set_offsets_buffer( + ctx, + c_query, + c_name.as_ptr(), + c_offsets_ptr, + c_offsets_size_ptr, + ) + })?; + } + } + + if let Some(validity) = buffer.validity_mut()? { + let c_validity_ptr = validity.validity_ptr(); + let c_validity_size_ptr = validity.size_ptr(); self.context.capi_call(|ctx| unsafe { ffi::tiledb_query_set_validity_buffer( ctx, From 4f842845df9838ddfbf9a1c0295a9b36bd9ef46e Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 14:18:30 -0500 Subject: [PATCH 12/42] clippy except for fn len --- tiledb/query-core/src/arrow.rs | 26 ++++++------- tiledb/query-core/src/buffers.rs | 62 +++++++++++++++---------------- tiledb/query-core/src/fields.rs | 5 +-- tiledb/query-core/src/lib.rs | 22 +++++------ tiledb/query-core/src/subarray.rs | 5 +-- 5 files changed, 55 insertions(+), 65 deletions(-) diff --git a/tiledb/query-core/src/arrow.rs b/tiledb/query-core/src/arrow.rs index 00f303de..b7119cab 100644 --- a/tiledb/query-core/src/arrow.rs +++ b/tiledb/query-core/src/arrow.rs @@ -112,16 +112,16 @@ impl ToArrowConverter { if arrow_type.is_primitive() { let width = arrow_type.primitive_width().unwrap(); - if width != dtype.size() as usize { + if width != dtype.size() { return Err(Error::PhysicalSizeMismatch(*dtype, arrow_type)); } if cvn.is_single_valued() { - return Ok(arrow_type); + Ok(arrow_type) } else if cvn.is_var_sized() { let field = Arc::new(adt::Field::new("item", arrow_type, nullable)); - return Ok(adt::DataType::LargeList(field)); + Ok(adt::DataType::LargeList(field)) } else { // SAFETY: Due to the logic above we can guarantee that this // is a fixed length cvn. @@ -131,23 +131,25 @@ impl ToArrowConverter { } let field = Arc::new(adt::Field::new("item", arrow_type, nullable)); - return Ok(adt::DataType::FixedSizeList(field, cvn as i32)); + Ok(adt::DataType::FixedSizeList(field, cvn as i32)) } } else if matches!(arrow_type, adt::DataType::Boolean) { if !cvn.is_single_valued() { - return Err(Error::RequiresSingleValued(arrow_type)); + Err(Error::RequiresSingleValued(arrow_type)) + } else { + Ok(arrow_type) } - return Ok(arrow_type); } else if matches!( arrow_type, adt::DataType::LargeBinary | adt::DataType::LargeUtf8 ) { if !cvn.is_var_sized() { - return Err(Error::RequiresVarSized(arrow_type)); + Err(Error::RequiresVarSized(arrow_type)) + } else { + Ok(arrow_type) } - return Ok(arrow_type); } else { - return Err(Error::InternalTypeError(arrow_type)); + Err(Error::InternalTypeError(arrow_type)) } } @@ -358,9 +360,7 @@ impl FromArrowConverter { arrow::Timestamp(adt::TimeUnit::Nanosecond, None) => { Ok((tiledb::DateTimeNanosecond, single, None)) } - arrow::Timestamp(_, Some(_)) => { - return Err(Error::TimeZonesNotSupported); - } + arrow::Timestamp(_, Some(_)) => Err(Error::TimeZonesNotSupported), arrow::Time64(adt::TimeUnit::Second) => { Ok((tiledb::TimeSecond, single, None)) @@ -480,7 +480,7 @@ impl FromArrowConverter { | arrow::Decimal256(_, _) | arrow::Map(_, _) | arrow::RunEndEncoded(_, _) => { - return Err(Error::UnsupportedArrowDataType(arrow_type)); + Err(Error::UnsupportedArrowDataType(arrow_type)) } } } diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index 43e5e8aa..14b7fb4b 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -120,7 +120,7 @@ trait NewBufferTraitThing { fn validity(&mut self) -> Option<&mut QueryBuffer>; /// Check if another buffer is compatible with this buffer - fn is_compatible(&self, other: &Box) -> bool; + fn is_compatible(&self, other: &dyn NewBufferTraitThing) -> bool; /// Consume self and return an Arc fn into_arrow(self: Box) -> IntoArrowResult; @@ -173,8 +173,7 @@ impl TryFrom> for BooleanBuffers { Ok(BooleanBuffers { data: QueryBuffer::new(abuf::MutableBuffer::from(data)), - validity: validity - .map(|v| QueryBuffer::new(abuf::MutableBuffer::from(v))), + validity: validity.map(QueryBuffer::new), }) } } @@ -200,7 +199,7 @@ impl NewBufferTraitThing for BooleanBuffers { self.validity.as_mut() } - fn is_compatible(&self, other: &Box) -> bool { + fn is_compatible(&self, other: &dyn NewBufferTraitThing) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; }; @@ -311,7 +310,7 @@ impl NewBufferTraitThing for ByteBuffers { self.validity.as_mut() } - fn is_compatible(&self, other: &Box) -> bool { + fn is_compatible(&self, other: &dyn NewBufferTraitThing) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; }; @@ -344,7 +343,7 @@ impl NewBufferTraitThing for ByteBuffers { vec![offsets.clone(), data.clone()], vec![], ) - .map(|data| aa::make_array(data)) + .map(aa::make_array) { Ok(arrow) => Ok(arrow), Err(e) => { @@ -456,7 +455,7 @@ impl NewBufferTraitThing for FixedListBuffers { self.validity.as_mut() } - fn is_compatible(&self, other: &Box) -> bool { + fn is_compatible(&self, other: &dyn NewBufferTraitThing) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; }; @@ -492,7 +491,7 @@ impl NewBufferTraitThing for FixedListBuffers { len, validity.clone().map(|v| v.into_inner().into_inner()), 0, - vec![data.clone().into()], + vec![data.clone()], vec![], ) .map(|data| { @@ -544,7 +543,7 @@ impl QueryBuffer { } pub fn validity_ptr(&mut self) -> *mut u8 { - self.buffer.as_mut_ptr() as *mut u8 + self.buffer.as_mut_ptr() } pub fn size_ptr(&mut self) -> *mut u64 { @@ -563,7 +562,7 @@ impl QueryBuffer { /// Returns the number of variable-length cells which this buffer /// has room to hold offsets for pub fn capacity_var_cells(&self) -> usize { - if self.buffer.len() == 0 { + if self.buffer.is_empty() { 0 } else { (self.buffer.len() / std::mem::size_of::()) - 1 @@ -684,7 +683,7 @@ impl NewBufferTraitThing for ListBuffers { self.validity.as_mut() } - fn is_compatible(&self, other: &Box) -> bool { + fn is_compatible(&self, other: &dyn NewBufferTraitThing) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; }; @@ -718,7 +717,7 @@ impl NewBufferTraitThing for ListBuffers { num_cells, None, 0, - vec![data.clone().into()], + vec![data.clone()], vec![], ) .and_then(|data| { @@ -840,7 +839,7 @@ impl NewBufferTraitThing for PrimitiveBuffers { self.validity.as_mut() } - fn is_compatible(&self, other: &Box) -> bool { + fn is_compatible(&self, other: &dyn NewBufferTraitThing) -> bool { let Some(other) = other.as_any().downcast_ref::() else { return false; }; @@ -968,7 +967,7 @@ impl TryFrom> for Box { | adt::DataType::Decimal256(_, _) | adt::DataType::Map(_, _) | adt::DataType::RunEndEncoded(_, _) => { - return Err((array, Error::UnsupportedArrowType(dtype))); + Err((array, Error::UnsupportedArrowType(dtype))) } } } @@ -1000,7 +999,7 @@ impl MutableOrShared { self.shared.as_ref().map(Arc::clone) } - pub fn to_mutable(&mut self) -> Result<()> { + pub fn make_mut(&mut self) -> Result<()> { self.validate(); if self.mutable.is_some() { @@ -1024,7 +1023,7 @@ impl MutableOrShared { ret } - pub fn to_shared(&mut self) -> Result<()> { + pub fn make_shared(&mut self) -> Result<()> { self.validate(); if self.shared.is_some() { @@ -1065,7 +1064,7 @@ impl BufferEntry { return Err(Error::UnshareableMutableBuffer); }; - return Ok(Arc::clone(array)); + Ok(Arc::clone(array)) } pub fn len(&self) -> usize { @@ -1083,7 +1082,7 @@ impl BufferEntry { .mutable .as_ref() .unwrap() - .is_compatible(other.entry.mutable.as_ref().unwrap()); + .is_compatible(other.entry.mutable.as_ref().unwrap().as_ref()); } false @@ -1131,12 +1130,12 @@ impl BufferEntry { Ok(mutable.validity()) } - fn to_mutable(&mut self) -> Result<()> { - self.entry.to_mutable() + fn make_mut(&mut self) -> Result<()> { + self.entry.make_mut() } - fn to_shared(&mut self) -> Result<()> { - self.entry.to_shared() + fn make_shared(&mut self) -> Result<()> { + self.entry.make_shared() } } @@ -1258,7 +1257,7 @@ impl QueryBuffers { } } - return true; + true } pub fn iter(&self) -> impl Iterator { @@ -1271,16 +1270,16 @@ impl QueryBuffers { self.buffers.iter_mut() } - pub fn to_mutable(&mut self) -> Result<()> { + pub fn make_mut(&mut self) -> Result<()> { for value in self.buffers.values_mut() { - value.to_mutable()? + value.make_mut()? } Ok(()) } - pub fn to_shared(&mut self) -> Result<()> { + pub fn make_shared(&mut self) -> Result<()> { for value in self.buffers.values_mut() { - value.to_shared()? + value.make_shared()? } Ok(()) } @@ -1304,7 +1303,7 @@ pub struct SharedBuffers { } impl SharedBuffers { - pub fn get(&self, key: &str) -> Option<&T> + pub fn get(&self, key: &str) -> Option<&T> where T: Any, { @@ -1413,7 +1412,7 @@ fn alloc_array( vec![data.into()], vec![], ) - .map_err(|e| Error::ArrayCreationFailed(e))?; + .map_err(Error::ArrayCreationFailed)?; Ok(aa::make_array(data)) } @@ -1520,10 +1519,7 @@ fn from_tdb_validity( ) -> Option { validity.as_ref().map(|v| { abuf::NullBuffer::from( - v.buffer - .iter() - .map(|f| if *f != 0 { true } else { false }) - .collect::>(), + v.buffer.iter().map(|f| *f != 0).collect::>(), ) }) } diff --git a/tiledb/query-core/src/fields.rs b/tiledb/query-core/src/fields.rs index 770d4f0b..2e03fac0 100644 --- a/tiledb/query-core/src/fields.rs +++ b/tiledb/query-core/src/fields.rs @@ -53,15 +53,14 @@ impl QueryFields { } } +#[derive(Default)] pub struct QueryFieldsBuilder { fields: QueryFields, } impl QueryFieldsBuilder { pub fn new() -> Self { - Self { - fields: Default::default(), - } + Self::default() } pub fn build(self) -> QueryFields { diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 7edf08f9..35613a7e 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -1,4 +1,4 @@ -///! The TileDB Query interface and supporting utilities +//! The TileDB Query interface and supporting utilities extern crate tiledb_sys as ffi; use std::collections::HashMap; @@ -176,7 +176,7 @@ impl Query { } pub fn submit(&mut self) -> Result { - self.buffers.to_mutable()?; + self.buffers.make_mut()?; if matches!(self.query_type, QueryType::Read) { self.buffers.reset_lengths()?; } @@ -194,14 +194,10 @@ impl Query { match self.curr_status()? { QueryStatus::Uninitialized | QueryStatus::Initialized - | QueryStatus::InProgress => { - return Err(Error::InternalError( - "Invalid query status after submit".to_string(), - )) - } - QueryStatus::Failed => { - return Err(self.context.expect_last_error().into()); - } + | QueryStatus::InProgress => Err(Error::InternalError( + "Invalid query status after submit".to_string(), + )), + QueryStatus::Failed => Err(self.context.expect_last_error().into()), QueryStatus::Incomplete => { if self.buffers.iter().any(|(_, b)| b.len() > 0) { Ok(QueryStatus::Incomplete) @@ -222,7 +218,7 @@ impl Query { ffi::tiledb_query_finalize(ctx, c_query) })?; - self.buffers.to_shared()?; + self.buffers.make_shared()?; let mut ret = HashMap::with_capacity(self.buffers.len()); for (field, buffer) in self.buffers.iter() { ret.insert(field.clone(), buffer.as_shared()?); @@ -232,7 +228,7 @@ impl Query { } pub fn buffers(&mut self) -> Result { - self.buffers.to_shared()?; + self.buffers.make_shared()?; let mut ret = HashMap::with_capacity(self.buffers.len()); for (field, buffer) in self.buffers.iter() { ret.insert(field.clone(), buffer.as_shared()?); @@ -250,7 +246,7 @@ impl Query { ) -> Result { let mut tmp_buffers = QueryBuffers::from_fields(self.array.schema()?, fields)?; - tmp_buffers.to_mutable()?; + tmp_buffers.make_mut()?; if self.buffers.is_compatible(&tmp_buffers) { std::mem::swap(&mut self.buffers, &mut tmp_buffers); Ok(tmp_buffers) diff --git a/tiledb/query-core/src/subarray.rs b/tiledb/query-core/src/subarray.rs index 3bd6e881..1b5fff7d 100644 --- a/tiledb/query-core/src/subarray.rs +++ b/tiledb/query-core/src/subarray.rs @@ -6,15 +6,14 @@ use super::QueryBuilder; pub type SubarrayData = HashMap>; +#[derive(Default)] pub struct SubarrayBuilder { subarray: SubarrayData, } impl SubarrayBuilder { pub fn new() -> Self { - Self { - subarray: Default::default(), - } + Self::default() } pub fn add_range>( From 47b8edd38863b2b8bc68188227d99d7b4d1382a9 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 15 Nov 2024 14:28:37 -0500 Subject: [PATCH 13/42] is_empty for clippy --- tiledb/query-core/src/buffers.rs | 8 ++++++++ tiledb/query-core/src/lib.rs | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index 14b7fb4b..aed4281a 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -1067,6 +1067,10 @@ impl BufferEntry { Ok(Arc::clone(array)) } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn len(&self) -> usize { if self.entry.shared.is_some() { return self.entry.shared.as_ref().unwrap().len(); @@ -1178,6 +1182,10 @@ impl QueryBuffers { Ok(()) } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn len(&self) -> usize { self.buffers.len() } diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 35613a7e..c5a6db87 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -199,7 +199,7 @@ impl Query { )), QueryStatus::Failed => Err(self.context.expect_last_error().into()), QueryStatus::Incomplete => { - if self.buffers.iter().any(|(_, b)| b.len() > 0) { + if self.buffers.iter().any(|(_, b)| !b.is_empty()) { Ok(QueryStatus::Incomplete) } else { Ok(QueryStatus::BuffersTooSmall) From 29fef07f112733cec3c8def40c4eca2659f6e7b9 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 19 Nov 2024 10:08:54 -0500 Subject: [PATCH 14/42] Add tiledb_query_field_t to sys --- tiledb/sys-defs/src/lib.rs | 8 +++++++ tiledb/sys/src/field.rs | 43 ++++++++++++++++++++++++++++++++++++++ tiledb/sys/src/lib.rs | 2 ++ tiledb/sys/src/types.rs | 6 ++++++ 4 files changed, 59 insertions(+) create mode 100644 tiledb/sys/src/field.rs diff --git a/tiledb/sys-defs/src/lib.rs b/tiledb/sys-defs/src/lib.rs index a79d02a8..8cf8eb75 100644 --- a/tiledb/sys-defs/src/lib.rs +++ b/tiledb/sys-defs/src/lib.rs @@ -185,6 +185,14 @@ pub const tiledb_query_type_t_TILEDB_UPDATE: tiledb_query_type_t = 3; pub const tiledb_query_type_t_TILEDB_MODIFY_EXCLUSIVE: tiledb_query_type_t = 4; pub type tiledb_query_type_t = ::std::os::raw::c_uint; +pub const tiledb_field_origin_t_TILEDB_ATTRIBUTE_FIELD: tiledb_field_origin_t = + 0; +pub const tiledb_field_origin_t_TILEDB_DIMENSION_FIELD: tiledb_field_origin_t = + 1; +pub const tiledb_field_origin_t_TILEDB_AGGREGATE_FIELD: tiledb_field_origin_t = + 2; +pub type tiledb_field_origin_t = ::std::os::raw::c_uint; + pub const tiledb_vfs_mode_t_TILEDB_VFS_READ: tiledb_vfs_mode_t = 0; pub const tiledb_vfs_mode_t_TILEDB_VFS_WRITE: tiledb_vfs_mode_t = 1; pub const tiledb_vfs_mode_t_TILEDB_VFS_APPEND: tiledb_vfs_mode_t = 2; diff --git a/tiledb/sys/src/field.rs b/tiledb/sys/src/field.rs new file mode 100644 index 00000000..a5c793d9 --- /dev/null +++ b/tiledb/sys/src/field.rs @@ -0,0 +1,43 @@ +use crate::capi_enum::{tiledb_datatype_t, tiledb_field_origin_t}; +use crate::types::{ + capi_return_t, tiledb_ctx_t, tiledb_query_channel_t, tiledb_query_field_t, + tiledb_query_t, +}; + +extern "C" { + pub fn tiledb_query_get_field( + ctx: *mut tiledb_ctx_t, + query: *mut tiledb_query_t, + field_name: *const ::std::os::raw::c_char, + field: *mut *mut tiledb_query_field_t, + ) -> capi_return_t; + + pub fn tiledb_query_field_free( + ctx: *mut tiledb_ctx_t, + field: *mut *mut tiledb_query_field_t, + ) -> capi_return_t; + + pub fn tiledb_field_datatype( + ctx: *mut tiledb_ctx_t, + field: *mut tiledb_query_field_t, + type_: *mut tiledb_datatype_t, + ) -> capi_return_t; + + pub fn tiledb_field_cell_val_num( + ctx: *mut tiledb_ctx_t, + field: *mut tiledb_query_field_t, + cell_val_num: *mut u32, + ) -> capi_return_t; + + pub fn tiledb_field_origin( + ctx: *mut tiledb_ctx_t, + field: *mut tiledb_query_field_t, + origin: *mut tiledb_field_origin_t, + ) -> capi_return_t; + + pub fn tiledb_field_channel( + ctx: *mut tiledb_ctx_t, + field: *mut tiledb_query_field_t, + channel: *mut *mut tiledb_query_channel_t, + ) -> capi_return_t; +} diff --git a/tiledb/sys/src/lib.rs b/tiledb/sys/src/lib.rs index da4e2ffe..147d8288 100644 --- a/tiledb/sys/src/lib.rs +++ b/tiledb/sys/src/lib.rs @@ -14,6 +14,7 @@ mod domain; mod encryption; mod enumeration; mod error; +mod field; mod filesystem; mod filter; mod filter_list; @@ -47,6 +48,7 @@ pub use domain::*; pub use encryption::*; pub use enumeration::*; pub use error::*; +pub use field::*; pub use filesystem::*; pub use filter::*; pub use filter_list::*; diff --git a/tiledb/sys/src/types.rs b/tiledb/sys/src/types.rs index 2b95e7f6..f5b58c2a 100644 --- a/tiledb/sys/src/types.rs +++ b/tiledb/sys/src/types.rs @@ -93,6 +93,12 @@ pub struct tiledb_query_condition_t { _unused: [u8; 0], } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct tiledb_query_field_t { + _unused: [u8; 0], +} + #[repr(C)] #[derive(Clone, Copy, Debug)] pub struct tiledb_string_t { From 25b0a86722831161e8d198d3355e8270435bbb7b Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 20 Nov 2024 13:55:58 -0500 Subject: [PATCH 15/42] sys tiledb_field_get_nullable --- tiledb/sys/src/field.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tiledb/sys/src/field.rs b/tiledb/sys/src/field.rs index a5c793d9..1f98fee5 100644 --- a/tiledb/sys/src/field.rs +++ b/tiledb/sys/src/field.rs @@ -29,6 +29,12 @@ extern "C" { cell_val_num: *mut u32, ) -> capi_return_t; + pub fn tiledb_field_get_nullable( + ctx: *mut tiledb_ctx_t, + field: *mut tiledb_query_field_t, + nullable: *mut u8, + ) -> capi_return_t; + pub fn tiledb_field_origin( ctx: *mut tiledb_ctx_t, field: *mut tiledb_query_field_t, From a6d45cfab8a69855d72ef19d2e228de4c8f1b630 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 20 Nov 2024 13:56:28 -0500 Subject: [PATCH 16/42] AggregateFunctionHandle --- tiledb/api/src/query/read/aggregate/mod.rs | 214 ++++++++++++--------- 1 file changed, 123 insertions(+), 91 deletions(-) diff --git a/tiledb/api/src/query/read/aggregate/mod.rs b/tiledb/api/src/query/read/aggregate/mod.rs index 64bfe71d..ffaec0d6 100644 --- a/tiledb/api/src/query/read/aggregate/mod.rs +++ b/tiledb/api/src/query/read/aggregate/mod.rs @@ -144,12 +144,22 @@ impl Display for AggregateFunction { } /// Encapsulates data needed to run an aggregate function in the C API. +#[cfg(feature = "raw")] +#[derive(Debug)] +pub struct AggregateFunctionHandle { + function: AggregateFunction, + // NB: C API uses this memory location to store the attribute name if any + agg_name: CString, + field_name: Option, +} + +#[cfg(not(feature = "raw"))] #[derive(Debug)] struct AggregateFunctionHandle { - pub function: AggregateFunction, + function: AggregateFunction, // NB: C API uses this memory location to store the attribute name if any - pub agg_name: CString, - pub field_name: Option, + agg_name: CString, + field_name: Option, } impl AggregateFunctionHandle { @@ -167,6 +177,115 @@ impl AggregateFunctionHandle { field_name, }) } + + pub fn aggregate(&self) -> &AggregateFunction { + &self.function + } + + pub fn aggregate_name(&self) -> &std::ffi::CStr { + &self.agg_name + } + + pub fn field_name(&self) -> Option<&std::ffi::CStr> { + self.field_name.as_ref().map(|c| c.deref()) + } +} + +impl AggregateFunctionHandle { + pub fn apply_to_raw_query( + &self, + context: &Context, + c_query: *mut ffi::tiledb_query_t, + ) -> TileDBResult<()> { + let mut c_channel: *mut tiledb_query_channel_t = out_ptr!(); + context.capi_call(|ctx| unsafe { + ffi::tiledb_query_get_default_channel(ctx, c_query, &mut c_channel) + })?; + + // C API functionality + let mut c_agg_operator: *const tiledb_channel_operator_t = out_ptr!(); + let mut c_agg_operation: *mut tiledb_channel_operation_t = out_ptr!(); + let c_agg_name = self.agg_name.as_c_str().as_ptr(); + + // The if statement and match statement are in different arms because of the agg_operation + // variable takes in different types in the respective functions. + if self.function == AggregateFunction::Count { + context.capi_call(|ctx| unsafe { + ffi::tiledb_aggregate_count_get( + ctx, + core::ptr::addr_of_mut!(c_agg_operation) + as *mut *const tiledb_channel_operation_t, + ) + })?; + } else { + let c_field_name = + self.field_name.as_ref().unwrap().as_c_str().as_ptr(); + match self.function { + AggregateFunction::Count => unreachable!( + "AggregateFunction::Count handled in above case, found {:?}", + self.function + ), + AggregateFunction::NullCount(_) => { + context.capi_call(|ctx| unsafe { + ffi::tiledb_channel_operator_null_count_get( + ctx, + &mut c_agg_operator, + ) + })?; + } + AggregateFunction::Sum(_) => { + context.capi_call(|ctx| unsafe { + ffi::tiledb_channel_operator_sum_get( + ctx, + &mut c_agg_operator, + ) + })?; + } + AggregateFunction::Max(_) => { + context.capi_call(|ctx| unsafe { + ffi::tiledb_channel_operator_max_get( + ctx, + &mut c_agg_operator, + ) + })?; + } + AggregateFunction::Min(_) => { + context.capi_call(|ctx| unsafe { + ffi::tiledb_channel_operator_min_get( + ctx, + &mut c_agg_operator, + ) + })?; + } + AggregateFunction::Mean(_) => { + context.capi_call(|ctx| unsafe { + ffi::tiledb_channel_operator_mean_get( + ctx, + &mut c_agg_operator, + ) + })?; + } + }; + context.capi_call(|ctx| unsafe { + ffi::tiledb_create_unary_aggregate( + ctx, + c_query, + c_agg_operator, + c_field_name, + &mut c_agg_operation, + ) + })?; + } + + context.capi_call(|ctx| unsafe { + ffi::tiledb_channel_apply_aggregate( + ctx, + c_channel, + c_agg_name, + c_agg_operation, + ) + }) + } } /// Query builder adapter for constructing queries with aggregate functions. @@ -328,94 +447,7 @@ pub trait AggregateQueryBuilder: QueryBuilder { let context = self.base().context(); let c_query = **self.base().cquery(); - let mut c_channel: *mut tiledb_query_channel_t = out_ptr!(); - context.capi_call(|ctx| unsafe { - ffi::tiledb_query_get_default_channel(ctx, c_query, &mut c_channel) - })?; - - // C API functionality - let mut c_agg_operator: *const tiledb_channel_operator_t = out_ptr!(); - let mut c_agg_operation: *mut tiledb_channel_operation_t = out_ptr!(); - let c_agg_name = handle.agg_name.as_c_str().as_ptr(); - - // The if statement and match statement are in different arms because of the agg_operation - // variable takes in different types in the respective functions. - if handle.function == AggregateFunction::Count { - context.capi_call(|ctx| unsafe { - ffi::tiledb_aggregate_count_get( - ctx, - core::ptr::addr_of_mut!(c_agg_operation) - as *mut *const tiledb_channel_operation_t, - ) - })?; - } else { - let c_field_name = - handle.field_name.as_ref().unwrap().as_c_str().as_ptr(); - match handle.function { - AggregateFunction::Count => unreachable!( - "AggregateFunction::Count handled in above case, found {:?}", - handle.function - ), - AggregateFunction::NullCount(_) => { - context.capi_call(|ctx| unsafe { - ffi::tiledb_channel_operator_null_count_get( - ctx, - &mut c_agg_operator, - ) - })?; - } - AggregateFunction::Sum(_) => { - context.capi_call(|ctx| unsafe { - ffi::tiledb_channel_operator_sum_get( - ctx, - &mut c_agg_operator, - ) - })?; - } - AggregateFunction::Max(_) => { - context.capi_call(|ctx| unsafe { - ffi::tiledb_channel_operator_max_get( - ctx, - &mut c_agg_operator, - ) - })?; - } - AggregateFunction::Min(_) => { - context.capi_call(|ctx| unsafe { - ffi::tiledb_channel_operator_min_get( - ctx, - &mut c_agg_operator, - ) - })?; - } - AggregateFunction::Mean(_) => { - context.capi_call(|ctx| unsafe { - ffi::tiledb_channel_operator_mean_get( - ctx, - &mut c_agg_operator, - ) - })?; - } - }; - context.capi_call(|ctx| unsafe { - ffi::tiledb_create_unary_aggregate( - ctx, - c_query, - c_agg_operator, - c_field_name, - &mut c_agg_operation, - ) - })?; - } - - context.capi_call(|ctx| unsafe { - ffi::tiledb_channel_apply_aggregate( - ctx, - c_channel, - c_agg_name, - c_agg_operation, - ) - })?; + handle.apply_to_raw_query(&context, c_query)?; Ok(AggregateBuilder:: { base: self, From 76cfe85284c11e46c79f5a1a955d760494b85f7b Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 20 Nov 2024 13:58:37 -0500 Subject: [PATCH 17/42] Add aggregate to query-core and a bunch of Error refactoring --- .../examples/reading_incomplete_arrow.rs | 11 +- tiledb/query-core/src/buffers.rs | 314 +++++++++++------- tiledb/query-core/src/field.rs | 188 +++++++++++ tiledb/query-core/src/fields.rs | 27 ++ tiledb/query-core/src/lib.rs | 21 +- 5 files changed, 437 insertions(+), 124 deletions(-) create mode 100644 tiledb/query-core/src/field.rs diff --git a/tiledb/query-core/examples/reading_incomplete_arrow.rs b/tiledb/query-core/examples/reading_incomplete_arrow.rs index 218e3d58..402e7470 100644 --- a/tiledb/query-core/examples/reading_incomplete_arrow.rs +++ b/tiledb/query-core/examples/reading_incomplete_arrow.rs @@ -10,7 +10,9 @@ use tiledb_api::{Factory, Result as TileDBResult}; use tiledb_common::array::{ArrayType, CellOrder, CellValNum, Mode, TileOrder}; use tiledb_common::Datatype; use tiledb_pod::array::{AttributeData, DimensionData, DomainData, SchemaData}; -use tiledb_query_core::buffers::Error as BuffersError; +use tiledb_query_core::buffers::{ + Error as BuffersError, FieldError, UnsupportedArrowArrayError, +}; use tiledb_query_core::fields::QueryFieldsBuilder; use tiledb_query_core::{ Error as QueryError, QueryBuilder, QueryLayout, QueryType, SharedBuffers, @@ -148,7 +150,12 @@ fn read_array(ctx: &Context) -> TileDBResult<()> { if matches!( err, - QueryError::QueryBuffersError(BuffersError::ArrayInUse) + QueryError::QueryBuffersError(BuffersError::Field( + _, + FieldError::UnsupportedArrowArray( + UnsupportedArrowArrayError::InUse(_) + ) + )) ) { drop(external_ref.take()); continue; diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index aed4281a..8271d215 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -10,25 +10,27 @@ use arrow::buffer::{ use arrow::datatypes as adt; use arrow::error::ArrowError; use thiserror::Error; -use tiledb_api::array::schema::{Field, Schema}; +use tiledb_api::array::schema::Schema; use tiledb_api::error::Error as TileDBError; +use tiledb_api::query::read::aggregate::AggregateFunctionHandle; +use tiledb_api::ContextBound; use tiledb_common::array::CellValNum; use super::arrow::ToArrowConverter; -use super::fields::{QueryField, QueryFields}; +use super::field::QueryField; +use super::fields::{QueryField as RequestField, QueryFields}; +use super::RawQuery; const AVERAGE_STRING_LENGTH: usize = 64; #[derive(Debug, Error)] pub enum Error { - #[error("Provided Arrow Array is externally referenced.")] - ArrayInUse, #[error("Error converting to Arrow for field '{0}': {1}")] ArrowConversionError(String, super::arrow::Error), #[error("Failed to convert Arrow Array for field '{0}': {1}")] FailedConversionFromArrow(String, Box), - #[error("Failed to allocate Arrow array: {0}")] - ArrayCreationFailed(ArrowError), + #[error("Failed to add field '{0}' to query: {1}")] + Field(String, #[source] FieldError), #[error("Capacity {0} is to small to hold {1} bytes per cell.")] CapacityTooSmall(usize, usize), #[error("Failed to convert owned buffers into a list array: {0}")] @@ -41,43 +43,75 @@ pub enum Error { InvalidBufferNoOffsets, #[error("Invalid buffer, no validity present.")] InvalidBufferNoValidity, - #[error("Invalid data type for bytes array: {0}")] - InvalidBytesType(adt::DataType), - #[error("Invalid fixed sized list length {0} is less than 2")] - InvalidFixedSizeListLength(i32), - #[error("TileDB does not support nullable list elements")] - InvalidNullableListElements, - #[error("Invalid data type for primitive data: {0}")] - InvalidPrimitiveType(adt::DataType), - #[error("Internal error: Converted primitive array is not scalar")] - InternalListTypeMismatch, #[error("Internal string type mismatch error")] InternalStringType, #[error("Error converting var sized buffers to arrow: {0}")] InvalidVarBuffers(ArrowError), - #[error("Only the large variant is supported: {0}")] - LargeVariantOnly(adt::DataType), #[error("Failed to convert internal list array: {0}")] ListSubarrayConversion(Box), - #[error("Provided array had external references to its offsets buffer.")] - OffsetsInUse, #[error("Internal TileDB Error: {0}")] TileDB(#[from] TileDBError), #[error("Unexpected var sized arrow type: {0}")] UnexpectedArrowVarType(adt::DataType), #[error("Mutable buffers are not shareable.")] UnshareableMutableBuffer, +} + +type Result = std::result::Result; + +/// An error that occurred when adding a field to the query. +#[derive(Debug, Error)] +pub enum FieldError { + #[error("Error reading query field: {0}")] + QueryField(#[from] crate::field::Error), + #[error("Type mismatch for requested field: {0}")] + TypeMismatch(crate::arrow::Error), + #[error("Failed to allocate buffer: {0}")] + BufferAllocation(ArrowError), + #[error("Unsupported arrow array: {0}")] + UnsupportedArrowArray(#[from] UnsupportedArrowArrayError), +} + +type FieldResult = std::result::Result; + +#[derive(Debug, Error)] +pub enum UnsupportedArrowArrayError { #[error("Unsupported arrow array type: {0}")] UnsupportedArrowType(adt::DataType), + #[error("TileDB does not support nullable list elements")] + InvalidNullableListElements, + #[error("Invalid fixed sized list length {0} is less than 2")] + InvalidFixedSizeListLength(i32), #[error( - "TileDB only supports fixed size lists of primtiive types, not {0}" + "TileDB only supports fixed size lists of primitive types, not {0}" )] UnsupportedFixedSizeListType(adt::DataType), + #[error("Invalid data type for bytes array: {0}")] + InvalidBytesType(adt::DataType), + #[error("Invalid data type for primitive data: {0}")] + InvalidPrimitiveType(adt::DataType), #[error("TileDB does not support timezones")] UnsupportedTimeZones, + #[error("Only the large variant is supported: {0}")] + LargeVariantOnly(adt::DataType), + #[error("Failed to create arrow array: {0}")] + ArrayCreationFailed(ArrowError), + #[error("Array is in use: {0}")] + InUse(#[from] ArrayInUseError), } -type Result = std::result::Result; +type UnsupportedArrowArrayResult = + std::result::Result; + +#[derive(Debug, Error)] +pub enum ArrayInUseError { + #[error("External references to offsets buffer")] + Offsets, + #[error("External references to array")] + Array, +} + +type ArrayInUseResult = std::result::Result; // The Arrow downcast_array function doesn't take an Arc which leaves us // with an outstanding reference when we attempt the Buffer::into_mutable @@ -93,11 +127,11 @@ where /// allows for fallible conversion without dropping the underlying buffers. type IntoArrowResult = std::result::Result< Arc, - (Box, Error), + (Box, UnsupportedArrowArrayError), >; /// The error type to use on TryFrom> implementations -type FromArrowError = (Arc, Error); +type FromArrowError = (Arc, UnsupportedArrowArrayError); /// The return type to use when implementing TryFrom type FromArrowResult = std::result::Result; @@ -268,7 +302,7 @@ macro_rules! to_byte_buffers { let array: Arc = Arc::new( aa::LargeBinaryArray::try_new(offsets, data.into(), nulls).unwrap(), ); - Err((array, Error::OffsetsInUse)) + Err((array, ArrayInUseError::Offsets.into())) }}; } @@ -284,7 +318,7 @@ impl TryFrom> for ByteBuffers { adt::DataType::LargeUtf8 => { to_byte_buffers!(array, dtype.clone(), aa::LargeStringArray) } - t => Err((array, Error::InvalidBytesType(t))), + t => Err((array, UnsupportedArrowArrayError::InvalidBytesType(t))), } } } @@ -364,7 +398,7 @@ impl NewBufferTraitThing for ByteBuffers { validity: self.validity, }); - Err((boxed, Error::ArrayCreationFailed(e))) + Err((boxed, UnsupportedArrowArrayError::ArrayCreationFailed(e))) } } } @@ -390,11 +424,17 @@ impl TryFrom> for FixedListBuffers { let (field, cvn, array, nulls) = array.into_parts(); if field.is_nullable() { - return Err((array, Error::InvalidNullableListElements)); + return Err(( + array, + UnsupportedArrowArrayError::InvalidNullableListElements, + )); } if cvn < 2 { - return Err((array, Error::InvalidFixedSizeListLength(cvn))); + return Err(( + array, + UnsupportedArrowArrayError::InvalidFixedSizeListLength(cvn), + )); } // SAFETY: We just showed cvn >= 2 && cvn is i32 whicih means @@ -404,7 +444,10 @@ impl TryFrom> for FixedListBuffers { let dtype = field.data_type().clone(); if !dtype.is_primitive() { - return Err((array, Error::UnsupportedFixedSizeListType(dtype))); + return Err(( + array, + UnsupportedArrowArrayError::UnsupportedFixedSizeListType(dtype), + )); } PrimitiveBuffers::try_from(array) @@ -517,7 +560,7 @@ impl NewBufferTraitThing for FixedListBuffers { validity: self.validity, }); - Err((boxed, Error::ArrayCreationFailed(e))) + Err((boxed, UnsupportedArrowArrayError::ArrayCreationFailed(e))) } } } @@ -597,12 +640,18 @@ impl TryFrom> for ListBuffers { let (field, offsets, array, nulls) = array.into_parts(); if field.is_nullable() { - return Err((array, Error::InvalidNullableListElements)); + return Err(( + array, + UnsupportedArrowArrayError::InvalidNullableListElements, + )); } let dtype = field.data_type().clone(); if !dtype.is_primitive() { - return Err((array, Error::UnsupportedFixedSizeListType(dtype))); + return Err(( + array, + UnsupportedArrowArrayError::UnsupportedFixedSizeListType(dtype), + )); } // N.B., I really, really tried to make this a fancy map/map_err @@ -644,7 +693,7 @@ impl TryFrom> for ListBuffers { ) .unwrap(), ); - return Err((array, Error::ArrayInUse)); + return Err((array, ArrayInUseError::Array.into())); } }; @@ -748,7 +797,7 @@ impl NewBufferTraitThing for ListBuffers { validity: self.validity, }); - Err((boxed, Error::ArrayCreationFailed(e))) + Err((boxed, UnsupportedArrowArrayError::ArrayCreationFailed(e))) } } } @@ -791,7 +840,7 @@ macro_rules! to_primitive { vec![], ) .unwrap(); - (aa::make_array(data), Error::ArrayInUse) + (aa::make_array(data), ArrayInUseError::Array.into()) }) }}; } @@ -812,7 +861,10 @@ impl TryFrom> for PrimitiveBuffers { adt::DataType::UInt64 => to_primitive!(array, aa::UInt64Array), adt::DataType::Float32 => to_primitive!(array, aa::Float32Array), adt::DataType::Float64 => to_primitive!(array, aa::Float64Array), - t => Err((array, Error::InvalidPrimitiveType(t))), + t => Err(( + array, + UnsupportedArrowArrayError::InvalidPrimitiveType(t), + )), } } } @@ -882,7 +934,7 @@ impl NewBufferTraitThing for PrimitiveBuffers { validity: self.validity, }); - Err((boxed, Error::ArrayCreationFailed(e))) + Err((boxed, UnsupportedArrowArrayError::ArrayCreationFailed(e))) } } } @@ -936,14 +988,15 @@ impl TryFrom> for Box { }), adt::DataType::Timestamp(_, Some(_)) => { - Err((array, Error::UnsupportedTimeZones)) + Err((array, UnsupportedArrowArrayError::UnsupportedTimeZones)) } adt::DataType::Binary | adt::DataType::List(_) - | adt::DataType::Utf8 => { - Err((array, Error::LargeVariantOnly(dtype))) - } + | adt::DataType::Utf8 => Err(( + array, + UnsupportedArrowArrayError::LargeVariantOnly(dtype), + )), adt::DataType::FixedSizeBinary(_) => { todo!("This can probably be supported.") @@ -966,9 +1019,10 @@ impl TryFrom> for Box { | adt::DataType::Decimal128(_, _) | adt::DataType::Decimal256(_, _) | adt::DataType::Map(_, _) - | adt::DataType::RunEndEncoded(_, _) => { - Err((array, Error::UnsupportedArrowType(dtype))) - } + | adt::DataType::RunEndEncoded(_, _) => Err(( + array, + UnsupportedArrowArrayError::UnsupportedArrowType(dtype), + )), } } } @@ -999,22 +1053,21 @@ impl MutableOrShared { self.shared.as_ref().map(Arc::clone) } - pub fn make_mut(&mut self) -> Result<()> { + pub fn make_mut(&mut self) -> UnsupportedArrowArrayResult<()> { self.validate(); if self.mutable.is_some() { return Ok(()); } - let shared = self.shared.take().unwrap(); - let mutable: FromArrowResult> = - shared.try_into(); + let shared: Arc = self.shared.take().unwrap(); + let maybe_mutable = Box::::try_from(shared); - let ret = if mutable.is_ok() { - self.mutable = mutable.ok(); + let ret = if maybe_mutable.is_ok() { + self.mutable = maybe_mutable.ok(); Ok(()) } else { - let (array, err) = mutable.err().unwrap(); + let (array, err) = maybe_mutable.err().unwrap(); self.shared = Some(array); Err(err) }; @@ -1023,7 +1076,7 @@ impl MutableOrShared { ret } - pub fn make_shared(&mut self) -> Result<()> { + pub fn make_shared(&mut self) -> UnsupportedArrowArrayResult<()> { self.validate(); if self.shared.is_some() { @@ -1056,6 +1109,7 @@ impl MutableOrShared { pub struct BufferEntry { entry: MutableOrShared, + aggregate: Option, } impl BufferEntry { @@ -1134,19 +1188,24 @@ impl BufferEntry { Ok(mutable.validity()) } - fn make_mut(&mut self) -> Result<()> { + fn make_mut(&mut self) -> UnsupportedArrowArrayResult<()> { self.entry.make_mut() } - fn make_shared(&mut self) -> Result<()> { + fn make_shared(&mut self) -> UnsupportedArrowArrayResult<()> { self.entry.make_shared() } + + pub fn aggregate(&self) -> Option<&AggregateFunctionHandle> { + self.aggregate.as_ref() + } } impl From> for BufferEntry { fn from(array: Arc) -> Self { Self { entry: MutableOrShared::new(array), + aggregate: None, } } } @@ -1202,44 +1261,39 @@ impl QueryBuffers { self.buffers.get_mut(key) } - pub fn from_fields(schema: Schema, fields: QueryFields) -> Result { - let conv = ToArrowConverter::strict(); - let mut ret = HashMap::with_capacity(fields.fields.len()); - for (name, field) in fields.fields.into_iter() { - let tdb_field = schema.field(name.clone())?; + pub(crate) fn from_fields( + schema: Schema, + raw: &RawQuery, + fields: QueryFields, + ) -> Result { + let mut buffers = HashMap::with_capacity(fields.fields.len()); + for (name, request_field) in fields.fields.into_iter() { + let query_field = QueryField::get(&schema.context(), raw, &name) + .map_err(|e| Error::Field(name.clone(), e.into()))?; + let array = to_array(request_field, query_field) + .map_err(|e| Error::Field(name.clone(), e))?; - if let QueryField::Buffer(array) = field { - Self::validate_buffer(&tdb_field, &array)?; - ret.insert(name.clone(), array); - continue; - } + buffers.insert(name.clone(), BufferEntry::from(array)); + } - // ToDo: Clean these error conversions up so they clearly indicate - // a failed buffer creation. - let tdb_dtype = tdb_field.datatype()?; - let tdb_cvn = tdb_field.cell_val_num()?; - let tdb_nullable = tdb_field.nullability()?; - let arrow_type = if let Some(dtype) = field.target_type() { - conv.convert_datatype_to( - &tdb_dtype, - &tdb_cvn, - tdb_nullable, - dtype, - ) - } else { - conv.convert_datatype(&tdb_dtype, &tdb_cvn, tdb_nullable) - } - .map_err(|e| Error::ArrowConversionError(name.clone(), e))?; - - let array = alloc_array( - arrow_type, - tdb_nullable, - field.capacity().unwrap(), - )?; - ret.insert(name.clone(), array); + for (name, (function, request_field)) in fields.aggregates.into_iter() { + let handle = AggregateFunctionHandle::new(function)?; + + let query_field = QueryField::get(&schema.context(), raw, &name) + .map_err(|e| Error::Field(name.clone(), e.into()))?; + let array = to_array(request_field, query_field) + .map_err(|e| Error::Field(name.clone(), e))?; + + buffers.insert( + name.to_owned(), + BufferEntry { + entry: MutableOrShared::new(array), + aggregate: Some(handle), + }, + ); } - Ok(Self::new(ret)) + Ok(Self { buffers }) } pub fn is_compatible(&self, other: &Self) -> bool { @@ -1279,27 +1333,47 @@ impl QueryBuffers { } pub fn make_mut(&mut self) -> Result<()> { - for value in self.buffers.values_mut() { - value.make_mut()? + for (name, value) in self.buffers.iter_mut() { + value + .make_mut() + .map_err(|e| Error::Field(name.clone(), e.into()))?; } Ok(()) } pub fn make_shared(&mut self) -> Result<()> { - for value in self.buffers.values_mut() { - value.make_shared()? + for (name, value) in self.buffers.iter_mut() { + value + .make_shared() + .map_err(|e| Error::Field(name.clone(), e.into()))?; } Ok(()) } +} - /// When I get to it, this needs to ensure that the provided array matches - /// the field's TileDB datatype. - fn validate_buffer( - _field: &Field, - _buffer: &Arc, - ) -> Result<()> { - Ok(()) +fn to_array( + field: RequestField, + tiledb_field: QueryField, +) -> FieldResult> { + let conv = ToArrowConverter::strict(); + + let tdb_dtype = tiledb_field.datatype()?; + let tdb_cvn = tiledb_field.cell_val_num()?; + let tdb_nullable = tiledb_field.nullable()?; + + if let RequestField::Buffer(array) = field { + // FIXME: validate data type and nullability + return Ok(array); + } + + let arrow_type = if let Some(dtype) = field.target_type() { + conv.convert_datatype_to(&tdb_dtype, &tdb_cvn, tdb_nullable, dtype) + } else { + conv.convert_datatype(&tdb_dtype, &tdb_cvn, tdb_nullable) } + .map_err(FieldError::TypeMismatch)?; + + alloc_array(arrow_type, tdb_nullable, field.capacity().unwrap()) } /// A small helper for users writing code directly against the TileDB API @@ -1335,7 +1409,7 @@ fn alloc_array( dtype: adt::DataType, nullable: bool, capacity: usize, -) -> Result> { +) -> FieldResult> { let num_cells = calculate_num_cells(dtype.clone(), nullable, capacity)?; match dtype { @@ -1355,7 +1429,7 @@ fn alloc_array( }; Ok(Arc::new( aa::LargeListArray::try_new(field, offsets, values, nulls) - .map_err(Error::ArrayCreationFailed)?, + .map_err(FieldError::BufferAllocation)?, )) } adt::DataType::FixedSizeList(field, cvn) => { @@ -1368,7 +1442,7 @@ fn alloc_array( alloc_array(field.data_type().clone(), false, capacity)?; Ok(Arc::new( aa::FixedSizeListArray::try_new(field, cvn, values, nulls) - .map_err(Error::ArrayCreationFailed)?, + .map_err(FieldError::BufferAllocation)?, )) } adt::DataType::LargeUtf8 => { @@ -1383,7 +1457,7 @@ fn alloc_array( }; Ok(Arc::new( aa::LargeStringArray::try_new(offsets, values.into(), nulls) - .map_err(Error::ArrayCreationFailed)?, + .map_err(FieldError::BufferAllocation)?, )) } adt::DataType::LargeBinary => { @@ -1398,7 +1472,7 @@ fn alloc_array( }; Ok(Arc::new( aa::LargeBinaryArray::try_new(offsets, values.into(), nulls) - .map_err(Error::ArrayCreationFailed)?, + .map_err(FieldError::BufferAllocation)?, )) } _ if dtype.is_primitive() => { @@ -1420,7 +1494,7 @@ fn alloc_array( vec![data.into()], vec![], ) - .map_err(Error::ArrayCreationFailed)?; + .map_err(FieldError::BufferAllocation)?; Ok(aa::make_array(data)) } @@ -1432,7 +1506,7 @@ fn calculate_num_cells( dtype: adt::DataType, nullable: bool, capacity: usize, -) -> Result { +) -> FieldResult { match dtype { adt::DataType::Boolean => { if nullable { @@ -1443,7 +1517,10 @@ fn calculate_num_cells( } adt::DataType::LargeList(ref field) => { if !field.data_type().is_primitive() { - return Err(Error::UnsupportedArrowType(dtype.clone())); + return Err(UnsupportedArrowArrayError::UnsupportedArrowType( + dtype.clone(), + ) + .into()); } // Todo: Figure out a better way to approximate values to offsets ratios @@ -1460,11 +1537,17 @@ fn calculate_num_cells( } adt::DataType::FixedSizeList(ref field, cvn) => { if !field.data_type().is_primitive() { - return Err(Error::UnsupportedArrowType(dtype)); + return Err(UnsupportedArrowArrayError::UnsupportedArrowType( + dtype, + ) + .into()); } if cvn < 2 { - return Err(Error::InvalidFixedSizeListLength(cvn)); + return Err( + UnsupportedArrowArrayError::InvalidFixedSizeListLength(cvn) + .into(), + ); } let cvn = cvn as usize; @@ -1492,18 +1575,23 @@ fn calculate_num_cells( let bytes_per_cell = width + if nullable { 1 } else { 0 }; Ok(capacity / bytes_per_cell) } - _ => Err(Error::UnsupportedArrowType(dtype.clone())), + _ => Err(UnsupportedArrowArrayError::UnsupportedArrowType( + dtype.clone(), + ) + .into()), } } // Private utility functions -fn to_tdb_offsets(offsets: abuf::OffsetBuffer) -> Result { +fn to_tdb_offsets( + offsets: abuf::OffsetBuffer, +) -> ArrayInUseResult { offsets .into_inner() .into_inner() .into_mutable() - .map_err(|_| Error::ArrayInUse) + .map_err(|_| ArrayInUseError::Offsets) } fn to_tdb_validity(nulls: Option) -> Option { diff --git a/tiledb/query-core/src/field.rs b/tiledb/query-core/src/field.rs new file mode 100644 index 00000000..d5e8e64f --- /dev/null +++ b/tiledb/query-core/src/field.rs @@ -0,0 +1,188 @@ +use std::ffi::CString; +use std::ops::Deref; + +use thiserror::Error; +use tiledb_api::context::{Context, ContextBound}; +use tiledb_common::array::{CellValNum, CellValNumError}; +use tiledb_common::datatype::{Datatype, TryFromFFIError as DatatypeError}; + +use crate::RawQuery; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Internal cell val num error: {0}")] + CellValNum(#[from] CellValNumError), + #[error("Internal datatype error: {0}")] + Datatype(#[from] DatatypeError), + #[error("Field name '{0}' error: {1}")] + NameError(String, #[source] std::ffi::NulError), + #[error("libtiledb error: {0}")] + LibTileDB(#[from] tiledb_api::error::Error), +} + +type Result = std::result::Result; + +pub enum RawQueryField { + Owned(Context, *mut ffi::tiledb_query_field_t), +} + +impl ContextBound for RawQueryField { + fn context(&self) -> Context { + let Self::Owned(ref ctx, _) = self; + ctx.clone() + } +} + +impl Deref for RawQueryField { + type Target = *mut ffi::tiledb_query_field_t; + + fn deref(&self) -> &Self::Target { + let Self::Owned(_, ref ffi) = self; + ffi + } +} + +impl Drop for RawQueryField { + fn drop(&mut self) { + let Self::Owned(ref mut ctx, ref mut ffi) = self; + ctx.capi_call(|ctx| unsafe { ffi::tiledb_query_field_free(ctx, ffi) }) + .expect("Internal error dropping `RawQueryField`"); + } +} + +pub struct QueryField { + raw: RawQueryField, +} + +impl QueryField { + pub(crate) fn get( + context: &Context, + query: &RawQuery, + name: &str, + ) -> Result { + let c_query = **query; + let c_name = CString::new(name) + .map_err(|e| Error::NameError(name.to_owned(), e))?; + let mut c_field = out_ptr!(); + context.capi_call(|ctx| unsafe { + ffi::tiledb_query_get_field( + ctx, + c_query, + c_name.as_c_str().as_ptr(), + &mut c_field, + ) + })?; + + let raw = RawQueryField::Owned(context.clone(), c_field); + Ok(Self { raw }) + } + + pub fn datatype(&self) -> Result { + let c_field = *self.raw; + let mut c_datatype = out_ptr!(); + self.context().capi_call(|ctx| unsafe { + ffi::tiledb_field_datatype(ctx, c_field, &mut c_datatype) + })?; + Ok(Datatype::try_from(c_datatype)?) + } + + pub fn cell_val_num(&self) -> Result { + let c_field = *self.raw; + let mut c_cvn = out_ptr!(); + self.context().capi_call(|ctx| unsafe { + ffi::tiledb_field_cell_val_num(ctx, c_field, &mut c_cvn) + })?; + Ok(CellValNum::try_from(c_cvn)?) + } + + pub fn nullable(&self) -> Result { + let c_field = *self.raw; + let mut c_nullable = out_ptr!(); + self.context().capi_call(|ctx| unsafe { + ffi::tiledb_field_get_nullable(ctx, c_field, &mut c_nullable) + })?; + Ok(c_nullable != 0) + } + + pub fn origin(&self) -> Result { + todo!() + } +} + +impl ContextBound for QueryField { + fn context(&self) -> Context { + self.raw.context() + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum QueryFieldOrigin { + Attribute, + Dimension, + Aggregate, +} + +#[derive(Debug, Error)] +pub enum QueryFieldOriginError { + #[error("Invalid discriminant for QueryFieldOriginError: {0}")] + InvalidDiscriminant(u64), +} + +impl From for ffi::tiledb_field_origin_t { + fn from(value: QueryFieldOrigin) -> Self { + match value { + QueryFieldOrigin::Attribute => { + ffi::tiledb_field_origin_t_TILEDB_ATTRIBUTE_FIELD + } + QueryFieldOrigin::Dimension => { + ffi::tiledb_field_origin_t_TILEDB_DIMENSION_FIELD + } + QueryFieldOrigin::Aggregate => { + ffi::tiledb_field_origin_t_TILEDB_AGGREGATE_FIELD + } + } + } +} + +impl TryFrom for QueryFieldOrigin { + type Error = QueryFieldOriginError; + + fn try_from( + value: ffi::tiledb_field_origin_t, + ) -> Result { + match value { + ffi::tiledb_field_origin_t_TILEDB_ATTRIBUTE_FIELD => { + Ok(Self::Attribute) + } + ffi::tiledb_field_origin_t_TILEDB_DIMENSION_FIELD => { + Ok(Self::Dimension) + } + ffi::tiledb_field_origin_t_TILEDB_AGGREGATE_FIELD => { + Ok(Self::Aggregate) + } + _ => Err(QueryFieldOriginError::InvalidDiscriminant(value as u64)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ffi_query_origin() { + for q in [ + QueryFieldOrigin::Attribute, + QueryFieldOrigin::Dimension, + QueryFieldOrigin::Aggregate, + ] + .into_iter() + { + assert_eq!( + q, + QueryFieldOrigin::try_from(ffi::tiledb_field_origin_t::from(q)) + .unwrap() + ); + } + } +} diff --git a/tiledb/query-core/src/fields.rs b/tiledb/query-core/src/fields.rs index 2e03fac0..994a69ec 100644 --- a/tiledb/query-core/src/fields.rs +++ b/tiledb/query-core/src/fields.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use arrow::array as aa; use arrow::datatypes as adt; +use tiledb_api::query::read::aggregate::AggregateFunction; use super::QueryBuilder; @@ -44,6 +45,7 @@ impl QueryField { #[derive(Debug, Default)] pub struct QueryFields { pub fields: HashMap, + pub aggregates: HashMap, } impl QueryFields { @@ -101,6 +103,17 @@ impl QueryFieldsBuilder { self.fields.insert(name, QueryField::WithType(dtype)); self } + + pub fn aggregate( + mut self, + function: AggregateFunction, + name: Option, + buffering: QueryField, + ) -> Self { + let name = name.unwrap_or(function.aggregate_name()); + self.fields.aggregates.insert(name, (function, buffering)); + self + } } pub struct QueryFieldsBuilderForQuery { @@ -167,4 +180,18 @@ impl QueryFieldsBuilderForQuery { ..self } } + + pub fn aggregate( + self, + function: AggregateFunction, + name: Option, + buffering: QueryField, + ) -> Self { + Self { + fields_builder: self + .fields_builder + .aggregate(function, name, buffering), + ..self + } + } } diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index c5a6db87..bd7d849d 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -21,20 +21,21 @@ use subarray::{SubarrayBuilderForQuery, SubarrayData}; pub use buffers::SharedBuffers; +macro_rules! out_ptr { + () => { + unsafe { std::mem::MaybeUninit::zeroed().assume_init() } + }; +} + pub mod arrow; pub mod buffers; +pub mod field; pub mod fields; pub mod subarray; pub type QueryType = tiledb_common::array::Mode; pub type QueryLayout = tiledb_common::array::CellOrder; -macro_rules! out_ptr { - () => { - unsafe { std::mem::MaybeUninit::zeroed().assume_init() } - }; -} - /// Errors related to query creation and execution #[derive(Debug, Error)] pub enum Error { @@ -176,7 +177,7 @@ impl Query { } pub fn submit(&mut self) -> Result { - self.buffers.make_mut()?; + self.buffers.make_mut().map_err(QueryBuffersError::from)?; if matches!(self.query_type, QueryType::Read) { self.buffers.reset_lengths()?; } @@ -245,7 +246,7 @@ impl Query { fields: QueryFields, ) -> Result { let mut tmp_buffers = - QueryBuffers::from_fields(self.array.schema()?, fields)?; + QueryBuffers::from_fields(self.array.schema()?, &self.raw, fields)?; tmp_buffers.make_mut()?; if self.buffers.is_compatible(&tmp_buffers) { std::mem::swap(&mut self.buffers, &mut tmp_buffers); @@ -371,12 +372,14 @@ impl QueryBuilder { self.set_subarray(&raw)?; self.set_query_condition(&raw)?; + let buffers = QueryBuffers::from_fields(schema, &raw, self.fields)?; + Ok(Query { context: self.array.context(), raw, query_type: self.query_type, array: self.array, - buffers: QueryBuffers::from_fields(schema, self.fields)?, + buffers, }) } From 3443d7df05d1709bd8101f4c2e274372df5c9592 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 20 Nov 2024 22:45:49 -0500 Subject: [PATCH 18/42] Add Cells -> RecordBatch conversion --- Cargo.lock | 3 + Cargo.toml | 2 + test-utils/cells/Cargo.toml | 4 ++ test-utils/cells/src/arrow.rs | 116 ++++++++++++++++++++++++++++++++++ test-utils/cells/src/lib.rs | 3 + 5 files changed, 128 insertions(+) create mode 100644 test-utils/cells/src/arrow.rs diff --git a/Cargo.lock b/Cargo.lock index 9264d592..fb11f5e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -528,6 +528,9 @@ dependencies = [ name = "cells" version = "0.1.0" dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", "paste", "proptest", "strategy-ext", diff --git a/Cargo.toml b/Cargo.toml index 40b89b9d..ff60e5cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ version = "0.1.0" anyhow = "1.0" armerge = "2" arrow = { version = "52.0.0", features = ["prettyprint"] } +arrow-array = { version = "52.0.0" } +arrow-buffer = { version = "52.0.0" } arrow-schema = { version = "52.0.0" } bindgen = "0.70" cells = { path = "test-utils/cells", version = "0.1.0" } diff --git a/test-utils/cells/Cargo.toml b/test-utils/cells/Cargo.toml index 4d29ba39..ecfba2fc 100644 --- a/test-utils/cells/Cargo.toml +++ b/test-utils/cells/Cargo.toml @@ -5,6 +5,9 @@ rust-version.workspace = true version.workspace = true [dependencies] +arrow-array = { workspace = true, optional = true } +arrow-buffer = { workspace = true, optional = true } +arrow-schema = { workspace = true, optional = true } paste = { workspace = true } proptest = { workspace = true } strategy-ext = { workspace = true } @@ -18,4 +21,5 @@ tiledb-pod = { workspace = true, features = ["proptest-strategies"] } [features] default = [] +arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"] proptest-strategies = ["dep:tiledb-proptest-config", "tiledb-common/proptest-strategies", "tiledb-pod/proptest-strategies"] diff --git a/test-utils/cells/src/arrow.rs b/test-utils/cells/src/arrow.rs new file mode 100644 index 00000000..5823f90b --- /dev/null +++ b/test-utils/cells/src/arrow.rs @@ -0,0 +1,116 @@ +use std::sync::Arc; + +use arrow_array::types::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::{ + Array, ArrowPrimitiveType, LargeListArray, PrimitiveArray, RecordBatch, +}; +use arrow_buffer::{ArrowNativeType, OffsetBuffer}; +use arrow_schema::{DataType, Field, Fields, Schema}; + +use crate::{Cells, FieldData}; + +pub fn to_record_batch(cells: &Cells) -> RecordBatch { + let (fnames, columns) = cells + .fields() + .iter() + .map(|(fname, fdata)| (fname, to_column(fdata))) + .collect::<(Vec<_>, Vec<_>)>(); + + let fields = fnames + .into_iter() + .zip(columns.iter()) + .map(|(fname, column)| { + Field::new( + fname.to_owned(), + column.data_type().clone(), + column.null_count() > 0, + ) + }) + .collect::(); + + let schema = Schema { + fields, + metadata: Default::default(), + }; + + RecordBatch::try_new(schema.into(), columns).unwrap() +} + +fn to_column(fdata: &FieldData) -> Arc { + match fdata { + FieldData::Int8(cells) => to_column_primitive::(cells), + FieldData::Int16(cells) => to_column_primitive::(cells), + FieldData::Int32(cells) => to_column_primitive::(cells), + FieldData::Int64(cells) => to_column_primitive::(cells), + FieldData::UInt8(cells) => to_column_primitive::(cells), + FieldData::UInt16(cells) => { + to_column_primitive::(cells) + } + FieldData::UInt32(cells) => { + to_column_primitive::(cells) + } + FieldData::UInt64(cells) => { + to_column_primitive::(cells) + } + FieldData::Float32(cells) => { + to_column_primitive::(cells) + } + FieldData::Float64(cells) => { + to_column_primitive::(cells) + } + FieldData::VecUInt8(cells) => to_column_list::(cells), + FieldData::VecUInt16(cells) => to_column_list::(cells), + FieldData::VecUInt32(cells) => to_column_list::(cells), + FieldData::VecUInt64(cells) => to_column_list::(cells), + FieldData::VecInt8(cells) => to_column_list::(cells), + FieldData::VecInt16(cells) => to_column_list::(cells), + FieldData::VecInt32(cells) => to_column_list::(cells), + FieldData::VecInt64(cells) => to_column_list::(cells), + FieldData::VecFloat32(cells) => { + to_column_list::(cells) + } + FieldData::VecFloat64(cells) => { + to_column_list::(cells) + } + } +} + +fn to_column_primitive(cells: &[T]) -> Arc +where + T: ArrowNativeType, + A: ArrowPrimitiveType, + PrimitiveArray: From>, +{ + Arc::new(PrimitiveArray::::from(cells.to_vec())) +} + +fn to_column_list(cells: &[Vec]) -> Arc +where + T: ArrowNativeType, + A: ArrowPrimitiveType, + PrimitiveArray: From>, +{ + let offsets = + OffsetBuffer::::from_lengths(cells.iter().map(|c| c.len())); + let values = PrimitiveArray::::from( + cells.iter().cloned().flatten().collect::>(), + ); + + Arc::new( + LargeListArray::try_new( + Field::new( + "unused", + DataType::new_large_list(values.data_type().clone(), false), + false, + ) + .into(), + offsets, + Arc::new(values), + None, + ) + .unwrap(), + ) +} diff --git a/test-utils/cells/src/lib.rs b/test-utils/cells/src/lib.rs index 295c6953..be25d9e1 100644 --- a/test-utils/cells/src/lib.rs +++ b/test-utils/cells/src/lib.rs @@ -1,6 +1,9 @@ pub mod field; pub mod write; +#[cfg(feature = "arrow")] +pub mod arrow; + #[cfg(any(test, feature = "proptest-strategies"))] pub mod strategy; From 170cb84ca0457d263bed25cf5a154f50f802e78b Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 20 Nov 2024 22:51:38 -0500 Subject: [PATCH 19/42] Fix cargo check --tests for cells crate --- test-utils/cells/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test-utils/cells/Cargo.toml b/test-utils/cells/Cargo.toml index ecfba2fc..1bfd53fc 100644 --- a/test-utils/cells/Cargo.toml +++ b/test-utils/cells/Cargo.toml @@ -18,6 +18,7 @@ tiledb-proptest-config = { workspace = true, optional = true } [dev-dependencies] tiledb-common = { workspace = true, features = ["proptest-strategies"] } tiledb-pod = { workspace = true, features = ["proptest-strategies"] } +tiledb-proptest-config = { workspace = true } [features] default = [] From 566c2dde90b7823fcb1ce780df0bc2d7e6e348c3 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 12:53:41 -0500 Subject: [PATCH 20/42] arrow.rs => datatype/mod.rs --- tiledb/query-core/src/buffers.rs | 6 +++--- tiledb/query-core/src/{arrow.rs => datatype/mod.rs} | 0 tiledb/query-core/src/lib.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename tiledb/query-core/src/{arrow.rs => datatype/mod.rs} (100%) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index 8271d215..f2a96513 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -16,7 +16,7 @@ use tiledb_api::query::read::aggregate::AggregateFunctionHandle; use tiledb_api::ContextBound; use tiledb_common::array::CellValNum; -use super::arrow::ToArrowConverter; +use super::datatype::ToArrowConverter; use super::field::QueryField; use super::fields::{QueryField as RequestField, QueryFields}; use super::RawQuery; @@ -26,7 +26,7 @@ const AVERAGE_STRING_LENGTH: usize = 64; #[derive(Debug, Error)] pub enum Error { #[error("Error converting to Arrow for field '{0}': {1}")] - ArrowConversionError(String, super::arrow::Error), + ArrowConversionError(String, crate::datatype::Error), #[error("Failed to convert Arrow Array for field '{0}': {1}")] FailedConversionFromArrow(String, Box), #[error("Failed to add field '{0}' to query: {1}")] @@ -65,7 +65,7 @@ pub enum FieldError { #[error("Error reading query field: {0}")] QueryField(#[from] crate::field::Error), #[error("Type mismatch for requested field: {0}")] - TypeMismatch(crate::arrow::Error), + TypeMismatch(crate::datatype::Error), #[error("Failed to allocate buffer: {0}")] BufferAllocation(ArrowError), #[error("Unsupported arrow array: {0}")] diff --git a/tiledb/query-core/src/arrow.rs b/tiledb/query-core/src/datatype/mod.rs similarity index 100% rename from tiledb/query-core/src/arrow.rs rename to tiledb/query-core/src/datatype/mod.rs diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index bd7d849d..fd38dd45 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -27,8 +27,8 @@ macro_rules! out_ptr { }; } -pub mod arrow; pub mod buffers; +pub mod datatype; pub mod field; pub mod fields; pub mod subarray; From 8cb083f2f241afed8945400bfc18a28d718cfe29 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 12:54:22 -0500 Subject: [PATCH 21/42] Use DimensionKey for SubarrayBuilder --- tiledb/query-core/src/lib.rs | 2 +- tiledb/query-core/src/subarray.rs | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index fd38dd45..25f53b95 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -493,7 +493,7 @@ impl QueryBuilder { let raw_subarray = self.alloc_subarray()?; for (key, ranges) in subarray_data.iter() { for range in ranges { - self.set_subarray_range(*raw_subarray, &key.into(), range)?; + self.set_subarray_range(*raw_subarray, &key, range)?; } } diff --git a/tiledb/query-core/src/subarray.rs b/tiledb/query-core/src/subarray.rs index 1b5fff7d..e2fd4ac7 100644 --- a/tiledb/query-core/src/subarray.rs +++ b/tiledb/query-core/src/subarray.rs @@ -1,10 +1,11 @@ use std::collections::HashMap; +use tiledb_common::array::dimension::DimensionKey; use tiledb_common::range::Range; use super::QueryBuilder; -pub type SubarrayData = HashMap>; +pub type SubarrayData = HashMap>; #[derive(Default)] pub struct SubarrayBuilder { @@ -16,13 +17,13 @@ impl SubarrayBuilder { Self::default() } - pub fn add_range>( + pub fn add_range, IntoRange: Into>( mut self, - dimension: &str, + dimension: IntoKey, range: IntoRange, ) -> Self { self.subarray - .entry(dimension.to_string()) + .entry(dimension.into()) .or_default() .push(range.into()); self @@ -51,9 +52,9 @@ impl SubarrayBuilderForQuery { .with_subarray_data(self.subarray_builder.build()) } - pub fn add_range>( + pub fn add_range, IntoRange: Into>( mut self, - dimension: &str, + dimension: IntoKey, range: IntoRange, ) -> Self { self.subarray_builder = From 5228a43709bff56a17c14d7522fada1820069329 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 12:54:58 -0500 Subject: [PATCH 22/42] SharedBuffers is type alias instead of wrapper --- tiledb/query-core/src/buffers.rs | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index f2a96513..8118a4bb 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -1376,34 +1376,7 @@ fn to_array( alloc_array(arrow_type, tdb_nullable, field.capacity().unwrap()) } -/// A small helper for users writing code directly against the TileDB API -/// -/// This struct is freely convertible to and from a HashMap of Arrow arrays. -#[derive(Clone)] -pub struct SharedBuffers { - buffers: HashMap>, -} - -impl SharedBuffers { - pub fn get(&self, key: &str) -> Option<&T> - where - T: Any, - { - self.buffers.get(key)?.as_any().downcast_ref::() - } -} - -impl From>> for SharedBuffers { - fn from(buffers: HashMap>) -> Self { - Self { buffers } - } -} - -impl From for HashMap> { - fn from(buffers: SharedBuffers) -> Self { - buffers.buffers - } -} +pub type SharedBuffers = HashMap>; fn alloc_array( dtype: adt::DataType, From 891a1bfe3939dac9bb8bb49baf0d4b3395765d83 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 12:55:27 -0500 Subject: [PATCH 23/42] Query::subarray --- tiledb/query-core/src/lib.rs | 49 +++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 25f53b95..159892b7 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -13,6 +13,8 @@ use tiledb_api::context::{CApiInterface, Context, ContextBound}; use tiledb_api::error::Error as TileDBError; use tiledb_api::key::LookupKey; use tiledb_api::query::conditions::QueryConditionExpr; +use tiledb_api::query::subarray::{RawSubarray, Subarray}; +use tiledb_api::Result as TileDBResult; use tiledb_common::range::{Range, SingleValueRange, VarValueRange}; use buffers::{Error as QueryBuffersError, QueryBuffers}; @@ -134,26 +136,6 @@ impl Drop for RawQuery { } } -pub(crate) enum RawSubarray { - Owned(*mut ffi::tiledb_subarray_t), -} - -impl Deref for RawSubarray { - type Target = *mut ffi::tiledb_subarray_t; - fn deref(&self) -> &Self::Target { - match *self { - RawSubarray::Owned(ref ffi) => ffi, - } - } -} - -impl Drop for RawSubarray { - fn drop(&mut self) { - let RawSubarray::Owned(ref mut ffi) = *self; - unsafe { ffi::tiledb_subarray_free(ffi) }; - } -} - /// The main Query interface /// /// This struct is responsible for executing queries against TileDB arrays. @@ -176,6 +158,33 @@ impl Query { *self.raw } + /// Get the subarray for this query. + /// + /// The returned [Subarray] is tied to the lifetime of the Query. + /// + /// ```compile_fail,E0505 + /// # use tiledb_api::query::{Query, QueryBase, Subarray}; + /// fn invalid_use(query: QueryBase) { + /// let subarray = query.subarray().unwrap(); + /// drop(query); + /// /// The subarray should not be usable after the query is dropped. + /// let _ = subarray.ranges(); + /// } + /// ``` + pub fn subarray(&self) -> TileDBResult { + let ctx = self.context(); + let c_query = *self.raw; + let mut c_subarray: *mut ffi::tiledb_subarray_t = out_ptr!(); + ctx.capi_call(|ctx| unsafe { + ffi::tiledb_query_get_subarray_t(ctx, c_query, &mut c_subarray) + })?; + + Ok(Subarray::new( + self.array.schema()?, + RawSubarray::Owned(c_subarray), + )) + } + pub fn submit(&mut self) -> Result { self.buffers.make_mut().map_err(QueryBuffersError::from)?; if matches!(self.query_type, QueryType::Read) { From 089c59aedea59f585072f34edcd8adaee3cf8181 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 12:56:12 -0500 Subject: [PATCH 24/42] Should be part of DimensionKey commit --- tiledb/common/src/array/dimension.rs | 2 ++ tiledb/common/src/key.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tiledb/common/src/array/dimension.rs b/tiledb/common/src/array/dimension.rs index 82f391fc..5d618ad9 100644 --- a/tiledb/common/src/array/dimension.rs +++ b/tiledb/common/src/array/dimension.rs @@ -10,6 +10,8 @@ use crate::array::CellValNum; use crate::datatype::{Datatype, Error as DatatypeError}; use crate::range::SingleValueRange; +pub use crate::key::LookupKey as DimensionKey; + #[derive(Clone, Debug, Error)] pub enum Error { #[error("Invalid datatype: {0}")] diff --git a/tiledb/common/src/key.rs b/tiledb/common/src/key.rs index 9ad4d09c..3a097dab 100644 --- a/tiledb/common/src/key.rs +++ b/tiledb/common/src/key.rs @@ -1,4 +1,4 @@ -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum LookupKey { Index(usize), Name(String), From c1caf7d093e74809cce0a12f462c80b4a447ba50 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 12:58:32 -0500 Subject: [PATCH 25/42] Add query_roundtrip tests, not passing yet --- Cargo.lock | 5 + test-utils/cells/Cargo.toml | 3 +- test-utils/cells/src/arrow.rs | 156 ++++++++- test-utils/cells/src/field.rs | 7 + tiledb/api/src/query/subarray.rs | 4 +- tiledb/query-core/Cargo.toml | 8 +- tiledb/query-core/src/lib.rs | 33 ++ tiledb/query-core/src/tests.rs | 559 +++++++++++++++++++++++++++++++ 8 files changed, 757 insertions(+), 18 deletions(-) create mode 100644 tiledb/query-core/src/tests.rs diff --git a/Cargo.lock b/Cargo.lock index fb11f5e0..99de8292 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -531,6 +531,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-schema", + "itertools 0.12.1", "paste", "proptest", "strategy-ext", @@ -1885,14 +1886,18 @@ dependencies = [ name = "tiledb-query-core" version = "0.1.0" dependencies = [ + "anyhow", "arrow", + "cells", "itertools 0.12.1", + "proptest", "thiserror", "tiledb-api", "tiledb-common", "tiledb-pod", "tiledb-sys", "tiledb-sys-cfg", + "uri", ] [[package]] diff --git a/test-utils/cells/Cargo.toml b/test-utils/cells/Cargo.toml index 1bfd53fc..af945502 100644 --- a/test-utils/cells/Cargo.toml +++ b/test-utils/cells/Cargo.toml @@ -8,6 +8,7 @@ version.workspace = true arrow-array = { workspace = true, optional = true } arrow-buffer = { workspace = true, optional = true } arrow-schema = { workspace = true, optional = true } +itertools = { workspace = true, optional = true } paste = { workspace = true } proptest = { workspace = true } strategy-ext = { workspace = true } @@ -22,5 +23,5 @@ tiledb-proptest-config = { workspace = true } [features] default = [] -arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"] +arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema", "dep:itertools"] proptest-strategies = ["dep:tiledb-proptest-config", "tiledb-common/proptest-strategies", "tiledb-pod/proptest-strategies"] diff --git a/test-utils/cells/src/arrow.rs b/test-utils/cells/src/arrow.rs index 5823f90b..e2734b95 100644 --- a/test-utils/cells/src/arrow.rs +++ b/test-utils/cells/src/arrow.rs @@ -1,16 +1,17 @@ +use itertools::Itertools; +use std::collections::HashMap; use std::sync::Arc; -use arrow_array::types::{ - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; +use arrow_array::cast::downcast_array; +use arrow_array::types::*; use arrow_array::{ - Array, ArrowPrimitiveType, LargeListArray, PrimitiveArray, RecordBatch, + downcast_primitive_array, Array, ArrowPrimitiveType, LargeBinaryArray, + LargeListArray, LargeStringArray, PrimitiveArray, RecordBatch, }; use arrow_buffer::{ArrowNativeType, OffsetBuffer}; use arrow_schema::{DataType, Field, Fields, Schema}; -use crate::{Cells, FieldData}; +use crate::{typed_field_data_go, Cells, FieldData}; pub fn to_record_batch(cells: &Cells) -> RecordBatch { let (fnames, columns) = cells @@ -39,7 +40,18 @@ pub fn to_record_batch(cells: &Cells) -> RecordBatch { RecordBatch::try_new(schema.into(), columns).unwrap() } -fn to_column(fdata: &FieldData) -> Arc { +pub fn from_record_batch(batch: &RecordBatch) -> Option { + batch + .schema() + .fields + .iter() + .zip_eq(batch.columns().iter()) + .map(|(f, c)| from_column(c).map(|fdata| (f.name().to_owned(), fdata))) + .collect::>>() + .map(Cells::new) +} + +pub fn to_column(fdata: &FieldData) -> Arc { match fdata { FieldData::Int8(cells) => to_column_primitive::(cells), FieldData::Int16(cells) => to_column_primitive::(cells), @@ -101,12 +113,7 @@ where Arc::new( LargeListArray::try_new( - Field::new( - "unused", - DataType::new_large_list(values.data_type().clone(), false), - false, - ) - .into(), + Field::new("unused", values.data_type().clone(), false).into(), offsets, Arc::new(values), None, @@ -114,3 +121,126 @@ where .unwrap(), ) } + +pub fn from_column(column: &dyn Array) -> Option { + downcast_primitive_array!( + column => { + column.maybe_to_field_data() + }, + DataType::LargeUtf8 => { + let column = downcast_array::(column); + column.maybe_to_field_data() + } + DataType::LargeBinary => { + let column = downcast_array::(column); + column.maybe_to_field_data() + } + DataType::LargeList(_) => { + let column = downcast_array::(column); + column.maybe_to_field_data() + }, + _ => None + ) +} + +trait MaybeToFieldData { + fn maybe_to_field_data(&self) -> Option; +} + +macro_rules! to_field_data { + ($($primitive:ty),+) => { + $( + impl MaybeToFieldData for PrimitiveArray<$primitive> { + fn maybe_to_field_data(&self) -> Option { + Some(self.values().to_vec().into()) + } + } + )+ + }; +} + +macro_rules! not_to_field_data { + ($($array:ty),+) => { + $( + impl MaybeToFieldData for PrimitiveArray<$array> { + fn maybe_to_field_data(&self) -> Option { + None + } + } + )+ + }; +} + +to_field_data!( + Int8Type, + Int16Type, + Int32Type, + Int64Type, + UInt8Type, + UInt16Type, + UInt32Type, + UInt64Type, + Float32Type, + Float64Type, + TimestampSecondType, + TimestampMillisecondType, + TimestampMicrosecondType, + TimestampNanosecondType, + Time64MicrosecondType, + Time64NanosecondType, + DurationSecondType, + DurationMillisecondType, + DurationMicrosecondType, + DurationNanosecondType, + Date64Type +); + +not_to_field_data!( + Float16Type, + Time32SecondType, + Time32MillisecondType, + Date32Type, + Decimal128Type, + Decimal256Type, + IntervalYearMonthType, + IntervalDayTimeType, + IntervalMonthDayNanoType +); + +impl MaybeToFieldData for LargeStringArray { + fn maybe_to_field_data(&self) -> Option { + self.iter() + .map(|s| s.map(|s| s.bytes().collect::>())) + .collect::>>() + .map(|v| v.into()) + } +} + +impl MaybeToFieldData for LargeBinaryArray { + fn maybe_to_field_data(&self) -> Option { + self.iter() + .map(|s| s.map(|s| s.to_vec())) + .collect::>>() + .map(|v| v.into()) + } +} + +impl MaybeToFieldData for LargeListArray { + fn maybe_to_field_data(&self) -> Option { + typed_field_data_go!( + from_column(self.values())?, + _DT, + _values, + { + Some( + self.offsets() + .windows(2) + .map(|w| _values[w[0] as usize..w[1] as usize].to_vec()) + .collect::>() + .into(), + ) + }, + None + ) + } +} diff --git a/test-utils/cells/src/field.rs b/test-utils/cells/src/field.rs index dba48075..41233409 100644 --- a/test-utils/cells/src/field.rs +++ b/test-utils/cells/src/field.rs @@ -388,6 +388,13 @@ impl FieldData { } } +#[cfg(feature = "arrow")] +impl FieldData { + pub fn to_arrow(&self) -> std::sync::Arc { + crate::arrow::to_column(self) + } +} + impl BitsEq for FieldData { fn bits_eq(&self, other: &Self) -> bool { typed_field_data_cmp!( diff --git a/tiledb/api/src/query/subarray.rs b/tiledb/api/src/query/subarray.rs index 84a97589..b55642fd 100644 --- a/tiledb/api/src/query/subarray.rs +++ b/tiledb/api/src/query/subarray.rs @@ -16,7 +16,7 @@ use tiledb_common::{ physical_type_go, single_value_range_go, var_value_range_go, }; -pub(crate) enum RawSubarray { +pub enum RawSubarray { Owned(*mut ffi::tiledb_subarray_t), } @@ -53,7 +53,7 @@ impl Subarray<'_> { *self.raw } - pub(crate) fn new(schema: Schema, raw: RawSubarray) -> Self { + pub fn new(schema: Schema, raw: RawSubarray) -> Self { Subarray { schema, raw, diff --git a/tiledb/query-core/Cargo.toml b/tiledb/query-core/Cargo.toml index d09fac6c..5a04cf2c 100644 --- a/tiledb/query-core/Cargo.toml +++ b/tiledb/query-core/Cargo.toml @@ -12,9 +12,13 @@ tiledb-common = { workspace = true } tiledb-sys = { workspace = true } [dev-dependencies] +anyhow = { workspace = true } +cells = { workspace = true, features = ["arrow", "proptest-strategies"] } itertools = { workspace = true } -tiledb-api = { workspace = true, features = ["pod", "raw"] } -tiledb-pod = { workspace = true } +proptest = { workspace = true } +tiledb-api = { workspace = true, features = ["pod", "proptest-strategies", "raw"] } +tiledb-pod = { workspace = true, features = ["proptest-strategies"] } +uri = { workspace = true } [build-dependencies] tiledb-sys-cfg = { workspace = true } diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 159892b7..67e4f03b 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -3,7 +3,11 @@ extern crate tiledb_sys as ffi; use std::collections::HashMap; use std::ops::Deref; +use std::sync::Arc; +use arrow::datatypes::{Field as ArrowField, Schema as ArrowSchema}; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; use thiserror::Error; use tiledb_common::{single_value_range_go, var_value_range_go}; @@ -55,6 +59,8 @@ pub enum Error { ), #[error("Encountered internal libtiledb error: {0}")] TileDBError(#[from] TileDBError), + #[error("Internal error constructing RecordBatch: {0}")] + RecordBatch(ArrowError), } impl From for TileDBError { @@ -69,6 +75,7 @@ type Result = std::result::Result; /// /// Note that BuffersTooSmall is a Rust invention. But given that we never /// attempt to translate this status object back into a capi value its fine. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum QueryStatus { Uninitialized, Initialized, @@ -247,6 +254,29 @@ impl Query { Ok(ret.into()) } + pub fn records(&mut self) -> Result { + let (fields, columns) = self + .buffers()? + .into_iter() + .map(|(fname, fdata)| { + let field = ArrowField::new( + fname, + fdata.data_type().clone(), + fdata.is_nullable(), + ); + (field, fdata) + }) + .collect::<(Vec<_>, Vec<_>)>(); + + let schema = ArrowSchema { + fields: fields.into(), + metadata: Default::default(), + }; + + RecordBatch::try_new(Arc::new(schema), columns) + .map_err(Error::RecordBatch) + } + /// Replace this queries buffers with a new set specified by fields /// /// This can be used to reallocate buffers with a larger capacity. @@ -619,3 +649,6 @@ impl QueryBuilder { Ok(()) } } + +#[cfg(test)] +mod tests; diff --git a/tiledb/query-core/src/tests.rs b/tiledb/query-core/src/tests.rs new file mode 100644 index 00000000..0ce0527f --- /dev/null +++ b/tiledb/query-core/src/tests.rs @@ -0,0 +1,559 @@ +use std::rc::Rc; + +use cells::write::strategy::{WriteParameters, WriteSequenceParameters}; +use cells::write::{ + DenseWriteInput, SparseWriteInput, WriteInput, WriteSequence, +}; +use cells::{self, Cells}; +use proptest::prelude::*; +use tiledb_api::array::{Array, ArrayOpener}; +use tiledb_api::query::strategy::query_write_schema_requirements; +use tiledb_api::Factory; +use tiledb_common::array::{ArrayType, Mode}; +use tiledb_common::range::NonEmptyDomain; +use tiledb_pod::array::SchemaData; +use uri::TestArrayUri; + +use super::*; + +#[test] +fn query_roundtrip() -> anyhow::Result<()> { + let ctx = Context::new()?; + + let schema_req = query_write_schema_requirements(None); + + let strategy = + any_with::(Rc::new(schema_req)).prop_flat_map(|schema| { + let schema = Rc::new(schema); + ( + Just(Rc::clone(&schema)), + any_with::(WriteParameters::default_for(schema)) + .prop_map(WriteSequence::from), + ) + }); + + proptest!(|((schema_spec, write_sequence) in strategy)| { + do_query_roundtrip(&ctx, schema_spec, write_sequence) + .expect("Error in query round trip") + }); + + Ok(()) +} + +#[test] +fn query_roundtrip_accumulated() -> anyhow::Result<()> { + let ctx = Context::new()?; + + let schema_req = query_write_schema_requirements(None); + + let strategy = + any_with::(Rc::new(schema_req)).prop_flat_map(|schema| { + let schema = Rc::new(schema); + ( + Just(Rc::clone(&schema)), + any_with::( + WriteSequenceParameters::default_for(schema), + ), + ) + }); + + proptest!(|((schema_spec, write_sequence) in strategy)| { + do_query_roundtrip(&ctx, schema_spec, write_sequence) + .expect("Error in query round trip"); + }); + + Ok(()) +} + +struct DenseCellsAccumulator { + // TODO: implement accepting more than one write for dense write sequence + write: Option, +} + +impl DenseCellsAccumulator { + pub fn new(_: &SchemaData) -> Self { + DenseCellsAccumulator { write: None } + } + + pub fn cells(&self) -> &Cells { + // will not be called until first cells are written + &self.write.as_ref().unwrap().data + } + + pub fn accumulate(&mut self, write: DenseWriteInput) { + if self.write.is_some() { + unimplemented!() + } + self.write = Some(write) + } +} + +struct SparseCellsAccumulator { + cells: Option, + dedup_keys: Option>, +} + +impl SparseCellsAccumulator { + pub fn new(schema: &SchemaData) -> Self { + let dedup_keys = if schema.allow_duplicates.unwrap_or(false) { + None + } else { + Some( + schema + .domain + .dimension + .iter() + .map(|d| d.name.clone()) + .collect::>(), + ) + }; + SparseCellsAccumulator { + cells: None, + dedup_keys, + } + } + + pub fn cells(&self) -> &Cells { + // will not be called until first cells arrive + self.cells.as_ref().unwrap() + } + + /// Update state representing what we expect to see in the array. + /// For a sparse array this means adding this write's coordinates, + /// overwriting the old coordinates if they overlap. + pub fn accumulate(&mut self, mut write: SparseWriteInput) { + if let Some(cells) = self.cells.take() { + write.data.extend(cells); + if let Some(dedup_keys) = self.dedup_keys.as_ref() { + self.cells = Some(write.data.dedup(dedup_keys)); + } else { + self.cells = Some(write.data); + } + } else { + self.cells = Some(write.data); + } + } +} + +enum CellsAccumulator { + Dense(DenseCellsAccumulator), + Sparse(SparseCellsAccumulator), +} + +impl CellsAccumulator { + pub fn new(schema: &SchemaData) -> Self { + match schema.array_type { + ArrayType::Dense => Self::Dense(DenseCellsAccumulator::new(schema)), + ArrayType::Sparse => { + Self::Sparse(SparseCellsAccumulator::new(schema)) + } + } + } + + pub fn cells(&self) -> &Cells { + match self { + Self::Dense(ref d) => d.cells(), + Self::Sparse(ref s) => s.cells(), + } + } + + pub fn accumulate(&mut self, write: WriteInput) { + match write { + WriteInput::Sparse(w) => { + let Self::Sparse(ref mut sparse) = self else { + unreachable!() + }; + sparse.accumulate(w) + } + WriteInput::Dense(w) => { + let Self::Dense(ref mut dense) = self else { + unreachable!() + }; + dense.accumulate(w) + } + } + } +} + +trait BuildReadQuery { + fn read_same_fields(&self, array: Array) -> anyhow::Result; +} + +impl BuildReadQuery for Cells { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + Ok(self + .fields() + .keys() + .fold( + QueryBuilder::new(array, QueryType::Read).start_fields(), + |b, k| b.field(k), + ) + .end_fields()) + } +} + +impl BuildReadQuery for DenseCellsAccumulator { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + self.write.as_ref().unwrap().read_same_fields(array) + } +} + +impl BuildReadQuery for SparseCellsAccumulator { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + self.cells.as_ref().unwrap().read_same_fields(array) + } +} + +impl BuildReadQuery for CellsAccumulator { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + match self { + Self::Dense(ref d) => d.read_same_fields(array), + Self::Sparse(ref s) => s.read_same_fields(array), + } + } +} + +impl BuildReadQuery for DenseWriteInput { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + Ok(self + .subarray + .iter() + .enumerate() + .fold( + self.data.read_same_fields(array)?.start_subarray(), + |b, (d, r)| b.add_range(d, r.clone()), + ) + .end_subarray()) + } +} + +impl BuildReadQuery for SparseWriteInput { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + self.data.read_same_fields(array) + } +} + +impl BuildReadQuery for WriteInput { + fn read_same_fields(&self, array: Array) -> anyhow::Result { + match self { + Self::Dense(ref d) => d.read_same_fields(array), + Self::Sparse(ref s) => s.read_same_fields(array), + } + } +} + +trait BuildWriteQuery { + fn write_query_builder(&self, array: Array) + -> anyhow::Result; +} + +impl BuildWriteQuery for Cells { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + Ok(self + .fields() + .iter() + .fold( + QueryBuilder::new(array, QueryType::Write).start_fields(), + |b, (k, v)| b.field_with_buffer(k, v.to_arrow()), + ) + .end_fields()) + } +} + +/* +impl BuildWriteQuery for DenseCellsAccumulator { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + self.write.as_ref().map(|w| w.write_query_ + todo!() + } +} + +impl BuildWriteQuery for SparseCellsAccumulator { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + todo!() + } +} + +impl BuildWriteQuery for CellsAccumulator { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + match self { + Self::Dense(ref d) => d.write_query_builder(array), + Self::Sparse(ref s) => s.write_query_builder(array), + } + } +} +*/ + +impl BuildWriteQuery for DenseWriteInput { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + Ok(self + .subarray + .iter() + .enumerate() + .fold( + self.data.write_query_builder(array)?.start_subarray(), + |b, (d, r)| b.add_range(d, r.clone()), + ) + .end_subarray()) + } +} + +impl BuildWriteQuery for SparseWriteInput { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + self.data.write_query_builder(array) + } +} + +impl BuildWriteQuery for WriteInput { + fn write_query_builder( + &self, + array: Array, + ) -> anyhow::Result { + match self { + Self::Dense(ref d) => d.write_query_builder(array), + Self::Sparse(ref s) => s.write_query_builder(array), + } + } +} + +fn do_query_roundtrip( + ctx: &Context, + schema_spec: Rc, + write_sequence: WriteSequence, +) -> anyhow::Result<()> { + let test_uri = uri::get_uri_generator()?; + let uri = test_uri.with_path("array")?; + + let schema_in = schema_spec + .create(ctx) + .expect("Error constructing arbitrary schema"); + Array::create(ctx, &uri, schema_in).expect("Error creating array"); + + let mut accumulated = AccumulatedArray::new(&schema_spec); + + /* + * Results do not come back in a defined order, so we must sort and + * compare. Writes currently have to write all fields. + */ + let sort_keys = match write_sequence { + WriteSequence::Dense(_) => schema_spec + .attributes + .iter() + .map(|f| f.name.clone()) + .collect::>(), + WriteSequence::Sparse(_) => schema_spec + .fields() + .map(|f| f.name().to_owned()) + .collect::>(), + }; + + for write in write_sequence { + apply_write(ctx, &uri, write, &mut accumulated, &sort_keys)?; + } + Ok(()) +} + +struct AccumulatedArray { + domain: Option, + cells: CellsAccumulator, +} + +impl AccumulatedArray { + pub fn new(schema: &SchemaData) -> Self { + Self { + domain: None, + cells: CellsAccumulator::new(schema), + } + } +} + +fn apply_write( + ctx: &Context, + uri: &str, + write: WriteInput, + accumulated_array: &mut AccumulatedArray, + cmp_sort_keys: &[String], +) -> anyhow::Result<()> { + /* write data and preserve ranges for sanity check */ + let write_ranges = { + let array = + Array::open(ctx, &uri, Mode::Write).expect("Error opening array"); + + let mut write_query = write + .write_query_builder(array) + .expect("Error building write query") + .build() + .unwrap(); + write_query.submit().expect("Error running write query"); + + let write_ranges = if let Some(ranges) = write.subarray() { + let generic_ranges = ranges + .iter() + .cloned() + .map(|r| vec![r]) + .collect::>>(); + assert_eq!( + generic_ranges, + write_query.subarray().unwrap().ranges().unwrap() + ); + Some(generic_ranges) + } else { + None + }; + + let _ = write_query + .finalize() + .expect("Error finalizing write query"); + + write_ranges + }; + + if write.cells().is_empty() { + // in this case, writing and finalizing does not create a new fragment + // TODO + return Ok(()); + } + + /* NB: results are not read back in a defined order, so we must sort and compare */ + + let mut array = ArrayOpener::new(ctx, &uri, Mode::Read) + .unwrap() + .open() + .unwrap(); + + /* + * First check fragment - its domain should match what we just wrote, and we need the + * timestamp so we can read back only this fragment + */ + let [timestamp_min, timestamp_max] = { + let fi = array.fragment_info().unwrap(); + let nf = fi.num_fragments().unwrap(); + assert!(nf > 0); + + let this_fragment = fi.get_fragment(nf - 1).unwrap(); + + if let Some(write_domain) = write.domain() { + let nonempty_domain = + this_fragment.non_empty_domain().unwrap().untyped(); + assert_eq!(write_domain, nonempty_domain); + } else { + // most recent fragment should be empty, + // what does that look like if no data was written? + } + + this_fragment.timestamp_range().unwrap() + }; + + let safety_write_start = std::time::Instant::now(); + + /* + * Then re-open the array to read back what we just wrote + * into the most recent fragment only + */ + { + array = array + .reopen() + .start_timestamp(timestamp_min) + .unwrap() + .end_timestamp(timestamp_max) + .unwrap() + .open() + .unwrap(); + + let mut read = write.read_same_fields(array).unwrap().build().unwrap(); + + if let Some(write_ranges) = write_ranges { + let read_ranges = read.subarray().unwrap().ranges().unwrap(); + assert_eq!(write_ranges, read_ranges); + } + + let mut cells = { + let status = read.submit().unwrap(); + assert_eq!(status, QueryStatus::Completed); + + let record_batch = read.records().unwrap(); + cells::arrow::from_record_batch(&record_batch).unwrap() + }; + + /* `cells` should match the write */ + { + let write_sorted = write.cells().sorted(&cmp_sort_keys); + cells.sort(&cmp_sort_keys); + assert_eq!(write_sorted, cells); + } + + (array, _) = read.finalize().unwrap(); + } + + /* finally, check that everything written up until now is correct */ + array = array.reopen().start_timestamp(0).unwrap().open().unwrap(); + + /* check array non-empty domain */ + if let Some(accumulated_domain) = accumulated_array.domain.as_mut() { + let Some(write_domain) = write.domain() else { + unreachable!() + }; + *accumulated_domain = accumulated_domain.union(&write_domain); + } else { + accumulated_array.domain = write.domain(); + } + { + let Some(acc) = accumulated_array.domain.as_ref() else { + unreachable!() + }; + let nonempty = array.nonempty_domain().unwrap().unwrap().untyped(); + assert_eq!(*acc, nonempty); + } + + /* update accumulated expected array data */ + accumulated_array.cells.accumulate(write); + { + let acc = accumulated_array.cells.cells().sorted(&cmp_sort_keys); + + let cells = { + let mut read = accumulated_array + .cells + .read_same_fields(array) + .unwrap() + .build() + .unwrap(); + + let mut cells = { + let status = read.submit().unwrap(); + assert_eq!(status, QueryStatus::Completed); + + let record_batch = read.records().unwrap(); + cells::arrow::from_record_batch(&record_batch).unwrap() + }; + cells.sort(&cmp_sort_keys); + cells + }; + + assert_eq!(acc, cells); + } + + // safety valve to ensure we don't write two fragments in the same millisecond + if safety_write_start.elapsed() < std::time::Duration::from_millis(1) { + std::thread::sleep(std::time::Duration::from_millis(1)); + } + + Ok(()) +} From a0d433d21981c9b47f01e352b7bea75780f95556 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 13:11:50 -0500 Subject: [PATCH 26/42] Make target query field available to MutableOrShared --- tiledb/query-core/src/buffers.rs | 212 +++++++++++++++---------------- 1 file changed, 102 insertions(+), 110 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index 8118a4bb..abfc4cef 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -940,90 +940,88 @@ impl NewBufferTraitThing for PrimitiveBuffers { } } -impl TryFrom> for Box { - type Error = FromArrowError; - - fn try_from(array: Arc) -> FromArrowResult { - let dtype = array.data_type().clone(); - match dtype { - adt::DataType::Boolean => { - BooleanBuffers::try_from(array).map(|buffers| { - let boxed: Box = Box::new(buffers); - boxed - }) - } - adt::DataType::LargeBinary | adt::DataType::LargeUtf8 => { - ByteBuffers::try_from(array).map(|buffers| { - let boxed: Box = Box::new(buffers); - boxed - }) - } - adt::DataType::FixedSizeList(_, _) => { - FixedListBuffers::try_from(array).map(|buffers| { - let boxed: Box = Box::new(buffers); - boxed - }) - } - adt::DataType::LargeList(_) => { - ListBuffers::try_from(array).map(|buffers| { - let boxed: Box = Box::new(buffers); - boxed - }) - } - adt::DataType::Int8 - | adt::DataType::Int16 - | adt::DataType::Int32 - | adt::DataType::Int64 - | adt::DataType::UInt8 - | adt::DataType::UInt16 - | adt::DataType::UInt32 - | adt::DataType::UInt64 - | adt::DataType::Float32 - | adt::DataType::Float64 - | adt::DataType::Timestamp(_, None) - | adt::DataType::Time64(_) => PrimitiveBuffers::try_from(array) - .map(|buffers| { - let boxed: Box = Box::new(buffers); - boxed - }), - - adt::DataType::Timestamp(_, Some(_)) => { - Err((array, UnsupportedArrowArrayError::UnsupportedTimeZones)) - } +fn to_target_buffers( + _target: &QueryField, + array: Arc, +) -> FromArrowResult> { + let dtype = array.data_type().clone(); + match dtype { + adt::DataType::Boolean => { + BooleanBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::LargeBinary | adt::DataType::LargeUtf8 => { + ByteBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::FixedSizeList(_, _) => FixedListBuffers::try_from(array) + .map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }), + adt::DataType::LargeList(_) => { + ListBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } + adt::DataType::Int8 + | adt::DataType::Int16 + | adt::DataType::Int32 + | adt::DataType::Int64 + | adt::DataType::UInt8 + | adt::DataType::UInt16 + | adt::DataType::UInt32 + | adt::DataType::UInt64 + | adt::DataType::Float32 + | adt::DataType::Float64 + | adt::DataType::Timestamp(_, None) + | adt::DataType::Time64(_) => { + PrimitiveBuffers::try_from(array).map(|buffers| { + let boxed: Box = Box::new(buffers); + boxed + }) + } - adt::DataType::Binary - | adt::DataType::List(_) - | adt::DataType::Utf8 => Err(( - array, - UnsupportedArrowArrayError::LargeVariantOnly(dtype), - )), + adt::DataType::Timestamp(_, Some(_)) => { + Err((array, UnsupportedArrowArrayError::UnsupportedTimeZones)) + } - adt::DataType::FixedSizeBinary(_) => { - todo!("This can probably be supported.") - } + adt::DataType::Binary + | adt::DataType::List(_) + | adt::DataType::Utf8 => { + Err((array, UnsupportedArrowArrayError::LargeVariantOnly(dtype))) + } - adt::DataType::Null - | adt::DataType::Float16 - | adt::DataType::Date32 - | adt::DataType::Date64 - | adt::DataType::Time32(_) - | adt::DataType::Duration(_) - | adt::DataType::Interval(_) - | adt::DataType::BinaryView - | adt::DataType::Utf8View - | adt::DataType::ListView(_) - | adt::DataType::LargeListView(_) - | adt::DataType::Struct(_) - | adt::DataType::Union(_, _) - | adt::DataType::Dictionary(_, _) - | adt::DataType::Decimal128(_, _) - | adt::DataType::Decimal256(_, _) - | adt::DataType::Map(_, _) - | adt::DataType::RunEndEncoded(_, _) => Err(( - array, - UnsupportedArrowArrayError::UnsupportedArrowType(dtype), - )), + adt::DataType::FixedSizeBinary(_) => { + todo!("This can probably be supported.") } + + adt::DataType::Null + | adt::DataType::Float16 + | adt::DataType::Date32 + | adt::DataType::Date64 + | adt::DataType::Time32(_) + | adt::DataType::Duration(_) + | adt::DataType::Interval(_) + | adt::DataType::BinaryView + | adt::DataType::Utf8View + | adt::DataType::ListView(_) + | adt::DataType::LargeListView(_) + | adt::DataType::Struct(_) + | adt::DataType::Union(_, _) + | adt::DataType::Dictionary(_, _) + | adt::DataType::Decimal128(_, _) + | adt::DataType::Decimal256(_, _) + | adt::DataType::Map(_, _) + | adt::DataType::RunEndEncoded(_, _) => Err(( + array, + UnsupportedArrowArrayError::UnsupportedArrowType(dtype), + )), } } @@ -1033,13 +1031,15 @@ impl TryFrom> for Box { /// a thing that can be done safely, so we have a specialized utility struct /// that does the same idea, at the cost of an extra None in the struct. struct MutableOrShared { + target: QueryField, mutable: Option>, shared: Option>, } impl MutableOrShared { - pub fn new(value: Arc) -> Self { + pub fn new(target: QueryField, value: Arc) -> Self { Self { + target, mutable: None, shared: Some(value), } @@ -1061,7 +1061,7 @@ impl MutableOrShared { } let shared: Arc = self.shared.take().unwrap(); - let maybe_mutable = Box::::try_from(shared); + let maybe_mutable = to_target_buffers(&self.target, shared); let ret = if maybe_mutable.is_ok() { self.mutable = maybe_mutable.ok(); @@ -1113,6 +1113,20 @@ pub struct BufferEntry { } impl BufferEntry { + pub fn new(target: QueryField, buffer: Arc) -> Self { + Self { + entry: MutableOrShared::new(target, buffer), + aggregate: None, + } + } + + pub fn with_aggregate(self, aggregate: AggregateFunctionHandle) -> Self { + Self { + aggregate: Some(aggregate), + ..self + } + } + pub fn as_shared(&self) -> Result> { let Some(ref array) = self.entry.shared() else { return Err(Error::UnshareableMutableBuffer); @@ -1201,30 +1215,11 @@ impl BufferEntry { } } -impl From> for BufferEntry { - fn from(array: Arc) -> Self { - Self { - entry: MutableOrShared::new(array), - aggregate: None, - } - } -} - pub struct QueryBuffers { buffers: HashMap, } impl QueryBuffers { - pub fn new(buffers: HashMap>) -> Self { - let mut new_buffers = HashMap::with_capacity(buffers.len()); - for (field, array) in buffers.into_iter() { - new_buffers.insert(field, BufferEntry::from(array)); - } - Self { - buffers: new_buffers, - } - } - /// Reset all mutable buffers' len to match its total capacity. pub fn reset_lengths(&mut self) -> Result<()> { for array in self.buffers.values_mut() { @@ -1270,10 +1265,10 @@ impl QueryBuffers { for (name, request_field) in fields.fields.into_iter() { let query_field = QueryField::get(&schema.context(), raw, &name) .map_err(|e| Error::Field(name.clone(), e.into()))?; - let array = to_array(request_field, query_field) + let array = request_to_buffers(request_field, &query_field) .map_err(|e| Error::Field(name.clone(), e))?; - buffers.insert(name.clone(), BufferEntry::from(array)); + buffers.insert(name.clone(), BufferEntry::new(query_field, array)); } for (name, (function, request_field)) in fields.aggregates.into_iter() { @@ -1281,15 +1276,12 @@ impl QueryBuffers { let query_field = QueryField::get(&schema.context(), raw, &name) .map_err(|e| Error::Field(name.clone(), e.into()))?; - let array = to_array(request_field, query_field) + let array = request_to_buffers(request_field, &query_field) .map_err(|e| Error::Field(name.clone(), e))?; buffers.insert( name.to_owned(), - BufferEntry { - entry: MutableOrShared::new(array), - aggregate: Some(handle), - }, + BufferEntry::new(query_field, array).with_aggregate(handle), ); } @@ -1351,9 +1343,9 @@ impl QueryBuffers { } } -fn to_array( +fn request_to_buffers( field: RequestField, - tiledb_field: QueryField, + tiledb_field: &QueryField, ) -> FieldResult> { let conv = ToArrowConverter::strict(); From be274008cacdeda645ed5d39efdd8a740b4589db Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Tue, 26 Nov 2024 14:02:17 -0500 Subject: [PATCH 27/42] Alloc validity buffer for writes when the source is not nullable and the target is --- tiledb/query-core/src/buffers.rs | 252 +++++++++++++++++++++---------- tiledb/query-core/src/lib.rs | 15 +- 2 files changed, 188 insertions(+), 79 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index abfc4cef..42871cc7 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -13,12 +13,13 @@ use thiserror::Error; use tiledb_api::array::schema::Schema; use tiledb_api::error::Error as TileDBError; use tiledb_api::query::read::aggregate::AggregateFunctionHandle; -use tiledb_api::ContextBound; +use tiledb_api::{Context, ContextBound}; use tiledb_common::array::CellValNum; use super::datatype::ToArrowConverter; use super::field::QueryField; use super::fields::{QueryField as RequestField, QueryFields}; +use super::QueryType; use super::RawQuery; const AVERAGE_STRING_LENGTH: usize = 64; @@ -189,21 +190,28 @@ trait NewBufferTraitThing { } } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum BufferTarget { + Read, + Write(bool), +} + struct BooleanBuffers { data: QueryBuffer, validity: Option, } -impl TryFrom> for BooleanBuffers { - type Error = FromArrowError; - fn try_from(array: Arc) -> FromArrowResult { - let array: aa::BooleanArray = downcast_consume(array); +impl BooleanBuffers { + pub fn try_new( + target: &BufferTarget, + array: aa::BooleanArray, + ) -> FromArrowResult { let (data, validity) = array.into_parts(); let data = data .iter() .map(|v| if v { 1u8 } else { 0 }) .collect::>(); - let validity = to_tdb_validity(validity); + let validity = to_tdb_validity(target, data.len(), validity); Ok(BooleanBuffers { data: QueryBuffer::new(abuf::MutableBuffer::from(data)), @@ -260,7 +268,8 @@ struct ByteBuffers { } macro_rules! to_byte_buffers { - ($ARRAY:expr, $ARROW_TYPE:expr, $ARROW_DT:ty) => {{ + ($TARGET:expr, $ARRAY:expr, $ARROW_TYPE:expr, $ARROW_DT:ty) => {{ + let target = $TARGET; let array: $ARROW_DT = downcast_consume($ARRAY); let (offsets, data, nulls) = array.into_parts(); @@ -268,9 +277,12 @@ macro_rules! to_byte_buffers { let offsets = offsets.into_inner().into_inner().into_mutable(); if data.is_ok() && offsets.is_ok() { + let offsets = offsets.unwrap(); + let num_cells = offsets.len() - 1; let data = QueryBuffer::new(data.ok().unwrap()); - let offsets = QueryBuffer::new(offsets.ok().unwrap()); - let validity = to_tdb_validity(nulls).map(QueryBuffer::new); + let offsets = QueryBuffer::new(offsets); + let validity = + to_tdb_validity(target, num_cells, nulls).map(QueryBuffer::new); return Ok(ByteBuffers { dtype: $ARROW_TYPE, data, @@ -306,17 +318,28 @@ macro_rules! to_byte_buffers { }}; } -impl TryFrom> for ByteBuffers { - type Error = FromArrowError; - - fn try_from(array: Arc) -> FromArrowResult { +impl ByteBuffers { + pub fn try_new( + target: &BufferTarget, + array: Arc, + ) -> FromArrowResult { let dtype = array.data_type().clone(); match dtype { adt::DataType::LargeBinary => { - to_byte_buffers!(array, dtype.clone(), aa::LargeBinaryArray) + to_byte_buffers!( + target, + array, + dtype.clone(), + aa::LargeBinaryArray + ) } adt::DataType::LargeUtf8 => { - to_byte_buffers!(array, dtype.clone(), aa::LargeStringArray) + to_byte_buffers!( + target, + array, + dtype.clone(), + aa::LargeStringArray + ) } t => Err((array, UnsupportedArrowArrayError::InvalidBytesType(t))), } @@ -411,16 +434,19 @@ struct FixedListBuffers { validity: Option, } -impl TryFrom> for FixedListBuffers { - type Error = FromArrowError; - - fn try_from(array: Arc) -> FromArrowResult { +impl FixedListBuffers { + pub fn try_new( + target: &BufferTarget, + array: Arc, + ) -> FromArrowResult { assert!(matches!( array.data_type(), adt::DataType::FixedSizeList(_, _) )); let array: aa::FixedSizeListArray = downcast_consume(array); + let num_cells = aa::Array::len(&array); + let (field, cvn, array, nulls) = array.into_parts(); if field.is_nullable() { @@ -450,11 +476,12 @@ impl TryFrom> for FixedListBuffers { )); } - PrimitiveBuffers::try_from(array) + PrimitiveBuffers::try_new(target, array) .map(|buffers| { assert_eq!(buffers.dtype, dtype); let validity = - to_tdb_validity(nulls.clone()).map(QueryBuffer::new); + to_tdb_validity(target, num_cells, nulls.clone()) + .map(QueryBuffer::new); FixedListBuffers { field: Arc::clone(&field), cell_val_num: cvn, @@ -630,10 +657,11 @@ struct ListBuffers { validity: Option, } -impl TryFrom> for ListBuffers { - type Error = FromArrowError; - - fn try_from(array: Arc) -> FromArrowResult { +impl ListBuffers { + pub fn try_new( + target: &BufferTarget, + array: Arc, + ) -> FromArrowResult { assert!(matches!(array.data_type(), adt::DataType::LargeList(_))); let array: aa::LargeListArray = downcast_consume(array); @@ -654,11 +682,13 @@ impl TryFrom> for ListBuffers { )); } + let num_cells = array.len(); + // N.B., I really, really tried to make this a fancy map/map_err // cascade like all of the others. But it turns out that keeping the // proper refcounts on either array or offsets turns into a bit of // an issue when passing things through multiple closures. - let result = PrimitiveBuffers::try_from(array); + let result = PrimitiveBuffers::try_new(target, array); if result.is_err() { let (array, err) = result.err().unwrap(); let array: Arc = Arc::new( @@ -700,7 +730,8 @@ impl TryFrom> for ListBuffers { // NB: by default the offsets are not arrow-shaped. // However we use the configuration options to make them so. - let validity = to_tdb_validity(nulls).map(QueryBuffer::new); + let validity = + to_tdb_validity(target, num_cells, nulls).map(QueryBuffer::new); Ok(ListBuffers { field, @@ -810,9 +841,10 @@ struct PrimitiveBuffers { } macro_rules! to_primitive { - ($ARRAY:expr, $ARROW_DT:ty) => {{ + ($TARGET:expr, $ARRAY:expr, $ARROW_DT:ty) => {{ + let target = $TARGET; let array: $ARROW_DT = downcast_consume($ARRAY); - let len = array.len(); + let num_cells = array.len(); let (dtype, buffer, nulls) = array.into_parts(); buffer @@ -820,7 +852,8 @@ macro_rules! to_primitive { .into_mutable() .map(|data| { let validity = - to_tdb_validity(nulls.clone()).map(QueryBuffer::new); + to_tdb_validity(target, num_cells, nulls.clone()) + .map(QueryBuffer::new); PrimitiveBuffers { dtype: dtype.clone(), data: QueryBuffer::new(data), @@ -833,7 +866,7 @@ macro_rules! to_primitive { // right back together again. Sorry, Humpty. let data = aa::ArrayData::try_new( dtype, - len, + num_cells, nulls.map(|n| n.into_inner().into_inner()), 0, vec![buffer], @@ -845,22 +878,42 @@ macro_rules! to_primitive { }}; } -impl TryFrom> for PrimitiveBuffers { - type Error = FromArrowError; - fn try_from(array: Arc) -> FromArrowResult { +impl PrimitiveBuffers { + pub fn try_new( + target: &BufferTarget, + array: Arc, + ) -> FromArrowResult { assert!(array.data_type().is_primitive()); match array.data_type().clone() { - adt::DataType::Int8 => to_primitive!(array, aa::Int8Array), - adt::DataType::Int16 => to_primitive!(array, aa::Int16Array), - adt::DataType::Int32 => to_primitive!(array, aa::Int32Array), - adt::DataType::Int64 => to_primitive!(array, aa::Int64Array), - adt::DataType::UInt8 => to_primitive!(array, aa::UInt8Array), - adt::DataType::UInt16 => to_primitive!(array, aa::UInt16Array), - adt::DataType::UInt32 => to_primitive!(array, aa::UInt32Array), - adt::DataType::UInt64 => to_primitive!(array, aa::UInt64Array), - adt::DataType::Float32 => to_primitive!(array, aa::Float32Array), - adt::DataType::Float64 => to_primitive!(array, aa::Float64Array), + adt::DataType::Int8 => to_primitive!(target, array, aa::Int8Array), + adt::DataType::Int16 => { + to_primitive!(target, array, aa::Int16Array) + } + adt::DataType::Int32 => { + to_primitive!(target, array, aa::Int32Array) + } + adt::DataType::Int64 => { + to_primitive!(target, array, aa::Int64Array) + } + adt::DataType::UInt8 => { + to_primitive!(target, array, aa::UInt8Array) + } + adt::DataType::UInt16 => { + to_primitive!(target, array, aa::UInt16Array) + } + adt::DataType::UInt32 => { + to_primitive!(target, array, aa::UInt32Array) + } + adt::DataType::UInt64 => { + to_primitive!(target, array, aa::UInt64Array) + } + adt::DataType::Float32 => { + to_primitive!(target, array, aa::Float32Array) + } + adt::DataType::Float64 => { + to_primitive!(target, array, aa::Float64Array) + } t => Err(( array, UnsupportedArrowArrayError::InvalidPrimitiveType(t), @@ -941,30 +994,32 @@ impl NewBufferTraitThing for PrimitiveBuffers { } fn to_target_buffers( - _target: &QueryField, + target: &BufferTarget, array: Arc, ) -> FromArrowResult> { let dtype = array.data_type().clone(); match dtype { adt::DataType::Boolean => { - BooleanBuffers::try_from(array).map(|buffers| { + let array = downcast_consume::(array); + BooleanBuffers::try_new(target, array).map(|buffers| { let boxed: Box = Box::new(buffers); boxed }) } adt::DataType::LargeBinary | adt::DataType::LargeUtf8 => { - ByteBuffers::try_from(array).map(|buffers| { + ByteBuffers::try_new(target, array).map(|buffers| { let boxed: Box = Box::new(buffers); boxed }) } - adt::DataType::FixedSizeList(_, _) => FixedListBuffers::try_from(array) - .map(|buffers| { + adt::DataType::FixedSizeList(_, _) => { + FixedListBuffers::try_new(target, array).map(|buffers| { let boxed: Box = Box::new(buffers); boxed - }), + }) + } adt::DataType::LargeList(_) => { - ListBuffers::try_from(array).map(|buffers| { + ListBuffers::try_new(target, array).map(|buffers| { let boxed: Box = Box::new(buffers); boxed }) @@ -980,12 +1035,11 @@ fn to_target_buffers( | adt::DataType::Float32 | adt::DataType::Float64 | adt::DataType::Timestamp(_, None) - | adt::DataType::Time64(_) => { - PrimitiveBuffers::try_from(array).map(|buffers| { + | adt::DataType::Time64(_) => PrimitiveBuffers::try_new(target, array) + .map(|buffers| { let boxed: Box = Box::new(buffers); boxed - }) - } + }), adt::DataType::Timestamp(_, Some(_)) => { Err((array, UnsupportedArrowArrayError::UnsupportedTimeZones)) @@ -1031,13 +1085,13 @@ fn to_target_buffers( /// a thing that can be done safely, so we have a specialized utility struct /// that does the same idea, at the cost of an extra None in the struct. struct MutableOrShared { - target: QueryField, + target: BufferTarget, mutable: Option>, shared: Option>, } impl MutableOrShared { - pub fn new(target: QueryField, value: Arc) -> Self { + pub fn new(target: BufferTarget, value: Arc) -> Self { Self { target, mutable: None, @@ -1113,7 +1167,7 @@ pub struct BufferEntry { } impl BufferEntry { - pub fn new(target: QueryField, buffer: Arc) -> Self { + fn new(target: BufferTarget, buffer: Arc) -> Self { Self { entry: MutableOrShared::new(target, buffer), aggregate: None, @@ -1258,30 +1312,38 @@ impl QueryBuffers { pub(crate) fn from_fields( schema: Schema, + query_type: QueryType, raw: &RawQuery, fields: QueryFields, ) -> Result { let mut buffers = HashMap::with_capacity(fields.fields.len()); for (name, request_field) in fields.fields.into_iter() { - let query_field = QueryField::get(&schema.context(), raw, &name) - .map_err(|e| Error::Field(name.clone(), e.into()))?; - let array = request_to_buffers(request_field, &query_field) - .map_err(|e| Error::Field(name.clone(), e))?; - - buffers.insert(name.clone(), BufferEntry::new(query_field, array)); + buffers.insert( + name.clone(), + make_buffer_entry( + &schema.context(), + query_type, + raw, + &name, + request_field, + None, + ) + .map_err(|e| Error::Field(name.clone(), e))?, + ); } for (name, (function, request_field)) in fields.aggregates.into_iter() { - let handle = AggregateFunctionHandle::new(function)?; - - let query_field = QueryField::get(&schema.context(), raw, &name) - .map_err(|e| Error::Field(name.clone(), e.into()))?; - let array = request_to_buffers(request_field, &query_field) - .map_err(|e| Error::Field(name.clone(), e))?; - buffers.insert( - name.to_owned(), - BufferEntry::new(query_field, array).with_aggregate(handle), + name.clone(), + make_buffer_entry( + &schema.context(), + query_type, + raw, + &name, + request_field, + Some(AggregateFunctionHandle::new(function)?), + ) + .map_err(|e| Error::Field(name.clone(), e))?, ); } @@ -1343,6 +1405,33 @@ impl QueryBuffers { } } +fn make_buffer_entry( + context: &Context, + query_type: QueryType, + raw: &RawQuery, + name: &str, + request: RequestField, + aggregate: Option, +) -> FieldResult { + let query_field = QueryField::get(context, raw, &name)?; + let array = request_to_buffers(request, &query_field)?; + + let target = match query_type { + QueryType::Read => BufferTarget::Read, + QueryType::Write => BufferTarget::Write(query_field.nullable()?), + QueryType::Delete => unimplemented!(), + QueryType::Update => unimplemented!(), + QueryType::ModifyExclusive => unimplemented!(), + }; + + let entry = BufferEntry::new(target, array); + Ok(if let Some(aggregate) = aggregate { + entry.with_aggregate(aggregate) + } else { + entry + }) +} + fn request_to_buffers( field: RequestField, tiledb_field: &QueryField, @@ -1559,15 +1648,26 @@ fn to_tdb_offsets( .map_err(|_| ArrayInUseError::Offsets) } -fn to_tdb_validity(nulls: Option) -> Option { - nulls.map(|nulls| { +fn to_tdb_validity( + target: &BufferTarget, + num_cells: usize, + nulls: Option, +) -> Option { + let arrow_validity = nulls.map(|nulls| { ArrowBufferMut::from( nulls .iter() .map(|v| if v { 1u8 } else { 0 }) .collect::>(), ) - }) + }); + if arrow_validity.is_some() { + arrow_validity + } else if matches!(target, BufferTarget::Write(true)) { + Some(ArrowBufferMut::from(vec![1u8; num_cells])) + } else { + None + } } fn from_tdb_offsets(offsets: ArrowBufferMut) -> abuf::OffsetBuffer { diff --git a/tiledb/query-core/src/lib.rs b/tiledb/query-core/src/lib.rs index 67e4f03b..d3f8807a 100644 --- a/tiledb/query-core/src/lib.rs +++ b/tiledb/query-core/src/lib.rs @@ -284,8 +284,12 @@ impl Query { &mut self, fields: QueryFields, ) -> Result { - let mut tmp_buffers = - QueryBuffers::from_fields(self.array.schema()?, &self.raw, fields)?; + let mut tmp_buffers = QueryBuffers::from_fields( + self.array.schema()?, + self.query_type, + &self.raw, + fields, + )?; tmp_buffers.make_mut()?; if self.buffers.is_compatible(&tmp_buffers) { std::mem::swap(&mut self.buffers, &mut tmp_buffers); @@ -411,7 +415,12 @@ impl QueryBuilder { self.set_subarray(&raw)?; self.set_query_condition(&raw)?; - let buffers = QueryBuffers::from_fields(schema, &raw, self.fields)?; + let buffers = QueryBuffers::from_fields( + schema, + self.query_type, + &raw, + self.fields, + )?; Ok(Query { context: self.array.context(), From 74a75d4ae28f241cd7212cc3c77d3d2af7f6307a Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 08:52:19 -0500 Subject: [PATCH 28/42] datatype mod cleanly separates default type mapping vs. is compatible --- tiledb/query-core/src/buffers.rs | 28 +- .../query-core/src/datatype/compatibility.rs | 251 ++++++++ tiledb/query-core/src/datatype/default_to.rs | 169 ++++++ tiledb/query-core/src/datatype/mod.rs | 548 ++---------------- 4 files changed, 482 insertions(+), 514 deletions(-) create mode 100644 tiledb/query-core/src/datatype/compatibility.rs create mode 100644 tiledb/query-core/src/datatype/default_to.rs diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index 42871cc7..a3fc0635 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -16,7 +16,6 @@ use tiledb_api::query::read::aggregate::AggregateFunctionHandle; use tiledb_api::{Context, ContextBound}; use tiledb_common::array::CellValNum; -use super::datatype::ToArrowConverter; use super::field::QueryField; use super::fields::{QueryField as RequestField, QueryFields}; use super::QueryType; @@ -26,8 +25,6 @@ const AVERAGE_STRING_LENGTH: usize = 64; #[derive(Debug, Error)] pub enum Error { - #[error("Error converting to Arrow for field '{0}': {1}")] - ArrowConversionError(String, crate::datatype::Error), #[error("Failed to convert Arrow Array for field '{0}': {1}")] FailedConversionFromArrow(String, Box), #[error("Failed to add field '{0}' to query: {1}")] @@ -65,8 +62,10 @@ type Result = std::result::Result; pub enum FieldError { #[error("Error reading query field: {0}")] QueryField(#[from] crate::field::Error), - #[error("Type mismatch for requested field: {0}")] - TypeMismatch(crate::datatype::Error), + #[error("Type mismatch: arrow type '{0}' is not compatible with tiledb type '({1}, {2})'")] + TypeMismatch(adt::DataType, tiledb_common::Datatype, CellValNum), + #[error("No default arrow type: {0}")] + TargetTypeRequired(crate::datatype::DefaultArrowTypeError), #[error("Failed to allocate buffer: {0}")] BufferAllocation(ArrowError), #[error("Unsupported arrow array: {0}")] @@ -1436,8 +1435,6 @@ fn request_to_buffers( field: RequestField, tiledb_field: &QueryField, ) -> FieldResult> { - let conv = ToArrowConverter::strict(); - let tdb_dtype = tiledb_field.datatype()?; let tdb_cvn = tiledb_field.cell_val_num()?; let tdb_nullable = tiledb_field.nullable()?; @@ -1447,13 +1444,20 @@ fn request_to_buffers( return Ok(array); } - let arrow_type = if let Some(dtype) = field.target_type() { - conv.convert_datatype_to(&tdb_dtype, &tdb_cvn, tdb_nullable, dtype) + let arrow_type = if let Some(atype) = field.target_type() { + if !crate::datatype::is_physically_compatible( + &atype, tdb_dtype, tdb_cvn, + ) { + return Err(FieldError::TypeMismatch(atype, tdb_dtype, tdb_cvn)); + } + atype } else { - conv.convert_datatype(&tdb_dtype, &tdb_cvn, tdb_nullable) - } - .map_err(FieldError::TypeMismatch)?; + crate::datatype::default_arrow_type(tdb_dtype, tdb_cvn) + .map_err(FieldError::TargetTypeRequired)? + .into_inner() + }; + // SAFETY: I dunno this unwrap looks sketchy alloc_array(arrow_type, tdb_nullable, field.capacity().unwrap()) } diff --git a/tiledb/query-core/src/datatype/compatibility.rs b/tiledb/query-core/src/datatype/compatibility.rs new file mode 100644 index 00000000..3bccc169 --- /dev/null +++ b/tiledb/query-core/src/datatype/compatibility.rs @@ -0,0 +1,251 @@ +use arrow::datatypes::{DataType as ArrowLogicalType, TimeUnit}; +use tiledb_common::array::CellValNum; +use tiledb_common::Datatype; + +/// Returns whether an arrow [DataType] can be used +/// to query fields with a particular [Datatype] and [CellValNum]. +/// +/// Both arrow's [DataType] and tiledb's [Datatype] are "logical" +/// types, i.e. they prescribe both the physical shape of the data +/// as well as how it should be interpreted. The variants of each +/// do not perfectly overlap, meaning that there are some logical +/// types which have a natural representation in arrow but not in +/// tiledb, and vice verse. +/// +/// To enable using arrow as a means of querying tiledb data, +/// the tiledb logical types which do not have a corresponding +/// arrow type may be queried using the arrow [DataType] which +/// matches the _physical_ type of the tiledb logical type. +/// ``` +/// use arrow::datatypes::DataType as ArrowLogicalType; +/// use tiledb_common::Datatype as TileDBLogicalType; +/// +/// // `DateTimeFemtosecond` has no corresponding arrow type +/// let tiledb = TileDBLogicalType::DateTimeFemtosecond; +/// +/// assert!(is_physically_compatible(&ArrowLogicalType::Int64, +/// tiledb, CellValNum::single())); +/// ``` +/// +/// If an application uses tiledb as the storage engine for data +/// which is described by an arrow schema, then it may need +/// to query an arrow logical type which does not have a corresponding +/// tiledb logical type. As above, the tiledb [Datatype] which +/// matches the corresponding [ArrowNativeType] can be used. +/// ``` +/// use arrow::datatypes::DataType as ArrowLogicalType; +/// use tiledb_common::Datatype as TileDBLogicalType; +/// +/// // `Date32` has no corresponding tiledb type +/// let arrow = ArrowLogicalType::Date32; +/// +/// assert!(is_physically_compatible(&arrow, +/// TileDBLogicalType::Int32, CellValNum::single())); +/// ``` +pub fn is_physically_compatible( + arrow_datatype: &ArrowLogicalType, + tiledb_datatype: Datatype, + tiledb_cell_val_num: CellValNum, +) -> bool { + let is_single = tiledb_cell_val_num == CellValNum::single(); + let is_var = tiledb_cell_val_num == CellValNum::Var; + + match arrow_datatype { + ArrowLogicalType::Null => false, + ArrowLogicalType::Boolean => { + matches!(tiledb_datatype, Datatype::Boolean) && is_single + } + + ArrowLogicalType::Int8 => { + matches!(tiledb_datatype, Datatype::Int8 | Datatype::Char) + && is_single + } + ArrowLogicalType::Int16 => { + matches!(tiledb_datatype, Datatype::Int16) && is_single + } + ArrowLogicalType::Int32 => { + matches!(tiledb_datatype, Datatype::Int32) && is_single + } + ArrowLogicalType::Int64 => { + matches!( + tiledb_datatype, + Datatype::Int64 + | Datatype::DateTimeYear + | Datatype::DateTimeMonth + | Datatype::DateTimeWeek + | Datatype::DateTimeDay + | Datatype::DateTimeHour + | Datatype::DateTimeMinute + | Datatype::DateTimeSecond + | Datatype::DateTimeMillisecond + | Datatype::DateTimeMicrosecond + | Datatype::DateTimeNanosecond + | Datatype::DateTimePicosecond + | Datatype::DateTimeFemtosecond + | Datatype::DateTimeAttosecond + | Datatype::TimeHour + | Datatype::TimeMinute + | Datatype::TimeSecond + | Datatype::TimeMillisecond + | Datatype::TimeMicrosecond + | Datatype::TimeNanosecond + | Datatype::TimePicosecond + | Datatype::TimeFemtosecond + | Datatype::TimeAttosecond + ) && is_single + } + ArrowLogicalType::UInt8 => { + matches!( + tiledb_datatype, + Datatype::UInt8 + | Datatype::StringAscii + | Datatype::StringUtf8 + | Datatype::Any + | Datatype::Blob + | Datatype::Boolean + | Datatype::GeometryWkb + | Datatype::GeometryWkt + ) && is_single + } + ArrowLogicalType::UInt16 => { + matches!( + tiledb_datatype, + Datatype::UInt16 | Datatype::StringUtf16 | Datatype::StringUcs2 + ) && is_single + } + ArrowLogicalType::UInt32 => { + matches!( + tiledb_datatype, + Datatype::UInt32 | Datatype::StringUtf32 | Datatype::StringUcs4 + ) && is_single + } + ArrowLogicalType::UInt64 => { + matches!(tiledb_datatype, Datatype::UInt64) && is_single + } + ArrowLogicalType::Float32 => { + matches!(tiledb_datatype, Datatype::Float32) && is_single + } + ArrowLogicalType::Float64 => { + matches!(tiledb_datatype, Datatype::Float64) && is_single + } + ArrowLogicalType::Timestamp(TimeUnit::Second, None) => { + matches!(tiledb_datatype, Datatype::DateTimeSecond) && is_single + } + ArrowLogicalType::Timestamp(TimeUnit::Millisecond, None) => { + matches!(tiledb_datatype, Datatype::DateTimeMillisecond) + && is_single + } + ArrowLogicalType::Timestamp(TimeUnit::Microsecond, None) => { + matches!(tiledb_datatype, Datatype::DateTimeMicrosecond) + && is_single + } + ArrowLogicalType::Timestamp(TimeUnit::Nanosecond, None) => { + matches!(tiledb_datatype, Datatype::DateTimeNanosecond) && is_single + } + ArrowLogicalType::Timestamp(_, Some(_)) => false, + ArrowLogicalType::Time64(TimeUnit::Second) => { + matches!(tiledb_datatype, Datatype::TimeSecond) && is_single + } + ArrowLogicalType::Time64(TimeUnit::Millisecond) => { + matches!(tiledb_datatype, Datatype::TimeMillisecond) && is_single + } + ArrowLogicalType::Time64(TimeUnit::Microsecond) => { + matches!(tiledb_datatype, Datatype::TimeMicrosecond) && is_single + } + ArrowLogicalType::Time64(TimeUnit::Nanosecond) => { + matches!(tiledb_datatype, Datatype::TimeNanosecond) && is_single + } + + ArrowLogicalType::Utf8 | ArrowLogicalType::LargeUtf8 => { + matches!( + tiledb_datatype, + Datatype::StringAscii | Datatype::StringUtf8 + ) && is_var + } + ArrowLogicalType::Binary | ArrowLogicalType::LargeBinary => { + is_physically_compatible( + &ArrowLogicalType::UInt8, + tiledb_datatype, + CellValNum::single(), + ) && is_var + } + ArrowLogicalType::FixedSizeBinary(cvn) => { + if tiledb_cell_val_num != *cvn as u32 { + false + } else { + is_physically_compatible( + &ArrowLogicalType::UInt8, + tiledb_datatype, + CellValNum::single(), + ) + } + } + + ArrowLogicalType::List(field) | ArrowLogicalType::LargeList(field) => { + // NB: any cell val num is allowed + is_physically_compatible( + field.data_type(), + tiledb_datatype, + CellValNum::single(), + ) + } + + ArrowLogicalType::FixedSizeList(field, cvn) => { + tiledb_cell_val_num == *cvn as u32 + && is_physically_compatible( + field.data_type(), + tiledb_datatype, + CellValNum::single(), + ) + } + ArrowLogicalType::Date32 => { + matches!(tiledb_datatype, Datatype::Int32) && is_single + } + ArrowLogicalType::Date64 => { + matches!(tiledb_datatype, Datatype::Int64) && is_single + } + ArrowLogicalType::Time32(_) => { + matches!(tiledb_datatype, Datatype::Int32) && is_single + } + + // TODO: Duration and Interval can be represented + // Decimal128 and Decimal256 could be blobs... + + // Notes on other possible relaxed conversions: + // + // Duration and some intervals are likely supportable, but + // leaving them off for now as the docs aren't clear. + // + // Views are also likely supportable, but will likely require + // separate buffer allocations since individual values are not + // contiguous. + // + // Struct and Union are never supportable (given current core) + // + // Dictionary is, but they should be handled higher up the stack + // to ensure that things line up with enumerations. + // + // Decimal128 and Decimal256 might be supportable using Float64 + // and 2 or 4 fixed length cell val num. Though it'd be fairly + // hacky. + // + // Map isn't supported in TileDB (given current core) + // + // RunEndEncoded is probably supportable, but like views will + // require separate buffer allocations so leaving for now. + ArrowLogicalType::Float16 + | ArrowLogicalType::Duration(_) + | ArrowLogicalType::Interval(_) + | ArrowLogicalType::BinaryView + | ArrowLogicalType::Utf8View + | ArrowLogicalType::ListView(_) + | ArrowLogicalType::LargeListView(_) + | ArrowLogicalType::Struct(_) + | ArrowLogicalType::Union(_, _) + | ArrowLogicalType::Dictionary(_, _) + | ArrowLogicalType::Decimal128(_, _) + | ArrowLogicalType::Decimal256(_, _) + | ArrowLogicalType::Map(_, _) + | ArrowLogicalType::RunEndEncoded(_, _) => false, + } +} diff --git a/tiledb/query-core/src/datatype/default_to.rs b/tiledb/query-core/src/datatype/default_to.rs new file mode 100644 index 00000000..14bfd5d5 --- /dev/null +++ b/tiledb/query-core/src/datatype/default_to.rs @@ -0,0 +1,169 @@ +use arrow::datatypes::{DataType as ArrowDataType, TimeUnit}; +use tiledb_common::array::CellValNum; +use tiledb_common::Datatype; + +use super::TypeConversion::{self, LogicalMatch, PhysicalMatch}; + +#[derive(Debug, thiserror::Error)] +pub enum NoMatchDetail { + #[error("Invalid fixed size cell val num: {0}")] + InvalidFixedSize(u32), +} + +pub fn default_arrow_type( + dtype: Datatype, + cell_val_num: CellValNum, +) -> Result, NoMatchDetail> { + Ok(match (dtype, cell_val_num) { + (Datatype::Blob, CellValNum::Fixed(nz)) if nz.get() != 1 => { + if let Ok(fl) = i32::try_from(nz.get()) { + LogicalMatch(ArrowDataType::FixedSizeBinary(fl)) + } else { + return Err(NoMatchDetail::InvalidFixedSize(nz.get())); + } + } + ( + Datatype::GeometryWkb | Datatype::GeometryWkt, + CellValNum::Fixed(nz), + ) if nz.get() != 1 => { + if let Ok(fl) = i32::try_from(nz.get()) { + PhysicalMatch(ArrowDataType::FixedSizeBinary(fl)) + } else { + return Err(NoMatchDetail::InvalidFixedSize(nz.get())); + } + } + (Datatype::StringAscii, CellValNum::Var) => { + PhysicalMatch(ArrowDataType::LargeUtf8) + } + (Datatype::StringUtf8, CellValNum::Var) => { + LogicalMatch(ArrowDataType::LargeUtf8) + } + (Datatype::Blob, CellValNum::Var) => { + LogicalMatch(ArrowDataType::LargeBinary) + } + + // then the general cases + (_, CellValNum::Fixed(nz)) if nz.get() == 1 => { + single_valued_type(dtype) + } + (_, CellValNum::Fixed(nz)) => { + if let Ok(fl) = i32::try_from(nz.get()) { + match single_valued_type(dtype) { + PhysicalMatch(adt) => PhysicalMatch( + ArrowDataType::new_fixed_size_list(adt, fl, false), + ), + LogicalMatch(adt) => LogicalMatch( + ArrowDataType::new_fixed_size_list(adt, fl, false), + ), + } + } else { + return Err(NoMatchDetail::InvalidFixedSize(nz.get())); + } + } + (_, CellValNum::Var) => match single_valued_type(dtype) { + PhysicalMatch(adt) => { + PhysicalMatch(ArrowDataType::new_large_list(adt, false)) + } + LogicalMatch(adt) => { + LogicalMatch(ArrowDataType::new_large_list(adt, false)) + } + }, + }) +} + +fn single_valued_type(tiledb: Datatype) -> TypeConversion { + use arrow::datatypes::DataType as arrow; + use tiledb_common::Datatype as tiledb; + + match tiledb { + // Any is basically blob + tiledb::Any => PhysicalMatch(arrow::UInt8), + + // Boolean + // NB: this requires a byte array to bit array conversion, + // it is a weird case of being a logical match but not a physical + // match, we'll just handle it specially + tiledb::Boolean => LogicalMatch(arrow::Boolean), + + // Char -> Int8 + tiledb::Char => PhysicalMatch(arrow::Int8), + + // Standard primitive types + tiledb::Int8 => LogicalMatch(arrow::Int8), + tiledb::Int16 => LogicalMatch(arrow::Int16), + tiledb::Int32 => LogicalMatch(arrow::Int32), + tiledb::Int64 => LogicalMatch(arrow::Int64), + tiledb::UInt8 => LogicalMatch(arrow::UInt8), + tiledb::UInt16 => LogicalMatch(arrow::UInt16), + tiledb::UInt32 => LogicalMatch(arrow::UInt32), + tiledb::UInt64 => LogicalMatch(arrow::UInt64), + tiledb::Float32 => LogicalMatch(arrow::Float32), + tiledb::Float64 => LogicalMatch(arrow::Float64), + + // string types + // NB: with `CellValNum::Var` these map to `LargeUtf8` + tiledb::StringAscii => PhysicalMatch(arrow::UInt8), + tiledb::StringUtf8 => PhysicalMatch(arrow::UInt8), + + // string types with no exact match + tiledb::StringUtf16 | tiledb::StringUcs2 => { + PhysicalMatch(arrow::UInt16) + } + tiledb::StringUtf32 | tiledb::StringUcs4 => { + PhysicalMatch(arrow::UInt32) + } + + // datetime types with logical matches + tiledb::DateTimeSecond => { + LogicalMatch(arrow::Timestamp(TimeUnit::Second, None)) + } + tiledb::DateTimeMillisecond => { + LogicalMatch(arrow::Timestamp(TimeUnit::Millisecond, None)) + } + tiledb::DateTimeMicrosecond => { + LogicalMatch(arrow::Timestamp(TimeUnit::Microsecond, None)) + } + tiledb::DateTimeNanosecond => { + LogicalMatch(arrow::Timestamp(TimeUnit::Nanosecond, None)) + } + tiledb::TimeSecond => LogicalMatch(arrow::Time64(TimeUnit::Second)), + tiledb::TimeMillisecond => { + LogicalMatch(arrow::Time64(TimeUnit::Millisecond)) + } + tiledb::TimeMicrosecond => { + LogicalMatch(arrow::Time64(TimeUnit::Microsecond)) + } + tiledb::TimeNanosecond => { + LogicalMatch(arrow::Time64(TimeUnit::Nanosecond)) + } + + // datetime types with no logical matches + // NB: these can lose data if converted to an Arrow logical date/time type resolution + tiledb::DateTimeYear + | tiledb::DateTimeMonth + | tiledb::DateTimeWeek + | tiledb::DateTimeDay + | tiledb::DateTimeHour + | tiledb::DateTimeMinute + | tiledb::DateTimePicosecond + | tiledb::DateTimeFemtosecond + | tiledb::DateTimeAttosecond + | tiledb::TimeHour + | tiledb::TimeMinute + | tiledb::TimePicosecond + | tiledb::TimeFemtosecond + | tiledb::TimeAttosecond => PhysicalMatch(arrow::Int64), + + // Supported string types + + // Blob + // NB: with other cell val nums this maps to `FixedSizeBinary` or `LargeBinary` + tiledb::Blob => PhysicalMatch(arrow::UInt8), + + // Geometry + // NB: with other cell val nums this maps to `FixedSizeBinary` or `LargeBinary` + tiledb::GeometryWkb | tiledb::GeometryWkt => { + PhysicalMatch(arrow::UInt8) + } + } +} diff --git a/tiledb/query-core/src/datatype/mod.rs b/tiledb/query-core/src/datatype/mod.rs index b7119cab..2a854454 100644 --- a/tiledb/query-core/src/datatype/mod.rs +++ b/tiledb/query-core/src/datatype/mod.rs @@ -1,519 +1,63 @@ -use std::sync::Arc; +mod compatibility; +mod default_to; -use arrow::datatypes as adt; - -use thiserror::Error; - -use tiledb_common::array::CellValNum; -use tiledb_common::Datatype; - -#[derive(Error, Debug, PartialEq, Eq)] -pub enum Error { - #[error("Cell value size '{0}' is out of range.")] - CellValNumOutOfRange(u32), - - #[error("Internal type error: Unhandled Arrow type: {0}")] - InternalTypeError(adt::DataType), - - #[error("Invalid fixed sized length: {0}")] - InvalidFixedSize(i32), - - #[error("Invalid Arrow type for conversion: '{0}'")] - InvalidTargetType(adt::DataType), - - #[error("Failed to convert Arrow list element type: {0}")] - ListElementTypeConversionFailed(Box), - - #[error( - "The TileDB datatype '{0}' does not have a default Arrow DataType." - )] - NoDefaultArrowType(Datatype), - - #[error("Arrow type '{0} requires the TileDB field to be single valued.")] - RequiresSingleValued(adt::DataType), - - #[error("Arrow type {0} requires the TileDB field be var sized.")] - RequiresVarSized(adt::DataType), - - #[error("TileDB does not support timezones on timestamps")] - TimeZonesNotSupported, - - #[error( - "TileDB type '{0}' and Arrow type '{1}' have different physical sizes" - )] - PhysicalSizeMismatch(Datatype, adt::DataType), - - #[error("Unsupported Arrow DataType: {0}")] - UnsupportedArrowDataType(adt::DataType), - - #[error("TileDB does not support lists with element type: '{0}'")] - UnsupportedListElementType(adt::DataType), - - #[error("The Arrow DataType '{0}' is not supported.")] - ArrowTypeNotSupported(adt::DataType), - #[error("DataFusion does not support multi-value cells.")] - InvalidMultiCellValNum, - #[error("The TileDB Datatype '{0}' is not supported by DataFusion")] - UnsupportedTileDBDatatype(Datatype), - #[error("Variable-length datatypes as list type elements are not supported by TileDB")] - UnsupportedListVariableLengthElement, -} - -pub type Result = std::result::Result; - -/// ConversionMode dictates whether certain conversions are allowed -pub enum ConversionMode { - /// Only allow conversions that are semantically equivalent - Strict, - /// Allow conversions as long as the physical type is maintained. - Relaxed, +pub enum TypeConversion { + PhysicalMatch(Output), + LogicalMatch(Output), } -pub struct ToArrowConverter { - mode: ConversionMode, -} - -impl ToArrowConverter { - pub fn strict() -> Self { - Self { - mode: ConversionMode::Strict, - } - } - - pub fn physical() -> Self { - Self { - mode: ConversionMode::Relaxed, - } - } - - pub fn convert_datatype( - &self, - dtype: &Datatype, - cvn: &CellValNum, - nullable: bool, - ) -> Result { - if let Some(arrow_type) = self.default_arrow_type(dtype) { - self.convert_datatype_to(dtype, cvn, nullable, arrow_type) - } else { - Err(Error::NoDefaultArrowType(*dtype)) - } - } - - pub fn convert_datatype_to( - &self, - dtype: &Datatype, - cvn: &CellValNum, - nullable: bool, - arrow_type: adt::DataType, - ) -> Result { - if matches!(arrow_type, adt::DataType::Null) { - return Err(Error::InvalidTargetType(arrow_type)); - } - - if arrow_type.is_primitive() { - let width = arrow_type.primitive_width().unwrap(); - if width != dtype.size() { - return Err(Error::PhysicalSizeMismatch(*dtype, arrow_type)); - } - - if cvn.is_single_valued() { - Ok(arrow_type) - } else if cvn.is_var_sized() { - let field = - Arc::new(adt::Field::new("item", arrow_type, nullable)); - Ok(adt::DataType::LargeList(field)) - } else { - // SAFETY: Due to the logic above we can guarantee that this - // is a fixed length cvn. - let cvn = cvn.fixed().unwrap().get(); - if cvn > i32::MAX as u32 { - return Err(Error::CellValNumOutOfRange(cvn)); - } - let field = - Arc::new(adt::Field::new("item", arrow_type, nullable)); - Ok(adt::DataType::FixedSizeList(field, cvn as i32)) - } - } else if matches!(arrow_type, adt::DataType::Boolean) { - if !cvn.is_single_valued() { - Err(Error::RequiresSingleValued(arrow_type)) - } else { - Ok(arrow_type) - } - } else if matches!( - arrow_type, - adt::DataType::LargeBinary | adt::DataType::LargeUtf8 - ) { - if !cvn.is_var_sized() { - Err(Error::RequiresVarSized(arrow_type)) - } else { - Ok(arrow_type) - } - } else { - Err(Error::InternalTypeError(arrow_type)) - } - } - - fn default_arrow_type(&self, dtype: &Datatype) -> Option { - use arrow::datatypes::DataType as arrow; - use tiledb_common::Datatype as tiledb; - let arrow_type = match dtype { - // Any <-> Null, both indicate lack of a type - tiledb::Any => Some(arrow::Null), - - // Boolean, n.b., this requires a byte array to bit array converesion - tiledb::Boolean => Some(arrow::Boolean), - - // Char -> Int8 - tiledb::Char => Some(arrow::Int8), - - // Standard primitive types - tiledb::Int8 => Some(arrow::Int8), - tiledb::Int16 => Some(arrow::Int16), - tiledb::Int32 => Some(arrow::Int32), - tiledb::Int64 => Some(arrow::Int64), - tiledb::UInt8 => Some(arrow::UInt8), - tiledb::UInt16 => Some(arrow::UInt16), - tiledb::UInt32 => Some(arrow::UInt32), - tiledb::UInt64 => Some(arrow::UInt64), - tiledb::Float32 => Some(arrow::Float32), - tiledb::Float64 => Some(arrow::Float64), - - // Supportable datetime types - tiledb::DateTimeSecond => { - Some(arrow::Timestamp(adt::TimeUnit::Second, None)) - } - tiledb::DateTimeMillisecond => { - Some(arrow::Timestamp(adt::TimeUnit::Millisecond, None)) - } - tiledb::DateTimeMicrosecond => { - Some(arrow::Timestamp(adt::TimeUnit::Microsecond, None)) - } - tiledb::DateTimeNanosecond => { - Some(arrow::Timestamp(adt::TimeUnit::Nanosecond, None)) - } - - // Supportable time types - tiledb::TimeSecond => Some(arrow::Time64(adt::TimeUnit::Second)), - tiledb::TimeMillisecond => { - Some(arrow::Time64(adt::TimeUnit::Millisecond)) - } - tiledb::TimeMicrosecond => { - Some(arrow::Time64(adt::TimeUnit::Microsecond)) - } - tiledb::TimeNanosecond => { - Some(arrow::Time64(adt::TimeUnit::Nanosecond)) - } - - // Supported string types - tiledb::StringAscii => Some(arrow::LargeUtf8), - tiledb::StringUtf8 => Some(arrow::LargeUtf8), - - // Blob <-> Binary - tiledb::Blob => Some(arrow::LargeBinary), - - tiledb::StringUtf16 - | tiledb::StringUtf32 - | tiledb::StringUcs2 - | tiledb::StringUcs4 - | tiledb::DateTimeYear - | tiledb::DateTimeMonth - | tiledb::DateTimeWeek - | tiledb::DateTimeDay - | tiledb::DateTimeHour - | tiledb::DateTimeMinute - | tiledb::DateTimePicosecond - | tiledb::DateTimeFemtosecond - | tiledb::DateTimeAttosecond - | tiledb::TimeHour - | tiledb::TimeMinute - | tiledb::TimePicosecond - | tiledb::TimeFemtosecond - | tiledb::TimeAttosecond - | tiledb::GeometryWkb - | tiledb::GeometryWkt => None, - }; - - if arrow_type.is_some() { - return arrow_type; - } - - // If we're doing a strict semantic conversion we don't attempt to find - // a matching physical type. - if matches!(self.mode, ConversionMode::Strict) { - return None; - } - - // Assert in case we add more conversion modes in the future. - assert!(matches!(self.mode, ConversionMode::Relaxed)); - - // Physical conversions means we'll allow dropping the TileDB semantic - // information to allow for raw data access. - match dtype { - // Uncommon string types - tiledb::StringUtf16 => Some(arrow::UInt16), - tiledb::StringUtf32 => Some(arrow::UInt32), - tiledb::StringUcs2 => Some(arrow::UInt16), - tiledb::StringUcs4 => Some(arrow::UInt32), - - // Time types that could lose data if converted to Arrow's - // time resolution. - tiledb::DateTimeYear => Some(arrow::Int64), - tiledb::DateTimeMonth => Some(arrow::Int64), - tiledb::DateTimeWeek => Some(arrow::Int64), - tiledb::DateTimeDay => Some(arrow::Int64), - tiledb::DateTimeHour => Some(arrow::Int64), - tiledb::DateTimeMinute => Some(arrow::Int64), - tiledb::DateTimePicosecond => Some(arrow::Int64), - tiledb::DateTimeFemtosecond => Some(arrow::Int64), - tiledb::DateTimeAttosecond => Some(arrow::Int64), - tiledb::TimeHour => Some(arrow::Int64), - tiledb::TimeMinute => Some(arrow::Int64), - tiledb::TimePicosecond => Some(arrow::Int64), - tiledb::TimeFemtosecond => Some(arrow::Int64), - tiledb::TimeAttosecond => Some(arrow::Int64), - - // Geometry types - tiledb::GeometryWkb => Some(arrow::LargeBinary), - tiledb::GeometryWkt => Some(arrow::LargeUtf8), - - // These are all of the types that have strict equivalents and - // should have already been handled above. - tiledb::Any - | tiledb::Boolean - | tiledb::Char - | tiledb::Int8 - | tiledb::Int16 - | tiledb::Int32 - | tiledb::Int64 - | tiledb::UInt8 - | tiledb::UInt16 - | tiledb::UInt32 - | tiledb::UInt64 - | tiledb::Float32 - | tiledb::Float64 - | tiledb::DateTimeSecond - | tiledb::DateTimeMillisecond - | tiledb::DateTimeMicrosecond - | tiledb::DateTimeNanosecond - | tiledb::TimeSecond - | tiledb::TimeMillisecond - | tiledb::TimeMicrosecond - | tiledb::TimeNanosecond - | tiledb::StringAscii - | tiledb::StringUtf8 - | tiledb::Blob => unreachable!("Strict conversion failed"), +impl TypeConversion { + pub fn into_inner(self) -> Output { + match self { + Self::PhysicalMatch(dt) => dt, + Self::LogicalMatch(dt) => dt, } } } -pub struct FromArrowConverter { - mode: ConversionMode, -} - -impl FromArrowConverter { - pub fn strict() -> Self { - Self { - mode: ConversionMode::Strict, - } - } - - pub fn relaxed() -> Self { - Self { - mode: ConversionMode::Relaxed, - } - } - - pub fn convert_datatype( - &self, - arrow_type: adt::DataType, - ) -> Result<(Datatype, CellValNum, Option)> { - use adt::DataType as arrow; - use Datatype as tiledb; - - let single = CellValNum::single(); - let var = CellValNum::Var; - - match arrow_type { - arrow::Null => Ok((tiledb::Any, single, None)), - arrow::Boolean => Ok((tiledb::Boolean, single, None)), - - arrow::Int8 => Ok((tiledb::Int8, single, None)), - arrow::Int16 => Ok((tiledb::Int16, single, None)), - arrow::Int32 => Ok((tiledb::Int32, single, None)), - arrow::Int64 => Ok((tiledb::Int64, single, None)), - arrow::UInt8 => Ok((tiledb::UInt8, single, None)), - arrow::UInt16 => Ok((tiledb::UInt16, single, None)), - arrow::UInt32 => Ok((tiledb::UInt32, single, None)), - arrow::UInt64 => Ok((tiledb::UInt64, single, None)), - arrow::Float32 => Ok((tiledb::Float32, single, None)), - arrow::Float64 => Ok((tiledb::Float64, single, None)), - - arrow::Timestamp(adt::TimeUnit::Second, None) => { - Ok((tiledb::DateTimeSecond, single, None)) - } - arrow::Timestamp(adt::TimeUnit::Millisecond, None) => { - Ok((tiledb::DateTimeMillisecond, single, None)) - } - arrow::Timestamp(adt::TimeUnit::Microsecond, None) => { - Ok((tiledb::DateTimeMicrosecond, single, None)) - } - arrow::Timestamp(adt::TimeUnit::Nanosecond, None) => { - Ok((tiledb::DateTimeNanosecond, single, None)) - } - arrow::Timestamp(_, Some(_)) => Err(Error::TimeZonesNotSupported), - - arrow::Time64(adt::TimeUnit::Second) => { - Ok((tiledb::TimeSecond, single, None)) - } - arrow::Time64(adt::TimeUnit::Millisecond) => { - Ok((tiledb::TimeMillisecond, single, None)) - } - arrow::Time64(adt::TimeUnit::Microsecond) => { - Ok((tiledb::TimeMicrosecond, single, None)) - } - arrow::Time64(adt::TimeUnit::Nanosecond) => { - Ok((tiledb::TimeNanosecond, single, None)) - } - - arrow::Utf8 => Ok((tiledb::StringUtf8, var, None)), - arrow::LargeUtf8 => Ok((tiledb::StringUtf8, var, None)), - arrow::Binary => Ok((tiledb::Blob, var, None)), - arrow::FixedSizeBinary(cvn) => { - if cvn < 1 { - return Err(Error::InvalidFixedSize(cvn)); - } - let cvn = if cvn == 1 { - CellValNum::single() - } else { - CellValNum::try_from(cvn as u32).unwrap() - }; - Ok((tiledb::Blob, cvn, None)) - } - arrow::LargeBinary => Ok((tiledb::Blob, var, None)), - - arrow::List(field) | arrow::LargeList(field) => { - let dtype = field.data_type(); - if !dtype.is_primitive() { - return Err(Error::UnsupportedListElementType( - dtype.clone(), - )); - } - - let (tdb_type, _, _) = - self.convert_datatype(dtype.clone()).map_err(|e| { - Error::ListElementTypeConversionFailed(Box::new(e)) - })?; - - Ok((tdb_type, var, Some(field.is_nullable()))) - } - - arrow::FixedSizeList(field, cvn) => { - let dtype = field.data_type(); - if !dtype.is_primitive() { - return Err(Error::UnsupportedListElementType( - dtype.clone(), - )); - } - - let (tdb_type, _, _) = - self.convert_datatype(dtype.clone()).map_err(|e| { - Error::ListElementTypeConversionFailed(Box::new(e)) - })?; - - Ok(( - tdb_type, - CellValNum::try_from(cvn as u32).unwrap(), - Some(field.is_nullable()), - )) - } - - // A few relaxed conversions for accepting Arrow types that don't - // line up directly with TileDB. - arrow::Date32 if self.is_relaxed() => { - Ok((tiledb::Int32, single, None)) - } - - arrow::Date64 if self.is_relaxed() => { - Ok((tiledb::Int64, single, None)) - } - - arrow::Time32(_) if self.is_relaxed() => { - Ok((tiledb::Int32, single, None)) - } - - // Notes on other possible relaxed conversions: - // - // Duration and some intervals are likely supportable, but - // leaving them off for now as the docs aren't clear. - // - // Views are also likely supportable, but will likely require - // separate buffer allocations since individual values are not - // contiguous. - // - // Struct and Union are never supportable (given current core) - // - // Dictionary is, but they should be handled higher up the stack - // to ensure that things line up with enumerations. - // - // Decimal128 and Decimal256 might be supportable using Float64 - // and 2 or 4 fixed length cell val num. Though it'd be fairly - // hacky. - // - // Map isn't supported in TileDB (given current core) - // - // RunEndEncoded is probably supportable, but like views will - // require separate buffer allocations so leaving for now. - arrow::Float16 - | arrow::Date32 - | arrow::Date64 - | arrow::Time32(_) - | arrow::Duration(_) - | arrow::Interval(_) - | arrow::BinaryView - | arrow::Utf8View - | arrow::ListView(_) - | arrow::LargeListView(_) - | arrow::Struct(_) - | arrow::Union(_, _) - | arrow::Dictionary(_, _) - | arrow::Decimal128(_, _) - | arrow::Decimal256(_, _) - | arrow::Map(_, _) - | arrow::RunEndEncoded(_, _) => { - Err(Error::UnsupportedArrowDataType(arrow_type)) - } - } - } - - fn is_relaxed(&self) -> bool { - matches!(self.mode, ConversionMode::Relaxed) - } -} +pub use self::compatibility::is_physically_compatible; +pub use self::default_to::{ + default_arrow_type, NoMatchDetail as DefaultArrowTypeError, +}; #[cfg(test)] mod tests { + use tiledb_common::array::CellValNum; + use tiledb_common::Datatype; + use super::*; - /// Test that a datatype is supported as a scalar type - /// if and only if it is also supported as a list element type #[test] - fn list_unsupported_element() { - let conv = ToArrowConverter::strict(); + fn default_compatibility() { for dt in Datatype::iter() { - let single_to_arrow = - conv.convert_datatype(&dt, &CellValNum::single(), false); - let var_to_arrow = - conv.convert_datatype(&dt, &CellValNum::Var, false); - - if let Err(Error::RequiresVarSized(_)) = single_to_arrow { - assert!(var_to_arrow.is_ok()); - } else if single_to_arrow.is_err() { - assert_eq!(single_to_arrow, var_to_arrow); - } - - if var_to_arrow.is_err() { - assert_eq!(var_to_arrow, single_to_arrow); - } + let single = default_arrow_type(dt, CellValNum::single()) + .unwrap() + .into_inner(); + assert!( + is_physically_compatible(&single, dt, CellValNum::single()), + "arrow = {}, tiledb = {}", + single, + dt + ); + + let fixed_cvn = CellValNum::try_from(4).unwrap(); + let fixed = default_arrow_type(dt, fixed_cvn).unwrap().into_inner(); + assert!( + is_physically_compatible(&fixed, dt, fixed_cvn), + "arrow = {}, tiledb = {}", + fixed, + dt + ); + + let var = default_arrow_type(dt, CellValNum::Var) + .unwrap() + .into_inner(); + assert!( + is_physically_compatible(&var, dt, CellValNum::Var), + "arrow = {}, tiledb = {}", + var, + dt + ); } } } From 535cdc21e985a95392931f37f914ee4c3638bd45 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 10:24:40 -0500 Subject: [PATCH 29/42] Change test name --- tiledb/query-core/src/tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiledb/query-core/src/tests.rs b/tiledb/query-core/src/tests.rs index 0ce0527f..b298cbc5 100644 --- a/tiledb/query-core/src/tests.rs +++ b/tiledb/query-core/src/tests.rs @@ -17,7 +17,7 @@ use uri::TestArrayUri; use super::*; #[test] -fn query_roundtrip() -> anyhow::Result<()> { +fn query_roundtrip_once() -> anyhow::Result<()> { let ctx = Context::new()?; let schema_req = query_write_schema_requirements(None); From d05c1c1919d7e42c6d9a2f5c47afb67044dfec83 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 10:24:53 -0500 Subject: [PATCH 30/42] Eq for Mode --- tiledb/common/src/array/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tiledb/common/src/array/mod.rs b/tiledb/common/src/array/mod.rs index 1e856874..c416024b 100644 --- a/tiledb/common/src/array/mod.rs +++ b/tiledb/common/src/array/mod.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; pub use dimension::DimensionConstraints; -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum Mode { Read, Write, From 27043b86c3ab18c986afc82bc79f6a781eea9c1f Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 10:25:39 -0500 Subject: [PATCH 31/42] fixed_offsets for reading/writing var-size tiledb field using FixedSize arrow array --- tiledb/query-core/src/buffers.rs | 81 +++++++++++++++++++++++--------- tiledb/query-core/src/field.rs | 2 +- 2 files changed, 59 insertions(+), 24 deletions(-) diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers.rs index a3fc0635..e4062ccc 100644 --- a/tiledb/query-core/src/buffers.rs +++ b/tiledb/query-core/src/buffers.rs @@ -190,9 +190,23 @@ trait NewBufferTraitThing { } #[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum BufferTarget { - Read, - Write(bool), +struct BufferTarget { + query_type: QueryType, + cell_val_num: CellValNum, + is_nullable: bool, +} + +impl BufferTarget { + pub fn new( + query_type: QueryType, + target: QueryField, + ) -> crate::field::Result { + Ok(Self { + query_type, + cell_val_num: target.cell_val_num()?, + is_nullable: target.nullable()?, + }) + } } struct BooleanBuffers { @@ -430,6 +444,8 @@ struct FixedListBuffers { field: Arc, cell_val_num: CellValNum, data: QueryBuffer, + /// Optional offsets buffer when targeting a variable-size field. + fixed_offsets: Option, validity: Option, } @@ -462,6 +478,19 @@ impl FixedListBuffers { )); } + let fixed_offsets = match target.cell_val_num { + CellValNum::Fixed(_) => None, + CellValNum::Var => { + let offsets = abuf::OffsetBuffer::::from_lengths(vec![ + cvn as usize; + num_cells + ]); + // SAFETY: just created a new one, it is definitely not in use + let tdb_offsets = to_tdb_offsets(offsets).unwrap(); + Some(QueryBuffer::new(tdb_offsets)) + } + }; + // SAFETY: We just showed cvn >= 2 && cvn is i32 whicih means // it can't be u32::MAX let cvn = CellValNum::try_from(cvn as u32) @@ -485,6 +514,7 @@ impl FixedListBuffers { field: Arc::clone(&field), cell_val_num: cvn, data: buffers.data, + fixed_offsets, validity, } }) @@ -562,18 +592,26 @@ impl NewBufferTraitThing for FixedListBuffers { 0, vec![data.clone()], vec![], - ) - .map(|data| { - let array: Arc = - Arc::new(aa::FixedSizeListArray::new( - Arc::clone(&field), - u32::from(cell_val_num) as i32, - aa::make_array(data), - validity.clone(), - )); - array - }) { - Ok(arrow) => Ok(arrow), + ) { + Ok(data) => { + // validate fixed offsets (if any) + if let Some(fixed_offsets) = self.fixed_offsets { + let offsets = from_tdb_offsets(fixed_offsets.buffer); + for w in offsets.windows(2) { + if w[1] - w[0] != (cvn as i64) { + todo!() + } + } + } + let array: Arc = + Arc::new(aa::FixedSizeListArray::new( + Arc::clone(&field), + u32::from(cell_val_num) as i32, + aa::make_array(data), + validity.clone(), + )); + Ok(array) + } Err(e) => { let boxed: Box = Box::new(FixedListBuffers { @@ -583,6 +621,7 @@ impl NewBufferTraitThing for FixedListBuffers { buffer: data.into_mutable().unwrap(), size: self.data.size, }, + fixed_offsets: self.fixed_offsets, validity: self.validity, }); @@ -1415,13 +1454,7 @@ fn make_buffer_entry( let query_field = QueryField::get(context, raw, &name)?; let array = request_to_buffers(request, &query_field)?; - let target = match query_type { - QueryType::Read => BufferTarget::Read, - QueryType::Write => BufferTarget::Write(query_field.nullable()?), - QueryType::Delete => unimplemented!(), - QueryType::Update => unimplemented!(), - QueryType::ModifyExclusive => unimplemented!(), - }; + let target = BufferTarget::new(query_type, query_field)?; let entry = BufferEntry::new(target, array); Ok(if let Some(aggregate) = aggregate { @@ -1667,7 +1700,9 @@ fn to_tdb_validity( }); if arrow_validity.is_some() { arrow_validity - } else if matches!(target, BufferTarget::Write(true)) { + } else if target.is_nullable + && matches!(target.query_type, QueryType::Write) + { Some(ArrowBufferMut::from(vec![1u8; num_cells])) } else { None diff --git a/tiledb/query-core/src/field.rs b/tiledb/query-core/src/field.rs index d5e8e64f..0cfae842 100644 --- a/tiledb/query-core/src/field.rs +++ b/tiledb/query-core/src/field.rs @@ -20,7 +20,7 @@ pub enum Error { LibTileDB(#[from] tiledb_api::error::Error), } -type Result = std::result::Result; +pub type Result = std::result::Result; pub enum RawQueryField { Owned(Context, *mut ffi::tiledb_query_field_t), From 2021b2821cc23ccea5c05f9e9d8a67b94db60cf4 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 13:21:10 -0500 Subject: [PATCH 32/42] Steal WithoutReplacement strategy extension from tables --- test-utils/strategy-ext/src/lib.rs | 14 +++++ .../strategy-ext/src/without_replacement.rs | 61 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 test-utils/strategy-ext/src/without_replacement.rs diff --git a/test-utils/strategy-ext/src/lib.rs b/test-utils/strategy-ext/src/lib.rs index 6aa1877d..74ff1e05 100644 --- a/test-utils/strategy-ext/src/lib.rs +++ b/test-utils/strategy-ext/src/lib.rs @@ -2,6 +2,7 @@ pub mod meta; pub mod records; pub mod sequence; pub mod strategy; +pub mod without_replacement; use std::fmt::Debug; @@ -20,6 +21,19 @@ pub trait StrategyExt: Strategy { meta::ValueTreeStrategy(self) } + /// Returns a strategy which produces only test cases it has not previously produced. + /// This requires keeping the previous test cases used, + /// so this should not be used with large values or for too many test cases. + fn prop_without_replacement( + self, + ) -> without_replacement::WithoutReplacement + where + Self: Sized, + Self::Value: Eq + std::hash::Hash, + { + without_replacement::WithoutReplacement::new(self) + } + /// Returns a strategy which produces values transformed by /// the [ValueTree] mapping function `transform`. /// diff --git a/test-utils/strategy-ext/src/without_replacement.rs b/test-utils/strategy-ext/src/without_replacement.rs new file mode 100644 index 00000000..93a3f703 --- /dev/null +++ b/test-utils/strategy-ext/src/without_replacement.rs @@ -0,0 +1,61 @@ +use std::cell::RefCell; +use std::collections::HashSet; +use std::fmt::Debug; +use std::hash::Hash; + +use proptest::strategy::{NewTree, Strategy, ValueTree}; +use proptest::test_runner::TestRunner; + +#[derive(Debug)] +pub struct WithoutReplacement { + source: T, + values: RefCell>, +} + +impl WithoutReplacement { + pub fn new(source: T) -> Self { + Self { + source, + values: RefCell::new(HashSet::new()), + } + } +} + +impl Strategy for WithoutReplacement +where + T::Value: Eq + Hash, +{ + type Tree = T::Tree; + type Value = T::Value; + + fn new_tree(&self, runner: &mut TestRunner) -> NewTree { + loop { + let tree = self.source.new_tree(runner)?; + if self.values.borrow_mut().insert(tree.current()) { + break Ok(tree); + } + runner.reject_local(format!( + "Strategy generated value already: {:?}", + tree.current() + ))?; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proptest::prelude::*; + + use crate::StrategyExt; + + #[test] + fn without_replacement() { + let previous = RefCell::new(HashSet::new()); + + proptest!(|(s in any::().prop_without_replacement())| { + let first_insert = previous.borrow_mut().insert(s); + assert!(first_insert); + }); + } +} From fa29848edbf8d8cae09bbc46615595902eb1060a Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 13:28:42 -0500 Subject: [PATCH 33/42] arrow-proptest-strategies, stolen from tables --- Cargo.lock | 12 + Cargo.toml | 2 + .../arrow-proptest-strategies/Cargo.toml | 13 + .../arrow-proptest-strategies/src/array.rs | 309 ++++++++++++++++++ .../arrow-proptest-strategies/src/lib.rs | 3 + .../src/record_batch.rs | 90 +++++ .../arrow-proptest-strategies/src/schema.rs | 202 ++++++++++++ 7 files changed, 631 insertions(+) create mode 100644 test-utils/arrow-proptest-strategies/Cargo.toml create mode 100644 test-utils/arrow-proptest-strategies/src/array.rs create mode 100644 test-utils/arrow-proptest-strategies/src/lib.rs create mode 100644 test-utils/arrow-proptest-strategies/src/record_batch.rs create mode 100644 test-utils/arrow-proptest-strategies/src/schema.rs diff --git a/Cargo.lock b/Cargo.lock index 99de8292..4ddbdcc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -329,6 +329,18 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-proptest-strategies" +version = "0.1.0" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", + "half", + "proptest", + "strategy-ext", +] + [[package]] name = "arrow-row" version = "52.2.0" diff --git a/Cargo.toml b/Cargo.toml index ff60e5cc..6d52c0bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "tiledb/sys-cfg", "tiledb/sys-defs", "tiledb/utils", + "test-utils/arrow-proptest-strategies", "test-utils/cells", "test-utils/proptest-config", "test-utils/signal", @@ -43,6 +44,7 @@ arrow-schema = { version = "52.0.0" } bindgen = "0.70" cells = { path = "test-utils/cells", version = "0.1.0" } cmake = "0.1" +half = { version = "2.2.1", default-features = false } itertools = "0" num-traits = "0.2" paste = "1.0" diff --git a/test-utils/arrow-proptest-strategies/Cargo.toml b/test-utils/arrow-proptest-strategies/Cargo.toml new file mode 100644 index 00000000..94488856 --- /dev/null +++ b/test-utils/arrow-proptest-strategies/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "arrow-proptest-strategies" +edition.workspace = true +rust-version.workspace = true +version.workspace = true + +[dependencies] +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-schema = { workspace = true } +half = { workspace = true } +proptest = { workspace = true } +strategy-ext = { workspace = true } diff --git a/test-utils/arrow-proptest-strategies/src/array.rs b/test-utils/arrow-proptest-strategies/src/array.rs new file mode 100644 index 00000000..24677eb2 --- /dev/null +++ b/test-utils/arrow-proptest-strategies/src/array.rs @@ -0,0 +1,309 @@ +use std::sync::Arc; + +use arrow_array::builder::FixedSizeBinaryBuilder; +use arrow_array::types::{ + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, +}; +use arrow_array::*; +use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer}; +use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; +use proptest::collection::{vec as strat_vec, SizeRange}; +use proptest::prelude::*; +use proptest::strategy::BoxedStrategy; + +pub const DEFAULT_NONE_PROBABILITY: f64 = 0.0625f64; + +#[derive(Clone, Debug)] +pub struct ColumnParameters { + /// Strategy for choosing the number of rows in the column. + pub num_rows: BoxedStrategy, + /// Strategy for choosing the number of elements in variable-length column elements. + pub num_collection_elements: SizeRange, + /// Whether to allow null values. + pub allow_null_values: bool, + /// Whether to allow elements of collection types such as `DataType::LargeList` + /// to be null. Defaults to `false`. + pub allow_null_collection_element: bool, +} + +pub fn prop_array( + params: ColumnParameters, + field: Arc, +) -> impl Strategy> { + fn to_arc_dyn(array: T) -> Arc + where + T: Array + 'static, + { + Arc::new(array) + } + + macro_rules! strat { + ($nrows:expr, $strat:expr, $arraytype:ident) => {{ + strat!($nrows, $strat, $arraytype::from, $arraytype::from) + }}; + ($nrows:expr, $strat:expr, $makearray:expr) => {{ + strat!($nrows, $strat, $makearray, $makearray) + }}; + ($nrows:expr, $strat:expr, $nullable:expr, $nonnullable:expr) => {{ + if field.is_nullable() && params.allow_null_values { + strat_vec(optional($strat), $nrows) + .prop_map($nullable) + .prop_map(to_arc_dyn) + .boxed() + } else { + strat_vec($strat, $nrows) + .prop_map($nonnullable) + .prop_map(to_arc_dyn) + .boxed() + } + }}; + } + + macro_rules! any { + ($nrows:expr, $datatype:ty, $arraytype:ident) => {{ + strat!($nrows, any::<$datatype>(), $arraytype) + }}; + } + + macro_rules! binary { + ($nrows:expr, $eltlen:expr, $arraytype:ident) => {{ + let strat_element = strat_vec(any::(), $eltlen); + strat!( + $nrows, + strat_element, + |elts| $arraytype::from( + elts.iter() + .map(|e| e.as_ref().map(|e| e.as_ref())) + .collect::>>() + ), + |elts| $arraytype::from( + elts.iter().map(|e| e.as_ref()).collect::>() + ) + ) + }}; + } + + params.num_rows.clone().prop_flat_map(move |num_rows| { + match field.data_type() { + DataType::Null => { + Just(to_arc_dyn(NullArray::new(num_rows))).boxed() + } + DataType::Boolean => any!(num_rows, bool, BooleanArray), + DataType::Int8 => any!(num_rows, i8, Int8Array), + DataType::Int16 => any!(num_rows, i16, Int16Array), + DataType::Int32 => any!(num_rows, i32, Int32Array), + DataType::Int64 => any!(num_rows, i64, Int64Array), + DataType::UInt8 => any!(num_rows, u8, UInt8Array), + DataType::UInt16 => any!(num_rows, u16, UInt16Array), + DataType::UInt32 => any!(num_rows, u32, UInt32Array), + DataType::UInt64 => any!(num_rows, u64, UInt64Array), + DataType::Float16 => { + strat!( + num_rows, + any::().prop_map(half::f16::from_f32), + Float16Array + ) + } + DataType::Float32 => any!(num_rows, f32, Float32Array), + DataType::Float64 => any!(num_rows, f64, Float64Array), + DataType::Timestamp(TimeUnit::Second, _) => { + any!(num_rows, i64, TimestampSecondArray) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + any!(num_rows, i64, TimestampMillisecondArray) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + any!(num_rows, i64, TimestampMicrosecondArray) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + any!(num_rows, i64, TimestampNanosecondArray) + } + DataType::Date32 => any!(num_rows, i32, Date32Array), + DataType::Date64 => any!(num_rows, i64, Date64Array), + DataType::Time32(TimeUnit::Second) => { + any!(num_rows, i32, Time32SecondArray) + } + DataType::Time32(TimeUnit::Millisecond) => { + any!(num_rows, i32, Time32MillisecondArray) + } + DataType::Time32(_) => { + panic!("Invalid data type: {:?}", field.data_type()) + } + DataType::Time64(TimeUnit::Microsecond) => { + any!(num_rows, i64, Time64MicrosecondArray) + } + DataType::Time64(TimeUnit::Nanosecond) => { + any!(num_rows, i64, Time64NanosecondArray) + } + DataType::Duration(TimeUnit::Second) => { + any!(num_rows, i64, DurationSecondArray) + } + DataType::Duration(TimeUnit::Millisecond) => { + any!(num_rows, i64, DurationMillisecondArray) + } + DataType::Duration(TimeUnit::Microsecond) => { + any!(num_rows, i64, DurationMicrosecondArray) + } + DataType::Duration(TimeUnit::Nanosecond) => { + any!(num_rows, i64, DurationNanosecondArray) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let strat_element = any::().prop_map(|val| { + IntervalYearMonthType::make_value(val / 12, val % 12) + }); + strat!(num_rows, strat_element, IntervalYearMonthArray) + } + DataType::Interval(IntervalUnit::DayTime) => { + let strat_element = + (any::(), any::()).prop_map(|(days, millis)| { + IntervalDayTimeType::make_value(days, millis) + }); + strat!(num_rows, strat_element, IntervalDayTimeArray) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let strat_element = (any::(), any::(), any::()) + .prop_map(|(months, days, nanos)| { + IntervalMonthDayNanoType::make_value( + months, days, nanos, + ) + }); + strat!(num_rows, strat_element, IntervalMonthDayNanoArray) + } + DataType::Binary => binary!( + num_rows, + params.num_collection_elements.clone(), + BinaryArray + ), + DataType::FixedSizeBinary(flen) => { + let flen = *flen; + let strat_element = strat_vec(any::(), flen as usize); + strat!( + num_rows, + strat_element, + move |elts| { + let mut values = FixedSizeBinaryBuilder::with_capacity( + elts.len(), + flen, + ); + elts.into_iter().for_each(|elt| { + if let Some(elt) = elt { + values.append_value(elt).unwrap(); + } else { + values.append_null(); + } + }); + values.finish() + }, + move |elts| FixedSizeBinaryArray::new( + flen, + elts.into_iter().flatten().collect::(), + None + ) + ) + } + DataType::LargeBinary => { + binary!( + num_rows, + params.num_collection_elements.clone(), + LargeBinaryArray + ) + } + DataType::Utf8 => any!(num_rows, String, StringArray), + DataType::LargeUtf8 => any!(num_rows, String, LargeStringArray), + DataType::Decimal128(p, s) => { + let (p, s) = (*p, *s); + strat!(num_rows, any::(), move |values| { + Decimal128Array::from(values) + .with_precision_and_scale(p, s) + .expect("Invalid precision and scale") + }) + } + DataType::Decimal256(p, s) => { + let (p, s) = (*p, *s); + strat!( + num_rows, + any::<[u8; 32]>().prop_map(i256::from_le_bytes), + move |values| Decimal256Array::from(values) + .with_precision_and_scale(p, s) + .expect("Invalid precision and scale") + ) + } + DataType::LargeList(element) => { + let field = Arc::clone(&field); + let element = Arc::clone(element); + let (min_values, max_values) = { + let r = params.num_collection_elements.clone(); + (num_rows * r.start(), num_rows * r.end_incl()) + }; + let values_parameters = ColumnParameters { + num_rows: (min_values..=max_values).boxed(), + allow_null_values: params.allow_null_collection_element, + ..params.clone() + }; + prop_array(values_parameters, Arc::clone(&element)) + .prop_flat_map(move |values| { + let num_values = values.len(); + ( + Just(values), + strat_list_subdivisions( + num_values, + field.is_nullable(), + ), + ) + }) + .prop_map(move |(values, (offsets, nulls))| { + GenericListArray::new( + Arc::clone(&element), + offsets, + values, + nulls, + ) + }) + .prop_map(to_arc_dyn) + .boxed() + } + _ => unreachable!( + "Not implemented in schema strategy: {}", + field.data_type() + ), + } + }) +} + +fn optional(strat: T) -> impl Strategy> { + proptest::option::weighted(1.0 - DEFAULT_NONE_PROBABILITY, strat) +} + +fn strat_list_subdivisions( + num_elements: usize, + is_nullable: bool, +) -> impl Strategy, Option)> { + let strat_offset = 0..=num_elements; + let strat_num_lists = 0..=num_elements; + + strat_vec((strat_offset, any::()), strat_num_lists).prop_map( + move |mut rows| { + rows.sort(); + rows.push((num_elements, false)); + rows[0].0 = 0; + + let offsets = OffsetBuffer::new( + rows.iter() + .map(|(o, _)| *o as i64) + .collect::>() + .into(), + ); + let nulls = if is_nullable { + Some( + rows.iter() + .map(|(_, b)| *b) + .take(rows.len() - 1) + .collect::(), + ) + } else { + None + }; + (offsets, nulls) + }, + ) +} diff --git a/test-utils/arrow-proptest-strategies/src/lib.rs b/test-utils/arrow-proptest-strategies/src/lib.rs new file mode 100644 index 00000000..58900775 --- /dev/null +++ b/test-utils/arrow-proptest-strategies/src/lib.rs @@ -0,0 +1,3 @@ +pub mod array; +pub mod record_batch; +pub mod schema; diff --git a/test-utils/arrow-proptest-strategies/src/record_batch.rs b/test-utils/arrow-proptest-strategies/src/record_batch.rs new file mode 100644 index 00000000..9e92831e --- /dev/null +++ b/test-utils/arrow-proptest-strategies/src/record_batch.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use arrow_array::{Array, RecordBatch}; +use arrow_schema::{Schema, SchemaRef}; +use proptest::collection::SizeRange; +use proptest::prelude::*; + +use crate::array::{prop_array, ColumnParameters}; + +#[derive(Clone, Debug)] +pub struct RecordBatchParameters { + /// Strategy for choosing a number of rows. + pub num_rows: BoxedStrategy, + /// Strategy for choosing the number of elements in variable-length column elements. + pub num_collection_elements: SizeRange, + /// Whether to allow elements of collection types such as `DataType::LargeList` + /// to be null. Defaults to `true`. + pub allow_null_collection_element: bool, +} + +impl RecordBatchParameters { + pub fn column_parameters(&self) -> ColumnParameters { + ColumnParameters { + num_rows: self.num_rows.clone(), + num_collection_elements: self.num_collection_elements.clone(), + allow_null_values: true, + allow_null_collection_element: self.allow_null_collection_element, + } + } +} + +impl Default for RecordBatchParameters { + fn default() -> Self { + Self { + num_rows: (0..=4usize).boxed(), + num_collection_elements: (0..=4usize).into(), + allow_null_collection_element: true, + } + } +} + +pub fn prop_record_batch( + schema: BoxedStrategy, + params: RecordBatchParameters, +) -> impl Strategy { + schema.prop_flat_map(move |s| { + prop_record_batch_for_schema(params.clone(), Arc::new(s)) + }) +} + +pub fn prop_record_batch_for_schema( + params: RecordBatchParameters, + schema: SchemaRef, +) -> impl Strategy { + params + .num_rows + .clone() + .prop_flat_map(move |num_rows| { + let column_params = ColumnParameters { + num_rows: Just(num_rows).boxed(), + ..params.column_parameters() + }; + let columns = schema + .fields + .iter() + .map(move |field| { + prop_array(column_params.clone(), Arc::clone(field)).boxed() + }) + .collect::>>>(); + + (Just(Arc::clone(&schema)), columns) + }) + .prop_map(|(schema, columns)| { + RecordBatch::try_new(schema, columns).unwrap() + }) +} +#[cfg(test)] +mod tests { + use super::*; + + proptest! { + #[test] + fn strategy_validity(_ in prop_record_batch( + crate::schema::prop_arrow_schema(Default::default()).boxed(), + Default::default()) + ) { + // NB: empty, this just checks that we produce correct record batches + } + } +} diff --git a/test-utils/arrow-proptest-strategies/src/schema.rs b/test-utils/arrow-proptest-strategies/src/schema.rs new file mode 100644 index 00000000..957ac001 --- /dev/null +++ b/test-utils/arrow-proptest-strategies/src/schema.rs @@ -0,0 +1,202 @@ +use std::rc::Rc; + +use arrow_schema::{self, DataType, Field, IntervalUnit, Schema, TimeUnit}; +use proptest::prelude::*; +use strategy_ext::StrategyExt; + +#[derive(Clone, Debug)] +pub struct SchemaParameters { + pub num_fields: proptest::collection::SizeRange, + pub field_names: BoxedStrategy, + pub field_type: BoxedStrategy, +} + +impl Default for SchemaParameters { + fn default() -> Self { + SchemaParameters { + num_fields: (1..32).into(), + field_names: proptest::string::string_regex("[a-zA-Z0-9_]*") + .unwrap() + .prop_without_replacement() + .boxed(), + field_type: prop_arrow_datatype(Default::default()).boxed(), + } + } +} + +#[derive(Clone, Debug)] +pub struct DataTypeParameters { + pub fixed_binary_len: BoxedStrategy, +} + +impl Default for DataTypeParameters { + fn default() -> Self { + DataTypeParameters { + fixed_binary_len: (1..(128 * 1024)).boxed(), + } + } +} + +pub fn prop_arrow_schema( + params: SchemaParameters, +) -> impl Strategy { + let strat_num_fields = params.num_fields.clone(); + proptest::collection::vec( + prop_arrow_field(Rc::new(params)), + strat_num_fields, + ) + .prop_map(Schema::new) +} + +pub fn prop_arrow_field( + params: Rc, +) -> impl Strategy { + let strat_name = params.field_names.clone(); + + (strat_name, params.field_type.clone(), any::()).prop_map( + |(name, datatype, nullable)| Field::new(name, datatype, nullable), + ) +} + +pub fn prop_arrow_datatype( + params: Rc, +) -> impl Strategy { + let leaf = prop_oneof![ + Just(DataType::Null), + Just(DataType::Int8), + Just(DataType::Int16), + Just(DataType::Int32), + Just(DataType::Int64), + Just(DataType::UInt8), + Just(DataType::UInt16), + Just(DataType::UInt32), + Just(DataType::UInt64), + Just(DataType::Float16), + Just(DataType::Float32), + Just(DataType::Float64), + Just(DataType::Timestamp(TimeUnit::Second, None)), + Just(DataType::Timestamp(TimeUnit::Millisecond, None)), + Just(DataType::Timestamp(TimeUnit::Microsecond, None)), + Just(DataType::Timestamp(TimeUnit::Nanosecond, None)), + Just(DataType::Date32), + Just(DataType::Date64), + Just(DataType::Time32(TimeUnit::Second)), + Just(DataType::Time32(TimeUnit::Millisecond)), + Just(DataType::Time64(TimeUnit::Microsecond)), + Just(DataType::Time64(TimeUnit::Nanosecond)), + Just(DataType::Duration(TimeUnit::Second)), + Just(DataType::Duration(TimeUnit::Millisecond)), + Just(DataType::Duration(TimeUnit::Nanosecond)), + Just(DataType::Interval(IntervalUnit::YearMonth)), + Just(DataType::Interval(IntervalUnit::DayTime)), + Just(DataType::Interval(IntervalUnit::MonthDayNano)), + Just(DataType::Binary), + params + .fixed_binary_len + .clone() + .prop_map(DataType::FixedSizeBinary), + Just(DataType::LargeBinary), + Just(DataType::Utf8), + Just(DataType::LargeUtf8), + prop_arrow_type_decimal128(), + prop_arrow_type_decimal256(), + ]; + + // TODO: use `prop_recursive` for other datatypes + leaf +} + +/// Returns a strategy which produces `DataType`s for which +/// `DataType::is_numeric()` returns `true`. +pub fn prop_arrow_datatype_numeric() -> impl Strategy { + prop_oneof![ + Just(DataType::Int8), + Just(DataType::Int16), + Just(DataType::Int32), + Just(DataType::Int64), + Just(DataType::UInt8), + Just(DataType::UInt16), + Just(DataType::UInt32), + Just(DataType::UInt64), + Just(DataType::Float16), + Just(DataType::Float32), + Just(DataType::Float64), + prop_arrow_type_decimal128(), + prop_arrow_type_decimal256(), + ] +} + +/// Returns a strategy which produces `DataType`s for which +/// `DataType::is_integer()` returns `true`. +pub fn prop_arrow_datatype_integer() -> impl Strategy { + prop_oneof![ + Just(DataType::Int8), + Just(DataType::Int16), + Just(DataType::Int32), + Just(DataType::Int64), + Just(DataType::UInt8), + Just(DataType::UInt16), + Just(DataType::UInt32), + Just(DataType::UInt64), + ] +} + +fn prop_arrow_type_decimal128() -> impl Strategy { + (1..=arrow_schema::DECIMAL128_MAX_PRECISION).prop_flat_map(|precision| { + ( + Just(precision), + (0..precision.clamp(0, arrow_schema::DECIMAL128_MAX_SCALE as u8)), + ) + .prop_map(|(precision, scale)| { + DataType::Decimal128(precision, scale as i8) + }) + }) +} + +fn prop_arrow_type_decimal256() -> impl Strategy { + (1..=arrow_schema::DECIMAL256_MAX_PRECISION).prop_flat_map(|precision| { + ( + Just(precision), + (0..precision.clamp(0, arrow_schema::DECIMAL256_MAX_SCALE as u8)), + ) + .prop_map(|(precision, scale)| { + DataType::Decimal256(precision, scale as i8) + }) + }) +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + fn do_unique_names(schema: Schema) { + let field_names = + schema.fields.iter().map(|f| f.name()).collect::>(); + let unique_names = field_names.iter().collect::>(); + assert_eq!( + field_names.len(), + unique_names.len(), + "field_names = {:?}", + field_names + ); + } + + proptest! { + #[test] + fn unique_names(schema in prop_arrow_schema(Default::default())) { + do_unique_names(schema); + } + + #[test] + fn is_numeric(datatype in prop_arrow_datatype_numeric()) { + assert!(datatype.is_numeric()); + } + + #[test] + fn is_integer(datatype in prop_arrow_datatype_integer()) { + assert!(datatype.is_integer()); + } + } +} From 2f7cf50dd0da53d373a886476c37e236f36aa89c Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 14:24:38 -0500 Subject: [PATCH 34/42] arrow-proptest-strategies tweaks --- .../arrow-proptest-strategies/src/array.rs | 17 ++++++++++++++--- test-utils/arrow-proptest-strategies/src/lib.rs | 2 ++ .../src/record_batch.rs | 8 ++++---- .../src/{buffers.rs => buffers/mod.rs} | 0 4 files changed, 20 insertions(+), 7 deletions(-) rename tiledb/query-core/src/{buffers.rs => buffers/mod.rs} (100%) diff --git a/test-utils/arrow-proptest-strategies/src/array.rs b/test-utils/arrow-proptest-strategies/src/array.rs index 24677eb2..f3954bb4 100644 --- a/test-utils/arrow-proptest-strategies/src/array.rs +++ b/test-utils/arrow-proptest-strategies/src/array.rs @@ -14,7 +14,7 @@ use proptest::strategy::BoxedStrategy; pub const DEFAULT_NONE_PROBABILITY: f64 = 0.0625f64; #[derive(Clone, Debug)] -pub struct ColumnParameters { +pub struct ArrayParameters { /// Strategy for choosing the number of rows in the column. pub num_rows: BoxedStrategy, /// Strategy for choosing the number of elements in variable-length column elements. @@ -26,8 +26,19 @@ pub struct ColumnParameters { pub allow_null_collection_element: bool, } +impl Default for ArrayParameters { + fn default() -> Self { + Self { + num_rows: (0..=8usize).boxed(), + num_collection_elements: (0..=8).into(), + allow_null_values: true, + allow_null_collection_element: false, + } + } +} + pub fn prop_array( - params: ColumnParameters, + params: ArrayParameters, field: Arc, ) -> impl Strategy> { fn to_arc_dyn(array: T) -> Arc @@ -235,7 +246,7 @@ pub fn prop_array( let r = params.num_collection_elements.clone(); (num_rows * r.start(), num_rows * r.end_incl()) }; - let values_parameters = ColumnParameters { + let values_parameters = ArrayParameters { num_rows: (min_values..=max_values).boxed(), allow_null_values: params.allow_null_collection_element, ..params.clone() diff --git a/test-utils/arrow-proptest-strategies/src/lib.rs b/test-utils/arrow-proptest-strategies/src/lib.rs index 58900775..87a4cf98 100644 --- a/test-utils/arrow-proptest-strategies/src/lib.rs +++ b/test-utils/arrow-proptest-strategies/src/lib.rs @@ -1,3 +1,5 @@ pub mod array; pub mod record_batch; pub mod schema; + +pub use array::prop_array; diff --git a/test-utils/arrow-proptest-strategies/src/record_batch.rs b/test-utils/arrow-proptest-strategies/src/record_batch.rs index 9e92831e..c6ac7d55 100644 --- a/test-utils/arrow-proptest-strategies/src/record_batch.rs +++ b/test-utils/arrow-proptest-strategies/src/record_batch.rs @@ -5,7 +5,7 @@ use arrow_schema::{Schema, SchemaRef}; use proptest::collection::SizeRange; use proptest::prelude::*; -use crate::array::{prop_array, ColumnParameters}; +use crate::array::{prop_array, ArrayParameters}; #[derive(Clone, Debug)] pub struct RecordBatchParameters { @@ -19,8 +19,8 @@ pub struct RecordBatchParameters { } impl RecordBatchParameters { - pub fn column_parameters(&self) -> ColumnParameters { - ColumnParameters { + pub fn column_parameters(&self) -> ArrayParameters { + ArrayParameters { num_rows: self.num_rows.clone(), num_collection_elements: self.num_collection_elements.clone(), allow_null_values: true, @@ -56,7 +56,7 @@ pub fn prop_record_batch_for_schema( .num_rows .clone() .prop_flat_map(move |num_rows| { - let column_params = ColumnParameters { + let column_params = ArrayParameters { num_rows: Just(num_rows).boxed(), ..params.column_parameters() }; diff --git a/tiledb/query-core/src/buffers.rs b/tiledb/query-core/src/buffers/mod.rs similarity index 100% rename from tiledb/query-core/src/buffers.rs rename to tiledb/query-core/src/buffers/mod.rs From 8d0fe421fbb74d9a55beb4e5a58b6f1f6444e7de Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 14:51:46 -0500 Subject: [PATCH 35/42] proptest_list_buffers_roundtrip_var --- Cargo.lock | 1 + Cargo.toml | 1 + tiledb/query-core/Cargo.toml | 1 + tiledb/query-core/src/buffers/mod.rs | 274 +++++++++++++++++-------- tiledb/query-core/src/buffers/tests.rs | 131 ++++++++++++ 5 files changed, 323 insertions(+), 85 deletions(-) create mode 100644 tiledb/query-core/src/buffers/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 4ddbdcc3..ed681226 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1900,6 +1900,7 @@ version = "0.1.0" dependencies = [ "anyhow", "arrow", + "arrow-proptest-strategies", "cells", "itertools 0.12.1", "proptest", diff --git a/Cargo.toml b/Cargo.toml index 6d52c0bd..bc82287b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ armerge = "2" arrow = { version = "52.0.0", features = ["prettyprint"] } arrow-array = { version = "52.0.0" } arrow-buffer = { version = "52.0.0" } +arrow-proptest-strategies = { path = "test-utils/arrow-proptest-strategies" } arrow-schema = { version = "52.0.0" } bindgen = "0.70" cells = { path = "test-utils/cells", version = "0.1.0" } diff --git a/tiledb/query-core/Cargo.toml b/tiledb/query-core/Cargo.toml index 5a04cf2c..a16ef8dd 100644 --- a/tiledb/query-core/Cargo.toml +++ b/tiledb/query-core/Cargo.toml @@ -13,6 +13,7 @@ tiledb-sys = { workspace = true } [dev-dependencies] anyhow = { workspace = true } +arrow-proptest-strategies = { workspace = true } cells = { workspace = true, features = ["arrow", "proptest-strategies"] } itertools = { workspace = true } proptest = { workspace = true } diff --git a/tiledb/query-core/src/buffers/mod.rs b/tiledb/query-core/src/buffers/mod.rs index e4062ccc..2d0f9c28 100644 --- a/tiledb/query-core/src/buffers/mod.rs +++ b/tiledb/query-core/src/buffers/mod.rs @@ -3,9 +3,10 @@ use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; -use arrow::array as aa; +use arrow::array::{self as aa, Array}; use arrow::buffer::{ self as abuf, Buffer as ArrowBuffer, MutableBuffer as ArrowBufferMut, + OffsetBuffer, }; use arrow::datatypes as adt; use arrow::error::ArrowError; @@ -86,6 +87,8 @@ pub enum UnsupportedArrowArrayError { "TileDB only supports fixed size lists of primitive types, not {0}" )] UnsupportedFixedSizeListType(adt::DataType), + #[error("Offsets do not match target query field: expected fixed-size values of length {0}, found {1} at cell {2}")] + FixedOffsets(usize, usize, usize), #[error("Invalid data type for bytes array: {0}")] InvalidBytesType(adt::DataType), #[error("Invalid data type for primitive data: {0}")] @@ -582,27 +585,32 @@ impl NewBufferTraitThing for FixedListBuffers { let cvn = u32::from(cell_val_num) as i32; let len = num_values / cvn as usize; - // N.B., data/validity clones are cheap. They are not cloning the - // underlying data buffers. We have to clone so that we can put ourself - // back together if the array conversion failes. - match aa::ArrayData::try_new( - field.data_type().clone(), - len, - validity.clone().map(|v| v.into_inner().into_inner()), - 0, - vec![data.clone()], - vec![], - ) { + let into_arrow = self + .fixed_offsets + .as_ref() + .map(|fixed_offsets| { + check_fixed_offset_compatibility( + cvn as usize, + fixed_offsets.buffer.typed_data::(), + ) + }) + .unwrap_or(Ok(())) + .and_then(|_| { + // N.B., data/validity clones are cheap. They are not cloning the + // underlying data buffers. We have to clone so that we can put ourself + // back together if the array conversion failes. + aa::ArrayData::try_new( + field.data_type().clone(), + len, + validity.clone().map(|v| v.into_inner().into_inner()), + 0, + vec![data.clone()], + vec![], + ) + .map_err(UnsupportedArrowArrayError::ArrayCreationFailed) + }); + match into_arrow { Ok(data) => { - // validate fixed offsets (if any) - if let Some(fixed_offsets) = self.fixed_offsets { - let offsets = from_tdb_offsets(fixed_offsets.buffer); - for w in offsets.windows(2) { - if w[1] - w[0] != (cvn as i64) { - todo!() - } - } - } let array: Arc = Arc::new(aa::FixedSizeListArray::new( Arc::clone(&field), @@ -625,7 +633,7 @@ impl NewBufferTraitThing for FixedListBuffers { validity: self.validity, }); - Err((boxed, UnsupportedArrowArrayError::ArrayCreationFailed(e))) + Err((boxed, e)) } } } @@ -688,52 +696,117 @@ impl QueryBuffer { } } +/// Buffers representing an arrow `LargeList` array. +/// +/// These may be used to target a tiledb field with any [CellValNum]. +/// For [CellValNum::Var] the mapping is obvious. +/// For [CellValNum::Fixed]: +/// * Read queries will generate offsets using the fixed size of each cell. +/// * Write queries will validate that the offsets are of fixed size +/// and return `Err` if they are not. struct ListBuffers { field: Arc, data: QueryBuffer, - offsets: QueryBuffer, + /// Offsets are optional in case the target field is fixed size + offsets: ListBuffersOffsets, validity: Option, } +enum ListBuffersOffsets { + ArrowOnly(OffsetBuffer), + Shared(QueryBuffer), +} + impl ListBuffers { pub fn try_new( target: &BufferTarget, - array: Arc, + array: aa::LargeListArray, ) -> FromArrowResult { - assert!(matches!(array.data_type(), adt::DataType::LargeList(_))); + let num_cells = array.len(); - let array: aa::LargeListArray = downcast_consume(array); - let (field, offsets, array, nulls) = array.into_parts(); + let (field, offsets, values, nulls) = array.into_parts(); + + // NB: this function has to be very careful with map/map_err + // as moving arrays or offsets into closures affects refcounts. + + // validation + let valid = { + if field.is_nullable() { + // FIXME: this depends on read/write + Err(UnsupportedArrowArrayError::InvalidNullableListElements) + } else if !field.data_type().is_primitive() { + Err(UnsupportedArrowArrayError::UnsupportedFixedSizeListType( + field.data_type().clone(), + )) + } else if let CellValNum::Fixed(nz) = target.cell_val_num { + if matches!(target.query_type, QueryType::Write) { + // validate that offsets match + check_fixed_offset_compatibility( + nz.get() as usize, + &offsets, + ) + } else { + Ok(()) + } + } else { + Ok(()) + } + }; - if field.is_nullable() { - return Err(( - array, - UnsupportedArrowArrayError::InvalidNullableListElements, - )); + if let Err(e) = valid { + let array = aa::LargeListArray::try_new( + field, + offsets, + values, + nulls.clone(), + ) + .unwrap(); + return Err((Arc::new(array), e.into())); } - let dtype = field.data_type().clone(); - if !dtype.is_primitive() { - return Err(( - array, - UnsupportedArrowArrayError::UnsupportedFixedSizeListType(dtype), - )); - } + let offsets = if let CellValNum::Fixed(_) = target.cell_val_num { + // NB: validated to match above + ListBuffersOffsets::ArrowOnly(offsets) + } else { + match offsets.into_inner().into_inner().into_mutable() { + Ok(offsets) => { + ListBuffersOffsets::Shared(QueryBuffer::new(offsets)) + } + Err(e) => { + let offsets_buffer = e; + let offsets = abuf::OffsetBuffer::new( + abuf::ScalarBuffer::::from(offsets_buffer), + ); + let array: Arc = Arc::new( + aa::LargeListArray::try_new( + Arc::clone(&field), + offsets, + values, + nulls.clone(), + ) + .unwrap(), + ); + return Err((array, ArrayInUseError::Array.into())); + } + } + }; - let num_cells = array.len(); + // NB: by default the offsets are not arrow-shaped. + // However we use the configuration options to make them so. - // N.B., I really, really tried to make this a fancy map/map_err - // cascade like all of the others. But it turns out that keeping the - // proper refcounts on either array or offsets turns into a bit of - // an issue when passing things through multiple closures. - let result = PrimitiveBuffers::try_new(target, array); + let result = PrimitiveBuffers::try_new(target, values); if result.is_err() { - let (array, err) = result.err().unwrap(); + let (values, err) = result.err().unwrap(); let array: Arc = Arc::new( aa::LargeListArray::try_new( Arc::clone(&field), - offsets, - array, + match offsets { + ListBuffersOffsets::ArrowOnly(offsets) => offsets, + ListBuffersOffsets::Shared(qb) => { + from_tdb_offsets(qb.buffer) + } + }, + values, nulls.clone(), ) .unwrap(), @@ -742,39 +815,13 @@ impl ListBuffers { } let data = result.ok().unwrap(); - - let offsets = match offsets.into_inner().into_inner().into_mutable() { - Ok(offsets) => offsets, - Err(e) => { - let offsets_buffer = e; - let offsets = abuf::OffsetBuffer::new( - abuf::ScalarBuffer::::from(offsets_buffer), - ); - let array: Arc = Arc::new( - aa::LargeListArray::try_new( - Arc::clone(&field), - offsets, - // Safety: We just turned this into a mutable buffer, so - // the inversion should never fail. - Box::new(data).into_arrow().ok().unwrap(), - nulls.clone(), - ) - .unwrap(), - ); - return Err((array, ArrayInUseError::Array.into())); - } - }; - - // NB: by default the offsets are not arrow-shaped. - // However we use the configuration options to make them so. - let validity = to_tdb_validity(target, num_cells, nulls).map(QueryBuffer::new); Ok(ListBuffers { field, data: data.data, - offsets: QueryBuffer::new(offsets), + offsets, validity, }) } @@ -786,7 +833,10 @@ impl NewBufferTraitThing for ListBuffers { } fn len(&self) -> usize { - self.offsets.num_var_cells() + match &self.offsets { + ListBuffersOffsets::ArrowOnly(offsets) => offsets.len() - 1, + ListBuffersOffsets::Shared(qb) => qb.num_var_cells(), + } } fn data(&mut self) -> &mut QueryBuffer { @@ -794,7 +844,10 @@ impl NewBufferTraitThing for ListBuffers { } fn offsets(&mut self) -> Option<&mut QueryBuffer> { - Some(&mut self.offsets) + match self.offsets { + ListBuffersOffsets::ArrowOnly(_) => None, + ListBuffersOffsets::Shared(ref mut qb) => Some(qb), + } } fn validity(&mut self) -> Option<&mut QueryBuffer> { @@ -818,13 +871,37 @@ impl NewBufferTraitThing for ListBuffers { assert!(field.data_type().is_primitive()); - let num_cells = self.offsets.num_var_cells(); + let num_values = match self.offsets { + ListBuffersOffsets::ArrowOnly(ref offsets) => { + // SAFETY: [OffsetBuffer] is always non-empty + *offsets.last().unwrap() as usize + } + ListBuffersOffsets::Shared(ref qb) => { + let num_offsets = qb.num_var_cells() + 1; + // SAFETY: offsets came from arrow and are thus non-NULL, + // and are owned by `self.offsets` + let offsets = unsafe { + std::slice::from_raw_parts( + qb.buffer.as_ptr() as *const i64, + num_offsets, + ) + }; + // SAFETY: offsets is always non-empty + // TODO: this might not be true for tiledb actually, though it is for arrow + *offsets.last().unwrap() as usize + } + }; // NB: by default the offsets are not arrow-shaped. // However we use the configuration options to make them so. let data = ArrowBuffer::from(self.data.buffer); - let offsets = from_tdb_offsets(self.offsets.buffer); + let (offsets, offsets_size) = match self.offsets { + ListBuffersOffsets::ArrowOnly(offsets) => (offsets, None), + ListBuffersOffsets::Shared(qb) => { + (from_tdb_offsets(qb.buffer), Some(qb.size)) + } + }; let validity = from_tdb_validity(&self.validity); // N.B., the calls to cloning the data/offsets/validity are as cheap @@ -832,7 +909,7 @@ impl NewBufferTraitThing for ListBuffers { // the underlying allocated data. match aa::ArrayData::try_new( field.data_type().clone(), - num_cells, + num_values, None, 0, vec![data.clone()], @@ -859,9 +936,14 @@ impl NewBufferTraitThing for ListBuffers { buffer: data.into_mutable().unwrap(), size: self.data.size, }, - offsets: QueryBuffer { - buffer: to_tdb_offsets(offsets).unwrap(), - size: self.offsets.size, + offsets: if let Some(offsets_size) = offsets_size { + // SAFETY: this was constructed via `from_tdb_offsets` + ListBuffersOffsets::Shared(QueryBuffer { + buffer: to_tdb_offsets(offsets).unwrap(), + size: offsets_size, + }) + } else { + ListBuffersOffsets::ArrowOnly(offsets) }, validity: self.validity, }); @@ -1057,6 +1139,7 @@ fn to_target_buffers( }) } adt::DataType::LargeList(_) => { + let array = downcast_consume::(array); ListBuffers::try_new(target, array).map(|buffers| { let boxed: Box = Box::new(buffers); boxed @@ -1709,8 +1792,11 @@ fn to_tdb_validity( } } -fn from_tdb_offsets(offsets: ArrowBufferMut) -> abuf::OffsetBuffer { - let buffer = abuf::ScalarBuffer::::from(offsets); +fn from_tdb_offsets(offsets: B) -> abuf::OffsetBuffer +where + B: Into, +{ + let buffer = abuf::ScalarBuffer::::from(offsets.into()); abuf::OffsetBuffer::new(buffer) } @@ -1723,3 +1809,21 @@ fn from_tdb_validity( ) }) } + +fn check_fixed_offset_compatibility( + fixed_cvn: usize, + offsets: &[i64], +) -> UnsupportedArrowArrayResult<()> { + for (i, w) in offsets.windows(2).enumerate() { + let length = (w[1] - w[0]) as usize; + if length != fixed_cvn { + return Err(UnsupportedArrowArrayError::FixedOffsets( + fixed_cvn, length, i, + )); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests; diff --git a/tiledb/query-core/src/buffers/tests.rs b/tiledb/query-core/src/buffers/tests.rs new file mode 100644 index 00000000..bbc06c7c --- /dev/null +++ b/tiledb/query-core/src/buffers/tests.rs @@ -0,0 +1,131 @@ +use arrow_proptest_strategies::prop_array; +use arrow_proptest_strategies::schema::prop_arrow_field; +use proptest::prelude::*; +use tiledb_common::Datatype; + +use super::*; + +fn copy_buffer(buffer: &abuf::Buffer) -> abuf::Buffer { + abuf::Buffer::from(buffer.as_slice().to_vec()) +} + +fn copy_array_data(array_data: &aa::ArrayData) -> aa::ArrayData { + let nulls = array_data + .nulls() + .map(|n| abuf::Buffer::from(n.validity().to_vec())); + let buffers = array_data + .buffers() + .iter() + .map(copy_buffer) + .collect::>(); + let child_data = array_data + .child_data() + .iter() + .map(copy_array_data) + .collect::>(); + + aa::ArrayData::try_new( + array_data.data_type().clone(), + array_data.len(), + nulls, + array_data.offset(), + buffers, + child_data, + ) + .expect("Error copying array data") +} + +fn copy_array(array: &dyn aa::Array) -> Arc { + let data_ref = array.to_data(); + let data_copy = copy_array_data(&data_ref); + + aa::make_array(data_copy) +} + +fn instance_copy_array(array_in: &dyn aa::Array) { + let array_out = copy_array(array_in); + assert_eq!(array_in, array_out.as_ref()); +} + +proptest! { + #[test] + fn proptest_copy_array( + array in prop_arrow_field(Default::default()) + .prop_flat_map(|field| prop_array(Default::default(), Arc::new(field))) + ) { + instance_copy_array(&array) + } +} + +fn instance_list_buffers_roundtrip_var(array_in: aa::LargeListArray) { + let target = BufferTarget { + query_type: QueryType::Write, + cell_val_num: CellValNum::Var, + is_nullable: true, + }; + let lb = Box::new({ + let array_in = + downcast_consume::(copy_array(&array_in)); + ListBuffers::try_new(&target, array_in).unwrap() + }); + let array_out = match lb.into_arrow() { + Ok(array) => array, + Err((_, e)) => panic!( + "For array of type {}, unexpected error in `into_arrow`: {}", + array_in.data_type(), + e + ), + }; + + assert_eq!(&array_in as &dyn Array, array_out.as_ref()); +} + +fn instance_list_buffers_roundtrip_fixed( + cell_val_num: CellValNum, + array_in: aa::LargeListArray, +) { + let target = BufferTarget { + query_type: QueryType::Write, + cell_val_num, + is_nullable: true, + }; + + let lb = Box::new(ListBuffers::try_new(&target, array_in.clone()).unwrap()); + let array_out = match lb.into_arrow() { + Ok(array) => array, + Err((_, e)) => panic!("Unexpected error in `into_arrow`: {}", e), + }; + + assert_eq!(&array_in as &dyn Array, array_out.as_ref()); +} + +fn strat_list_buffers_roundtrip_var( +) -> impl Strategy { + any::() + .prop_map(|dt| { + crate::datatype::default_arrow_type(dt, CellValNum::single()) + .unwrap() + .into_inner() + }) + .prop_flat_map(|adt| { + let field = adt::Field::new( + "unused", + adt::DataType::LargeList(Arc::new(adt::Field::new_list_field( + adt, false, + ))), + true, + ); + arrow_proptest_strategies::prop_array( + Default::default(), + Arc::new(field), + ) + }) + .prop_map(|array| downcast_consume::(array)) +} + +proptest! { + #[test] + fn proptest_list_buffers_roundtrip_var(array in strat_list_buffers_roundtrip_var()) { + instance_list_buffers_roundtrip_var(array) + } +} From 8c51932de05f634fa0df5ac3456a3ca2d3733d61 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 15:13:08 -0500 Subject: [PATCH 36/42] Fix PrimitiveBuffers for all primitive types --- tiledb/query-core/src/buffers/mod.rs | 117 +++++++++++---------------- 1 file changed, 49 insertions(+), 68 deletions(-) diff --git a/tiledb/query-core/src/buffers/mod.rs b/tiledb/query-core/src/buffers/mod.rs index 2d0f9c28..18120a21 100644 --- a/tiledb/query-core/src/buffers/mod.rs +++ b/tiledb/query-core/src/buffers/mod.rs @@ -960,42 +960,45 @@ struct PrimitiveBuffers { validity: Option, } -macro_rules! to_primitive { - ($TARGET:expr, $ARRAY:expr, $ARROW_DT:ty) => {{ - let target = $TARGET; - let array: $ARROW_DT = downcast_consume($ARRAY); - let num_cells = array.len(); - let (dtype, buffer, nulls) = array.into_parts(); +fn to_primitive_buffers( + target: &BufferTarget, + array: Arc, +) -> FromArrowResult +where + T: aa::ArrowPrimitiveType, +{ + let array = downcast_consume::>(array); - buffer - .into_inner() - .into_mutable() - .map(|data| { - let validity = - to_tdb_validity(target, num_cells, nulls.clone()) - .map(QueryBuffer::new); - PrimitiveBuffers { - dtype: dtype.clone(), - data: QueryBuffer::new(data), - validity, - } - }) - .map_err(|buffer| { - // Safety: We just broke an array open to get these so - // unless someone did something unsafe they should go - // right back together again. Sorry, Humpty. - let data = aa::ArrayData::try_new( - dtype, - num_cells, - nulls.map(|n| n.into_inner().into_inner()), - 0, - vec![buffer], - vec![], - ) - .unwrap(); - (aa::make_array(data), ArrayInUseError::Array.into()) - }) - }}; + let num_cells = array.len(); + let (dtype, buffer, nulls) = array.into_parts(); + + buffer + .into_inner() + .into_mutable() + .map(|data| { + let validity = to_tdb_validity(target, num_cells, nulls.clone()) + .map(QueryBuffer::new); + PrimitiveBuffers { + dtype: dtype.clone(), + data: QueryBuffer::new(data), + validity, + } + }) + .map_err(|buffer| { + // SAFETY: We just broke an array open to get these so + // unless someone did something unsafe they should go + // right back together again. Sorry, Humpty. + let data = aa::ArrayData::try_new( + dtype, + num_cells, + nulls.map(|n| n.into_inner().into_inner()), + 0, + vec![buffer], + vec![], + ) + .unwrap(); + (aa::make_array(data), ArrayInUseError::Array.into()) + }) } impl PrimitiveBuffers { @@ -1005,40 +1008,18 @@ impl PrimitiveBuffers { ) -> FromArrowResult { assert!(array.data_type().is_primitive()); - match array.data_type().clone() { - adt::DataType::Int8 => to_primitive!(target, array, aa::Int8Array), - adt::DataType::Int16 => { - to_primitive!(target, array, aa::Int16Array) - } - adt::DataType::Int32 => { - to_primitive!(target, array, aa::Int32Array) - } - adt::DataType::Int64 => { - to_primitive!(target, array, aa::Int64Array) - } - adt::DataType::UInt8 => { - to_primitive!(target, array, aa::UInt8Array) - } - adt::DataType::UInt16 => { - to_primitive!(target, array, aa::UInt16Array) - } - adt::DataType::UInt32 => { - to_primitive!(target, array, aa::UInt32Array) - } - adt::DataType::UInt64 => { - to_primitive!(target, array, aa::UInt64Array) - } - adt::DataType::Float32 => { - to_primitive!(target, array, aa::Float32Array) - } - adt::DataType::Float64 => { - to_primitive!(target, array, aa::Float64Array) - } - t => Err(( - array, - UnsupportedArrowArrayError::InvalidPrimitiveType(t), - )), + macro_rules! go { + ($ARROWTYPE:ty) => { + to_primitive_buffers::<$ARROWTYPE>(target, array) + }; } + + use adt as arrow_schema; + let dtype = array.data_type(); + aa::downcast_primitive!( + dtype => (go), + _ => todo!() + ) } } From 351a1f85c2d243f426672e9d92235c44279c8855 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 15:37:51 -0500 Subject: [PATCH 37/42] Fix a few datatype cases --- tiledb/query-core/src/buffers/tests.rs | 3 +++ tiledb/query-core/src/datatype/default_to.rs | 6 ++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tiledb/query-core/src/buffers/tests.rs b/tiledb/query-core/src/buffers/tests.rs index bbc06c7c..57862e99 100644 --- a/tiledb/query-core/src/buffers/tests.rs +++ b/tiledb/query-core/src/buffers/tests.rs @@ -102,6 +102,9 @@ fn instance_list_buffers_roundtrip_fixed( fn strat_list_buffers_roundtrip_var( ) -> impl Strategy { any::() + .prop_filter("Boolean in list needs special handling", |dt| { + *dt != Datatype::Boolean + }) .prop_map(|dt| { crate::datatype::default_arrow_type(dt, CellValNum::single()) .unwrap() diff --git a/tiledb/query-core/src/datatype/default_to.rs b/tiledb/query-core/src/datatype/default_to.rs index 14bfd5d5..44005ee1 100644 --- a/tiledb/query-core/src/datatype/default_to.rs +++ b/tiledb/query-core/src/datatype/default_to.rs @@ -126,10 +126,8 @@ fn single_valued_type(tiledb: Datatype) -> TypeConversion { tiledb::DateTimeNanosecond => { LogicalMatch(arrow::Timestamp(TimeUnit::Nanosecond, None)) } - tiledb::TimeSecond => LogicalMatch(arrow::Time64(TimeUnit::Second)), - tiledb::TimeMillisecond => { - LogicalMatch(arrow::Time64(TimeUnit::Millisecond)) - } + tiledb::TimeSecond => PhysicalMatch(arrow::Int64), + tiledb::TimeMillisecond => PhysicalMatch(arrow::Int64), tiledb::TimeMicrosecond => { LogicalMatch(arrow::Time64(TimeUnit::Microsecond)) } From e28cbe52843bf733095afcd62e9b2d0152b4601b Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 16:41:43 -0500 Subject: [PATCH 38/42] Fill in arrow array FixedSizeList case --- .../arrow-proptest-strategies/src/array.rs | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test-utils/arrow-proptest-strategies/src/array.rs b/test-utils/arrow-proptest-strategies/src/array.rs index f3954bb4..cbbcfec9 100644 --- a/test-utils/arrow-proptest-strategies/src/array.rs +++ b/test-utils/arrow-proptest-strategies/src/array.rs @@ -239,6 +239,43 @@ pub fn prop_array( .expect("Invalid precision and scale") ) } + DataType::FixedSizeList(element, flen) => { + let flen = *flen; + let element = Arc::clone(&element); + + let values_parameters = ArrayParameters { + num_rows: Just(num_rows * (flen as usize)).boxed(), + allow_null_values: params.allow_null_collection_element, + ..params.clone() + }; + + ( + prop_array(values_parameters, Arc::clone(&element)), + if field.is_nullable() { + strat_vec( + proptest::bool::weighted( + 1.0 - DEFAULT_NONE_PROBABILITY, + ), + num_rows, + ) + .prop_map(Some) + .boxed() + } else { + Just(None).boxed() + }, + ) + .prop_map(move |(values, nulls)| { + FixedSizeListArray::new( + Arc::clone(&element), + flen, + values, + nulls + .map(|n| n.into_iter().collect::()), + ) + }) + .prop_map(to_arc_dyn) + .boxed() + } DataType::LargeList(element) => { let field = Arc::clone(&field); let element = Arc::clone(element); From 9aaf2d50e7d79e9a907fc522a95e1ff6245594f7 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Wed, 27 Nov 2024 16:42:01 -0500 Subject: [PATCH 39/42] proptest_list_buffers_roundtrip_fixed --- tiledb/query-core/src/buffers/tests.rs | 61 +++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/tiledb/query-core/src/buffers/tests.rs b/tiledb/query-core/src/buffers/tests.rs index 57862e99..0a23e374 100644 --- a/tiledb/query-core/src/buffers/tests.rs +++ b/tiledb/query-core/src/buffers/tests.rs @@ -1,4 +1,4 @@ -use arrow_proptest_strategies::prop_array; +use arrow_proptest_strategies::array::{prop_array, ArrayParameters}; use arrow_proptest_strategies::schema::prop_arrow_field; use proptest::prelude::*; use tiledb_common::Datatype; @@ -90,7 +90,11 @@ fn instance_list_buffers_roundtrip_fixed( is_nullable: true, }; - let lb = Box::new(ListBuffers::try_new(&target, array_in.clone()).unwrap()); + let lb = Box::new({ + let array_in = + downcast_consume::(copy_array(&array_in)); + ListBuffers::try_new(&target, array_in).unwrap() + }); let array_out = match lb.into_arrow() { Ok(array) => array, Err((_, e)) => panic!("Unexpected error in `into_arrow`: {}", e), @@ -126,9 +130,62 @@ fn strat_list_buffers_roundtrip_var( .prop_map(|array| downcast_consume::(array)) } +fn strat_list_buffers_roundtrip_fixed( +) -> impl Strategy { + (any::(), 1..=256i32) + .prop_filter( + "FIXME: Boolean in list needs special handling", + |(dt, _)| *dt != Datatype::Boolean, + ) + .prop_map(|(dt, fl)| { + ( + fl, + crate::datatype::default_arrow_type(dt, CellValNum::single()) + .unwrap() + .into_inner(), + ) + }) + .prop_flat_map(|(fl, adt)| { + let params = ArrayParameters { + num_collection_elements: (fl as usize).into(), + ..Default::default() + }; + let field = adt::Field::new( + "unused", + adt::DataType::FixedSizeList( + Arc::new(adt::Field::new_list_field(adt, false)), + fl, + ), + true, + ); + prop_array(params, Arc::new(field)) + }) + .prop_map(|array| { + let num_lists = array.len(); + let fl = downcast_consume::(dbg!(array)); + let (field, fl, values, nulls) = fl.into_parts(); + let array = aa::LargeListArray::try_new( + field, + abuf::OffsetBuffer::::from_lengths(vec![ + fl as usize; + num_lists + ]), + values, + nulls, + ) + .unwrap(); + (CellValNum::try_from(fl as u32).unwrap(), array) + }) +} + proptest! { #[test] fn proptest_list_buffers_roundtrip_var(array in strat_list_buffers_roundtrip_var()) { instance_list_buffers_roundtrip_var(array) } + + #[test] + fn proptest_list_buffers_roundtrip_fixed((cvn, array) in strat_list_buffers_roundtrip_fixed()) { + instance_list_buffers_roundtrip_fixed(cvn, array) + } } From 358c6747cde7240a7c285c62247736f3dd4197dd Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Mon, 2 Dec 2024 13:03:45 -0500 Subject: [PATCH 40/42] enum Capacity --- tiledb/query-core/src/buffers/mod.rs | 139 +++-------- tiledb/query-core/src/fields.rs | 344 ++++++++++++++++++++++++++- 2 files changed, 365 insertions(+), 118 deletions(-) diff --git a/tiledb/query-core/src/buffers/mod.rs b/tiledb/query-core/src/buffers/mod.rs index 18120a21..de79caf4 100644 --- a/tiledb/query-core/src/buffers/mod.rs +++ b/tiledb/query-core/src/buffers/mod.rs @@ -18,12 +18,10 @@ use tiledb_api::{Context, ContextBound}; use tiledb_common::array::CellValNum; use super::field::QueryField; -use super::fields::{QueryField as RequestField, QueryFields}; +use super::fields::{Capacity, QueryField as RequestField, QueryFields}; use super::QueryType; use super::RawQuery; -const AVERAGE_STRING_LENGTH: usize = 64; - #[derive(Debug, Error)] pub enum Error { #[error("Failed to convert Arrow Array for field '{0}': {1}")] @@ -69,6 +67,8 @@ pub enum FieldError { TargetTypeRequired(crate::datatype::DefaultArrowTypeError), #[error("Failed to allocate buffer: {0}")] BufferAllocation(ArrowError), + #[error("Cannot calculate number of cells: {0}")] + Capacity(#[from] super::fields::CapacityNumCellsError), #[error("Unsupported arrow array: {0}")] UnsupportedArrowArray(#[from] UnsupportedArrowArrayError), } @@ -1554,29 +1554,33 @@ fn request_to_buffers( .into_inner() }; - // SAFETY: I dunno this unwrap looks sketchy - alloc_array(arrow_type, tdb_nullable, field.capacity().unwrap()) + // SAFETY: `QueryField::capacity` returns `None` only if there is no buffer + // already, which is ruled out by the time we get here + let capacity = field.capacity().unwrap(); + + alloc_array(arrow_type, tdb_nullable, capacity) } pub type SharedBuffers = HashMap>; fn alloc_array( - dtype: adt::DataType, + target_type: adt::DataType, nullable: bool, - capacity: usize, + capacity: Capacity, ) -> FieldResult> { - let num_cells = calculate_num_cells(dtype.clone(), nullable, capacity)?; - - match dtype { + let num_cells = capacity.num_cells(&target_type, nullable)?; + let num_values = capacity.num_values(&target_type, nullable)?; + match target_type { adt::DataType::Boolean => { Ok(Arc::new(aa::BooleanArray::new_null(num_cells))) } adt::DataType::LargeList(field) => { let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); - let value_capacity = - capacity - (num_cells * std::mem::size_of::()); - let values = - alloc_array(field.data_type().clone(), false, value_capacity)?; + let values = alloc_array( + field.data_type().clone(), + false, + Capacity::Values(num_values), + )?; let nulls = if nullable { Some(abuf::NullBuffer::new_null(num_cells)) } else { @@ -1593,8 +1597,12 @@ fn alloc_array( } else { None }; - let values = - alloc_array(field.data_type().clone(), false, capacity)?; + let num_values = num_cells * (cvn as usize); + let values = alloc_array( + field.data_type().clone(), + false, + Capacity::Values(num_values), + )?; Ok(Arc::new( aa::FixedSizeListArray::try_new(field, cvn, values, nulls) .map_err(FieldError::BufferAllocation)?, @@ -1602,9 +1610,7 @@ fn alloc_array( } adt::DataType::LargeUtf8 => { let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); - let values = ArrowBufferMut::from_len_zeroed( - capacity - (num_cells * std::mem::size_of::()), - ); + let values = ArrowBufferMut::from_len_zeroed(num_values); let nulls = if nullable { Some(abuf::NullBuffer::new_null(num_cells)) } else { @@ -1617,9 +1623,7 @@ fn alloc_array( } adt::DataType::LargeBinary => { let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); - let values = ArrowBufferMut::from_len_zeroed( - capacity - (num_cells * std::mem::size_of::()), - ); + let values = ArrowBufferMut::from_len_zeroed(num_values); let nulls = if nullable { Some(abuf::NullBuffer::new_null(num_cells)) } else { @@ -1630,10 +1634,11 @@ fn alloc_array( .map_err(FieldError::BufferAllocation)?, )) } - _ if dtype.is_primitive() => { - let data = ArrowBufferMut::from_len_zeroed( - num_cells * dtype.primitive_width().unwrap(), - ); + _ if target_type.is_primitive() => { + // SAFETY: we know that `target_type.is_primitive()` per match arm + let width = target_type.primitive_width().unwrap(); + + let data = ArrowBufferMut::from_len_zeroed(num_cells * width); let nulls = if nullable { Some(ArrowBufferMut::from_len_zeroed(num_cells).into()) @@ -1642,7 +1647,7 @@ fn alloc_array( }; let data = aa::ArrayData::try_new( - dtype, + target_type, num_cells, nulls, 0, @@ -1657,86 +1662,6 @@ fn alloc_array( } } -fn calculate_num_cells( - dtype: adt::DataType, - nullable: bool, - capacity: usize, -) -> FieldResult { - match dtype { - adt::DataType::Boolean => { - if nullable { - Ok(capacity * 8 / 2) - } else { - Ok(capacity * 8) - } - } - adt::DataType::LargeList(ref field) => { - if !field.data_type().is_primitive() { - return Err(UnsupportedArrowArrayError::UnsupportedArrowType( - dtype.clone(), - ) - .into()); - } - - // Todo: Figure out a better way to approximate values to offsets ratios - // based on whatever Python does or some such. - // - // For now, I'll pull a guess at of the ether and assume on average a - // var sized primitive array averages two values per cell. Becuase why - // not? - let width = field.data_type().primitive_width().unwrap(); - let bytes_per_cell = (width * 2) - + std::mem::size_of::() - + if nullable { 1 } else { 0 }; - Ok(capacity / bytes_per_cell) - } - adt::DataType::FixedSizeList(ref field, cvn) => { - if !field.data_type().is_primitive() { - return Err(UnsupportedArrowArrayError::UnsupportedArrowType( - dtype, - ) - .into()); - } - - if cvn < 2 { - return Err( - UnsupportedArrowArrayError::InvalidFixedSizeListLength(cvn) - .into(), - ); - } - - let cvn = cvn as usize; - let width = field.data_type().primitive_width().unwrap(); - let bytes_per_cell = capacity / (width * cvn); - let bytes_per_cell = if nullable { - bytes_per_cell + 1 - } else { - bytes_per_cell - }; - Ok(capacity / bytes_per_cell) - } - adt::DataType::LargeUtf8 | adt::DataType::LargeBinary => { - let bytes_per_cell = - AVERAGE_STRING_LENGTH + std::mem::size_of::(); - let bytes_per_cell = if nullable { - bytes_per_cell + 1 - } else { - bytes_per_cell - }; - Ok(capacity / bytes_per_cell) - } - _ if dtype.is_primitive() => { - let width = dtype.primitive_width().unwrap(); - let bytes_per_cell = width + if nullable { 1 } else { 0 }; - Ok(capacity / bytes_per_cell) - } - _ => Err(UnsupportedArrowArrayError::UnsupportedArrowType( - dtype.clone(), - ) - .into()), - } -} - // Private utility functions fn to_tdb_offsets( diff --git a/tiledb/query-core/src/fields.rs b/tiledb/query-core/src/fields.rs index 994a69ec..2f643e94 100644 --- a/tiledb/query-core/src/fields.rs +++ b/tiledb/query-core/src/fields.rs @@ -7,26 +7,344 @@ use tiledb_api::query::read::aggregate::AggregateFunction; use super::QueryBuilder; -/// Default field capacity is 10MiB -const DEFAULT_CAPACITY: usize = 1024 * 1024 * 10; +#[derive(Debug, thiserror::Error)] +pub enum CapacityNumCellsError { + #[error("")] + InvalidFixedSize(i32), + #[error("")] + UnsupportedArrowType(adt::DataType), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Capacity { + /// Request a maximum number of cells of the target field. + /// + /// The amount of memory allocated for fixed-length query + /// fields is the exact amount needed to hold the requested number + /// of values. + /// + /// The amount of space allocated for variable-length query + /// fields is determined by estimating the size of each variable-length cell. + Cells(usize), + /// Request a maximum number of total value of the target field. + /// + /// The amount of memory allocated for fixed-length query + /// fields is the exact amount needed to hold the requested number + /// of values. This behavior is identical to that of [Self::Cells]. + /// + /// The amount of memory allocated for variable-length query + /// fields is the exact amount needed to hold the requested number + /// of values, plus an additional amount needed to hold an estimated + /// number of cell offsets. + Values(usize), + /// Request whatever fits within a fixed memory limit. + /// + /// For variable-length query fields, the fixed memory is apportioned + /// among the cell values and cell offsets using an estimate for + /// the average cell size. + Memory(usize), +} + +impl Capacity { + /// Returns a number of cells of `target_type` which can be held by this capacity. + /// + /// For fixed-length target types, the result is exact. + /// + /// For variable-length target types, the result is an estimate using an estimated + /// average number of values per cell. + /// + /// Returns `Err` if `target_type` is not a supported [DataType]. + pub fn num_cells( + &self, + target_type: &adt::DataType, + nullable: bool, + ) -> Result { + match self { + Self::Cells(num_cells) => Ok(*num_cells), + Self::Values(num_values) => { + calculate_num_cells_by_values(*num_values, target_type) + } + Self::Memory(memory_limit) => { + calculate_by_memory(*memory_limit, target_type, nullable) + .map(|(num_cells, _)| num_cells) + } + } + } + + pub fn num_values( + &self, + target_type: &adt::DataType, + nullable: bool, + ) -> Result { + match self { + Self::Cells(num_cells) => { + calculate_num_values_by_cells(*num_cells, target_type) + } + Self::Values(num_values) => Ok(*num_values), + Self::Memory(memory_limit) => { + calculate_by_memory(*memory_limit, target_type, nullable) + .map(|(_, num_values)| num_values) + } + } + } +} + +fn calculate_num_cells_by_values( + num_values: usize, + target_type: &adt::DataType, +) -> Result { + match target_type { + adt::DataType::FixedSizeBinary(fl) => { + if *fl < 1 { + Err(CapacityNumCellsError::InvalidFixedSize(*fl)) + } else { + Ok(num_values / (*fl as usize)) + } + } + adt::DataType::FixedSizeList(ref field, fl) => { + if *fl < 1 { + Err(CapacityNumCellsError::InvalidFixedSize(*fl)) + } else { + let num_elements = num_values / (*fl as usize); + calculate_num_cells_by_values(num_elements, field.data_type()) + } + } + adt::DataType::LargeUtf8 + | adt::DataType::LargeBinary + | adt::DataType::LargeList(_) => { + Ok(num_values + / estimate_average_variable_length_values(target_type)) + } + _ if target_type.is_primitive() => Ok(num_values), + _ => todo!(), + } +} + +fn calculate_num_values_by_cells( + num_cells: usize, + target_type: &adt::DataType, +) -> Result { + match target_type { + adt::DataType::FixedSizeBinary(fl) => { + if *fl < 1 { + Err(CapacityNumCellsError::InvalidFixedSize(*fl)) + } else { + Ok(num_cells * (*fl as usize)) + } + } + adt::DataType::FixedSizeList(ref field, fl) => { + if *fl < 1 { + Err(CapacityNumCellsError::InvalidFixedSize(*fl)) + } else { + let num_elements = num_cells * (*fl as usize); + calculate_num_cells_by_values(num_elements, field.data_type()) + } + } + adt::DataType::LargeUtf8 + | adt::DataType::LargeBinary + | adt::DataType::LargeList(_) => Ok( + num_cells * estimate_average_variable_length_values(target_type) + ), + _ if target_type.is_primitive() => Ok(num_cells), + _ => todo!(), + } +} + +fn calculate_by_memory( + memory_limit: usize, + target_type: &adt::DataType, + nullable: bool, +) -> Result<(usize, usize), CapacityNumCellsError> { + match target_type { + adt::DataType::Boolean => { + let num_cells = if nullable { + memory_limit * 8 / 2 + } else { + memory_limit * 8 + }; + Ok((num_cells, num_cells)) + } + adt::DataType::LargeList(ref field) => { + if !field.data_type().is_primitive() { + return Err(CapacityNumCellsError::UnsupportedArrowType( + target_type.clone(), + )); + } + + let estimate_values_per_cell = + estimate_average_variable_length_values(target_type); + + // TODO: Figure out a better way to approximate values to offsets ratios + // based on whatever Python does or some such. + // + // For now, I'll pull a guess at of the ether and assume on average a + // var sized primitive array averages two values per cell. Becuase why + // not? + let width = field.data_type().primitive_width().unwrap(); + let bytes_per_cell = (width * estimate_values_per_cell) + + std::mem::size_of::() + + if nullable { 1 } else { 0 }; + + let num_cells = memory_limit / bytes_per_cell; + let num_values = num_cells * estimate_values_per_cell; + assert!( + num_cells + * (std::mem::size_of::() + + if nullable { 1 } else { 0 }) + + num_values * width + <= memory_limit + ); + + Ok((num_cells, num_values)) + } + adt::DataType::FixedSizeList(ref field, cvn) => { + if !field.data_type().is_primitive() { + return Err(CapacityNumCellsError::UnsupportedArrowType( + target_type.clone(), + )); + } + + if *cvn < 1 { + return Err(CapacityNumCellsError::InvalidFixedSize(*cvn)); + } + + let cvn = *cvn as usize; + let width = field.data_type().primitive_width().unwrap(); + let bytes_per_cell = memory_limit / (width * cvn); + let bytes_per_cell = if nullable { + bytes_per_cell + 1 + } else { + bytes_per_cell + }; + + let num_cells = memory_limit / bytes_per_cell; + let num_values = num_cells * cvn; + assert!( + num_cells * if nullable { 1 } else { 0 } + num_values * width + <= memory_limit + ); + Ok((num_cells, num_values)) + } + adt::DataType::LargeUtf8 | adt::DataType::LargeBinary => { + let values_per_cell = + estimate_average_variable_length_values(target_type); + let bytes_per_cell = values_per_cell + + std::mem::size_of::() + + if nullable { 1 } else { 0 }; + + let num_cells = memory_limit / bytes_per_cell; + let num_values = num_cells * values_per_cell; + assert!( + num_cells + * (std::mem::size_of::() + + if nullable { 1 } else { 0 }) + + num_values + <= memory_limit + ); + Ok((num_cells, num_values)) + } + _ if target_type.is_primitive() => { + let width = target_type.primitive_width().unwrap(); + let bytes_per_cell = width + if nullable { 1 } else { 0 }; + let num_cells = memory_limit / bytes_per_cell; + let num_values = num_cells; + Ok((num_cells, num_values)) + } + _ => Err(CapacityNumCellsError::UnsupportedArrowType( + target_type.clone(), + ) + .into()), + } +} + +/// Returns a guess for how many variable-length values the average cell has. +fn estimate_average_variable_length_values( + target_type: &adt::DataType, +) -> usize { + // A bad value here will lead to poor memory utilization. + // - if this estimate is too small then the results will fill up the variable-length + // data buffers quickly, and the fixed-size data buffers will be under-utilized. + // - if this estimate is too large, then the results will fill up the fixed-length + // data buffers quickly, and the variable-size data buffers will be under-utilized. + // + // Some ideas core could implement to improve this estimate: + // - keep a histogram of average cell length in fragment metadata + // - register a single buffer for all variable-length data values + // - write offsets and variable-length data into a single buffer, writing the fixed-size offsets + // in order from the front and the variable-size values in reverse cell order from the back + // (the result buffer is full when the two would meet in the middle) + // - produce results in row-major order and write the variable-length parts + // for all query fields in reverse from the end of the buffer, similar to the above + match target_type { + adt::DataType::LargeUtf8 => { + // https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ + // "German" strings have a buffer for 16 bytes which optimizes access for strings + // which are 12 bytes and shorter. + // + // https://www.vldb.org/pvldb/vol17/p148-zeng.pdf + // claims that in real-world datasets 99% of strings are of length 128 or less. + // + // But of course what makes sense is domain-specific. + // A username is probably short, an email is probably longer than this. + 16 + } + adt::DataType::LargeBinary => { + // this can be literally anything, so go with 1KiB? + 1024 + } + adt::DataType::LargeList(_) => { + // also pulling a number out of thin air + 4 + } + _ => unreachable!(), + } +} + +/// The default capacity for a field is 10MiB. +/// +/// Use of this default is not recommended for queries which request +/// multiple fields of different fixed sizes. Except where aggregates +/// are concerned, queries return the same number of cells for each +/// target field. This means that the number of cells returned by a +/// query is bounded by the number of cells which fit in the buffer +/// allocated for the largest field. If the buffers for each field are +/// the same size, then buffers for smaller fields will not be fully +/// utilized. +/// +/// For example, a query to a `Datatype::Int32` field and a +/// `Datatype::Int64` field which writes to a 12MiB buffer per field +/// can write only up to 1.5M cells per submit. This fully utilizes +/// the `Datatype::Int64` buffer but only utilizes 50% of the +/// `Datatype::Int32` buffer. A better strategy would be to allocate +/// twice as much memory for the `Datatype::Int64` field as for +/// the `Datatype::Int32` field, such as by using [Self::Cells]. +/// +/// Note that [Self::Values] is not the default for similar reasons, +/// and [Self::Cells] is not the default to avoid large fields +/// from using unexpectedly large amounts of memory. +impl Default for Capacity { + fn default() -> Self { + Self::Memory(1024 * 1024 * 10) + } +} #[derive(Debug, Default)] pub enum QueryField { #[default] Default, - WithCapacity(usize), - WithCapacityAndType(usize, adt::DataType), + WithCapacity(Capacity), + WithCapacityAndType(Capacity, adt::DataType), WithType(adt::DataType), Buffer(Arc), } impl QueryField { - pub fn capacity(&self) -> Option { + pub fn capacity(&self) -> Option { match self { - Self::Default => Some(DEFAULT_CAPACITY), + Self::Default => Some(Default::default()), Self::WithCapacity(capacity) => Some(*capacity), Self::WithCapacityAndType(capacity, _) => Some(*capacity), - Self::WithType(_) => Some(DEFAULT_CAPACITY), + Self::WithType(_) => Some(Default::default()), Self::Buffer(_) => None, } } @@ -83,7 +401,11 @@ impl QueryFieldsBuilder { self } - pub fn field_with_capacity(mut self, name: &str, capacity: usize) -> Self { + pub fn field_with_capacity( + mut self, + name: &str, + capacity: Capacity, + ) -> Self { self.fields.insert(name, QueryField::WithCapacity(capacity)); self } @@ -91,7 +413,7 @@ impl QueryFieldsBuilder { pub fn field_with_capacity_and_type( mut self, name: &str, - capacity: usize, + capacity: Capacity, dtype: adt::DataType, ) -> Self { self.fields @@ -151,7 +473,7 @@ impl QueryFieldsBuilderForQuery { } } - pub fn field_with_capacity(self, name: &str, capacity: usize) -> Self { + pub fn field_with_capacity(self, name: &str, capacity: Capacity) -> Self { Self { fields_builder: self .fields_builder @@ -163,7 +485,7 @@ impl QueryFieldsBuilderForQuery { pub fn field_with_capacity_and_type( self, name: &str, - capacity: usize, + capacity: Capacity, dtype: adt::DataType, ) -> Self { Self { From 15cd3ae126155286389e5ac05f1c2fd3eebebbf1 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 13 Dec 2024 10:22:22 -0500 Subject: [PATCH 41/42] Add stdx-binary-search crate for searching Range --- Cargo.lock | 8 ++ Cargo.toml | 1 + stdx/binary-search/Cargo.toml | 11 +++ stdx/binary-search/src/lib.rs | 153 ++++++++++++++++++++++++++++++++++ 4 files changed, 173 insertions(+) create mode 100644 stdx/binary-search/Cargo.toml create mode 100644 stdx/binary-search/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index ed681226..9b75d333 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1679,6 +1679,14 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stdx-binary-search" +version = "0.1.0" +dependencies = [ + "num-traits", + "proptest", +] + [[package]] name = "strategy-ext" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index bc82287b..dc8df2a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "stdx/binary-search", "tiledb/api", "tiledb/common", "tiledb/pod", diff --git a/stdx/binary-search/Cargo.toml b/stdx/binary-search/Cargo.toml new file mode 100644 index 00000000..62b6c32e --- /dev/null +++ b/stdx/binary-search/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "stdx-binary-search" +edition.workspace = true +rust-version.workspace = true +version.workspace = true + +[dependencies] +num-traits = { workspace = true } + +[dev-dependencies] +proptest = { workspace = true } diff --git a/stdx/binary-search/src/lib.rs b/stdx/binary-search/src/lib.rs new file mode 100644 index 00000000..638a9425 --- /dev/null +++ b/stdx/binary-search/src/lib.rs @@ -0,0 +1,153 @@ +use std::ops::Range; + +use num_traits::{FromPrimitive, ToPrimitive}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Bisect { + NeverTrue, + UpperBound(T), + AlwaysTrue, +} + +/// A type which represents a searchable space of values. +pub trait Search { + type Item; + + /// Performs an efficient search over the items of `self` to find the upper bound + /// where `property` is true. + /// + /// `property` is some function which bisects the search space, + /// returning `true` on each value in the first segment and `false` on each value in the second. + fn upper_bound(&self, property: F) -> Bisect + where + F: Fn(&Self::Item) -> bool; +} + +macro_rules! binary_search_impl { + ($($ITYPE:ty),+) => { + $( + impl Search for Range<$ITYPE> { + type Item = $ITYPE; + + fn upper_bound(&self, property: F) -> Bisect + where + F: Fn(&Self::Item) -> bool, + { + if self.is_empty() { + return Bisect::AlwaysTrue + } else if self.start + 1 == self.end { + return if property(&self.start) { + Bisect::AlwaysTrue + } else { + Bisect::NeverTrue + } + } + let mut search = self.clone(); + while search.start + 1 < search.end { + let midpoint = midpoint(&search); + if property(&midpoint) { + search.start = midpoint; + } else { + search.end = midpoint; + } + } + if search.end == self.end { + Bisect::AlwaysTrue + } else if property(&search.start) { + Bisect::UpperBound(search.start) + } else { + Bisect::NeverTrue + } + } + } + )+ + }; +} + +fn midpoint(range: &Range) -> T +where + T: Copy + FromPrimitive + ToPrimitive, +{ + T::from_i128( + (range.start.to_i128().unwrap() + range.end.to_i128().unwrap()) / 2, + ) + .unwrap() +} + +binary_search_impl!( + u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize +); + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + /// Performs a linear search to return the maximum value in the range + /// for which a `property` is true + fn linear_search(range: Range, property: F) -> Bisect + where + Range: Iterator, + F: Fn(&T) -> bool, + { + let mut prev = None; + for i in range { + if property(&i) { + prev = Some(i); + } else if let Some(prev) = prev { + return Bisect::UpperBound(prev); + } else { + return Bisect::NeverTrue; + } + } + Bisect::AlwaysTrue + } + + fn search_results( + range: Range, + property: F, + ) -> (Bisect, Bisect) + where + Range: Iterator + Search, + T: Clone + PartialEq, + F: Clone + Fn(&T) -> bool, + { + let linear_search_result = + linear_search(range.clone(), property.clone()); + let binary_search_result = range.upper_bound(property.clone()); + + (linear_search_result, binary_search_result) + } + + #[test] + fn example_simple_less_than() { + let cmp = |value: &usize| *value < 5; + + for i in 0..10 { + for j in i..10 { + let (linear, binary) = search_results(i..j, &cmp); + assert_eq!(linear, binary, "i..j = {:?}", i..j) + } + } + } + + proptest! { + #[test] + fn proptest_simple_less_than(target in any::(), range in any::>()) { + match range.upper_bound(|value: &usize| *value < target) { + Bisect::AlwaysTrue => assert!(range.end <= target), + Bisect::NeverTrue => assert!(target < range.start), + Bisect::UpperBound(bound) => { + assert_eq!(target - 1, bound); + } + } + } + + #[test] + fn proptest_search_compare(target in any::(), range in any::>()) { + let (linear, binary) = search_results(range, |value: &u8| *value < target); + assert_eq!(linear, binary); + } + } +} From d93fb8164e4befbd89e0309dbf041e45e76cd628 Mon Sep 17 00:00:00 2001 From: Ryan Roelke Date: Fri, 13 Dec 2024 16:36:46 -0500 Subject: [PATCH 42/42] proptest_capacity_limits passes --- Cargo.lock | 1 + Cargo.toml | 3 +- tiledb/query-core/Cargo.toml | 1 + tiledb/query-core/src/buffers/mod.rs | 41 +- tiledb/query-core/src/buffers/tests.rs | 63 +++- tiledb/query-core/src/datatype/default_to.rs | 5 + .../src/{fields.rs => fields/mod.rs} | 262 ++++++++----- tiledb/query-core/src/fields/tests.rs | 353 ++++++++++++++++++ 8 files changed, 639 insertions(+), 90 deletions(-) rename tiledb/query-core/src/{fields.rs => fields/mod.rs} (69%) create mode 100644 tiledb/query-core/src/fields/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 9b75d333..527e079a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1912,6 +1912,7 @@ dependencies = [ "cells", "itertools 0.12.1", "proptest", + "stdx-binary-search", "thiserror", "tiledb-api", "tiledb-common", diff --git a/Cargo.toml b/Cargo.toml index dc8df2a2..bb5ebccd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ members = [ "test-utils/signal", "test-utils/strategy-ext", "test-utils/uri", - "tools/api-coverage", + "tools/api-coverage" ] default-members = [ "tiledb/api", @@ -55,6 +55,7 @@ regex = "1" serde = { version = "1", features = ["derive"] } serde_json = { version = "1", features = ["float_roundtrip"] } signal = { path = "test-utils/signal", version = "0.1.0" } +stdx-binary-search = { path = "stdx/binary-search", version = "0.1.0" } strategy-ext = { path = "test-utils/strategy-ext", version = "0.1.0" } tempfile = { version = "3" } thiserror = { version = "1" } diff --git a/tiledb/query-core/Cargo.toml b/tiledb/query-core/Cargo.toml index a16ef8dd..9f146b05 100644 --- a/tiledb/query-core/Cargo.toml +++ b/tiledb/query-core/Cargo.toml @@ -6,6 +6,7 @@ version.workspace = true [dependencies] arrow = { workspace = true } +stdx-binary-search = { workspace = true } thiserror = { workspace = true } tiledb-api = { workspace = true, features = ["raw"] } tiledb-common = { workspace = true } diff --git a/tiledb/query-core/src/buffers/mod.rs b/tiledb/query-core/src/buffers/mod.rs index de79caf4..409ceb64 100644 --- a/tiledb/query-core/src/buffers/mod.rs +++ b/tiledb/query-core/src/buffers/mod.rs @@ -1,5 +1,6 @@ use std::any::Any; use std::collections::HashMap; +use std::fmt::Debug; use std::pin::Pin; use std::sync::Arc; @@ -140,7 +141,7 @@ type FromArrowError = (Arc, UnsupportedArrowArrayError); type FromArrowResult = std::result::Result; /// An interface to our mutable buffer implementations. -trait NewBufferTraitThing { +trait NewBufferTraitThing: Debug { /// Return this trait object as any for downcasting. fn as_any(&self) -> &dyn Any; @@ -212,6 +213,7 @@ impl BufferTarget { } } +#[derive(Debug)] struct BooleanBuffers { data: QueryBuffer, validity: Option, @@ -276,6 +278,7 @@ impl NewBufferTraitThing for BooleanBuffers { } } +#[derive(Debug)] struct ByteBuffers { dtype: adt::DataType, data: QueryBuffer, @@ -443,6 +446,7 @@ impl NewBufferTraitThing for ByteBuffers { } } +#[derive(Debug)] struct FixedListBuffers { field: Arc, cell_val_num: CellValNum, @@ -639,6 +643,7 @@ impl NewBufferTraitThing for FixedListBuffers { } } +#[derive(Debug)] pub struct QueryBuffer { buffer: ArrowBufferMut, size: Pin>, @@ -704,6 +709,7 @@ impl QueryBuffer { /// * Read queries will generate offsets using the fixed size of each cell. /// * Write queries will validate that the offsets are of fixed size /// and return `Err` if they are not. +#[derive(Debug)] struct ListBuffers { field: Arc, data: QueryBuffer, @@ -712,6 +718,7 @@ struct ListBuffers { validity: Option, } +#[derive(Debug)] enum ListBuffersOffsets { ArrowOnly(OffsetBuffer), Shared(QueryBuffer), @@ -954,6 +961,7 @@ impl NewBufferTraitThing for ListBuffers { } } +#[derive(Debug)] struct PrimitiveBuffers { dtype: adt::DataType, data: QueryBuffer, @@ -1563,7 +1571,7 @@ fn request_to_buffers( pub type SharedBuffers = HashMap>; -fn alloc_array( +pub fn alloc_array( target_type: adt::DataType, nullable: bool, capacity: Capacity, @@ -1572,7 +1580,15 @@ fn alloc_array( let num_values = capacity.num_values(&target_type, nullable)?; match target_type { adt::DataType::Boolean => { - Ok(Arc::new(aa::BooleanArray::new_null(num_cells))) + if nullable { + Ok(Arc::new(aa::BooleanArray::new_null(num_cells))) + } else { + Ok(Arc::new( + std::iter::repeat(Some(false)) + .take(num_cells) + .collect::(), + )) + } } adt::DataType::LargeList(field) => { let offsets = abuf::OffsetBuffer::::new_zeroed(num_cells); @@ -1591,6 +1607,18 @@ fn alloc_array( .map_err(FieldError::BufferAllocation)?, )) } + adt::DataType::FixedSizeBinary(cvn) => { + let nulls = if nullable { + Some(abuf::NullBuffer::new_null(num_cells)) + } else { + None + }; + + let values = + abuf::Buffer::from_vec(vec![0u8; num_cells * (cvn as usize)]); + + Ok(Arc::new(aa::FixedSizeBinaryArray::new(cvn, values, nulls))) + } adt::DataType::FixedSizeList(field, cvn) => { let nulls = if nullable { Some(abuf::NullBuffer::new_null(num_cells)) @@ -1641,7 +1669,12 @@ fn alloc_array( let data = ArrowBufferMut::from_len_zeroed(num_cells * width); let nulls = if nullable { - Some(ArrowBufferMut::from_len_zeroed(num_cells).into()) + Some( + abuf::NullBuffer::new_null(num_cells) + .into_inner() + .into_inner() + .into(), + ) } else { None }; diff --git a/tiledb/query-core/src/buffers/tests.rs b/tiledb/query-core/src/buffers/tests.rs index 0a23e374..42bdb284 100644 --- a/tiledb/query-core/src/buffers/tests.rs +++ b/tiledb/query-core/src/buffers/tests.rs @@ -1,14 +1,37 @@ use arrow_proptest_strategies::array::{prop_array, ArrayParameters}; -use arrow_proptest_strategies::schema::prop_arrow_field; +use arrow_proptest_strategies::schema::{ + prop_arrow_datatype, prop_arrow_field, +}; use proptest::prelude::*; use tiledb_common::Datatype; use super::*; +impl Arbitrary for BufferTarget { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + let query_type = + prop_oneof![Just(QueryType::Read), Just(QueryType::Write)]; + (query_type, any::(), any::()) + .prop_map(|(query_type, cell_val_num, is_nullable)| Self { + query_type, + cell_val_num, + is_nullable, + }) + .boxed() + } +} + +/// Returns a deep copy of `buffer`. fn copy_buffer(buffer: &abuf::Buffer) -> abuf::Buffer { abuf::Buffer::from(buffer.as_slice().to_vec()) } +/// Returns a deep copy of `array_data`. +/// +/// The returned [ArrayData] does not share any buffers with `array_data`. fn copy_array_data(array_data: &aa::ArrayData) -> aa::ArrayData { let nulls = array_data .nulls() @@ -35,6 +58,9 @@ fn copy_array_data(array_data: &aa::ArrayData) -> aa::ArrayData { .expect("Error copying array data") } +/// Returns a deep copy of `array`. +/// +/// The returned [Array] does not share any buffers with `array`. fn copy_array(array: &dyn aa::Array) -> Arc { let data_ref = array.to_data(); let data_copy = copy_array_data(&data_ref); @@ -189,3 +215,38 @@ proptest! { instance_list_buffers_roundtrip_fixed(cvn, array) } } + +/// Test that if a data type can be used to alloc an array then it also +/// can be converted to a mutable buffer +fn instance_make_mut( + target_type: adt::DataType, + capacity: Capacity, + target: BufferTarget, +) { + let Ok(array) = alloc_array(target_type, target.is_nullable, capacity) + else { + return; + }; + + let array_expect = copy_array(&array); + + let entry_mut = to_target_buffers(&target, array) + .expect("alloc_array succeeded but to_target_buffers failed"); + + let array_out = entry_mut + .into_arrow() + .expect("to_target_buffers succeeded but into_array failed"); + + assert_eq!(array_expect.as_ref(), array_out.as_ref()); +} + +proptest! { + #[test] + fn proptest_make_mut( + target_type in prop_arrow_datatype(Default::default()), + capacity in any::(), + target in any::() + ) { + instance_make_mut(target_type, capacity, target) + } +} diff --git a/tiledb/query-core/src/datatype/default_to.rs b/tiledb/query-core/src/datatype/default_to.rs index 44005ee1..b995612a 100644 --- a/tiledb/query-core/src/datatype/default_to.rs +++ b/tiledb/query-core/src/datatype/default_to.rs @@ -42,6 +42,11 @@ pub fn default_arrow_type( LogicalMatch(ArrowDataType::LargeBinary) } + (Datatype::Boolean, cvn) if !cvn.is_single_valued() => { + // FIXME: not implemented in capacity/alloc + default_arrow_type(Datatype::UInt8, CellValNum::Var)? + } + // then the general cases (_, CellValNum::Fixed(nz)) if nz.get() == 1 => { single_valued_type(dtype) diff --git a/tiledb/query-core/src/fields.rs b/tiledb/query-core/src/fields/mod.rs similarity index 69% rename from tiledb/query-core/src/fields.rs rename to tiledb/query-core/src/fields/mod.rs index 2f643e94..9aa427f1 100644 --- a/tiledb/query-core/src/fields.rs +++ b/tiledb/query-core/src/fields/mod.rs @@ -3,15 +3,16 @@ use std::sync::Arc; use arrow::array as aa; use arrow::datatypes as adt; +use stdx_binary_search::{Bisect, Search}; use tiledb_api::query::read::aggregate::AggregateFunction; use super::QueryBuilder; #[derive(Debug, thiserror::Error)] pub enum CapacityNumCellsError { - #[error("")] + #[error("Invalid fixed size data type with fixed size '{0}'")] InvalidFixedSize(i32), - #[error("")] + #[error("Unsupported arrow data type: {0}")] UnsupportedArrowType(adt::DataType), } @@ -26,16 +27,19 @@ pub enum Capacity { /// The amount of space allocated for variable-length query /// fields is determined by estimating the size of each variable-length cell. Cells(usize), - /// Request a maximum number of total value of the target field. + /// Request a maximum number of total values of the target field. /// - /// The amount of memory allocated for fixed-length query - /// fields is the exact amount needed to hold the requested number - /// of values. This behavior is identical to that of [Self::Cells]. + /// For fixed-length query fields, the requested number of values + /// is an upper bound on the number of values which will fit + /// in the allocated space. For single-valued fields, this behavior + /// is identical to that of [Self::Cells]. For multi-valued fixed-length + /// fields, space is allocated for the maximum number of cells + /// which fit within the upper bound. /// - /// The amount of memory allocated for variable-length query - /// fields is the exact amount needed to hold the requested number - /// of values, plus an additional amount needed to hold an estimated - /// number of cell offsets. + /// For variable-length query fields, the exact amount of memory + /// needed to hold the requested number of values is allocated, + /// and additional memory is allocated to hold an estimated number + /// of cell offsets. Values(usize), /// Request whatever fits within a fixed memory limit. /// @@ -94,6 +98,7 @@ fn calculate_num_cells_by_values( target_type: &adt::DataType, ) -> Result { match target_type { + adt::DataType::Boolean => Ok(num_values), adt::DataType::FixedSizeBinary(fl) => { if *fl < 1 { Err(CapacityNumCellsError::InvalidFixedSize(*fl)) @@ -112,8 +117,12 @@ fn calculate_num_cells_by_values( adt::DataType::LargeUtf8 | adt::DataType::LargeBinary | adt::DataType::LargeList(_) => { - Ok(num_values - / estimate_average_variable_length_values(target_type)) + // SAFETY: obvious from definition of `estimate_average_variable_length_values` + let est = + estimate_average_variable_length_values(target_type).unwrap(); + + // NB: round up + Ok((num_values + est - 1) / est) } _ if target_type.is_primitive() => Ok(num_values), _ => todo!(), @@ -125,6 +134,7 @@ fn calculate_num_values_by_cells( target_type: &adt::DataType, ) -> Result { match target_type { + adt::DataType::Boolean => Ok(num_cells), adt::DataType::FixedSizeBinary(fl) => { if *fl < 1 { Err(CapacityNumCellsError::InvalidFixedSize(*fl)) @@ -142,9 +152,12 @@ fn calculate_num_values_by_cells( } adt::DataType::LargeUtf8 | adt::DataType::LargeBinary - | adt::DataType::LargeList(_) => Ok( - num_cells * estimate_average_variable_length_values(target_type) - ), + | adt::DataType::LargeList(_) => { + // SAFETY: obvious from definition of `estimate_average_variable_length_values` + let est = + estimate_average_variable_length_values(target_type).unwrap(); + Ok(num_cells * est) + } _ if target_type.is_primitive() => Ok(num_cells), _ => todo!(), } @@ -157,13 +170,34 @@ fn calculate_by_memory( ) -> Result<(usize, usize), CapacityNumCellsError> { match target_type { adt::DataType::Boolean => { - let num_cells = if nullable { - memory_limit * 8 / 2 - } else { - memory_limit * 8 + // need space for the translation buffer as well as the bit-packed arrow array + let num_cells_memory = |num_cells: usize| -> usize { + let mem_values = num_cells + (num_cells + 7) / 8; + let mem_nulls = if nullable { mem_values } else { 0 }; + mem_values + mem_nulls }; + let num_cells = + match (1..u32::MAX as usize).upper_bound(|num_cells: &usize| { + num_cells_memory(*num_cells) <= memory_limit + }) { + Bisect::NeverTrue => 0, + Bisect::AlwaysTrue => u32::MAX as usize, + Bisect::UpperBound(max_num_cells) => max_num_cells, + }; Ok((num_cells, num_cells)) } + adt::DataType::FixedSizeBinary(cvn) => { + if *cvn < 1 { + return Err(CapacityNumCellsError::InvalidFixedSize(*cvn)); + } + + Ok(calculate_by_memory_fixed_length( + memory_limit, + nullable, + *cvn as usize, + 1, + )) + } adt::DataType::LargeList(ref field) => { if !field.data_type().is_primitive() { return Err(CapacityNumCellsError::UnsupportedArrowType( @@ -171,31 +205,17 @@ fn calculate_by_memory( )); } + // SAFETY: obvious from definition of `estimate_average_variable_length_values` let estimate_values_per_cell = - estimate_average_variable_length_values(target_type); - - // TODO: Figure out a better way to approximate values to offsets ratios - // based on whatever Python does or some such. - // - // For now, I'll pull a guess at of the ether and assume on average a - // var sized primitive array averages two values per cell. Becuase why - // not? - let width = field.data_type().primitive_width().unwrap(); - let bytes_per_cell = (width * estimate_values_per_cell) - + std::mem::size_of::() - + if nullable { 1 } else { 0 }; - - let num_cells = memory_limit / bytes_per_cell; - let num_values = num_cells * estimate_values_per_cell; - assert!( - num_cells - * (std::mem::size_of::() - + if nullable { 1 } else { 0 }) - + num_values * width - <= memory_limit - ); - - Ok((num_cells, num_values)) + estimate_average_variable_length_values(target_type).unwrap(); + let value_width = field.data_type().primitive_width().unwrap(); + + Ok(calculate_by_memory_var_length( + memory_limit, + nullable, + estimate_values_per_cell, + value_width, + )) } adt::DataType::FixedSizeList(ref field, cvn) => { if !field.data_type().is_primitive() { @@ -208,47 +228,36 @@ fn calculate_by_memory( return Err(CapacityNumCellsError::InvalidFixedSize(*cvn)); } - let cvn = *cvn as usize; - let width = field.data_type().primitive_width().unwrap(); - let bytes_per_cell = memory_limit / (width * cvn); - let bytes_per_cell = if nullable { - bytes_per_cell + 1 - } else { - bytes_per_cell - }; + let value_width = field.data_type().primitive_width().unwrap(); - let num_cells = memory_limit / bytes_per_cell; - let num_values = num_cells * cvn; - assert!( - num_cells * if nullable { 1 } else { 0 } + num_values * width - <= memory_limit - ); - Ok((num_cells, num_values)) + Ok(calculate_by_memory_fixed_length( + memory_limit, + nullable, + *cvn as usize, + value_width, + )) } adt::DataType::LargeUtf8 | adt::DataType::LargeBinary => { - let values_per_cell = - estimate_average_variable_length_values(target_type); - let bytes_per_cell = values_per_cell - + std::mem::size_of::() - + if nullable { 1 } else { 0 }; - - let num_cells = memory_limit / bytes_per_cell; - let num_values = num_cells * values_per_cell; - assert!( - num_cells - * (std::mem::size_of::() - + if nullable { 1 } else { 0 }) - + num_values - <= memory_limit - ); - Ok((num_cells, num_values)) + // SAFETY: obvious from definition of `estimate_average_variable_length_values` + let estimate_values_per_cell = + estimate_average_variable_length_values(target_type).unwrap(); + let value_width = 1; + + Ok(calculate_by_memory_var_length( + memory_limit, + nullable, + estimate_values_per_cell, + value_width, + )) } _ if target_type.is_primitive() => { - let width = target_type.primitive_width().unwrap(); - let bytes_per_cell = width + if nullable { 1 } else { 0 }; - let num_cells = memory_limit / bytes_per_cell; - let num_values = num_cells; - Ok((num_cells, num_values)) + let value_width = target_type.primitive_width().unwrap(); + Ok(calculate_by_memory_fixed_length( + memory_limit, + nullable, + 1, + value_width, + )) } _ => Err(CapacityNumCellsError::UnsupportedArrowType( target_type.clone(), @@ -257,10 +266,92 @@ fn calculate_by_memory( } } +fn estimate_num_cells_memory( + num_cells: usize, + nullable: bool, + per_cell_overhead: usize, + cell_value_memory: usize, +) -> usize { + // NB: arrow bit-packs the null values but tiledb does not, which + // requires a translation buffer between them. + let mem_nulls = if nullable { + num_cells + (num_cells + 7) / 8 + } else { + 0 + }; + + let mem_overhead = num_cells * per_cell_overhead; + + let mem_values = num_cells * cell_value_memory; + + mem_values + mem_overhead + mem_nulls +} + +fn calculate_by_memory_fixed_length( + memory_limit: usize, + nullable: bool, + cvn: usize, + value_width: usize, +) -> (usize, usize) { + let estimate_memory = |num_cells: usize| { + estimate_num_cells_memory(num_cells, nullable, 0, cvn * value_width) + }; + + let num_cells = + match (1..u32::MAX as usize).upper_bound(|num_cells: &usize| { + estimate_memory(*num_cells) <= memory_limit + }) { + Bisect::NeverTrue => 0, + Bisect::AlwaysTrue => u32::MAX as usize, + Bisect::UpperBound(max_num_cells) => max_num_cells, + }; + let num_values = cvn * num_cells; + + (num_cells, num_values) +} + +fn calculate_by_memory_var_length( + memory_limit: usize, + nullable: bool, + estimate_values_per_cell: usize, + value_width: usize, +) -> (usize, usize) { + let estimate_memory = |num_cells: usize| { + estimate_num_cells_memory( + num_cells, + nullable, + std::mem::size_of::(), + estimate_values_per_cell * value_width, + ) + }; + + let num_cells = + match (1..u32::MAX as usize).upper_bound(|num_cells: &usize| { + estimate_memory(*num_cells) + <= memory_limit - std::mem::size_of::() + }) { + Bisect::NeverTrue => 0, + Bisect::AlwaysTrue => u32::MAX as usize, + Bisect::UpperBound(max_num_cells) => max_num_cells, + }; + + let cell_overhead_memory = (num_cells + 1) * std::mem::size_of::(); + let null_overhead_memory = if nullable { + num_cells + (num_cells + 7) / 8 + } else { + 0 + }; + + let num_values = + (memory_limit - cell_overhead_memory - null_overhead_memory) + / value_width; + (num_cells, num_values) +} + /// Returns a guess for how many variable-length values the average cell has. fn estimate_average_variable_length_values( target_type: &adt::DataType, -) -> usize { +) -> Option { // A bad value here will lead to poor memory utilization. // - if this estimate is too small then the results will fill up the variable-length // data buffers quickly, and the fixed-size data buffers will be under-utilized. @@ -286,17 +377,17 @@ fn estimate_average_variable_length_values( // // But of course what makes sense is domain-specific. // A username is probably short, an email is probably longer than this. - 16 + Some(16) } adt::DataType::LargeBinary => { // this can be literally anything, so go with 1KiB? - 1024 + Some(1024) } adt::DataType::LargeList(_) => { // also pulling a number out of thin air - 4 + Some(4) } - _ => unreachable!(), + _ => None, } } @@ -517,3 +608,6 @@ impl QueryFieldsBuilderForQuery { } } } + +#[cfg(test)] +mod tests; diff --git a/tiledb/query-core/src/fields/tests.rs b/tiledb/query-core/src/fields/tests.rs new file mode 100644 index 00000000..f677a7f6 --- /dev/null +++ b/tiledb/query-core/src/fields/tests.rs @@ -0,0 +1,353 @@ +use arrow::array::AsArray; +use proptest::prelude::*; +use tiledb_common::array::CellValNum; +use tiledb_common::Datatype; + +use super::*; +use crate::buffers::alloc_array; + +impl Arbitrary for Capacity { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + let min_memory_limit = std::mem::size_of::(); + prop_oneof![ + (1usize..=1024).prop_map(Capacity::Cells), + (1usize..=(1024 * 16)).prop_map(Capacity::Values), + (min_memory_limit..=(10 * 1024 * 1024)).prop_map(Capacity::Memory) + ] + .boxed() + } +} + +/// Instance of a capacity limits test. +fn instance_capacity_limits( + capacity: Capacity, + target_type: adt::DataType, + nullable: bool, +) -> anyhow::Result<()> { + let a1 = alloc_array(target_type.clone(), nullable, capacity)?; + assert_eq!(nullable, a1.nulls().is_some()); + + // NB: boolean requires special handling because it is bit-packed in arrow + // but not in tiledb, so there is an extra buffer allocated to translate + match target_type { + adt::DataType::Boolean => { + return instance_capacity_limits_boolean(capacity, nullable, a1) + } + adt::DataType::FixedSizeList(ref elt, fl) + if matches!(elt.data_type(), adt::DataType::Boolean) => + { + return instance_capacity_limits_boolean_fl( + capacity, + nullable, + a1, + fl as usize, + ) + } + adt::DataType::LargeList(ref elt) + if matches!(elt.data_type(), adt::DataType::Boolean) => + { + return instance_capacity_limits_boolean_vl(capacity, nullable, a1) + } + _ => { + // not a supported boolean thing, fall through + } + } + + let (actual_num_cells, actual_num_values, cell_width, value_width) = { + use arrow::array::Array; + use arrow::datatypes as arrow_schema; + + let a1 = a1.as_ref(); + aa::downcast_primitive_array!( + a1 => { + let value_width = a1.data_type().primitive_width().unwrap(); + (a1.len(), a1.len(), Some(value_width), value_width) + }, + adt::DataType::FixedSizeBinary(fl) => { + (a1.len(), a1.as_fixed_size_binary().values().len(), Some(*fl as usize), 1) + }, + adt::DataType::FixedSizeList(_, fl) => { + let elt = a1.as_fixed_size_list().values(); + let value_size = elt.data_type().primitive_width().unwrap(); + (a1.len(), elt.len(), Some(*fl as usize * value_size), value_size) + }, + adt::DataType::LargeUtf8 => { + (a1.len(), a1.as_string::().values().len(), None, 1) + }, + adt::DataType::LargeBinary => { + (a1.len(), a1.as_binary::().values().len(), None, 1) + }, + adt::DataType::LargeList(_) => { + let elt = a1.as_list::().values(); + (a1.len(), elt.len(), None, elt.data_type().primitive_width().unwrap()) + }, + dt => unreachable!("Unexpected data type {} in downcast", dt) + ) + }; + + match capacity { + Capacity::Cells(num_cells) => { + assert_eq!(num_cells, actual_num_cells); + + // adding another cell should increase the number of values + let num_values = capacity.num_values(&target_type, nullable)?; + let next_values = Capacity::Cells(num_cells + 1) + .num_values(&target_type, nullable)?; + assert!( + next_values > num_values, + "num_cells = {:?}, num_values = {:?}, next_values = {:?}", + num_cells, + num_values, + next_values + ); + } + Capacity::Values(num_values) => { + let a1 = a1.as_ref(); + + if let Some(cell_width) = cell_width { + // we should hold the largest amount of integral cells + assert_eq!(0, cell_width % value_width); + let num_values_per_cell = cell_width / value_width; + assert_eq!(num_values_per_cell * a1.len(), actual_num_values); + assert!(num_values_per_cell * a1.len() <= num_values); + assert!(num_values_per_cell * (a1.len() + 1) > num_values); + } else { + // whichever buffer holds the values must be big enough + assert_eq!(num_values, actual_num_values); + } + + let num_cells = capacity.num_cells(&target_type, nullable)?; + + // there is a threshold over which adding values to the request + // should increase the number of cells + if let Some(est_value) = + estimate_average_variable_length_values(&target_type) + { + for delta in 0..est_value { + let delta_num_cells = Capacity::Values(num_values + delta) + .num_cells(&target_type, nullable)?; + assert!(delta_num_cells >= num_cells); + } + let delta_cells = Capacity::Values(num_values + est_value) + .num_cells(&target_type, nullable)?; + + assert!(delta_cells > num_cells); + } else if let adt::DataType::FixedSizeBinary(ref fl) + | adt::DataType::FixedSizeList(_, ref fl) = target_type + { + let fl = *fl as usize; + let delta_cells = Capacity::Values(num_values + fl) + .num_cells(&target_type, nullable)?; + assert!(delta_cells > num_cells); + } else { + let delta_cells = Capacity::Values(num_values + 1) + .num_cells(&target_type, nullable)?; + assert!(delta_cells > num_cells); + } + } + Capacity::Memory(memory_limit) => { + let null_translation_buffer_overhead = + if nullable { a1.len() } else { 0 }; + + let a1_memory = a1.get_buffer_memory_size(); + assert!( + a1_memory + null_translation_buffer_overhead <= memory_limit, + "a1_memory = {:?}, memory_limit = {:?}", + a1_memory, + memory_limit + ); + + let num_cells = capacity.num_cells(&target_type, nullable)?; + + // there should be no room for another full cell within the memory limit + let a2 = alloc_array( + target_type.clone(), + nullable, + Capacity::Cells(num_cells), + )?; + let a2_memory = a2.get_buffer_memory_size(); + + let null_overhead = if nullable { + num_cells + 1 + (num_cells + 8) / 8 + } else { + 0 + }; + + if let Some(est_values) = + estimate_average_variable_length_values(&target_type) + { + // the memory limit should have no room for another estimated cell + let cell_overhead = std::mem::size_of::(); + assert!( + a1_memory + + value_width * est_values + + cell_overhead + + null_overhead + > memory_limit + ); + } else { + assert_eq!(a1_memory, a2_memory); + + let cell_width = cell_width.unwrap(); + + // the memory limit should have no room for another cell + assert!( + a1_memory + cell_width + null_overhead > memory_limit, + "memory_limit = {}, a1_memory = {}, cell_width = {}, null_overhead = {}", + memory_limit, a1_memory, cell_width, null_overhead + ); + } + } + } + + Ok(()) +} + +/// Instance of a capacity limits test when the target type is Boolean. +/// +/// Boolean data is bit packed in arrow but not in tiledb. When there is a memory +/// limit the buffer which unpacks the bits must be accounted for. +fn instance_capacity_limits_boolean( + capacity: Capacity, + nullable: bool, + a1: Arc, +) -> anyhow::Result<()> { + match capacity { + Capacity::Cells(num_cells) => { + assert_eq!(num_cells, a1.len()); + assert_eq!( + a1.as_ref(), + alloc_array( + adt::DataType::Boolean, + nullable, + Capacity::Values(num_cells) + )? + .as_ref() + ); + } + Capacity::Values(num_values) => { + assert_eq!(num_values, a1.len()); + assert_eq!( + a1.as_ref(), + alloc_array( + adt::DataType::Boolean, + nullable, + Capacity::Cells(num_values) + )? + .as_ref() + ); + } + Capacity::Memory(memory_limit) => { + let num_cells = a1.len(); + + let a1_memory = a1.get_buffer_memory_size(); + let a1_translation_memory = a1.len(); + let a1_null_translation_memory = + if nullable { a1.len() } else { 0 }; + + let translation_bytes_per_cell = if nullable { 2 } else { 1 }; + + // we may not always be able to use the full memory limit because + // advancing to the next unpacked byte requires 2 (+1 if nullable) + // additional bytes of memory + assert!( + memory_limit - translation_bytes_per_cell + <= a1_memory + + a1_translation_memory + + a1_null_translation_memory + ); + assert!( + a1_memory + a1_translation_memory + a1_null_translation_memory + <= memory_limit + ); + + // there should be no room for another full cell within the memory limit + let a2 = alloc_array( + adt::DataType::Boolean, + nullable, + Capacity::Cells(num_cells + 1), + )?; + let a2_memory = a2.get_buffer_memory_size(); + let a2_translation_memory = a2.len(); + let a2_null_translation_memory = + if nullable { a2.len() } else { 0 }; + + assert!( + a2_memory + a2_translation_memory + a2_null_translation_memory + > memory_limit + ); + + if nullable { + assert!( + a2_memory + + a2_translation_memory + + a2_null_translation_memory + > memory_limit + ) + /* + if num_cells % 4 == 0 { + // a1 saturated a byte, a2 needs an extra + assert!(a2_memory + a2_translation_memory > memory_limit) + } else { + assert!(a2_memory + a2_translation_memory >= memory_limit); + } + */ + } else { + assert!(a2_memory + a2_translation_memory > memory_limit) + /* + if num_cells % 8 == 0 { + // a1 saturated a byte, a2 needs an extra + assert!(a2_memory + a2_translation_memory > memory_limit) + } else { + assert!(a2_memory + a2_translation_memory >= memory_limit); + } + */ + } + } + } + Ok(()) +} + +fn instance_capacity_limits_boolean_fl( + _capacity: Capacity, + _nullable: bool, + _a1: Arc, + _fl: usize, +) -> anyhow::Result<()> { + // FIXME: gonna do it later + Ok(()) +} + +fn instance_capacity_limits_boolean_vl( + _capacity: Capacity, + _nullable: bool, + _a1: Arc, +) -> anyhow::Result<()> { + // FIXME: gonna do it later + Ok(()) +} + +fn strat_capacity_limits( +) -> impl Strategy { + let strat_datatype = (any::(), any::()).prop_map( + |(dt, cell_val_num)| { + crate::datatype::default_arrow_type(dt, cell_val_num) + .unwrap() + .into_inner() + }, + ); + + (any::(), strat_datatype, any::()) +} + +proptest! { + #[test] + fn proptest_capacity_limits( + (capacity, target_type, nullable) in strat_capacity_limits() + ) { + instance_capacity_limits(capacity, target_type, nullable).unwrap() + } +}