
feat: add support for accepting Substrait ExtendedExpression messages as filters (#1863)
westonpace authored Jan 30, 2024
1 parent a150a4b commit e34fc4d
Showing 8 changed files with 293 additions and 5 deletions.
36 changes: 34 additions & 2 deletions python/python/lance/dataset.py
@@ -1547,6 +1547,7 @@ def __init__(self, ds: LanceDataset):
self.ds = ds
self._limit = 0
self._filter = None
self._substrait_filter = None
self._prefilter = None
self._offset = None
self._columns = None
@@ -1607,8 +1608,38 @@ def columns(self, cols: Optional[list[str]] = None) -> ScannerBuilder:

def filter(self, filter: Union[str, pa.compute.Expression]) -> ScannerBuilder:
if isinstance(filter, pa.compute.Expression):
filter = str(filter)
self._filter = filter
try:
from pyarrow.substrait import serialize_expressions

fields_without_lists = []
counter = 0
# Pyarrow cannot handle fixed-size lists when converting
# types to Substrait, so we can't reference those fields in
# our filter. That is ok for now, but we replace them with
# placeholders instead of dropping them, because Substrait
# uses ordinal field references and the ordinals of the
# remaining fields must stay correct.
for field in self.ds.schema:
if pa.types.is_fixed_size_list(field.type):
pos = counter
counter += 1
fields_without_lists.append(
pa.field(f"__unlikely_name_placeholder_{pos}", pa.int8())
)
else:
fields_without_lists.append(field)
# Serialize the pyarrow compute expression to Substrait and use
# that as a filter.
scalar_schema = pa.schema(fields_without_lists)
self._substrait_filter = serialize_expressions(
[filter], ["my_filter"], scalar_schema
)
except ImportError:
# serialize_expressions was introduced in pyarrow 14. Fall back to
# stringifying the expression if pyarrow is too old.
self._filter = str(filter)
else:
self._filter = filter
return self

def prefilter(self, prefilter: bool) -> ScannerBuilder:
@@ -1709,6 +1740,7 @@ def to_scanner(self) -> LanceScanner:
self._fragments,
self._with_row_id,
self._use_stats,
self._substrait_filter,
)
return LanceScanner(scanner, self.ds)

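For reference, a minimal sketch of the serialization this new `filter` path performs, assuming pyarrow >= 14 and a hypothetical two-column schema (the names `schema`, `expr`, and the "my_filter" label mirror the code above):

```python
import pyarrow as pa
import pyarrow.compute as pc
from pyarrow.substrait import serialize_expressions  # added in pyarrow 14

# Hypothetical schema standing in for the dataset's schema.
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])

# The same kind of expression a caller passes to ScannerBuilder.filter.
expr = pc.field("a") > 50

# One expression, one name, serialized against the schema; the result is a
# protobuf-encoded Substrait ExtendedExpression message (a pyarrow Buffer)
# that the scanner hands to the Rust side.
substrait_bytes = serialize_expressions([expr], ["my_filter"], schema)
print(substrait_bytes.size)
```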
31 changes: 30 additions & 1 deletion python/python/tests/test_dataset.py
@@ -404,7 +404,10 @@ def test_pickle(tmp_path: Path):


def test_polar_scan(tmp_path: Path):
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
some_structs = [{"x": counter, "y": counter} for counter in range(100)]
table = pa.Table.from_pydict(
{"a": range(100), "b": range(100), "struct": some_structs}
)
base_dir = tmp_path / "test"
lance.write_dataset(table, base_dir)

@@ -413,6 +416,32 @@ def test_polar_scan(tmp_path: Path):
df = dataset.to_table().to_pandas()
tm.assert_frame_equal(polars_df.collect().to_pandas(), df)

# Note, this doesn't verify that the filter is actually pushed down.
# It only checks that, if the filter is pushed down, we interpret it
# correctly.
def check_pushdown_filt(pl_filt, sql_filt):
polars_df = pl.scan_pyarrow_dataset(dataset).filter(pl_filt)
df = dataset.to_table(filter=sql_filt).to_pandas()
tm.assert_frame_equal(polars_df.collect().to_pandas(), df)

# These three should push down (but we don't verify)
check_pushdown_filt(pl.col("a") > 50, "a > 50")
check_pushdown_filt(~(pl.col("a") > 50), "a <= 50")
check_pushdown_filt(pl.col("a").is_in([50, 51, 52]), "a IN (50, 51, 52)")
# At the moment it seems polars cannot push down this
# kind of filter
check_pushdown_filt((pl.col("a") + 3) < 100, "(a + 3) < 100")

# I can't seem to get struct["x"] to work in Lance, but maybe there is
# a way. For now, compare it directly to the pyarrow compute version.

# Doesn't yet work today :( due to upstream issue (datafusion's substrait parser
# doesn't yet handle nested refs)
# if pa.cpp_version_info.major >= 14:
# polars_df = pl.scan_pyarrow_dataset(dataset).filter(pl.col("struct.x") < 10)
# df = dataset.to_table(filter=pc.field("struct", "x") < 10).to_pandas()
# tm.assert_frame_equal(polars_df.collect().to_pandas(), df)


def test_count_fragments(tmp_path: Path):
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
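For context, the user-facing path this test exercises looks like the following sketch (hypothetical dataset path; assumes an int64 column "a"). polars decides whether a predicate is pushed down through the pyarrow dataset interface; when it is, Lance receives it as a compute expression and, on pyarrow >= 14, as Substrait:

```python
import lance
import polars as pl

# Hypothetical path; any Lance dataset with an int64 column "a" works.
dataset = lance.dataset("/tmp/test.lance")

# If polars pushes the predicate down, the Lance scanner receives it as a
# pyarrow compute expression rather than evaluating it after the scan.
df = pl.scan_pyarrow_dataset(dataset).filter(pl.col("a") > 50).collect()
```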
5 changes: 5 additions & 0 deletions python/python/tests/test_filter.py
@@ -74,6 +74,11 @@ def test_simple_predicates(dataset):
pc.field("float") >= 30.0,
pc.field("str") != "aa",
pc.field("str") == "aa",
(pc.field("int") >= 50) & (pc.field("int") < 200),
pc.invert(pc.field("int") >= 50),
pc.is_null(pc.field("int")),
pc.field("int") + 3 >= 50,
pc.is_valid(pc.field("int")),
]
# test simple
for expr in predicates:
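These predicates run through the compute-expression branch of `ScannerBuilder.filter`. A usage sketch, assuming a dataset at a hypothetical path with the test fixture's `int` column:

```python
import lance
import pyarrow.compute as pc

ds = lance.dataset("/tmp/test.lance")  # hypothetical path

# With pyarrow >= 14 the expression travels to the scanner as Substrait;
# on older pyarrow it is stringified as before.
tbl = ds.to_table(filter=(pc.field("int") >= 50) & (pc.field("int") < 200))
```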
11 changes: 11 additions & 0 deletions python/src/dataset.rs
@@ -289,6 +289,7 @@ impl Dataset {
fragments: Option<Vec<FileFragment>>,
with_row_id: Option<bool>,
use_stats: Option<bool>,
substrait_filter: Option<Vec<u8>>,
) -> PyResult<Scanner> {
let mut scanner: LanceScanner = self_.ds.scan();
if let Some(c) = columns {
@@ -297,10 +298,20 @@
.map_err(|err| PyValueError::new_err(err.to_string()))?;
}
if let Some(f) = filter {
if substrait_filter.is_some() {
return Err(PyValueError::new_err(
"cannot specify both a string filter and a substrait filter",
));
}
scanner
.filter(f.as_str())
.map_err(|err| PyValueError::new_err(err.to_string()))?;
}
if let Some(f) = substrait_filter {
RT.runtime
.block_on(scanner.filter_substrait(f.as_slice()))
.map_err(|err| PyIOError::new_err(err.to_string()))?;
}
if let Some(prefilter) = prefilter {
scanner.prefilter(prefilter);
}
3 changes: 3 additions & 0 deletions rust/Cargo.toml
@@ -84,6 +84,7 @@ datafusion-common = "35.0"
datafusion-sql = "35.0"
datafusion-expr = "35.0"
datafusion-physical-expr = "35.0"
datafusion-substrait = "35.0"
either = "1.0"
futures = "0.3"
http = "0.2.9"
@@ -109,6 +110,8 @@ serde = { version = "^1" }
serde_json = { version = "1" }
shellexpand = "3.0"
snafu = "0.7.4"
substrait = "0.22.1"
substrait-expr = "0.2.0"
tempfile = "3"
tokio = { version = "1.23", features = [
"rt-multi-thread",
7 changes: 7 additions & 0 deletions rust/lance-datafusion/Cargo.toml
@@ -18,7 +18,14 @@ async-trait.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-physical-expr.workspace = true
datafusion-substrait.workspace = true
futures.workspace = true
lance-arrow.workspace = true
lance-core = { workspace = true, features = ["datafusion"] }
prost.workspace = true
snafu.workspace = true
substrait.workspace = true
tokio.workspace = true

[dev-dependencies]
substrait-expr.workspace = true
192 changes: 190 additions & 2 deletions rust/lance-datafusion/src/expr.rs
@@ -18,8 +18,24 @@ use std::sync::Arc;

use arrow::compute::cast;
use arrow_array::{cast::AsArray, ArrayRef};
use arrow_schema::DataType;
use datafusion_common::ScalarValue;
use arrow_schema::{DataType, Schema};
use datafusion::{
datasource::empty::EmptyTable, execution::context::SessionContext, logical_expr::Expr,
};
use datafusion_common::{
tree_node::{Transformed, TreeNode},
Column, DataFusionError, ScalarValue, TableReference,
};
use prost::Message;
use snafu::{location, Location};

use lance_core::{Error, Result};
use substrait::proto::{
expression_reference::ExprType,
plan_rel::RelType,
read_rel::{NamedTable, ReadType},
rel, ExtendedExpression, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot,
};

// This is slightly tedious but when we convert expressions from SQL strings to logical
// datafusion expressions there is no type coercion that happens. In other words "x = 7"
@@ -284,3 +300,175 @@ pub fn safe_coerce_scalar(value: &ScalarValue, ty: &DataType) -> Option<ScalarVa
_ => None,
}
}

/// Convert a Substrait ExtendedExpression message into a DataFusion Expr
///
/// The ExtendedExpression message must contain a single scalar expression
pub async fn parse_substrait(expr: &[u8], input_schema: Arc<Schema>) -> Result<Expr> {
let envelope = ExtendedExpression::decode(expr)?;
if envelope.referred_expr.is_empty() {
return Err(Error::InvalidInput {
source: "the provided substrait expression is empty (contains no expressions)".into(),
location: location!(),
});
}
if envelope.referred_expr.len() > 1 {
return Err(Error::InvalidInput {
source: format!(
"the provided substrait expression had {} expressions when only 1 was expected",
envelope.referred_expr.len()
)
.into(),
location: location!(),
});
}
let expr = match &envelope.referred_expr[0].expr_type {
None => Err(Error::InvalidInput {
source: "the provided substrait had an expression but was missing an expr_type".into(),
location: location!(),
}),
Some(ExprType::Expression(expr)) => Ok(expr.clone()),
_ => Err(Error::InvalidInput {
source: "the provided substrait was not a scalar expression".into(),
location: location!(),
}),
}?;

// Datafusion's substrait consumer only supports Plan (not ExtendedExpression) and so
// we need to create a dummy plan with a single project node
let plan = Plan {
version: None,
extensions: envelope.extensions.clone(),
advanced_extensions: envelope.advanced_extensions.clone(),
expected_type_urls: envelope.expected_type_urls.clone(),
extension_uris: envelope.extension_uris.clone(),
relations: vec![PlanRel {
rel_type: Some(RelType::Root(RelRoot {
input: Some(Rel {
rel_type: Some(rel::RelType::Project(Box::new(ProjectRel {
common: None,
input: Some(Box::new(Rel {
rel_type: Some(rel::RelType::Read(Box::new(ReadRel {
common: None,
base_schema: envelope.base_schema.clone(),
filter: None,
best_effort_filter: None,
projection: None,
advanced_extension: None,
read_type: Some(ReadType::NamedTable(NamedTable {
names: vec!["dummy".to_string()],
advanced_extension: None,
})),
}))),
})),
expressions: vec![expr],
advanced_extension: None,
}))),
}),
// Not technically accurate, but we're pretty sure DF ignores this
names: vec![],
})),
}],
};

let session_context = SessionContext::new();
let dummy_table = Arc::new(EmptyTable::new(input_schema));
session_context.register_table(
TableReference::Bare {
table: "dummy".into(),
},
dummy_table,
)?;
let df_plan =
datafusion_substrait::logical_plan::consumer::from_substrait_plan(&session_context, &plan)
.await?;

let expr = df_plan.expressions().pop().unwrap();

// When DF parses the above plan it turns column references into qualified
// references against `dummy` (e.g. we get `WHERE dummy.x < 0` instead of
// `WHERE x < 0`). We want unqualified references instead, so we need a quick
// transformation pass.

let expr = expr.transform(&|node| match node {
Expr::Column(column) => {
if let Some(relation) = column.relation {
match relation {
TableReference::Bare { table } => {
if table == "dummy" {
Ok(Transformed::Yes(Expr::Column(Column {
relation: None,
name: column.name,
})))
} else {
// This should not be possible
Err(DataFusionError::Substrait(format!(
"Unexpected reference to table {} found when parsing filter",
table
)))
}
}
// This should not be possible
_ => Err(DataFusionError::Substrait("Unexpected partially or fully qualified table reference encountered when parsing filter".into()))
}
} else {
Ok(Transformed::No(Expr::Column(column)))
}
}
_ => Ok(Transformed::No(node)),
})?;
Ok(expr)
}

#[cfg(test)]
mod tests {
use super::*;

use arrow_schema::Field;
use datafusion::logical_expr::{BinaryExpr, Operator};
use datafusion_common::Column;
use prost::Message;
use substrait_expr::{
builder::{schema::SchemaBuildersExt, BuilderParams, ExpressionsBuilder},
functions::functions_comparison::FunctionsComparisonExt,
helpers::{literals::literal, schema::SchemaInfo},
};

#[tokio::test]
async fn test_substrait_conversion() {
let schema = SchemaInfo::new_full()
.field("x", substrait_expr::helpers::types::i32(true))
.build();
let expr_builder = ExpressionsBuilder::new(schema, BuilderParams::default());
expr_builder
.add_expression(
"filter_mask",
expr_builder
.functions()
.lt(
expr_builder.fields().resolve_by_name("x").unwrap(),
literal(0_i32),
)
.build()
.unwrap(),
)
.unwrap();
let expr = expr_builder.build();
let expr_bytes = expr.encode_to_vec();

let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)]));

let df_expr = parse_substrait(expr_bytes.as_slice(), schema)
.await
.unwrap();

let expected = Expr::BinaryExpr(BinaryExpr {
left: Box::new(Expr::Column(Column {
relation: None,
name: "x".to_string(),
})),
op: Operator::Lt,
right: Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))),
});
assert_eq!(df_expr, expected);
}
}
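The round trip in this test has a natural Python analogue. A sketch assuming pyarrow >= 14 and a hypothetical dataset with a nullable int32 column `x`: pyarrow serializes the expression to a Substrait ExtendedExpression, and `parse_substrait` decodes it into the same `x < 0` DataFusion expression asserted above:

```python
import lance
import pyarrow.compute as pc

# Hypothetical dataset containing an int32 column "x".
ds = lance.dataset("/tmp/ints.lance")

# End to end this exercises serialize_expressions (Python) followed by
# parse_substrait (Rust) before the filter is applied to the scan.
tbl = ds.to_table(filter=pc.field("x") < 0)
```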