diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index 4e9db0881..abc324db8 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -16,6 +16,7 @@ # under the License. import gzip import os +import datetime as dt import pyarrow as pa import pyarrow.dataset as ds @@ -322,6 +323,40 @@ def test_pyarrow_predicate_pushdown_is_null(ctx, capfd): assert result[0].column(0) == pa.array([2]) +def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd): + """Ensure that pyarrow filter gets pushed down for timestamp""" + # Ref: https://github.com/apache/datafusion-python/issues/703 + + # create pyarrow dataset with no actual files + col_type = pa.timestamp("ns", "+00:00") + nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type) + pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem()) + pa_dataset_format = pa.dataset.ParquetFileFormat() + pa_dataset_partition = pa.dataset.field("a") <= nyd_2000 + fragments = [ + # NOTE: we never actually make this file. + # Working predicate pushdown means it never gets accessed + pa_dataset_format.make_fragment( + "1.parquet", + filesystem=pa_dataset_fs, + partition_expression=pa_dataset_partition, + ) + ] + pa_dataset = pa.dataset.FileSystemDataset( + fragments, + pa.schema([pa.field("a", col_type)]), + pa_dataset_format, + pa_dataset_fs, + ) + + ctx.register_dataset("t", pa_dataset) + + # the partition for our only fragment is for a <= 2000-01-01. 
+ # so querying for a > 2024-01-01 should not touch any files + df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'") + assert df.collect() == [] + + def test_dataset_filter_nested_data(ctx): # create Arrow StructArrays to test nested data types data = pa.StructArray.from_arrays( diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs index 5f2c9592d..ff447e1ab 100644 --- a/src/pyarrow_filter_expression.rs +++ b/src/pyarrow_filter_expression.rs @@ -21,6 +21,7 @@ use pyo3::prelude::*; use std::convert::TryFrom; use std::result::Result; +use arrow::pyarrow::ToPyArrow; use datafusion_common::{Column, ScalarValue}; use datafusion_expr::{expr::InList, Between, BinaryExpr, Expr, Operator}; @@ -56,6 +57,7 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result, Data let ret: Result, DataFusionError> = exprs .iter() .map(|expr| match expr { + // TODO: should we also leverage `ScalarValue::to_pyarrow` here? Expr::Literal(v) => match v { ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)), ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)), @@ -100,23 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { let op_module = Python::import_bound(py, "operator")?; let pc_expr: Result, DataFusionError> = match expr { Expr::Column(Column { name, .. 
}) => Ok(pc.getattr("field")?.call1((name,))?), - Expr::Literal(v) => match v { - ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?), - ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?), - ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?), - ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?), - ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?), - _ => Err(DataFusionError::Common(format!( - "PyArrow can't handle ScalarValue: {v:?}" - ))), - }, + Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let operator = operator_to_py(op, &op_module)?; let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;