Skip to content

Commit

Permalink
expand the Expr::Literal's that can be used in PyArrowFilterExpression
Browse files Browse the repository at this point in the history
Closes apache#703
  • Loading branch information
Michael-J-Ward committed Jun 18, 2024
1 parent df69c6a commit 5dd2b88
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
35 changes: 35 additions & 0 deletions python/datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import gzip
import os
import datetime as dt

import pyarrow as pa
import pyarrow.dataset as ds
Expand Down Expand Up @@ -322,6 +323,40 @@ def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
assert result[0].column(0) == pa.array([2])


def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
    """Ensure that pyarrow filter gets pushed down for timestamp"""
    # Ref: https://github.com/apache/datafusion-python/issues/703

    # Assemble a pyarrow dataset whose single fragment points at a parquet
    # file that is never written to disk.
    ts_type = pa.timestamp("ns", "+00:00")
    cutoff = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), ts_type)
    subtree_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
    parquet_format = pa.dataset.ParquetFileFormat()
    partition_expr = pa.dataset.field("a") <= cutoff

    # NOTE: we never actually make this file.
    # Working predicate pushdown means it never gets accessed
    phantom_fragment = parquet_format.make_fragment(
        "1.parquet",
        filesystem=subtree_fs,
        partition_expression=partition_expr,
    )
    dataset = pa.dataset.FileSystemDataset(
        [phantom_fragment],
        pa.schema([pa.field("a", ts_type)]),
        parquet_format,
        subtree_fs,
    )

    ctx.register_dataset("t", dataset)

    # The lone fragment's partition expression guarantees a <= 2000-01-01,
    # so a query for a > 2024-01-01 must prune it and touch no files at all.
    df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
    assert df.collect() == []


def test_dataset_filter_nested_data(ctx):
# create Arrow StructArrays to test nested data types
data = pa.StructArray.from_arrays(
Expand Down
20 changes: 3 additions & 17 deletions src/pyarrow_filter_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use pyo3::prelude::*;
use std::convert::TryFrom;
use std::result::Result;

use arrow::pyarrow::ToPyArrow;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};

Expand Down Expand Up @@ -56,6 +57,7 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, Data
let ret: Result<Vec<PyObject>, DataFusionError> = exprs
.iter()
.map(|expr| match expr {
// TODO: should we also leverage `ScalarValue::to_pyarrow` here?
Expr::Literal(v) => match v {
ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
Expand Down Expand Up @@ -100,23 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
let op_module = Python::import_bound(py, "operator")?;
let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
Expr::Literal(v) => match v {
ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
_ => Err(DataFusionError::Common(format!(
"PyArrow can't handle ScalarValue: {v:?}"
))),
},
Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let operator = operator_to_py(op, &op_module)?;
let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
Expand Down

0 comments on commit 5dd2b88

Please sign in to comment.