Skip to content

Commit

Permalink
feat: add [not_]starts_with and [not_]in arrow predicate pushdown
Browse files Browse the repository at this point in the history
  • Loading branch information
sdd committed Jun 13, 2024
1 parent 070576b commit 07bef71
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 20 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ arrow-array = { version = "52" }
arrow-ord = { version = "52" }
arrow-schema = { version = "52" }
arrow-select = { version = "52" }
arrow-string = { version = "52" }
async-stream = "0.3.5"
async-trait = "0.1"
aws-config = "1.1.8"
Expand Down
1 change: 1 addition & 0 deletions crates/iceberg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ arrow-array = { workspace = true }
arrow-ord = { workspace = true }
arrow-schema = { workspace = true }
arrow-select = { workspace = true }
arrow-string = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
bimap = { workspace = true }
Expand Down
89 changes: 73 additions & 16 deletions crates/iceberg/src/arrow/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef};
use arrow_string::like::starts_with;
use async_stream::try_stream;
use bytes::Bytes;
use fnv::FnvHashSet;
Expand Down Expand Up @@ -741,42 +742,98 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {

/// Builds a row-level filter for the `STARTS WITH` predicate.
///
/// Resolves the referenced column once, converts the literal prefix to an
/// arrow scalar once, and returns a closure that projects the column out of
/// each `RecordBatch` and applies the arrow-string `starts_with` kernel.
fn starts_with(
    &mut self,
    reference: &BoundReference,
    literal: &Datum,
    _predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
    if let Some(idx) = self.bound_reference(reference)? {
        // Convert the iceberg Datum to an arrow scalar up front rather
        // than once per batch.
        let literal = get_arrow_datum(literal)?;

        Ok(Box::new(move |batch| {
            let left = project_column(&batch, idx)?;
            starts_with(&left, literal.as_ref())
        }))
    } else {
        // A missing column is treated as null, and null never satisfies
        // STARTS WITH, so the predicate is always false.
        self.build_always_false()
    }
}

/// Builds a row-level filter for the `NOT STARTS WITH` predicate.
///
/// Mirrors [`Self::starts_with`]: projects the referenced column and negates
/// the arrow-string `starts_with` kernel result per batch.
fn not_starts_with(
    &mut self,
    reference: &BoundReference,
    literal: &Datum,
    _predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
    if let Some(idx) = self.bound_reference(reference)? {
        // Convert the iceberg Datum to an arrow scalar up front rather
        // than once per batch.
        let literal = get_arrow_datum(literal)?;

        Ok(Box::new(move |batch| {
            let left = project_column(&batch, idx)?;

            // update here if arrow ever adds a native not_starts_with
            not(&starts_with(&left, literal.as_ref())?)
        }))
    } else {
        // A missing column is treated as null; per SQL three-valued logic
        // the negated predicate is still not true, so always false.
        self.build_always_false()
    }
}

fn r#in(
&mut self,
_reference: &BoundReference,
_literals: &FnvHashSet<Datum>,
reference: &BoundReference,
literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
// TODO: Implement in
self.build_always_true()
if let Some(idx) = self.bound_reference(reference)? {
let literals: Vec<_> = literals
.iter()
.map(|lit| get_arrow_datum(lit).unwrap())
.collect();

Ok(Box::new(move |batch| {
// update this if arrow ever adds a native is_in kernel
let left = project_column(&batch, idx)?;
let mut acc = BooleanArray::from(vec![false; batch.num_rows()]);
for literal in &literals {
acc = or(&acc, &eq(&left, literal.as_ref())?)?
}

Ok(acc)
}))
} else {
// A missing column, treating it as null.
self.build_always_false()
}
}

fn not_in(
&mut self,
_reference: &BoundReference,
_literals: &FnvHashSet<Datum>,
reference: &BoundReference,
literals: &FnvHashSet<Datum>,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
// TODO: Implement not_in
self.build_always_true()
if let Some(idx) = self.bound_reference(reference)? {
let literals: Vec<_> = literals
.iter()
.map(|lit| get_arrow_datum(lit).unwrap())
.collect();

Ok(Box::new(move |batch| {
// update this if arrow ever adds a native not_in kernel
let left = project_column(&batch, idx)?;
let mut acc = BooleanArray::from(vec![true; batch.num_rows()]);
for literal in &literals {
acc = and(&acc, &neq(&left, literal.as_ref())?)?
}

Ok(acc)
}))
} else {
// A missing column, treating it as null.
self.build_always_false()
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/iceberg/src/arrow/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::{Error, ErrorKind};
use arrow_array::types::{validate_decimal_precision_and_scale, Decimal128Type};
use arrow_array::{
BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array,
StringArray,
};
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
use bitvec::macros::internal::funty::Fundamental;
Expand Down Expand Up @@ -605,6 +606,7 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send
PrimitiveLiteral::Long(value) => Ok(Box::new(Int64Array::new_scalar(*value))),
PrimitiveLiteral::Float(value) => Ok(Box::new(Float32Array::new_scalar(value.as_f32()))),
PrimitiveLiteral::Double(value) => Ok(Box::new(Float64Array::new_scalar(value.as_f64()))),
PrimitiveLiteral::String(value) => Ok(Box::new(StringArray::new_scalar(value.as_str()))),
l => Err(Error::new(
ErrorKind::FeatureUnsupported,
format!(
Expand Down
108 changes: 105 additions & 3 deletions crates/iceberg/src/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ mod tests {
};
use crate::table::Table;
use crate::TableIdent;
use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use futures::TryStreamExt;
use parquet::arrow::{ArrowWriter, PARQUET_FIELD_ID_META_KEY};
use parquet::basic::Compression;
Expand Down Expand Up @@ -705,10 +705,15 @@ mod tests {
PARQUET_FIELD_ID_META_KEY.to_string(),
"3".to_string(),
)])),
arrow_schema::Field::new("a", arrow_schema::DataType::Utf8, false)
.with_metadata(HashMap::from([(
PARQUET_FIELD_ID_META_KEY.to_string(),
"4".to_string(),
)])),
];
Arc::new(arrow_schema::Schema::new(fields))
};
// 3 columns:
// 4 columns:
// x: [1, 1, 1, 1, ...]
let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;

Expand All @@ -725,7 +730,14 @@ mod tests {

// z: [3, 3, 3, 3, ..., 4, 4, 4, 4]
let col3 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
let to_write = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap();

// a: ["Apache", "Apache", "Apache", ..., "Iceberg", "Iceberg", "Iceberg"]
let mut values = vec!["Apache"; 512];
values.append(vec!["Iceberg"; 512].as_mut());
let col4 = Arc::new(StringArray::from_iter_values(values)) as ArrayRef;

let to_write =
RecordBatch::try_new(schema.clone(), vec![col1, col2, col3, col4]).unwrap();

// Write the Parquet files
let props = WriterProperties::builder()
Expand Down Expand Up @@ -1040,4 +1052,94 @@ mod tests {
let expected_z = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
assert_eq!(col, &expected_z);
}

#[tokio::test]
async fn test_filter_on_arrow_startswith() {
    let mut fixture = TableTestFixture::new();
    fixture.setup_manifest_files().await;

    // Filter: a STARTSWITH "Ice" — matches the "Iceberg" half of column `a`.
    let table_scan = fixture
        .table
        .scan()
        .filter(Reference::new("a").starts_with(Datum::string("Ice")))
        .build()
        .unwrap();

    let batches: Vec<_> = table_scan
        .to_arrow()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    assert_eq!(batches[0].num_rows(), 512);

    let string_arr = batches[0]
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(string_arr.value(0), "Iceberg");
}

#[tokio::test]
async fn test_filter_on_arrow_not_startswith() {
    let mut fixture = TableTestFixture::new();
    fixture.setup_manifest_files().await;

    // Filter: a NOT STARTSWITH "Ice" — matches the "Apache" half of column `a`.
    let table_scan = fixture
        .table
        .scan()
        .filter(Reference::new("a").not_starts_with(Datum::string("Ice")))
        .build()
        .unwrap();

    let batches: Vec<_> = table_scan
        .to_arrow()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    assert_eq!(batches[0].num_rows(), 512);

    let string_arr = batches[0]
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(string_arr.value(0), "Apache");
}

#[tokio::test]
async fn test_filter_on_arrow_in() {
    let mut fixture = TableTestFixture::new();
    fixture.setup_manifest_files().await;

    // Filter: a IN ("Sioux", "Iceberg") — only "Iceberg" exists in column `a`.
    let predicate =
        Reference::new("a").is_in([Datum::string("Sioux"), Datum::string("Iceberg")]);
    let table_scan = fixture.table.scan().filter(predicate).build().unwrap();

    let batches: Vec<_> = table_scan
        .to_arrow()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    assert_eq!(batches[0].num_rows(), 512);

    let string_arr = batches[0]
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(string_arr.value(0), "Iceberg");
}

#[tokio::test]
async fn test_filter_on_arrow_not_in() {
    let mut fixture = TableTestFixture::new();
    fixture.setup_manifest_files().await;

    // Filter: a NOT IN ("Sioux", "Iceberg") — leaves the "Apache" rows.
    let predicate =
        Reference::new("a").is_not_in([Datum::string("Sioux"), Datum::string("Iceberg")]);
    let table_scan = fixture.table.scan().filter(predicate).build().unwrap();

    let batches: Vec<_> = table_scan
        .to_arrow()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    assert_eq!(batches[0].num_rows(), 512);

    let string_arr = batches[0]
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(string_arr.value(0), "Apache");
}
}
3 changes: 2 additions & 1 deletion crates/iceberg/testdata/example_table_metadata_v2.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"fields": [
{"id": 1, "name": "x", "required": true, "type": "long"},
{"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"},
{"id": 3, "name": "z", "required": true, "type": "long"}
{"id": 3, "name": "z", "required": true, "type": "long"},
{"id": 4, "name": "a", "required": true, "type": "string"}
]
}
],
Expand Down

0 comments on commit 07bef71

Please sign in to comment.