Skip to content

Commit

Permalink
Add more unit tests for collect rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
phillipleblanc committed Jan 13, 2025
1 parent e670031 commit 4a9ea34
Showing 1 changed file with 186 additions and 0 deletions.
186 changes: 186 additions & 0 deletions sources/sql/src/rewrite/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1676,3 +1676,189 @@ mod tests {
Ok(())
}
}

#[cfg(test)]
mod collect_rewrites_tests {
use crate::{SQLExecutor, SQLFederationProvider};

use super::*;
use async_trait::async_trait;
use datafusion::{
arrow::datatypes::{DataType, Field, Schema, SchemaRef},
common::DFSchema,
datasource::DefaultTableSource,
execution::SendableRecordBatchStream,
sql::unparser::dialect::{DefaultDialect, Dialect},
};
use datafusion_federation::FederatedTableProviderAdaptor;

struct TestSQLExecutor {}

#[async_trait]
impl SQLExecutor for TestSQLExecutor {
fn name(&self) -> &str {
"test_sql_table_source"
}

fn compute_context(&self) -> Option<String> {
None
}

fn dialect(&self) -> Arc<dyn Dialect> {
Arc::new(DefaultDialect {})
}

fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
Err(DataFusionError::NotImplemented(
"execute not implemented".to_string(),
))
}

async fn table_names(&self) -> Result<Vec<String>> {
Err(DataFusionError::NotImplemented(
"table inference not implemented".to_string(),
))
}

async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
Err(DataFusionError::NotImplemented(
"table inference not implemented".to_string(),
))
}
}

fn create_test_table_scan() -> LogicalPlan {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Utf8, false),
]));

let sql_federation_provider =
Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));
let table_source = Arc::new(
SQLTableSource::new_with_schema(
sql_federation_provider,
"remote_table".to_string(),
schema.clone(),
)
.expect("to have a valid SQLTableSource"),
);
let source = Arc::new(DefaultTableSource::new(Arc::new(
FederatedTableProviderAdaptor::new(table_source),
)));

let df_schema =
DFSchema::try_from(schema.as_ref().clone()).expect("to have a valid DFSchema");

LogicalPlan::TableScan(logical_expr::TableScan {
table_name: TableReference::from("foo.df_table"),
source,
projection: None,
projected_schema: df_schema.into(),
filters: vec![],
fetch: None,
})
}

#[test]
fn test_collect_from_table_scan() -> Result<()> {
let plan = create_test_table_scan();
let mut known_rewrites = HashMap::new();

collect_known_rewrites_from_plan(&plan, &mut known_rewrites)?;

assert_eq!(known_rewrites.len(), 1);
assert_eq!(
known_rewrites.get(&TableReference::from("foo.df_table")),
Some(&MultiPartTableReference::TableReference(
TableReference::from("remote_table")
))
);
Ok(())
}

#[test]
fn test_collect_from_scalar_subquery() -> Result<()> {
let table_scan = create_test_table_scan();
let subquery = Expr::ScalarSubquery(Subquery {
subquery: Arc::new(table_scan),
outer_ref_columns: vec![],
});

let mut known_rewrites = HashMap::new();
collect_known_rewrites_from_expr(subquery, &mut known_rewrites)?;

assert_eq!(known_rewrites.len(), 1);
assert_eq!(
known_rewrites.get(&TableReference::from("foo.df_table")),
Some(&MultiPartTableReference::TableReference(
TableReference::from("remote_table")
))
);
Ok(())
}

#[test]
fn test_collect_from_binary_expr() -> Result<()> {
let left = Expr::Column(Column::from_qualified_name("foo.df_table.a"));
let right = Expr::Column(Column::from_qualified_name("foo.df_table.b"));
let binary = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
datafusion::logical_expr::Operator::Eq,
Box::new(right),
));

let mut known_rewrites = HashMap::new();
collect_known_rewrites_from_expr(binary, &mut known_rewrites)?;

// Column expressions don't generate rewrites on their own
assert_eq!(known_rewrites.len(), 0);
Ok(())
}

#[test]
fn test_collect_from_case_expression() -> Result<()> {
let col = Expr::Column(Column::from_qualified_name("foo.df_table.a"));
let case = Expr::Case(Case::new(
Some(Box::new(col.clone())),
vec![(
Box::new(Expr::Literal(datafusion::scalar::ScalarValue::Int64(Some(
1,
)))),
Box::new(col.clone()),
)],
Some(Box::new(col)),
));

let mut known_rewrites = HashMap::new();
collect_known_rewrites_from_expr(case, &mut known_rewrites)?;

// Column expressions don't generate rewrites on their own
assert_eq!(known_rewrites.len(), 0);
Ok(())
}

#[test]
fn test_collect_from_exists_subquery() -> Result<()> {
let table_scan = create_test_table_scan();
let exists = Expr::Exists(Exists::new(
Subquery {
subquery: Arc::new(table_scan),
outer_ref_columns: vec![],
},
false,
));

let mut known_rewrites = HashMap::new();
collect_known_rewrites_from_expr(exists, &mut known_rewrites)?;

assert_eq!(known_rewrites.len(), 1);
assert_eq!(
known_rewrites.get(&TableReference::from("foo.df_table")),
Some(&MultiPartTableReference::TableReference(
TableReference::from("remote_table")
))
);
Ok(())
}
}

0 comments on commit 4a9ea34

Please sign in to comment.