From 67dd97d8df574bb7d2c82515a75cc7acf324342c Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 8 Feb 2023 17:51:51 -0800 Subject: [PATCH 1/4] Planner: Add implementation LogicalPlan::Values --- dask_planner/src/sql/logical.rs | 16 ++++++--- dask_planner/src/sql/logical/values.rs | 45 ++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 dask_planner/src/sql/logical/values.rs diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 823dc08bb..3b8e2855e 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -33,6 +33,7 @@ pub mod sort; pub mod subquery_alias; pub mod table_scan; pub mod use_schema; +pub mod values; pub mod window; use datafusion_common::{DFSchemaRef, DataFusionError}; @@ -141,16 +142,23 @@ impl PyLogicalPlan { to_py_plan(self.current_node.as_ref()) } - /// LogicalPlan::Window as PyWindow - pub fn window(&self) -> PyResult<window::PyWindow> { + /// LogicalPlan::TableScan as PyTableScan + pub fn table_scan(&self) -> PyResult<table_scan::PyTableScan> { to_py_plan(self.current_node.as_ref()) } - /// LogicalPlan::TableScan as PyTableScan - pub fn table_scan(&self) -> PyResult<table_scan::PyTableScan> { + /// LogicalPlan::Values as PyValues + pub fn values(&self) -> PyResult<values::PyValues> { to_py_plan(self.current_node.as_ref()) } + /// LogicalPlan::Window as PyWindow + pub fn window(&self) -> PyResult<window::PyWindow> { + to_py_plan(self.current_node.as_ref()) + } + + // Custom LogicalPlan Nodes + /// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable pub fn create_memory_table(&self) -> PyResult<create_memory_table::PyCreateMemoryTable> { to_py_plan(self.current_node.as_ref()) diff --git a/dask_planner/src/sql/logical/values.rs b/dask_planner/src/sql/logical/values.rs new file mode 100644 index 000000000..31d71acac --- /dev/null +++ b/dask_planner/src/sql/logical/values.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use datafusion_expr::{logical_plan::Values, LogicalPlan}; +use pyo3::prelude::*; + +use crate::{ + expression::{py_expr_list, PyExpr}, + sql::exceptions::py_type_err, +}; 
+ +#[pyclass(name = "Values", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PyValues { + values: Values, + plan: Arc<LogicalPlan>, +} + +#[pymethods] +impl PyValues { + /// Returns the rows of this VALUES relation. Each inner list is one row, + /// and each element is one of that row's expressions, converted with + /// `py_expr_list` against the stored plan. + #[pyo3(name = "getValues")] + fn get_values(&self) -> PyResult<Vec<Vec<PyExpr>>> { + self.values + .values + .iter() + .map(|e| py_expr_list(&self.plan, e)) + .collect() + } +} + +impl TryFrom<LogicalPlan> for PyValues { + type Error = PyErr; + + fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> { + match logical_plan { + LogicalPlan::Values(values) => Ok(PyValues { + plan: Arc::new(LogicalPlan::Values(values.clone())), + values, + }), + _ => Err(py_type_err("unexpected plan")), + } + } +} From 64f2a839b4504ce84e8c0a942175676cb1cb2c3a Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 8 Feb 2023 17:52:37 -0800 Subject: [PATCH 2/4] Update values to use new bindings --- dask_sql/physical/rel/logical/values.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/dask_sql/physical/rel/logical/values.py b/dask_sql/physical/rel/logical/values.py index ca95375c9..56d49afba 100644 --- a/dask_sql/physical/rel/logical/values.py +++ b/dask_sql/physical/rel/logical/values.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: import dask_sql - from dask_sql.java import org + from dask_planner.rust import LogicalPlan class DaskValuesPlugin(BaseRelPlugin): @@ -26,15 +26,12 @@ class DaskValuesPlugin(BaseRelPlugin): data samples. """ - class_name = "com.dask.sql.nodes.DaskValues" + class_name = "Values" - def convert( - self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" - ) -> DataContainer: + def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: # There should not be any input. This is the first step. 
self.assert_inputs(rel, 0) - - rex_expression_rows = list(rel.getTuples()) + rex_expression_rows = rel.values().getValues() rows = [] for rex_expression_row in rex_expression_rows: # We convert each of the cells in the row @@ -44,7 +41,7 @@ def convert( # their index. rows.append( { - str(i): RexConverter.convert(rex_cell, None, context=context) + str(i): RexConverter.convert(rel, rex_cell, None, context=context) for i, rex_cell in enumerate(rex_expression_row) } ) From 338c3a1b212139467934d1d3f19199cfbfa2eab0 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 26 Apr 2023 15:08:52 -0700 Subject: [PATCH 3/4] Add tests --- tests/integration/test_values.py | 52 ++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/integration/test_values.py diff --git a/tests/integration/test_values.py b/tests/integration/test_values.py new file mode 100644 index 000000000..ef6138b90 --- /dev/null +++ b/tests/integration/test_values.py @@ -0,0 +1,52 @@ +import pandas as pd +import pytest + +from tests.utils import assert_eq + + +def test_values(c): + result_df = c.sql( + """ + SELECT * FROM (VALUES (1, 2), (1, 3)) as tbl(column1, column2) + """ + ) + expected_df = pd.DataFrame({"column1": [1, 1], "column2": [2, 3]}) + assert_eq(result_df, expected_df, check_index=False) + + +def test_values_join(c): + result_df = c.sql( + """ + SELECT * FROM df_simple, (VALUES (1, 2), (1, 3)) as tbl(aa, bb) + WHERE a = aa + """ + ) + expected_df = pd.DataFrame( + {"a": [1, 1], "b": [1.1, 1.1], "aa": [1, 1], "bb": [2, 3]} + ) + assert_eq(result_df, expected_df, check_index=False) + + +@pytest.mark.xfail(reason="Datafusion doesn't handle values relations cleanly") +def test_values_join_alias(c): + result_df = c.sql( + """ + SELECT * FROM df_simple, (VALUES (1, 2), (1, 3)) as tbl(aa, bb) + WHERE a = tbl.aa + """ + ) + expected_df = pd.DataFrame( + {"a": [1, 1], "b": [1.1, 1.1], "aa": [1, 1], "bb": [2, 3]} + ) + assert_eq(result_df, expected_df, 
check_index=False) + + result_df = c.sql( + """ + SELECT * FROM df_simple t1, (VALUES (1, 2), (1, 3)) as t2(a, b) + WHERE t1.a = t2.a + """ + ) + expected_df = pd.DataFrame( + {"t1.a": [1, 1], "t1.b": [1.1, 1.1], "t2.aa": [1, 1], "t2.bb": [2, 3]} + ) + assert_eq(result_df, expected_df, check_index=False) From fe72445c6cc3d6c122e270be924b73464e6b3a40 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Thu, 4 May 2023 11:42:59 -0700 Subject: [PATCH 4/4] Update crate namespace --- dask_planner/src/sql/logical/values.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_planner/src/sql/logical/values.rs b/dask_planner/src/sql/logical/values.rs index 31d71acac..4de881056 100644 --- a/dask_planner/src/sql/logical/values.rs +++ b/dask_planner/src/sql/logical/values.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use datafusion_expr::{logical_plan::Values, LogicalPlan}; +use datafusion_python::datafusion_expr::{logical_plan::Values, LogicalPlan}; use pyo3::prelude::*; use crate::{