diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 45ff6c5984ca..dd9978f5c26d 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -275,6 +275,7 @@ message Statistics { int64 num_rows = 1; int64 total_byte_size = 2; repeated ColumnStats column_stats = 3; + bool is_exact = 4; } message PartitionedFile { diff --git a/ballista/rust/core/src/datasource.rs b/ballista/rust/core/src/datasource.rs index b774b8d39b9d..3310a4a75a46 100644 --- a/ballista/rust/core/src/datasource.rs +++ b/ballista/rust/core/src/datasource.rs @@ -20,7 +20,7 @@ use std::{any::Any, sync::Arc}; use datafusion::arrow::datatypes::SchemaRef; use datafusion::error::Result as DFResult; use datafusion::{ - datasource::{datasource::Statistics, TableProvider}, + datasource::TableProvider, logical_plan::{Expr, LogicalPlan}, physical_plan::ExecutionPlan, }; @@ -61,12 +61,4 @@ impl TableProvider for DfTableAdapter { ) -> DFResult> { Ok(self.plan.clone()) } - - fn statistics(&self) -> Statistics { - Statistics { - num_rows: None, - total_byte_size: None, - column_statistics: None, - } - } } diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index 7793ad9e9244..bebc98f08cc4 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -35,7 +35,7 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, + SendableRecordBatchStream, Statistics, }; use async_trait::async_trait; @@ -203,6 +203,13 @@ impl ExecutionPlan for DistributedQueryExec { } } } + + fn statistics(&self) -> Statistics { + // This execution plan sends the logical plan to the scheduler without + // performing the node by node conversion to a full physical plan. + // This implies that we cannot infer the statistics at this stage. 
+        Statistics::default()
+    }
 }
 
 async fn fetch_partition(
diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
index bc5dbc175c7a..6cdd8cc7665a 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -21,7 +21,7 @@ use std::{any::Any, pin::Pin};
 
 use crate::client::BallistaClient;
 use crate::memory_stream::MemoryStream;
-use crate::serde::scheduler::PartitionLocation;
+use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
 use crate::utils::WrappedStream;
 use async_trait::async_trait;
@@ -31,7 +31,9 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
 };
-use datafusion::physical_plan::{DisplayFormatType, ExecutionPlan, Metric, Partitioning};
+use datafusion::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Metric, Partitioning, Statistics,
+};
 use datafusion::{
     error::{DataFusionError, Result},
     physical_plan::RecordBatchStream,
@@ -156,6 +158,38 @@ impl ExecutionPlan for ShuffleReaderExec {
     fn metrics(&self) -> Option<MetricsSet> {
         Some(self.metrics.clone_inner())
     }
+
+    fn statistics(&self) -> Statistics {
+        stats_for_partitions(
+            self.partition
+                .iter()
+                .flatten()
+                .map(|loc| loc.partition_stats),
+        )
+    }
+}
+
+fn stats_for_partitions(
+    partition_stats: impl Iterator<Item = PartitionStats>,
+) -> Statistics {
+    // TODO stats: add column statistics to PartitionStats
+    partition_stats.fold(
+        Statistics {
+            is_exact: true,
+            num_rows: Some(0),
+            total_byte_size: Some(0),
+            column_statistics: None,
+        },
+        |mut acc, part| {
+            // if any statistic is unknown it makes the entire statistic unknown
+            acc.num_rows = acc.num_rows.zip(part.num_rows).map(|(a, b)| a + b as usize);
+            acc.total_byte_size = acc
+                .total_byte_size
+                .zip(part.num_bytes)
+                .map(|(a, b)| a + b as usize);
+            acc
+        },
+    )
 }
 
 async fn fetch_partition(
@@ -177,3 +211,76 @@ async fn fetch_partition(
         .await
         .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_stats_for_partitions_empty() {
+        let result = stats_for_partitions(std::iter::empty());
+
+        let expected = Statistics {
+            is_exact: true,
+            num_rows: Some(0),
+            total_byte_size: Some(0),
+            column_statistics: None,
+        };
+
+        assert_eq!(result, expected);
+    }
+
+    #[tokio::test]
+    async fn test_stats_for_partitions_full() {
+        let part_stats = vec![
+            PartitionStats {
+                num_rows: Some(10),
+                num_bytes: Some(84),
+                num_batches: Some(1),
+            },
+            PartitionStats {
+                num_rows: Some(4),
+                num_bytes: Some(65),
+                num_batches: None,
+            },
+        ];
+
+        let result = stats_for_partitions(part_stats.into_iter());
+
+        let expected = Statistics {
+            is_exact: true,
+            num_rows: Some(14),
+            total_byte_size: Some(149),
+            column_statistics: None,
+        };
+
+        assert_eq!(result, expected);
+    }
+
+    #[tokio::test]
+    async fn test_stats_for_partitions_missing() {
+        let part_stats = vec![
+            PartitionStats {
+                num_rows: Some(10),
+                num_bytes: Some(84),
+                num_batches: Some(1),
+            },
+            PartitionStats {
+                num_rows: None,
+                num_bytes: None,
+                num_batches: None,
+            },
+        ];
+
+        let result = stats_for_partitions(part_stats.into_iter());
+
+        let expected = Statistics {
+            is_exact: true,
+            num_rows: None,
+            total_byte_size: None,
+            column_statistics: None,
+        };
+
+        assert_eq!(result, expected);
+    }
+}
diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index 36e445bc4ead..6884720501fa 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -51,7 +51,7 @@ use datafusion::physical_plan::metrics::{
 use datafusion::physical_plan::repartition::RepartitionExec;
 use datafusion::physical_plan::Partitioning::RoundRobinBatch;
 use datafusion::physical_plan::{
-    DisplayFormatType, ExecutionPlan, Metric, Partitioning, RecordBatchStream,
+    DisplayFormatType, ExecutionPlan, Metric, Partitioning, RecordBatchStream, Statistics,
 };
 use futures::StreamExt;
 use hashbrown::HashMap;
@@ -417,6 +417,10 @@ impl ExecutionPlan for ShuffleWriterExec {
             }
         }
     }
+
+    fn statistics(&self) -> Statistics {
+        self.plan.statistics()
+    }
 }
 
 fn result_schema() -> SchemaRef {
diff --git a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs
index 3111b5a41be3..6290add4e2b4 100644
--- a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs
+++ b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs
@@ -23,7 +23,9 @@ use crate::serde::scheduler::PartitionLocation;
 
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
-use datafusion::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
+use datafusion::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Partitioning, Statistics,
+};
 use datafusion::{
     error::{DataFusionError, Result},
     physical_plan::RecordBatchStream,
@@ -117,4 +119,10 @@ impl ExecutionPlan for UnresolvedShuffleExec {
             }
         }
     }
+
+    fn statistics(&self) -> Statistics {
+        // The full statistics are computed in the `ShuffleReaderExec` node
+        // that replaces this one once the previous stage is completed.
+ Statistics::default() + } } diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index fc4ac2c9076c..38de341ed01d 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -359,6 +359,7 @@ impl TryInto for &protobuf::Statistics { num_rows: Some(self.num_rows as usize), total_byte_size: Some(self.total_byte_size as usize), column_statistics: Some(column_statistics), + is_exact: self.is_exact, }) } } @@ -1177,8 +1178,7 @@ impl TryInto for &protobuf::Field { } use crate::serde::protobuf::ColumnStats; -use datafusion::datasource::datasource::{ColumnStatistics, Statistics}; -use datafusion::physical_plan::{aggregates, windows}; +use datafusion::physical_plan::{aggregates, windows, ColumnStatistics, Statistics}; use datafusion::prelude::{ array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper, diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index aa7a973dd340..10bc63e4807b 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -25,7 +25,6 @@ use crate::serde::{protobuf, BallistaError}; use datafusion::arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, }; -use datafusion::datasource::datasource::{ColumnStatistics, Statistics}; use datafusion::datasource::{CsvFile, PartitionedFile, TableDescriptor}; use datafusion::logical_plan::{ window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, @@ -36,6 +35,7 @@ use datafusion::physical_plan::functions::BuiltinScalarFunction; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; +use datafusion::physical_plan::{ColumnStatistics, Statistics}; use datafusion::{datasource::parquet::ParquetTable, logical_plan::exprlist_to_fields}; use protobuf::{ arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType, @@ -278,6 +278,7 @@ impl From<&Statistics> for protobuf::Statistics { num_rows: s.num_rows.map(|n| n as i64).unwrap_or(none_value), total_byte_size: s.total_byte_size.map(|n| n as i64).unwrap_or(none_value), column_stats, + is_exact: s.is_exact, } } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 6aa0fa111921..3cd8cf3871cf 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -34,7 +34,6 @@ use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider, }; -use datafusion::datasource::datasource::Statistics; use datafusion::datasource::object_store::ObjectStoreRegistry; use datafusion::datasource::FilePartition; use datafusion::execution::context::{ @@ -74,7 +73,9 @@ use datafusion::physical_plan::{ sort::{SortExec, SortOptions}, Partitioning, }; -use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; +use datafusion::physical_plan::{ + AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, +}; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::physical_expr_node::ExprType; diff --git a/ballista/rust/executor/src/collect.rs b/ballista/rust/executor/src/collect.rs index 
e9448c82d861..494bed2c5b7b 100644 --- a/ballista/rust/executor/src/collect.rs +++ b/ballista/rust/executor/src/collect.rs @@ -28,7 +28,7 @@ use datafusion::arrow::{ }; use datafusion::error::DataFusionError; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use datafusion::{error::Result, physical_plan::RecordBatchStream}; use futures::stream::SelectAll; @@ -116,6 +116,10 @@ impl ExecutionPlan for CollectExec { } } } + + fn statistics(&self) -> Statistics { + self.plan.statistics() + } } struct MergedRecordBatchStream { diff --git a/datafusion/src/datasource/csv.rs b/datafusion/src/datasource/csv.rs index 987c4fdb079d..971bd91315f9 100644 --- a/datafusion/src/datasource/csv.rs +++ b/datafusion/src/datasource/csv.rs @@ -39,7 +39,6 @@ use std::io::{Read, Seek}; use std::string::String; use std::sync::{Arc, Mutex}; -use crate::datasource::datasource::Statistics; use crate::datasource::{Source, TableProvider}; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Expr; @@ -54,7 +53,6 @@ pub struct CsvFile { has_header: bool, delimiter: u8, file_extension: String, - statistics: Statistics, } impl CsvFile { @@ -82,7 +80,6 @@ impl CsvFile { has_header: options.has_header, delimiter: options.delimiter, file_extension: String::from(options.file_extension), - statistics: Statistics::default(), }) } @@ -105,7 +102,6 @@ impl CsvFile { schema, has_header: options.has_header, delimiter: options.delimiter, - statistics: Statistics::default(), file_extension: String::new(), }) } @@ -133,7 +129,6 @@ impl CsvFile { schema, has_header: options.has_header, delimiter: options.delimiter, - statistics: Statistics::default(), file_extension: String::new(), }) } @@ -210,10 +205,6 @@ impl TableProvider for CsvFile { }; Ok(Arc::new(exec)) } - - fn statistics(&self) -> Statistics { - self.statistics.clone() - } } #[cfg(test)] diff --git a/datafusion/src/datasource/datasource.rs b/datafusion/src/datasource/datasource.rs index e173d6e0d771..3c60255590a1 100644 --- a/datafusion/src/datasource/datasource.rs +++ b/datafusion/src/datasource/datasource.rs @@ -20,34 +20,10 @@ use std::any::Any; use std::sync::Arc; +use crate::arrow::datatypes::SchemaRef; use crate::error::Result; use crate::logical_plan::Expr; use crate::physical_plan::ExecutionPlan; -use crate::{arrow::datatypes::SchemaRef, scalar::ScalarValue}; - -/// This table statistics are estimates. -/// It can not be used directly in the precise compute -#[derive(Debug, Clone, Default, PartialEq)] -pub struct Statistics { - /// The number of table rows - pub num_rows: Option, - /// total byte of the table rows - pub total_byte_size: Option, - /// Statistics on a column level - pub column_statistics: Option>, -} -/// This table statistics are estimates about column -#[derive(Clone, Debug, PartialEq)] -pub struct ColumnStatistics { - /// Number of null values on column - pub null_count: Option, - /// Maximum value of column - pub max_value: Option, - /// Minimum value of column - pub min_value: Option, - /// Number of distinct values - pub distinct_count: Option, -} /// Indicates whether and how a filter expression can be handled by a /// TableProvider for table scans. @@ -104,15 +80,6 @@ pub trait TableProvider: Sync + Send { limit: Option, ) -> Result>; - /// Returns the table Statistics - /// Statistics should be optional because not all data sources can provide statistics. 
- fn statistics(&self) -> Statistics; - - /// Returns whether statistics provided are exact values or estimates - fn has_exact_statistics(&self) -> bool { - false - } - /// Tests whether the table provider can make use of a filter expression /// to optimise data retrieval. fn supports_filter_pushdown( diff --git a/datafusion/src/datasource/empty.rs b/datafusion/src/datasource/empty.rs index e6140cdb8de6..183db76829c0 100644 --- a/datafusion/src/datasource/empty.rs +++ b/datafusion/src/datasource/empty.rs @@ -22,7 +22,6 @@ use std::sync::Arc; use arrow::datatypes::*; -use crate::datasource::datasource::Statistics; use crate::datasource::TableProvider; use crate::error::Result; use crate::logical_plan::Expr; @@ -69,12 +68,4 @@ impl TableProvider for EmptyTable { ); Ok(Arc::new(EmptyExec::new(false, Arc::new(projected_schema)))) } - - fn statistics(&self) -> Statistics { - Statistics { - num_rows: Some(0), - total_byte_size: Some(0), - column_statistics: None, - } - } } diff --git a/datafusion/src/datasource/json.rs b/datafusion/src/datasource/json.rs index 90fedfd6f528..f4e67828906e 100644 --- a/datafusion/src/datasource/json.rs +++ b/datafusion/src/datasource/json.rs @@ -37,8 +37,6 @@ use crate::{ }; use arrow::{datatypes::SchemaRef, json::reader::infer_json_schema_from_seekable}; -use super::datasource::Statistics; - trait SeekRead: Read + Seek {} impl SeekRead for T {} @@ -48,7 +46,6 @@ pub struct NdJsonFile { source: Source>, schema: SchemaRef, file_extension: String, - statistics: Statistics, } impl NdJsonFile { @@ -77,7 +74,6 @@ impl NdJsonFile { source: Source::Path(path.to_string()), schema, file_extension: options.file_extension.to_string(), - statistics: Statistics::default(), }) } @@ -101,7 +97,6 @@ impl NdJsonFile { Ok(Self { source: Source::Reader(Mutex::new(Some(Box::new(reader)))), schema, - statistics: Statistics::default(), file_extension: String::new(), }) } @@ -154,10 +149,6 @@ impl TableProvider for NdJsonFile { }; Ok(Arc::new(exec)) } - - fn statistics(&self) -> Statistics { - self.statistics.clone() - } } #[cfg(test)] diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index a4dbfd6c4a24..67b0e7b7c030 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -20,11 +20,10 @@ //! repeatedly queried without incurring additional file I/O overhead. 
use futures::StreamExt; -use log::debug; use std::any::Any; use std::sync::Arc; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use crate::datasource::TableProvider; @@ -33,56 +32,12 @@ use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::ExecutionPlan; -use crate::{ - datasource::datasource::Statistics, - physical_plan::{repartition::RepartitionExec, Partitioning}, -}; - -use super::datasource::ColumnStatistics; +use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; /// In-memory table pub struct MemTable { schema: SchemaRef, batches: Vec>, - statistics: Statistics, -} - -// Calculates statistics based on partitions -fn calculate_statistics( - schema: &SchemaRef, - partitions: &[Vec], -) -> Statistics { - let num_rows: usize = partitions - .iter() - .flat_map(|batches| batches.iter().map(RecordBatch::num_rows)) - .sum(); - - let mut null_count: Vec = vec![0; schema.fields().len()]; - for partition in partitions.iter() { - for batch in partition { - for (i, array) in batch.columns().iter().enumerate() { - null_count[i] += array.null_count(); - } - } - } - - let column_statistics = Some( - null_count - .iter() - .map(|null_count| ColumnStatistics { - null_count: Some(*null_count), - distinct_count: None, - max_value: None, - min_value: None, - }) - .collect(), - ); - - Statistics { - num_rows: Some(num_rows), - total_byte_size: None, - column_statistics, - } } impl MemTable { @@ -93,13 +48,9 @@ impl MemTable { .flatten() .all(|batches| schema.contains(&batches.schema())) { - let statistics = calculate_statistics(&schema, &partitions); - debug!("MemTable statistics: {:?}", statistics); - Ok(Self { schema, batches: partitions, - statistics, }) } else { Err(DataFusionError::Plan( @@ -179,47 +130,12 @@ impl TableProvider for MemTable { _filters: &[Expr], _limit: Option, ) -> Result> { - let columns: Vec = match projection { - Some(p) => p.clone(), - None => { - let l = self.schema.fields().len(); - let mut v = Vec::with_capacity(l); - for i in 0..l { - v.push(i); - } - v - } - }; - - let projected_columns: Result> = columns - .iter() - .map(|i| { - if *i < self.schema.fields().len() { - Ok(self.schema.field(*i).clone()) - } else { - Err(DataFusionError::Internal( - "Projection index out of range".to_string(), - )) - } - }) - .collect(); - - let projected_schema = Arc::new(Schema::new(projected_columns?)); - Ok(Arc::new(MemoryExec::try_new( &self.batches.clone(), - projected_schema, + self.schema(), projection.clone(), )?)) } - - fn statistics(&self) -> Statistics { - self.statistics.clone() - } - - fn has_exact_statistics(&self) -> bool { - true - } } #[cfg(test)] @@ -251,37 +167,6 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; - assert_eq!(provider.statistics().num_rows, Some(3)); - assert_eq!( - provider.statistics().column_statistics, - Some(vec![ - ColumnStatistics { - null_count: Some(0), - max_value: None, - min_value: None, - distinct_count: None, - }, - ColumnStatistics { - null_count: Some(0), - max_value: None, - min_value: None, - distinct_count: None, - }, - ColumnStatistics { - null_count: Some(0), - max_value: None, - min_value: None, - distinct_count: None, - }, - ColumnStatistics { - null_count: Some(2), - max_value: None, - min_value: None, - distinct_count: None, - }, - ]) - ); - // scan with projection let exec = provider.scan(&Some(vec![2, 1]), 1024, &[], None)?; let mut 
it = exec.execute(0).await?; @@ -469,7 +354,6 @@ mod tests { let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); - assert_eq!(provider.statistics().num_rows, Some(6)); Ok(()) } diff --git a/datafusion/src/datasource/mod.rs b/datafusion/src/datasource/mod.rs index df3328ec81c8..53ba5177a2fc 100644 --- a/datafusion/src/datasource/mod.rs +++ b/datafusion/src/datasource/mod.rs @@ -29,11 +29,10 @@ pub use self::csv::{CsvFile, CsvReadOptions}; pub use self::datasource::{TableProvider, TableType}; pub use self::memory::MemTable; use crate::arrow::datatypes::{Schema, SchemaRef}; -use crate::datasource::datasource::{ColumnStatistics, Statistics}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::common::build_file_list; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; -use crate::physical_plan::Accumulator; +use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; use std::sync::Arc; /// Source for table input data @@ -194,9 +193,11 @@ pub fn get_statistics_with_limit( let mut num_rows = 0; let mut num_files = 0; + let mut is_exact = true; for file in &all_files { num_files += 1; let file_stats = &file.statistics; + is_exact &= file_stats.is_exact; num_rows += file_stats.num_rows.unwrap_or(0); total_byte_size += file_stats.total_byte_size.unwrap_or(0); if let Some(vec) = &file_stats.column_statistics { @@ -231,7 +232,10 @@ pub fn get_statistics_with_limit( break; } } - all_files.truncate(num_files); + if num_files < all_files.len() { + is_exact = false; + all_files.truncate(num_files); + } let column_stats = if has_statistics { Some(get_col_stats( @@ -248,6 +252,7 @@ pub fn get_statistics_with_limit( num_rows: Some(num_rows as usize), total_byte_size: Some(total_byte_size as usize), column_statistics: column_stats, + is_exact, }; (all_files, statistics) } diff --git a/datafusion/src/datasource/parquet.rs b/datafusion/src/datasource/parquet.rs index c11aadea9a64..8dc9bc52213d 100644 --- a/datafusion/src/datasource/parquet.rs +++ b/datafusion/src/datasource/parquet.rs @@ -28,16 +28,15 @@ use parquet::file::statistics::Statistics as ParquetStatistics; use super::datasource::TableProviderFilterPushDown; use crate::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use crate::datasource::datasource::Statistics; use crate::datasource::{ - create_max_min_accs, get_col_stats, get_statistics_with_limit, FileAndSchema, - PartitionedFile, TableDescriptor, TableDescriptorBuilder, TableProvider, + create_max_min_accs, get_col_stats, FileAndSchema, PartitionedFile, TableDescriptor, + TableDescriptorBuilder, TableProvider, }; use crate::error::Result; use crate::logical_plan::{combine_filters, Expr}; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::parquet::ParquetExec; -use crate::physical_plan::{Accumulator, ExecutionPlan}; +use crate::physical_plan::{Accumulator, ExecutionPlan, Statistics}; use crate::scalar::ScalarValue; /// Table-based representation of a `ParquetFile`. 
@@ -156,14 +155,6 @@ impl TableProvider for ParquetTable { limit, )?)) } - - fn statistics(&self) -> Statistics { - self.desc.statistics() - } - - fn has_exact_statistics(&self) -> bool { - true - } } #[derive(Debug, Clone)] @@ -200,11 +191,6 @@ impl ParquetTableDescriptor { self.descriptor.schema.clone() } - /// Get the summary statistics for all parquet files - pub fn statistics(&self) -> Statistics { - get_statistics_with_limit(&self.descriptor, None).1 - } - fn summarize_min_max( max_values: &mut Vec>, min_values: &mut Vec>, @@ -404,6 +390,7 @@ impl TableDescriptorBuilder for ParquetTableDescriptor { num_rows: Some(num_rows as usize), total_byte_size: Some(total_byte_size as usize), column_statistics: column_stats, + is_exact: true, }; Ok(FileAndSchema { @@ -440,8 +427,8 @@ mod tests { .await; // test metadata - assert_eq!(table.statistics().num_rows, Some(8)); - assert_eq!(table.statistics().total_byte_size, Some(671)); + assert_eq!(exec.statistics().num_rows, Some(8)); + assert_eq!(exec.statistics().total_byte_size, Some(671)); Ok(()) } diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index cbb2e73bb9d1..82947aaee1ba 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -22,11 +22,11 @@ use crate::{ information_schema::CatalogWithInformationSchema, }, logical_plan::{PlanType, ToStringifiedPlan}, - optimizer::{ - aggregate_statistics::AggregateStatistics, eliminate_limit::EliminateLimit, - hash_build_probe_order::HashBuildProbeOrder, + optimizer::eliminate_limit::EliminateLimit, + physical_optimizer::{ + aggregate_statistics::AggregateStatistics, + hash_build_probe_order::HashBuildProbeOrder, optimizer::PhysicalOptimizerRule, }, - physical_optimizer::optimizer::PhysicalOptimizerRule, }; use log::debug; use std::fs; @@ -710,14 +710,14 @@ impl Default for ExecutionConfig { optimizers: vec![ Arc::new(ConstantFolding::new()), Arc::new(EliminateLimit::new()), - Arc::new(AggregateStatistics::new()), Arc::new(ProjectionPushDown::new()), Arc::new(FilterPushDown::new()), Arc::new(SimplifyExpressions::new()), - Arc::new(HashBuildProbeOrder::new()), Arc::new(LimitPushDown::new()), ], physical_optimizers: vec![ + Arc::new(AggregateStatistics::new()), + Arc::new(HashBuildProbeOrder::new()), Arc::new(CoalesceBatches::new()), Arc::new(Repartition::new()), Arc::new(AddCoalescePartitionsExec::new()), @@ -3180,10 +3180,6 @@ mod tests { ) -> Result> { unimplemented!() } - - fn statistics(&self) -> crate::datasource::datasource::Statistics { - unimplemented!() - } } let mut ctx = ExecutionContext::with_config( diff --git a/datafusion/src/optimizer/aggregate_statistics.rs b/datafusion/src/optimizer/aggregate_statistics.rs deleted file mode 100644 index e2d905464201..000000000000 --- a/datafusion/src/optimizer/aggregate_statistics.rs +++ /dev/null @@ -1,489 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Utilizing exact statistics from sources to avoid scanning data -use std::collections::HashMap; -use std::{sync::Arc, vec}; - -use crate::{ - execution::context::ExecutionProps, - logical_plan::{col, DFField, DFSchema, Expr, LogicalPlan}, - physical_plan::aggregates::AggregateFunction, - scalar::ScalarValue, -}; - -use super::{optimizer::OptimizerRule, utils}; -use crate::error::Result; - -/// Optimizer that uses available statistics for aggregate functions -pub struct AggregateStatistics {} - -impl AggregateStatistics { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for AggregateStatistics { - fn optimize( - &self, - plan: &LogicalPlan, - execution_props: &ExecutionProps, - ) -> crate::error::Result { - match plan { - // match only select count(*) from table_scan - LogicalPlan::Aggregate { - input, - group_expr, - aggr_expr, - schema, - } if group_expr.is_empty() => { - // aggregations that can not be replaced - // using statistics - let mut agg = vec![]; - let mut max_values = HashMap::new(); - let mut min_values = HashMap::new(); - - // expressions that can be replaced by constants - let mut projections = vec![]; - if let Some(num_rows) = match input.as_ref() { - LogicalPlan::TableScan { - table_name, source, .. - } if source.has_exact_statistics() => { - let schema = source.schema(); - let fields = schema.fields(); - if let Some(column_statistics) = - source.statistics().column_statistics - { - if fields.len() == column_statistics.len() { - for (i, field) in fields.iter().enumerate() { - if let Some(max_value) = - column_statistics[i].max_value.clone() - { - let max_key = - format!("{}.{}", table_name, field.name()); - max_values.insert(max_key, max_value); - } - if let Some(min_value) = - column_statistics[i].min_value.clone() - { - let min_key = - format!("{}.{}", table_name, field.name()); - min_values.insert(min_key, min_value); - } - } - } - } - - source.statistics().num_rows - } - _ => None, - } { - for expr in aggr_expr { - match expr { - Expr::AggregateFunction { - fun: AggregateFunction::Count, - args, - distinct: false, - } if args - == &[Expr::Literal(ScalarValue::UInt8(Some(1)))] => - { - projections.push(Expr::Alias( - Box::new(Expr::Literal(ScalarValue::UInt64(Some( - num_rows as u64, - )))), - "COUNT(Uint8(1))".to_string(), - )); - } - Expr::AggregateFunction { - fun: AggregateFunction::Max, - args, - .. - } => match &args[0] { - Expr::Column(c) => match max_values.get(&c.flat_name()) { - Some(max_value) => { - if !max_value.is_null() { - let name = format!("MAX({})", c.name); - projections.push(Expr::Alias( - Box::new(Expr::Literal( - max_value.clone(), - )), - name, - )); - } else { - agg.push(expr.clone()); - } - } - None => { - agg.push(expr.clone()); - } - }, - _ => { - agg.push(expr.clone()); - } - }, - Expr::AggregateFunction { - fun: AggregateFunction::Min, - args, - .. 
- } => match &args[0] { - Expr::Column(c) => match min_values.get(&c.flat_name()) { - Some(min_value) => { - if !min_value.is_null() { - let name = format!("MIN({})", c.name); - projections.push(Expr::Alias( - Box::new(Expr::Literal( - min_value.clone(), - )), - name, - )); - } else { - agg.push(expr.clone()); - } - } - None => { - agg.push(expr.clone()); - } - }, - _ => { - agg.push(expr.clone()); - } - }, - _ => { - agg.push(expr.clone()); - } - } - } - - return Ok(if agg.is_empty() { - // table scan can be entirely removed - - LogicalPlan::Projection { - expr: projections, - input: Arc::new(LogicalPlan::EmptyRelation { - produce_one_row: true, - schema: Arc::new(DFSchema::empty()), - }), - schema: schema.clone(), - } - } else if projections.is_empty() { - // no replacements -> return original plan - plan.clone() - } else { - // Split into parts that can be supported and part that should stay in aggregate - let agg_fields = agg - .iter() - .map(|x| x.to_field(input.schema())) - .collect::>>()?; - let agg_schema = DFSchema::new(agg_fields)?; - let cols = agg - .iter() - .map(|e| e.name(&agg_schema)) - .collect::>>()?; - projections.extend(cols.iter().map(|x| col(x))); - LogicalPlan::Projection { - expr: projections, - schema: schema.clone(), - input: Arc::new(LogicalPlan::Aggregate { - input: input.clone(), - group_expr: vec![], - aggr_expr: agg, - schema: Arc::new(agg_schema), - }), - } - }); - } - Ok(plan.clone()) - } - // Rest: recurse and find possible statistics - _ => { - let expr = plan.expressions(); - - // apply the optimization to all inputs of the plan - let inputs = plan.inputs(); - let new_inputs = inputs - .iter() - .map(|plan| self.optimize(plan, execution_props)) - .collect::>>()?; - - utils::from_plan(plan, &expr, &new_inputs) - } - } - } - - fn name(&self) -> &str { - "aggregate_statistics" - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::datatypes::{DataType, Field, Schema}; - - use crate::error::Result; - use crate::execution::context::ExecutionProps; - use crate::logical_plan::LogicalPlan; - use crate::optimizer::aggregate_statistics::AggregateStatistics; - use crate::optimizer::optimizer::OptimizerRule; - use crate::scalar::ScalarValue; - use crate::{ - datasource::{ - datasource::{ColumnStatistics, Statistics}, - TableProvider, - }, - logical_plan::Expr, - }; - - struct TestTableProvider { - num_rows: usize, - column_statistics: Vec, - is_exact: bool, - } - - impl TableProvider for TestTableProvider { - fn as_any(&self) -> &dyn std::any::Any { - unimplemented!() - } - fn schema(&self) -> arrow::datatypes::SchemaRef { - Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])) - } - - fn scan( - &self, - _projection: &Option>, - _batch_size: usize, - _filters: &[Expr], - _limit: Option, - ) -> Result> { - unimplemented!() - } - fn statistics(&self) -> Statistics { - Statistics { - num_rows: Some(self.num_rows), - total_byte_size: None, - column_statistics: Some(self.column_statistics.clone()), - } - } - fn has_exact_statistics(&self) -> bool { - self.is_exact - } - } - - #[test] - fn optimize_count_using_statistics() -> Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics: Vec::new(), - is_exact: true, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("select count(*) from test") - .unwrap(); - let expected = "\ - Projection: #COUNT(UInt8(1))\ - \n Projection: 
UInt64(100) AS COUNT(Uint8(1))\ - \n EmptyRelation"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - #[test] - fn optimize_count_not_exact() -> Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics: Vec::new(), - is_exact: false, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("select count(*) from test") - .unwrap(); - let expected = "\ - Projection: #COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: test projection=None"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - #[test] - fn optimize_count_sum() -> Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics: Vec::new(), - is_exact: true, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("select sum(a)/count(*) from test") - .unwrap(); - let expected = "\ - Projection: #SUM(test.a) Divide #COUNT(UInt8(1))\ - \n Projection: UInt64(100) AS COUNT(Uint8(1)), #SUM(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(#test.a)]]\ - \n TableScan: test projection=None"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - #[test] - fn optimize_count_group_by() -> Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics: Vec::new(), - is_exact: true, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("SELECT count(*), a FROM test GROUP BY a") - .unwrap(); - let expected = "\ - Projection: #COUNT(UInt8(1)), #test.a\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: test projection=None"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - #[test] - fn optimize_count_filter() -> Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics: Vec::new(), - is_exact: true, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("SELECT count(*) FROM test WHERE a < 5") - .unwrap(); - let expected = "\ - Projection: #COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n Filter: #test.a Lt Int64(5)\ - \n TableScan: test projection=None"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - #[test] - fn optimize_max_min_using_statistics() -> Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - - let column_statistic = ColumnStatistics { - null_count: None, - max_value: Some(ScalarValue::from(100_i64)), - min_value: Some(ScalarValue::from(1_i64)), - distinct_count: None, - }; - let column_statistics = vec![column_statistic]; - - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics, - is_exact: true, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("select max(a), min(a) from test") - .unwrap(); - let expected = "\ - Projection: #MAX(test.a), #MIN(test.a)\ - \n Projection: Int64(100) AS MAX(a), Int64(1) AS MIN(a)\ - \n EmptyRelation"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - #[test] - fn optimize_max_min_not_using_statistics() -> 
Result<()> { - use crate::execution::context::ExecutionContext; - let mut ctx = ExecutionContext::new(); - ctx.register_table( - "test", - Arc::new(TestTableProvider { - num_rows: 100, - column_statistics: Vec::new(), - is_exact: true, - }), - ) - .unwrap(); - - let plan = ctx - .create_logical_plan("select max(a), min(a) from test") - .unwrap(); - let expected = "\ - Projection: #MAX(test.a), #MIN(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(#test.a), MIN(#test.a)]]\ - \n TableScan: test projection=None"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } - - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let opt = AggregateStatistics::new(); - let optimized_plan = opt.optimize(plan, &ExecutionProps::new()).unwrap(); - let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), plan.schema()); - } -} diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index d0990de38dca..a51fbc225724 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -543,7 +543,6 @@ fn rewrite(expr: &Expr, projection: &HashMap) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::datasource::datasource::Statistics; use crate::datasource::TableProvider; use crate::logical_plan::{lit, sum, DFSchema, Expr, LogicalPlanBuilder, Operator}; use crate::physical_plan::ExecutionPlan; @@ -1161,10 +1160,6 @@ mod tests { fn as_any(&self) -> &dyn std::any::Any { self } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } fn table_scan_with_pushdown_provider( diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs deleted file mode 100644 index 209faf49bbe1..000000000000 --- a/datafusion/src/optimizer/hash_build_probe_order.rs +++ /dev/null @@ -1,317 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License - -//! Optimizer rule to switch build and probe order of hash join -//! based on statistics of a `TableProvider`. If the number of -//! rows of both sources is known, the order can be switched -//! for a faster hash join. - -use std::sync::Arc; - -use crate::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder}; -use crate::optimizer::optimizer::OptimizerRule; -use crate::{error::Result, prelude::JoinType}; - -use super::utils; -use crate::execution::context::ExecutionProps; - -/// BuildProbeOrder reorders the build and probe phase of -/// hash joins. This uses the amount of rows that a datasource has. -/// The rule optimizes the order such that the left (build) side of the join -/// is the smallest. 
-/// If the information is not available, the order stays the same, -/// so that it could be optimized manually in a query. -pub struct HashBuildProbeOrder {} - -// Gets exact number of rows, if known by the statistics of the underlying -fn get_num_rows(logical_plan: &LogicalPlan) -> Option { - match logical_plan { - LogicalPlan::TableScan { source, .. } => source.statistics().num_rows, - LogicalPlan::EmptyRelation { - produce_one_row, .. - } => { - if *produce_one_row { - Some(1) - } else { - Some(0) - } - } - LogicalPlan::Limit { n: limit, input } => { - let num_rows_input = get_num_rows(input); - num_rows_input.map(|rows| std::cmp::min(*limit, rows)) - } - LogicalPlan::Window { input, .. } => { - // window functions do not change num of rows - get_num_rows(input) - } - LogicalPlan::Aggregate { .. } => { - // we cannot yet predict how many rows will be produced by an aggregate because - // we do not know the cardinality of the grouping keys - None - } - LogicalPlan::Filter { .. } => { - // we cannot yet predict how many rows will be produced by a filter because - // we don't know how selective it is (how many rows it will filter out) - None - } - LogicalPlan::Join { .. } => { - // we cannot predict the cardinality of the join output - None - } - LogicalPlan::CrossJoin { left, right, .. } => { - // number of rows is equal to num_left * num_right - get_num_rows(left).and_then(|x| get_num_rows(right).map(|y| x * y)) - } - LogicalPlan::Repartition { .. } => { - // we cannot predict how rows will be repartitioned - None - } - LogicalPlan::Analyze { .. } => { - // Analyze produces one row, verbose produces more - // but it should never be used as an input to a Join anyways - None - } - // the following operators are special cases and not querying data - LogicalPlan::CreateExternalTable { .. } => None, - LogicalPlan::Explain { .. } => None, - // we do not support estimating rows with extensions yet - LogicalPlan::Extension { .. } => None, - // the following operators do not modify row count in any way - LogicalPlan::Projection { input, .. } => get_num_rows(input), - LogicalPlan::Sort { input, .. } => get_num_rows(input), - // Add number of rows of below plans - LogicalPlan::Union { inputs, .. 
} => { - inputs.iter().map(|plan| get_num_rows(plan)).sum() - } - } -} - -// Finds out whether to swap left vs right order based on statistics -fn should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool { - let left_rows = get_num_rows(left); - let right_rows = get_num_rows(right); - - match (left_rows, right_rows) { - (Some(l), Some(r)) => l > r, - _ => false, - } -} - -fn supports_swap(join_type: JoinType) -> bool { - match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true, - JoinType::Semi | JoinType::Anti => false, - } -} - -impl OptimizerRule for HashBuildProbeOrder { - fn name(&self) -> &str { - "hash_build_probe_order" - } - - fn optimize( - &self, - plan: &LogicalPlan, - execution_props: &ExecutionProps, - ) -> Result { - match plan { - // Main optimization rule, swaps order of left and right - // based on number of rows in each table - LogicalPlan::Join { - left, - right, - on, - join_type, - join_constraint, - schema, - } => { - let left = self.optimize(left, execution_props)?; - let right = self.optimize(right, execution_props)?; - if should_swap_join_order(&left, &right) && supports_swap(*join_type) { - // Swap left and right, change join type and (equi-)join key order - Ok(LogicalPlan::Join { - left: Arc::new(right), - right: Arc::new(left), - on: on.iter().map(|(l, r)| (r.clone(), l.clone())).collect(), - join_type: swap_join_type(*join_type), - join_constraint: *join_constraint, - schema: schema.clone(), - }) - } else { - // Keep join as is - Ok(LogicalPlan::Join { - left: Arc::new(left), - right: Arc::new(right), - on: on.clone(), - join_type: *join_type, - join_constraint: *join_constraint, - schema: schema.clone(), - }) - } - } - LogicalPlan::CrossJoin { - left, - right, - schema, - } => { - let left = self.optimize(left, execution_props)?; - let right = self.optimize(right, execution_props)?; - if should_swap_join_order(&left, &right) { - let swapped = - LogicalPlanBuilder::from(right.clone()).cross_join(&left)?; - // wrap plan with projection to maintain column order - let left_cols = left - .schema() - .fields() - .iter() - .map(|f| Expr::Column(f.qualified_column())); - let right_cols = right - .schema() - .fields() - .iter() - .map(|f| Expr::Column(f.qualified_column())); - swapped.project(left_cols.chain(right_cols))?.build() - } else { - // Keep join as is - Ok(LogicalPlan::CrossJoin { - left: Arc::new(left), - right: Arc::new(right), - schema: schema.clone(), - }) - } - } - // Rest: recurse into plan, apply optimization where possible - LogicalPlan::Projection { .. } - | LogicalPlan::Window { .. } - | LogicalPlan::Aggregate { .. } - | LogicalPlan::TableScan { .. } - | LogicalPlan::Limit { .. } - | LogicalPlan::Filter { .. } - | LogicalPlan::Repartition { .. } - | LogicalPlan::EmptyRelation { .. } - | LogicalPlan::Sort { .. } - | LogicalPlan::CreateExternalTable { .. } - | LogicalPlan::Explain { .. } - | LogicalPlan::Analyze { .. } - | LogicalPlan::Union { .. } - | LogicalPlan::Extension { .. 
} => { - let expr = plan.expressions(); - - // apply the optimization to all inputs of the plan - let inputs = plan.inputs(); - let new_inputs = inputs - .iter() - .map(|plan| self.optimize(plan, execution_props)) - .collect::>>()?; - - utils::from_plan(plan, &expr, &new_inputs) - } - } - } -} - -impl HashBuildProbeOrder { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn swap_join_type(join_type: JoinType) -> JoinType { - match join_type { - JoinType::Inner => JoinType::Inner, - JoinType::Full => JoinType::Full, - JoinType::Left => JoinType::Right, - JoinType::Right => JoinType::Left, - _ => unreachable!(), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use crate::{ - datasource::{datasource::Statistics, TableProvider}, - logical_plan::{DFSchema, Expr}, - test::*, - }; - - struct TestTableProvider { - num_rows: usize, - } - - impl TableProvider for TestTableProvider { - fn as_any(&self) -> &dyn std::any::Any { - unimplemented!() - } - fn schema(&self) -> arrow::datatypes::SchemaRef { - unimplemented!() - } - - fn scan( - &self, - _projection: &Option>, - _batch_size: usize, - _filters: &[Expr], - _limit: Option, - ) -> Result> { - unimplemented!() - } - fn statistics(&self) -> crate::datasource::datasource::Statistics { - Statistics { - num_rows: Some(self.num_rows), - total_byte_size: None, - column_statistics: None, - } - } - } - - #[test] - fn test_num_rows() -> Result<()> { - let table_scan = test_table_scan()?; - - assert_eq!(get_num_rows(&table_scan), Some(0)); - - Ok(()) - } - - #[test] - fn test_swap_order() { - let lp_left = LogicalPlan::TableScan { - table_name: "left".to_string(), - projection: None, - source: Arc::new(TestTableProvider { num_rows: 1000 }), - projected_schema: Arc::new(DFSchema::empty()), - filters: vec![], - limit: None, - }; - - let lp_right = LogicalPlan::TableScan { - table_name: "right".to_string(), - projection: None, - source: Arc::new(TestTableProvider { num_rows: 100 }), - projected_schema: Arc::new(DFSchema::empty()), - filters: vec![], - limit: None, - }; - - assert!(should_swap_join_order(&lp_left, &lp_right)); - assert!(!should_swap_join_order(&lp_right, &lp_left)); - } -} diff --git a/datafusion/src/optimizer/mod.rs b/datafusion/src/optimizer/mod.rs index 68758474d594..0246f9a12c65 100644 --- a/datafusion/src/optimizer/mod.rs +++ b/datafusion/src/optimizer/mod.rs @@ -18,11 +18,9 @@ //! This module contains a query optimizer that operates against a logical plan and applies //! some simple rules to a logical plan, such as "Projection Push Down" and "Type Coercion". -pub mod aggregate_statistics; pub mod constant_folding; pub mod eliminate_limit; pub mod filter_push_down; -pub mod hash_build_probe_order; pub mod limit_push_down; pub mod optimizer; pub mod projection_push_down; diff --git a/datafusion/src/physical_optimizer/aggregate_statistics.rs b/datafusion/src/physical_optimizer/aggregate_statistics.rs new file mode 100644 index 000000000000..1b361dd54936 --- /dev/null +++ b/datafusion/src/physical_optimizer/aggregate_statistics.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utilizing exact statistics from sources to avoid scanning data
+use std::sync::Arc;
+
+use arrow::datatypes::Schema;
+
+use crate::execution::context::ExecutionConfig;
+use crate::physical_plan::empty::EmptyExec;
+use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
+use crate::physical_plan::projection::ProjectionExec;
+use crate::physical_plan::{
+    expressions, AggregateExpr, ColumnStatistics, ExecutionPlan, Statistics,
+};
+use crate::scalar::ScalarValue;
+
+use super::optimizer::PhysicalOptimizerRule;
+use super::utils::optimize_children;
+use crate::error::Result;
+
+/// Optimizer that uses available statistics for aggregate functions
+pub struct AggregateStatistics {}
+
+impl AggregateStatistics {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl PhysicalOptimizerRule for AggregateStatistics {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        execution_config: &ExecutionConfig,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if let Some(partial_agg_exec) = take_optimizable(&*plan) {
+            let partial_agg_exec = partial_agg_exec
+                .as_any()
+                .downcast_ref::<HashAggregateExec>()
+                .expect("take_optimizable() ensures that this is a HashAggregateExec");
+            let stats = partial_agg_exec.input().statistics();
+            let mut projections = vec![];
+            for expr in partial_agg_exec.aggr_expr() {
+                if let Some((num_rows, name)) = take_optimizable_count(&**expr, &stats) {
+                    projections.push((expressions::lit(num_rows), name.to_owned()));
+                } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) {
+                    projections.push((expressions::lit(min), name.to_owned()));
+                } else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) {
+                    projections.push((expressions::lit(max), name.to_owned()));
+                } else {
+                    // TODO: we need all aggr_expr to be resolved (cf TODO fullres)
+                    break;
+                }
+            }
+
+            // TODO fullres: use statistics even if not all aggr_expr could be resolved
+            if projections.len() == partial_agg_exec.aggr_expr().len() {
+                // input can be entirely removed
+                Ok(Arc::new(ProjectionExec::try_new(
+                    projections,
+                    Arc::new(EmptyExec::new(true, Arc::new(Schema::empty()))),
+                )?))
+            } else {
+                optimize_children(self, plan, execution_config)
+            }
+        } else {
+            optimize_children(self, plan, execution_config)
+        }
+    }
+
+    fn name(&self) -> &str {
+        "aggregate_statistics"
+    }
+}
+
+/// Checks whether the node passed as argument is a final `HashAggregateExec` node that can be optimized:
+/// - its child (with possible intermediate layers) is a partial `HashAggregateExec` node
+/// - they both have no grouping expression
+/// - the statistics are exact
+/// If this is the case, return a ref to the partial `HashAggregateExec`, else `None`.
+/// We would have preferred to return a cast ref to the `HashAggregateExec`, but the recursion requires
+/// the `ExecutionPlan.children()` method that returns an owned reference.
+fn take_optimizable(node: &dyn ExecutionPlan) -> Option> { + if let Some(final_agg_exec) = node.as_any().downcast_ref::() { + if final_agg_exec.mode() == &AggregateMode::Final + && final_agg_exec.group_expr().is_empty() + { + let mut child = Arc::clone(final_agg_exec.input()); + loop { + if let Some(partial_agg_exec) = + child.as_any().downcast_ref::() + { + if partial_agg_exec.mode() == &AggregateMode::Partial + && partial_agg_exec.group_expr().is_empty() + { + let stats = partial_agg_exec.input().statistics(); + if stats.is_exact { + return Some(child); + } + } + } + if let [ref childrens_child] = child.children().as_slice() { + child = Arc::clone(childrens_child); + } else { + break; + } + } + } + } + None +} + +/// If this agg_expr is a count that is defined in the statistics, return it +fn take_optimizable_count( + agg_expr: &dyn AggregateExpr, + stats: &Statistics, +) -> Option<(ScalarValue, &'static str)> { + if let (Some(num_rows), Some(casted_expr)) = ( + stats.num_rows, + agg_expr.as_any().downcast_ref::(), + ) { + // TODO implementing Eq on PhysicalExpr would help a lot here + if casted_expr.expressions().len() == 1 { + if let Some(lit_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &ScalarValue::UInt8(Some(1)) { + return Some(( + ScalarValue::UInt64(Some(num_rows as u64)), + "COUNT(Uint8(1))", + )); + } + } + } + } + None +} + +/// If this agg_expr is a min that is defined in the statistics, return it +fn take_optimizable_min( + agg_expr: &dyn AggregateExpr, + stats: &Statistics, +) -> Option<(ScalarValue, String)> { + if let (Some(col_stats), Some(casted_expr)) = ( + &stats.column_statistics, + agg_expr.as_any().downcast_ref::(), + ) { + if casted_expr.expressions().len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if let ColumnStatistics { + min_value: Some(val), + .. + } = &col_stats[col_expr.index()] + { + return Some((val.clone(), format!("MIN({})", col_expr.name()))); + } + } + } + } + None +} + +/// If this agg_expr is a max that is defined in the statistics, return it +fn take_optimizable_max( + agg_expr: &dyn AggregateExpr, + stats: &Statistics, +) -> Option<(ScalarValue, String)> { + if let (Some(col_stats), Some(casted_expr)) = ( + &stats.column_statistics, + agg_expr.as_any().downcast_ref::(), + ) { + if casted_expr.expressions().len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = casted_expr.expressions()[0] + .as_any() + .downcast_ref::() + { + if let ColumnStatistics { + max_value: Some(val), + .. 
+                } = &col_stats[col_expr.index()]
+                {
+                    return Some((val.clone(), format!("MAX({})", col_expr.name())));
+                }
+            }
+        }
+    }
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+
+    use arrow::array::{Int32Array, UInt64Array};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+
+    use crate::error::Result;
+    use crate::logical_plan::Operator;
+    use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+    use crate::physical_plan::common;
+    use crate::physical_plan::expressions::Count;
+    use crate::physical_plan::filter::FilterExec;
+    use crate::physical_plan::hash_aggregate::HashAggregateExec;
+    use crate::physical_plan::memory::MemoryExec;
+
+    /// Mock data using a MemoryExec which has an exact count statistic
+    fn mock_data() -> Result<Arc<dyn ExecutionPlan>> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(Int32Array::from(vec![4, 5, 6])),
+            ],
+        )?;
+
+        Ok(Arc::new(MemoryExec::try_new(
+            &[vec![batch]],
+            Arc::clone(&schema),
+            None,
+        )?))
+    }
+
+    /// Checks that the count optimization was applied and we still get the right result
+    async fn assert_count_optim_success(plan: HashAggregateExec) -> Result<()> {
+        let conf = ExecutionConfig::new();
+        let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?;
+
+        assert!(optimized.as_any().is::<ProjectionExec>());
+        let result = common::collect(optimized.execute(0).await?).await?;
+        assert_eq!(
+            result[0].schema(),
+            Arc::new(Schema::new(vec![Field::new(
+                "COUNT(Uint8(1))",
+                DataType::UInt64,
+                false
+            )]))
+        );
+        assert_eq!(
+            result[0]
+                .column(0)
+                .as_any()
+                .downcast_ref::<UInt64Array>()
+                .unwrap()
+                .values(),
+            &[3]
+        );
+        Ok(())
+    }
+
+    fn count_expr() -> Arc<dyn AggregateExpr> {
+        Arc::new(Count::new(
+            expressions::lit(ScalarValue::UInt8(Some(1))),
+            "my_count_alias",
+            DataType::UInt64,
+        ))
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_direct_child() -> Result<()> {
+        // basic test case with the aggregation applied on a source with exact statistics
+        let source = mock_data()?;
+        let schema = source.schema();
+
+        let partial_agg = HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            vec![],
+            vec![count_expr()],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = HashAggregateExec::try_new(
+            AggregateMode::Final,
+            vec![],
+            vec![count_expr()],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_indirect_child() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+
+        let partial_agg = HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            vec![],
+            vec![count_expr()],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        // We introduce an intermediate optimization step between the partial and final aggregator
+        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
+
+        let final_agg = HashAggregateExec::try_new(
+            AggregateMode::Final,
+            vec![],
+            vec![count_expr()],
+            Arc::new(coalesce),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_inexact_stat() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+
+        // adding a filter makes the statistics inexact
+        let filter = Arc::new(FilterExec::try_new(
+            expressions::binary(
expressions::col("a", &schema)?, + Operator::Gt, + expressions::lit(ScalarValue::from(1u32)), + &schema, + )?, + source, + )?); + + let partial_agg = HashAggregateExec::try_new( + AggregateMode::Partial, + vec![], + vec![count_expr()], + filter, + Arc::clone(&schema), + )?; + + let final_agg = HashAggregateExec::try_new( + AggregateMode::Final, + vec![], + vec![count_expr()], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ExecutionConfig::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } +} diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs new file mode 100644 index 000000000000..0b87ceb1a4e2 --- /dev/null +++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilizing exact statistics from sources to avoid scanning data +use std::sync::Arc; + +use arrow::datatypes::Schema; + +use crate::execution::context::ExecutionConfig; +use crate::logical_plan::JoinType; +use crate::physical_plan::cross_join::CrossJoinExec; +use crate::physical_plan::expressions::Column; +use crate::physical_plan::hash_join::HashJoinExec; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; + +use super::optimizer::PhysicalOptimizerRule; +use super::utils::optimize_children; +use crate::error::Result; + +/// BuildProbeOrder reorders the build and probe phase of +/// hash joins. This uses the amount of rows that a datasource has. +/// The rule optimizes the order such that the left (build) side of the join +/// is the smallest. +/// If the information is not available, the order stays the same, +/// so that it could be optimized manually in a query. 
+pub struct HashBuildProbeOrder {} + +impl HashBuildProbeOrder { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -> bool { + let left_rows = left.statistics().num_rows; + let right_rows = right.statistics().num_rows; + + match (left_rows, right_rows) { + (Some(l), Some(r)) => l > r, + _ => false, + } +} + +fn supports_swap(join_type: JoinType) -> bool { + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true, + JoinType::Semi | JoinType::Anti => false, + } +} + +fn swap_join_type(join_type: JoinType) -> JoinType { + match join_type { + JoinType::Inner => JoinType::Inner, + JoinType::Full => JoinType::Full, + JoinType::Left => JoinType::Right, + JoinType::Right => JoinType::Left, + _ => unreachable!(), + } +} + +/// When the order of the join is changed by the optimizer, +/// the columns in the output should not be impacted. +/// This helper creates the expressions that will allow to swap +/// back the values from the original left as first columns and +/// those on the right next +fn swap_reverting_projection( + left_schema: &Schema, + right_schema: &Schema, +) -> Vec<(Arc, String)> { + let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), i)) as Arc, + f.name().to_owned(), + ) + }); + let right_len = right_cols.len(); + let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), right_len + i)) as Arc, + f.name().to_owned(), + ) + }); + + left_cols.chain(right_cols).collect() +} + +impl PhysicalOptimizerRule for HashBuildProbeOrder { + fn optimize( + &self, + plan: Arc, + execution_config: &ExecutionConfig, + ) -> Result> { + let plan = optimize_children(self, plan, execution_config)?; + if let Some(hash_join) = plan.as_any().downcast_ref::() { + let left = hash_join.left(); + let right = hash_join.right(); + if should_swap_join_order(&**left, &**right) + && supports_swap(*hash_join.join_type()) + { + let new_join = HashJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + hash_join + .on() + .iter() + .map(|(l, r)| (r.clone(), l.clone())) + .collect(), + &swap_join_type(*hash_join.join_type()), + *hash_join.partition_mode(), + )?; + let proj = ProjectionExec::try_new( + swap_reverting_projection(&*left.schema(), &*right.schema()), + Arc::new(new_join), + )?; + return Ok(Arc::new(proj)); + } + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right) { + let new_join = + CrossJoinExec::try_new(Arc::clone(right), Arc::clone(left))?; + let proj = ProjectionExec::try_new( + swap_reverting_projection(&*left.schema(), &*right.schema()), + Arc::new(new_join), + )?; + return Ok(Arc::new(proj)); + } + } + Ok(plan) + } + + fn name(&self) -> &str { + "hash_build_probe_order" + } +} + +#[cfg(test)] +mod tests { + use crate::{ + physical_plan::{hash_join::PartitionMode, Statistics}, + test::exec::StatisticsExec, + }; + + use super::*; + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + + fn create_big_and_small() -> (Arc, Arc) { + let big = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Some(100000), + ..Default::default() + }, + Schema::new(vec![Field::new("big_col", DataType::Int32, false)]), + )); + + let small = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Some(10), + 
..Default::default() + }, + Schema::new(vec![Field::new("small_col", DataType::Int32, false)]), + )); + (big, small) + } + + #[tokio::test] + async fn test_join_with_swap() { + let (big, small) = create_big_and_small(); + + let join = HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Column::new_with_schema("big_col", &big.schema()).unwrap(), + Column::new_with_schema("small_col", &small.schema()).unwrap(), + )], + &JoinType::Left, + PartitionMode::CollectLeft, + ) + .unwrap(); + + let optimized_join = HashBuildProbeOrder::new() + .optimize(Arc::new(join), &ExecutionConfig::new()) + .unwrap(); + + let swapping_projection = optimized_join + .as_any() + .downcast_ref::() + .expect("A proj is required to swap columns back to their original order"); + + assert_eq!(swapping_projection.expr().len(), 2); + let (col, name) = &swapping_projection.expr()[0]; + assert_eq!(name, "big_col"); + assert_col_expr(col, "big_col", 1); + let (col, name) = &swapping_projection.expr()[1]; + assert_eq!(name, "small_col"); + assert_col_expr(col, "small_col", 0); + + let swapped_join = swapping_projection + .input() + .as_any() + .downcast_ref::() + .expect("The type of the plan should not be changed"); + + assert_eq!(swapped_join.left().statistics().num_rows, Some(10)); + assert_eq!(swapped_join.right().statistics().num_rows, Some(100000)); + } + + #[tokio::test] + async fn test_join_no_swap() { + let (big, small) = create_big_and_small(); + + let join = HashJoinExec::try_new( + Arc::clone(&small), + Arc::clone(&big), + vec![( + Column::new_with_schema("small_col", &small.schema()).unwrap(), + Column::new_with_schema("big_col", &big.schema()).unwrap(), + )], + &JoinType::Left, + PartitionMode::CollectLeft, + ) + .unwrap(); + + let optimized_join = HashBuildProbeOrder::new() + .optimize(Arc::new(join), &ExecutionConfig::new()) + .unwrap(); + + let swapped_join = optimized_join + .as_any() + .downcast_ref::() + .expect("The type of the plan should not be changed"); + + assert_eq!(swapped_join.left().statistics().num_rows, Some(10)); + assert_eq!(swapped_join.right().statistics().num_rows, Some(100000)); + } + + #[tokio::test] + async fn test_swap_reverting_projection() { + let left_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let right_schema = Schema::new(vec![Field::new("c", DataType::Int32, false)]); + + let proj = swap_reverting_projection(&left_schema, &right_schema); + + assert_eq!(proj.len(), 3); + + let (col, name) = &proj[0]; + assert_eq!(name, "a"); + assert_col_expr(col, "a", 1); + + let (col, name) = &proj[1]; + assert_eq!(name, "b"); + assert_col_expr(col, "b", 2); + + let (col, name) = &proj[2]; + assert_eq!(name, "c"); + assert_col_expr(col, "c", 0); + } + + fn assert_col_expr(expr: &Arc, name: &str, index: usize) { + let col = expr + .as_any() + .downcast_ref::() + .expect("Projection items should be Column expression"); + assert_eq!(col.name(), name); + assert_eq!(col.index(), index); + } +} diff --git a/datafusion/src/physical_optimizer/mod.rs b/datafusion/src/physical_optimizer/mod.rs index 8e79fe932874..ed45057784ae 100644 --- a/datafusion/src/physical_optimizer/mod.rs +++ b/datafusion/src/physical_optimizer/mod.rs @@ -18,8 +18,11 @@ //! This module contains a query optimizer that operates against a physical plan and applies //! rules to a physical plan, such as "Repartition". 
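// Editorial sketch, not part of this diff: the two rules declared below can be
// applied by hand to a physical plan as shown here. The function itself is
// hypothetical and assumes the relevant imports are in scope (Arc, ExecutionPlan,
// ExecutionConfig, Result, AggregateStatistics, HashBuildProbeOrder).
#[allow(dead_code)]
fn apply_statistics_rules_sketch(
    plan: Arc<dyn ExecutionPlan>,
    config: &ExecutionConfig,
) -> Result<Arc<dyn ExecutionPlan>> {
    // First fold statistics-answerable aggregates into constants...
    let plan = AggregateStatistics::new().optimize(plan, config)?;
    // ...then put the smaller input on the build side of hash joins.
    HashBuildProbeOrder::new().optimize(plan, config)
}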
+pub mod aggregate_statistics; pub mod coalesce_batches; +pub mod hash_build_probe_order; pub mod merge_exec; pub mod optimizer; pub mod pruning; pub mod repartition; +mod utils; diff --git a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index fd8650411d71..61266e442c98 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -109,11 +109,11 @@ mod tests { use arrow::datatypes::Schema; use super::*; - use crate::datasource::datasource::Statistics; use crate::datasource::PartitionedFile; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; use crate::physical_plan::parquet::{ParquetExec, ParquetPartition}; use crate::physical_plan::projection::ProjectionExec; + use crate::physical_plan::Statistics; #[test] fn added_repartition_to_single_partition() -> Result<()> { diff --git a/datafusion/src/physical_optimizer/utils.rs b/datafusion/src/physical_optimizer/utils.rs new file mode 100644 index 000000000000..962b8ce14557 --- /dev/null +++ b/datafusion/src/physical_optimizer/utils.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Collection of utility functions that are leveraged by the query optimizer rules + +use super::optimizer::PhysicalOptimizerRule; +use crate::execution::context::ExecutionConfig; + +use crate::error::Result; +use crate::physical_plan::ExecutionPlan; +use std::sync::Arc; + +/// Convenience rule for writing optimizers: recursively invoke +/// optimize on plan's children and then return a node of the same +/// type. Useful for optimizer rules which want to leave the type +/// of plan unchanged but still apply to the children. 
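// Editorial sketch, not part of this diff: the smallest possible rule built on
// the `optimize_children` helper defined just below. All the names it uses are
// already imported in this file; only `PassThroughRule` itself is hypothetical.
#[allow(dead_code)]
struct PassThroughRule {}

impl PhysicalOptimizerRule for PassThroughRule {
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        execution_config: &ExecutionConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Delegate entirely to the helper: the children are optimized recursively
        // and the node itself is rebuilt unchanged via `with_new_children`.
        optimize_children(self, plan, execution_config)
    }

    fn name(&self) -> &str {
        "pass_through"
    }
}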
+pub fn optimize_children( + optimizer: &impl PhysicalOptimizerRule, + plan: Arc, + execution_config: &ExecutionConfig, +) -> Result> { + let children = plan + .children() + .iter() + .map(|child| optimizer.optimize(Arc::clone(child), execution_config)) + .collect::>>()?; + + if children.is_empty() { + Ok(Arc::clone(&plan)) + } else { + plan.with_new_children(children) + } +} diff --git a/datafusion/src/physical_plan/analyze.rs b/datafusion/src/physical_plan/analyze.rs index d0125579ace2..e68acc5fab2e 100644 --- a/datafusion/src/physical_plan/analyze.rs +++ b/datafusion/src/physical_plan/analyze.rs @@ -22,8 +22,10 @@ use std::{any::Any, time::Instant}; use crate::{ error::{DataFusionError, Result}, - physical_plan::{display::DisplayableExecutionPlan, Partitioning}, - physical_plan::{DisplayFormatType, ExecutionPlan}, + physical_plan::{ + display::DisplayableExecutionPlan, DisplayFormatType, ExecutionPlan, + Partitioning, Statistics, + }, }; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use futures::StreamExt; @@ -206,4 +208,9 @@ impl ExecutionPlan for AnalyzeExec { } } } + + fn statistics(&self) -> Statistics { + // Statistics of an ANALYZE plan are not relevant + Statistics::default() + } } diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index e25412d9d6b8..aee9aea71ddb 100644 --- a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -37,6 +37,8 @@ use async_trait::async_trait; use futures::stream::{Stream, StreamExt}; use log::debug; +use super::Statistics; + /// CoalesceBatchesExec combines small batches into larger batches for more efficient use of /// vectorized processing by upstream operators. #[derive(Debug)] @@ -131,6 +133,10 @@ impl ExecutionPlan for CoalesceBatchesExec { } } } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } } struct CoalesceBatchesStream { diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs b/datafusion/src/physical_plan/coalesce_partitions.rs index 8781a3d3ad75..0d2cc899ebf6 100644 --- a/datafusion/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/src/physical_plan/coalesce_partitions.rs @@ -31,7 +31,7 @@ use arrow::record_batch::RecordBatch; use arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use super::RecordBatchStream; +use super::{RecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; @@ -153,6 +153,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } } pin_project!
{ diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index 2482bfc0872c..d0b7a07f3b79 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -19,9 +19,9 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::ExecutionPlan; +use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::compute::concat; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -169,6 +169,48 @@ pub(crate) fn spawn_execution( }) } +/// Computes the statistics for an in-memory RecordBatch +/// +/// Only computes statistics that are in arrows metadata (num rows, byte size and nulls) +/// and does not apply any kernel on the actual data. +pub fn compute_record_batch_statistics( + batches: &[Vec], + schema: &Schema, + projection: Option>, +) -> Statistics { + let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); + + let total_byte_size = batches + .iter() + .flatten() + .flat_map(RecordBatch::columns) + .map(|a| a.get_array_memory_size()) + .sum(); + + let projection = match projection { + Some(p) => p, + None => (0..schema.fields().len()).collect(), + }; + + let mut column_statistics = vec![ColumnStatistics::default(); projection.len()]; + + for partition in batches.iter() { + for batch in partition { + for (stat_index, col_index) in projection.iter().enumerate() { + *column_statistics[stat_index].null_count.get_or_insert(0) += + batch.column(*col_index).null_count(); + } + } + } + + Statistics { + num_rows: Some(nb_rows), + total_byte_size: Some(total_byte_size), + column_statistics: Some(column_statistics), + is_exact: true, + } +} + #[cfg(test)] mod tests { use super::*; @@ -217,4 +259,58 @@ mod tests { assert_eq!(batch_count * batch_size, result.num_rows()); Ok(()) } + + #[test] + fn test_compute_record_batch_statistics_empty() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + ])); + let stats = compute_record_batch_statistics(&[], &schema, Some(vec![0, 1])); + + assert_eq!(stats.num_rows, Some(0)); + assert!(stats.is_exact); + assert_eq!(stats.total_byte_size, Some(0)); + Ok(()) + } + + #[test] + fn test_compute_record_batch_statistics() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Float32Array::from(vec![1., 2., 3.])), + Arc::new(Float64Array::from(vec![9., 8., 7.])), + ], + )?; + let result = + compute_record_batch_statistics(&[vec![batch]], &schema, Some(vec![0, 1])); + + let expected = Statistics { + is_exact: true, + num_rows: Some(3), + total_byte_size: Some(416), // this might change a bit if the way we compute the size changes + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: None, + max_value: None, + min_value: None, + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: None, + max_value: None, + min_value: None, + null_count: Some(0), + }, + ]), + }; + + assert_eq!(result, expected); + Ok(()) + } } diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index 98ad3440aa4a..1575958b5375 100644 --- 
a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -29,6 +29,7 @@ use futures::{Stream, TryStreamExt}; use super::{ coalesce_partitions::CoalescePartitionsExec, hash_utils::check_join_is_valid, + ColumnStatistics, Statistics, }; use crate::{ error::{DataFusionError, Result}, @@ -207,6 +208,79 @@ impl ExecutionPlan for CrossJoinExec { } } } + + fn statistics(&self) -> Statistics { + stats_cartesian_product( + self.left.statistics(), + self.left.schema().fields().len(), + self.right.statistics(), + self.right.schema().fields().len(), + ) + } +} + +/// [left/right]_col_count are required in case the column statistics are None +fn stats_cartesian_product( + left_stats: Statistics, + left_col_count: usize, + right_stats: Statistics, + right_col_count: usize, +) -> Statistics { + let left_row_count = left_stats.num_rows; + let right_row_count = right_stats.num_rows; + + // calculate global stats + let is_exact = left_stats.is_exact && right_stats.is_exact; + let num_rows = left_stats + .num_rows + .zip(right_stats.num_rows) + .map(|(a, b)| a * b); + // the result size is two times a*b because you have the columns of both left and right + let total_byte_size = left_stats + .total_byte_size + .zip(right_stats.total_byte_size) + .map(|(a, b)| 2 * a * b); + + // calculate column stats + let column_statistics = + // complete the column statistics if they are missing only on one side + match (left_stats.column_statistics, right_stats.column_statistics) { + (None, None) => None, + (None, Some(right_col_stat)) => Some(( + vec![ColumnStatistics::default(); left_col_count], + right_col_stat, + )), + (Some(left_col_stat), None) => Some(( + left_col_stat, + vec![ColumnStatistics::default(); right_col_count], + )), + (Some(left_col_stat), Some(right_col_stat)) => { + Some((left_col_stat, right_col_stat)) + } + } + .map(|(left_col_stats, right_col_stats)| { + // the null counts must be multiplied by the row counts of the other side (if defined) + // Min, max and distinct_count on the other hand are invariants. + left_col_stats.into_iter().map(|s| ColumnStatistics{ + null_count: s.null_count.zip(right_row_count).map(|(a, b)| a * b), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + }).chain( + right_col_stats.into_iter().map(|s| ColumnStatistics{ + null_count: s.null_count.zip(left_row_count).map(|(a, b)| a * b), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + })).collect() + }); + + Statistics { + is_exact, + num_rows, + total_byte_size, + column_statistics, + } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. 
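// Worked example for stats_cartesian_product above (editorial note, mirrored by
// the test below): with an exact left side of 11 rows and an exact right side of
// 7 rows, the product has 11 * 7 = 77 rows. A left column with 3 nulls is
// repeated once per right row, so its null_count becomes 3 * 7 = 21, while min,
// max and distinct_count carry over unchanged. total_byte_size is estimated as
// 2 * left_bytes * right_bytes because every output row concatenates the columns
// of one left row and one right row.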
@@ -331,3 +405,145 @@ impl Stream for CrossJoinStream { }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_stats_cartesian_product() { + let left_row_count = 11; + let left_bytes = 23; + let right_row_count = 7; + let right_bytes = 27; + + let left = Statistics { + is_exact: true, + num_rows: Some(left_row_count), + total_byte_size: Some(left_bytes), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: Some(3), + }, + ]), + }; + + let right = Statistics { + is_exact: true, + num_rows: Some(right_row_count), + total_byte_size: Some(right_bytes), + column_statistics: Some(vec![ColumnStatistics { + distinct_count: Some(3), + max_value: Some(ScalarValue::Int64(Some(12))), + min_value: Some(ScalarValue::Int64(Some(0))), + null_count: Some(2), + }]), + }; + + let result = stats_cartesian_product(left, 3, right, 2); + + let expected = Statistics { + is_exact: true, + num_rows: Some(left_row_count * right_row_count), + total_byte_size: Some(2 * left_bytes * right_bytes), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: Some(3 * right_row_count), + }, + ColumnStatistics { + distinct_count: Some(3), + max_value: Some(ScalarValue::Int64(Some(12))), + min_value: Some(ScalarValue::Int64(Some(0))), + null_count: Some(2 * left_row_count), + }, + ]), + }; + + assert_eq!(result, expected); + } + + #[tokio::test] + async fn test_stats_cartesian_product_with_unknwon_size() { + let left_row_count = 11; + + let left = Statistics { + is_exact: true, + num_rows: Some(left_row_count), + total_byte_size: Some(23), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: Some(3), + }, + ]), + }; + + let right = Statistics { + is_exact: true, + num_rows: None, // not defined! + total_byte_size: None, // not defined! 
+ column_statistics: Some(vec![ColumnStatistics { + distinct_count: Some(3), + max_value: Some(ScalarValue::Int64(Some(12))), + min_value: Some(ScalarValue::Int64(Some(0))), + null_count: Some(2), + }]), + }; + + let result = stats_cartesian_product(left, 3, right, 2); + + let expected = Statistics { + is_exact: true, + num_rows: None, + total_byte_size: None, + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: None, // we don't know the row count on the right + }, + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: None, // we don't know the row count on the right + }, + ColumnStatistics { + distinct_count: Some(3), + max_value: Some(ScalarValue::Int64(Some(12))), + min_value: Some(ScalarValue::Int64(Some(0))), + null_count: Some(2 * left_row_count), + }, + ]), + }; + + assert_eq!(result, expected); + } +} diff --git a/datafusion/src/physical_plan/csv.rs b/datafusion/src/physical_plan/csv.rs index 544f98cba0c6..35bd2247bfbc 100644 --- a/datafusion/src/physical_plan/csv.rs +++ b/datafusion/src/physical_plan/csv.rs @@ -33,7 +33,9 @@ use std::sync::Arc; use std::sync::Mutex; use std::task::{Context, Poll}; -use super::{DisplayFormatType, RecordBatchStream, SendableRecordBatchStream}; +use super::{ + DisplayFormatType, RecordBatchStream, SendableRecordBatchStream, Statistics, +}; use async_trait::async_trait; /// CSV file read option @@ -363,6 +365,11 @@ impl ExecutionPlan for CsvExec { } } } + + fn statistics(&self) -> Statistics { + // TODO stats: handle statistics + Statistics::default() + } } /// Iterator over batches diff --git a/datafusion/src/physical_plan/empty.rs b/datafusion/src/physical_plan/empty.rs index 391a695f4501..430beaf592e2 100644 --- a/datafusion/src/physical_plan/empty.rs +++ b/datafusion/src/physical_plan/empty.rs @@ -28,7 +28,7 @@ use arrow::array::NullArray; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use super::SendableRecordBatchStream; +use super::{common, SendableRecordBatchStream, Statistics}; use async_trait::async_trait; @@ -54,6 +54,23 @@ impl EmptyExec { pub fn produce_one_row(&self) -> bool { self.produce_one_row } + + fn data(&self) -> Result> { + let batch = if self.produce_one_row { + vec![RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "placeholder", + DataType::Null, + true, + )])), + vec![Arc::new(NullArray::new(1))], + )?] + } else { + vec![] + }; + + Ok(batch) + } } #[async_trait] @@ -101,22 +118,8 @@ impl ExecutionPlan for EmptyExec { ))); } - // Makes a stream only contains one null element if needed - let data = if self.produce_one_row { - vec![RecordBatch::try_new( - Arc::new(Schema::new(vec![Field::new( - "placeholder", - DataType::Null, - true, - )])), - vec![Arc::new(NullArray::new(1))], - )?] 
- } else { - vec![] - }; - Ok(Box::pin(MemoryStream::try_new( - data, + self.data()?, self.schema.clone(), None, )?)) @@ -133,6 +136,13 @@ impl ExecutionPlan for EmptyExec { } } } + + fn statistics(&self) -> Statistics { + let batch = self + .data() + .expect("Create empty RecordBatch should not fail"); + common::compute_record_batch_statistics(&[batch], &self.schema, None) + } } #[cfg(test)] diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index a6a34f5d0b0c..74093259aaf6 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -23,8 +23,10 @@ use std::sync::Arc; use crate::{ error::{DataFusionError, Result}, logical_plan::StringifiedPlan, - physical_plan::Partitioning, - physical_plan::{common::SizedRecordBatchStream, DisplayFormatType, ExecutionPlan}, + physical_plan::{ + common::SizedRecordBatchStream, DisplayFormatType, ExecutionPlan, Partitioning, + Statistics, + }, }; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; @@ -156,6 +158,11 @@ impl ExecutionPlan for ExplainExec { } } } + + fn statistics(&self) -> Statistics { + // Statistics an EXPLAIN plan are not relevant + Statistics::default() + } } /// If this plan should be shown, given the previous plan that was diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 21cf95d6d626..97486680f2e0 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -305,12 +305,12 @@ macro_rules! min_max { } /// the minimum of two scalar values -fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { +pub fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { min_max!(lhs, rhs, min) } /// the maximum of two scalar values -fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { +pub fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { min_max!(lhs, rhs, max) } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index d60a871baa80..5a5a1189af05 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -38,6 +38,7 @@ mod is_not_null; mod is_null; mod lead_lag; mod literal; +#[macro_use] mod min_max; mod negative; mod not; @@ -48,6 +49,11 @@ mod row_number; mod sum; mod try_cast; +/// Module with some convenient methods used in expression building +pub mod helpers { + pub use super::min_max::{max, min}; +} + pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index 52017c63b253..8acfd1b92e6b 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -23,7 +23,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, @@ -144,6 +144,11 @@ impl ExecutionPlan for FilterExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + /// The output statistics of a filtering operation are unknown + fn statistics(&self) -> Statistics { + 
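// Without an estimate of the predicate's selectivity there is no sound way to
// scale the input statistics, so every field is left as unknown (and `is_exact`
// defaults to false).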
Statistics::default() + } } /// The FilterExec streams wraps the input iterator and applies the predicate expression to diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index e21cc311ccf0..adeeb0bf8eab 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -53,6 +53,7 @@ use async_trait::async_trait; use super::metrics::{ self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; +use super::Statistics; use super::{expressions::Column, RecordBatchStream, SendableRecordBatchStream}; /// Hash aggregate modes @@ -285,6 +286,26 @@ impl ExecutionPlan for HashAggregateExec { } Ok(()) } + + fn statistics(&self) -> Statistics { + // TODO stats: group expressions: + // - once expressions will be able to compute their own stats, use it here + // - case where we group by on a column for which with have the `distinct` stat + // TODO stats: aggr expression: + // - aggregations somtimes also preserve invariants such as min, max... + match self.mode { + AggregateMode::Final | AggregateMode::FinalPartitioned + if self.group_expr.is_empty() => + { + Statistics { + num_rows: Some(1), + is_exact: true, + ..Default::default() + } + } + _ => Statistics::default(), + } + } } /* @@ -1145,6 +1166,11 @@ mod tests { } Ok(Box::pin(stream)) } + + fn statistics(&self) -> Statistics { + let (_, batches) = some_data(); + common::compute_record_batch_statistics(&[batches], &self.schema(), None) + } } /// A stream using the demo data. If inited as new, it will first yield to runtime before returning records diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index e189f94085ad..f2ce88fddad4 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -50,7 +50,6 @@ use arrow::array::{ use hashbrown::raw::RawTable; -use super::hash_utils::create_hashes; use super::{ coalesce_partitions::CoalescePartitionsExec, hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, @@ -59,6 +58,7 @@ use super::{ expressions::Column, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, }; +use super::{hash_utils::create_hashes, Statistics}; use crate::error::{DataFusionError, Result}; use crate::logical_plan::JoinType; @@ -462,6 +462,13 @@ impl ExecutionPlan for HashJoinExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn statistics(&self) -> Statistics { + // TODO stats: it is not possible in general to know the output size of joins + // There are some special cases though, for example: + // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` + Statistics::default() + } } /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, diff --git a/datafusion/src/physical_plan/json.rs b/datafusion/src/physical_plan/json.rs index 24631c57739e..675d88ec3bfa 100644 --- a/datafusion/src/physical_plan/json.rs +++ b/datafusion/src/physical_plan/json.rs @@ -20,7 +20,9 @@ use async_trait::async_trait; use futures::Stream; use super::DisplayFormatType; -use super::{common, source::Source, ExecutionPlan, Partitioning, RecordBatchStream}; +use super::{ + common, source::Source, ExecutionPlan, Partitioning, RecordBatchStream, Statistics, +}; use crate::error::{DataFusionError, Result}; use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow::{ @@ -324,6 +326,11 @@ impl ExecutionPlan for NdJsonExec { } 
} } + + fn statistics(&self) -> Statistics { + // TODO stats: handle statistics + Statistics::default() + } } struct NdJsonStream { diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index 9f4744291c49..792b8f50e1d6 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -35,7 +35,7 @@ use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use async_trait::async_trait; @@ -135,6 +135,27 @@ impl ExecutionPlan for GlobalLimitExec { } } } + + fn statistics(&self) -> Statistics { + let input_stats = self.input.statistics(); + match input_stats { + // if the input does not reach the limit globally, return input stats + Statistics { + num_rows: Some(nr), .. + } if nr <= self.limit => input_stats, + // if the input is greater than the limit, the num_row will be the limit + // but we won't be able to predict the other statistics + Statistics { + num_rows: Some(nr), .. + } if nr > self.limit => Statistics { + num_rows: Some(self.limit), + is_exact: input_stats.is_exact, + ..Default::default() + }, + // if we don't know the input size, we can't predict the limit's behaviour + _ => Statistics::default(), + } + } } /// LocalLimitExec applies a limit to a single partition @@ -213,6 +234,30 @@ impl ExecutionPlan for LocalLimitExec { } } } + + fn statistics(&self) -> Statistics { + let input_stats = self.input.statistics(); + match input_stats { + // if the input does not reach the limit globally, return input stats + Statistics { + num_rows: Some(nr), .. + } if nr <= self.limit => input_stats, + // if the input is greater than the limit, the num_row will be greater + // than the limit because the partitions will be limited separatly + // the statistic + Statistics { + num_rows: Some(nr), .. 
+ } if nr > self.limit => Statistics { + num_rows: Some(self.limit), + // this is not actually exact, but will be when GlobalLimit is applied + // TODO stats: find a more explicit way to vehiculate this information + is_exact: input_stats.is_exact, + ..Default::default() + }, + // if we don't know the input size, we can't predict the limit's behaviour + _ => Statistics::default(), + } + } } /// Truncate a RecordBatch to maximum of n rows diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index 85d8aeef073c..e2e6221cada6 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -23,11 +23,11 @@ use std::sync::Arc; use std::task::{Context, Poll}; use super::{ - DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, + common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; use crate::error::{DataFusionError, Result}; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -38,8 +38,10 @@ use futures::Stream; pub struct MemoryExec { /// The partitions to query partitions: Vec>, - /// Schema representing the data after the optional projection is applied + /// Schema representing the data before projection schema: SchemaRef, + /// Schema representing the data after the optional projection is applied + projected_schema: SchemaRef, /// Optional projection projection: Option>, } @@ -47,7 +49,7 @@ pub struct MemoryExec { impl fmt::Debug for MemoryExec { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "partitions: [...]")?; - write!(f, "schema: {:?}", self.schema)?; + write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection) } } @@ -61,7 +63,7 @@ impl ExecutionPlan for MemoryExec { /// Get the schema for this execution plan fn schema(&self) -> SchemaRef { - self.schema.clone() + self.projected_schema.clone() } fn children(&self) -> Vec> { @@ -87,7 +89,7 @@ impl ExecutionPlan for MemoryExec { async fn execute(&self, partition: usize) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), - self.schema.clone(), + self.projected_schema.clone(), self.projection.clone(), )?)) } @@ -110,18 +112,47 @@ impl ExecutionPlan for MemoryExec { } } } + + /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so + fn statistics(&self) -> Statistics { + common::compute_record_batch_statistics( + &self.partitions, + &self.schema, + self.projection.clone(), + ) + } } impl MemoryExec { /// Create a new execution plan for reading in-memory record batches + /// The provided `schema` should not have the projection applied. 
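// Editorial sketch, not part of this diff: the limit statistics described above,
// exercised with the StatisticsExec helper added to test/exec.rs later in this
// patch. It assumes GlobalLimitExec::new(input, limit) and the usual imports;
// the schema and values are hypothetical.
#[test]
fn global_limit_statistics_sketch() {
    let source = Arc::new(StatisticsExec::new(
        Statistics {
            num_rows: Some(1_000),
            is_exact: true,
            ..Default::default()
        },
        Schema::new(vec![Field::new("c", DataType::Int32, false)]),
    ));
    let limit = GlobalLimitExec::new(source, 10);
    // The input exceeds the limit, so the row count is capped at the limit while
    // the remaining fields become unknown.
    assert_eq!(limit.statistics().num_rows, Some(10));
    assert_eq!(limit.statistics().total_byte_size, None);
}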
pub fn try_new( partitions: &[Vec], schema: SchemaRef, projection: Option>, ) -> Result { + let projected_schema = match &projection { + Some(columns) => { + let fields: Result> = columns + .iter() + .map(|i| { + if *i < schema.fields().len() { + Ok(schema.field(*i).clone()) + } else { + Err(DataFusionError::Internal( + "Projection index out of range".to_string(), + )) + } + }) + .collect(); + Arc::new(Schema::new(fields?)) + } + None => Arc::clone(&schema), + }; Ok(Self { partitions: partitions.to_vec(), schema, + projected_schema, projection, }) } @@ -189,3 +220,116 @@ impl RecordBatchStream for MemoryStream { self.schema.clone() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::ColumnStatistics; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use futures::StreamExt; + + fn mock_data() -> Result<(SchemaRef, RecordBatch)> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![4, 5, 6])), + Arc::new(Int32Array::from(vec![None, None, Some(9)])), + Arc::new(Int32Array::from(vec![7, 8, 9])), + ], + )?; + + Ok((schema, batch)) + } + + #[tokio::test] + async fn test_with_projection() -> Result<()> { + let (schema, batch) = mock_data()?; + + let executor = MemoryExec::try_new(&[vec![batch]], schema, Some(vec![2, 1]))?; + let statistics = executor.statistics(); + + assert_eq!(statistics.num_rows, Some(3)); + assert_eq!( + statistics.column_statistics, + Some(vec![ + ColumnStatistics { + null_count: Some(2), + max_value: None, + min_value: None, + distinct_count: None, + }, + ColumnStatistics { + null_count: Some(0), + max_value: None, + min_value: None, + distinct_count: None, + }, + ]) + ); + + // scan with projection + let mut it = executor.execute(0).await?; + let batch2 = it.next().await.unwrap()?; + assert_eq!(2, batch2.schema().fields().len()); + assert_eq!("c", batch2.schema().field(0).name()); + assert_eq!("b", batch2.schema().field(1).name()); + assert_eq!(2, batch2.num_columns()); + + Ok(()) + } + + #[tokio::test] + async fn test_without_projection() -> Result<()> { + let (schema, batch) = mock_data()?; + + let executor = MemoryExec::try_new(&[vec![batch]], schema, None)?; + let statistics = executor.statistics(); + + assert_eq!(statistics.num_rows, Some(3)); + assert_eq!( + statistics.column_statistics, + Some(vec![ + ColumnStatistics { + null_count: Some(0), + max_value: None, + min_value: None, + distinct_count: None, + }, + ColumnStatistics { + null_count: Some(0), + max_value: None, + min_value: None, + distinct_count: None, + }, + ColumnStatistics { + null_count: Some(2), + max_value: None, + min_value: None, + distinct_count: None, + }, + ColumnStatistics { + null_count: Some(0), + max_value: None, + min_value: None, + distinct_count: None, + }, + ]) + ); + + let mut it = executor.execute(0).await?; + let batch1 = it.next().await.unwrap()?; + assert_eq!(4, batch1.schema().fields().len()); + assert_eq!(4, batch1.num_columns()); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index b7a31fa7f25e..af868871abb8 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -89,6 +89,36 @@ impl Stream for EmptyRecordBatchStream { /// 
Physical planner interface pub use self::planner::PhysicalPlanner; +/// Statistics for a physical plan node +/// Fields are optional and can be inexact because the sources +/// sometimes provide approximate estimates for performance reasons +/// and the transformations output are not always predictable. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct Statistics { + /// The number of table rows + pub num_rows: Option, + /// total byte of the table rows + pub total_byte_size: Option, + /// Statistics on a column level + pub column_statistics: Option>, + /// If true, any field that is `Some(..)` is the actual value in the data provided by the operator (it is not + /// an estimate). Any or all other fields might still be None, in which case no information is known. + /// if false, any field that is `Some(..)` may contain an inexact estimate and may not be the actual value. + pub is_exact: bool, +} +/// This table statistics are estimates about column +#[derive(Clone, Debug, Default, PartialEq)] +pub struct ColumnStatistics { + /// Number of null values on column + pub null_count: Option, + /// Maximum value of column + pub max_value: Option, + /// Minimum value of column + pub min_value: Option, + /// Number of distinct values + pub distinct_count: Option, +} + /// `ExecutionPlan` represent nodes in the DataFusion Physical Plan. /// /// Each `ExecutionPlan` is Partition-aware and is responsible for @@ -150,6 +180,9 @@ pub trait ExecutionPlan: Debug + Send + Sync { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "ExecutionPlan(PlaceHolder)") } + + /// Returns the global output statistics for this `ExecutionPlan` node. + fn statistics(&self) -> Statistics; } /// Return a [wrapper](DisplayableExecutionPlan) around an diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs index eb8f927fc2ad..feed181ca83d 100644 --- a/datafusion/src/physical_plan/parquet.rs +++ b/datafusion/src/physical_plan/parquet.rs @@ -32,6 +32,8 @@ use crate::{ scalar::ScalarValue, }; +use super::Statistics; + use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, @@ -53,7 +55,6 @@ use tokio::{ task, }; -use crate::datasource::datasource::Statistics; use async_trait::async_trait; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; @@ -224,6 +225,7 @@ impl ParquetExec { num_rows: statistics.num_rows, total_byte_size: statistics.total_byte_size, column_statistics: new_column_statistics, + is_exact: statistics.is_exact, }; Self { @@ -252,11 +254,6 @@ impl ParquetExec { pub fn batch_size(&self) -> usize { self.batch_size } - - /// Statistics for the data set (sum of statistics for all partitions) - pub fn statistics(&self) -> &Statistics { - &self.statistics - } } impl ParquetPartition { @@ -390,6 +387,10 @@ impl ExecutionPlan for ParquetExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn statistics(&self) -> Statistics { + self.statistics.clone() + } } fn send_result( diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d4991746f9a0..0ff595817e7c 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -1393,8 +1393,9 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { mod tests { use super::*; use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; - use crate::physical_plan::DisplayFormatType; - use crate::physical_plan::{csv::CsvReadOptions, expressions, Partitioning}; + 
use crate::physical_plan::{ + csv::CsvReadOptions, expressions, DisplayFormatType, Partitioning, Statistics, + }; use crate::scalar::ScalarValue; use crate::{ logical_plan::{col, lit, sum, LogicalPlanBuilder}, @@ -1815,6 +1816,10 @@ mod tests { } } } + + fn statistics(&self) -> Statistics { + unimplemented!("NoOpExecutionPlan::statistics"); + } } // Produces an execution plan where the schema is mismatched from diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index 5110e5b5a879..97ff83edd2fc 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -27,13 +27,14 @@ use std::task::{Context, Poll}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, }; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::expressions::Column; +use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use async_trait::async_trait; use futures::stream::Stream; @@ -157,6 +158,40 @@ impl ExecutionPlan for ProjectionExec { } } } + + fn statistics(&self) -> Statistics { + stats_projection( + self.input.statistics(), + self.expr.iter().map(|(e, _)| Arc::clone(e)), + ) + } +} + +fn stats_projection( + stats: Statistics, + exprs: impl Iterator>, +) -> Statistics { + let column_statistics = stats.column_statistics.map(|input_col_stats| { + exprs + .map(|e| { + if let Some(col) = e.as_any().downcast_ref::() { + input_col_stats[col.index()].clone() + } else { + // TODO stats: estimate more statistics from expressions + // (expressions should compute their statistics themselves) + ColumnStatistics::default() + } + }) + .collect() + }); + + Statistics { + is_exact: stats.is_exact, + num_rows: stats.num_rows, + column_statistics, + // TODO stats: knowing the type of the new columns we can guess the output size + total_byte_size: None, + } } fn batch_project( @@ -213,7 +248,8 @@ mod tests { use super::*; use crate::physical_plan::csv::{CsvExec, CsvReadOptions}; - use crate::physical_plan::expressions::col; + use crate::physical_plan::expressions::{self, col}; + use crate::scalar::ScalarValue; use crate::test; use futures::future; @@ -258,4 +294,62 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_stats_projection_columns_only() { + let source = Statistics { + is_exact: true, + num_rows: Some(5), + total_byte_size: Some(23), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: Some(3), + }, + ColumnStatistics { + distinct_count: None, + max_value: Some(ScalarValue::Float32(Some(1.1))), + min_value: Some(ScalarValue::Float32(Some(0.1))), + null_count: None, + }, + ]), + }; + + let exprs: Vec> = vec![ + Arc::new(expressions::Column::new("col1", 1)), + Arc::new(expressions::Column::new("col0", 0)), + ]; + + let result = stats_projection(source, exprs.into_iter()); + + let expected = Statistics { + is_exact: true, + num_rows: Some(5), + total_byte_size: None, + 
column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: Some(3), + }, + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(0), + }, + ]), + }; + + assert_eq!(result, expected); + } } diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 8ba9a4f3ad47..56de364cc995 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -25,7 +25,7 @@ use std::{any::Any, vec}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::hash_utils::create_hashes; -use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; use arrow::record_batch::RecordBatch; use arrow::{array::Array, error::Result as ArrowResult}; use arrow::{compute::take, datatypes::SchemaRef}; @@ -228,6 +228,10 @@ impl ExecutionPlan for RepartitionExec { } } } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } } impl RepartitionExec { diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 5a47931f96e8..b732797c1d26 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -20,7 +20,7 @@ use super::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ @@ -179,6 +179,10 @@ impl ExecutionPlan for SortExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } } fn sort_batch( diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs index 1bcdd63886b6..f63695057a7d 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sort_preserving_merge.rs @@ -43,7 +43,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, + SendableRecordBatchStream, Statistics, }; /// Sort preserving merge execution plan @@ -187,6 +187,10 @@ impl ExecutionPlan for SortPreservingMergeExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } } /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs index 932bd5c5c0f5..f30cd575f997 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -25,8 +25,11 @@ use std::{any::Any, sync::Arc}; use arrow::datatypes::SchemaRef; -use super::{DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream}; -use crate::error::Result; +use super::{ + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, + 
SendableRecordBatchStream, Statistics, +}; +use crate::{error::Result, physical_plan::expressions}; use async_trait::async_trait; /// UNION ALL execution plan @@ -106,16 +109,68 @@ impl ExecutionPlan for UnionExec { } } } + + fn statistics(&self) -> Statistics { + self.inputs + .iter() + .map(|ep| ep.statistics()) + .reduce(stats_union) + .unwrap_or_default() + } +} + +fn col_stats_union( + mut left: ColumnStatistics, + right: ColumnStatistics, +) -> ColumnStatistics { + left.distinct_count = None; + left.min_value = left + .min_value + .zip(right.min_value) + .map(|(a, b)| expressions::helpers::min(&a, &b)) + .map(Result::ok) + .flatten(); + left.max_value = left + .max_value + .zip(right.max_value) + .map(|(a, b)| expressions::helpers::max(&a, &b)) + .map(Result::ok) + .flatten(); + left.null_count = left.null_count.zip(right.null_count).map(|(a, b)| a + b); + + left +} + +fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { + left.is_exact = left.is_exact && right.is_exact; + left.num_rows = left.num_rows.zip(right.num_rows).map(|(a, b)| a + b); + left.total_byte_size = left + .total_byte_size + .zip(right.total_byte_size) + .map(|(a, b)| a + b); + left.column_statistics = + left.column_statistics + .zip(right.column_statistics) + .map(|(a, b)| { + a.into_iter() + .zip(b) + .map(|(ca, cb)| col_stats_union(ca, cb)) + .collect() + }); + left } #[cfg(test)] mod tests { use super::*; - use crate::physical_plan::{ - collect, - csv::{CsvExec, CsvReadOptions}, - }; use crate::test; + use crate::{ + physical_plan::{ + collect, + csv::{CsvExec, CsvReadOptions}, + }, + scalar::ScalarValue, + }; use arrow::record_batch::RecordBatch; #[tokio::test] @@ -152,4 +207,88 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_stats_union() { + let left = Statistics { + is_exact: true, + num_rows: Some(5), + total_byte_size: Some(23), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(5), + max_value: Some(ScalarValue::Int64(Some(21))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: Some(1), + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: Some(3), + }, + ColumnStatistics { + distinct_count: None, + max_value: Some(ScalarValue::Float32(Some(1.1))), + min_value: Some(ScalarValue::Float32(Some(0.1))), + null_count: None, + }, + ]), + }; + + let right = Statistics { + is_exact: true, + num_rows: Some(7), + total_byte_size: Some(29), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(3), + max_value: Some(ScalarValue::Int64(Some(34))), + min_value: Some(ScalarValue::Int64(Some(1))), + null_count: Some(1), + }, + ColumnStatistics { + distinct_count: None, + max_value: Some(ScalarValue::Utf8(Some(String::from("c")))), + min_value: Some(ScalarValue::Utf8(Some(String::from("b")))), + null_count: None, + }, + ColumnStatistics { + distinct_count: None, + max_value: None, + min_value: None, + null_count: None, + }, + ]), + }; + + let result = stats_union(left, right); + let expected = Statistics { + is_exact: true, + num_rows: Some(12), + total_byte_size: Some(52), + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: None, + max_value: Some(ScalarValue::Int64(Some(34))), + min_value: Some(ScalarValue::Int64(Some(-4))), + null_count: Some(1), + }, + ColumnStatistics { + distinct_count: None, + max_value: Some(ScalarValue::Utf8(Some(String::from("x")))), + min_value: 
Some(ScalarValue::Utf8(Some(String::from("a")))), + null_count: None, + }, + ColumnStatistics { + distinct_count: None, + max_value: None, + min_value: None, + null_count: None, + }, + ]), + }; + + assert_eq!(result, expected); + } } diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs b/datafusion/src/physical_plan/windows/window_agg_exec.rs index c7466477ce79..0524adc7073a 100644 --- a/datafusion/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs @@ -19,8 +19,8 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ - common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, WindowExpr, + common, ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, + Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; use arrow::{ array::ArrayRef, @@ -162,6 +162,26 @@ impl ExecutionPlan for WindowAggExec { } Ok(()) } + + fn statistics(&self) -> Statistics { + let input_stat = self.input.statistics(); + let win_cols = self.window_expr.len(); + let input_cols = self.input_schema.fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = vec![ColumnStatistics::default(); win_cols]; + if let Some(input_col_stats) = input_stat.column_statistics { + column_statistics.extend(input_col_stats); + } else { + column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + } + Statistics { + is_exact: input_stat.is_exact, + num_rows: input_stat.num_rows, + column_statistics: Some(column_statistics), + // TODO stats: knowing the type of the new columns we can guess the output size + total_byte_size: None, + } + } } fn create_schema( diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs index fa1f36c230f9..688cff838be6 100644 --- a/datafusion/src/test/exec.rs +++ b/datafusion/src/test/exec.rs @@ -33,8 +33,8 @@ use arrow::{ use futures::Stream; use crate::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, + common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; use crate::{ error::{DataFusionError, Result}, @@ -203,6 +203,22 @@ impl ExecutionPlan for MockExec { } } } + + // Panics if one of the batches is an error + fn statistics(&self) -> Statistics { + let data: ArrowResult> = self + .data + .iter() + .map(|r| match r { + Ok(batch) => Ok(batch.clone()), + Err(e) => Err(clone_error(e)), + }) + .collect(); + + let data = data.unwrap(); + + common::compute_record_batch_statistics(&[data], &self.schema, None) + } } fn clone_error(e: &ArrowError) -> ArrowError { @@ -306,6 +322,10 @@ impl ExecutionPlan for BarrierExec { } } } + + fn statistics(&self) -> Statistics { + common::compute_record_batch_statistics(&self.data, &self.schema, None) + } } /// A mock execution plan that errors on a call to execute @@ -368,4 +388,87 @@ impl ExecutionPlan for ErrorExec { } } } + + fn statistics(&self) -> Statistics { + Statistics::default() + } +} + +/// A mock execution plan that simply returns the provided statistics +#[derive(Debug, Clone)] +pub struct StatisticsExec { + stats: Statistics, + schema: Arc, +} +impl StatisticsExec { + pub fn new(stats: Statistics, schema: Schema) -> Self { + assert!( + stats + .column_statistics + .as_ref() + .map(|cols| cols.len() == schema.fields().len()) + .unwrap_or(true), + "if 
defined, the column statistics vector length should be the number of fields" + ); + Self { + stats, + schema: Arc::new(schema), + } + } +} +#[async_trait] +impl ExecutionPlan for StatisticsExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(2) + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + if children.is_empty() { + Ok(Arc::new(self.clone())) + } else { + Err(DataFusionError::Internal( + "Children cannot be replaced in CustomExecutionPlan".to_owned(), + )) + } + } + + async fn execute(&self, _partition: usize) -> Result { + unimplemented!("This plan only serves for testing statistics") + } + + fn statistics(&self) -> Statistics { + self.stats.clone() + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!( + f, + "StatisticsExec: col_count={}, row_count={:?}", + self.schema.fields().len(), + self.stats.num_rows, + ) + } + } + } } diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index 36adbea1be0e..31551a91d904 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Int32Array; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::array::{Int32Array, PrimitiveArray, UInt64Array}; +use arrow::compute::kernels::aggregate; +use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use datafusion::{ - datasource::{datasource::Statistics, TableProvider}, - physical_plan::collect, -}; +use datafusion::physical_plan::empty::EmptyExec; +use datafusion::scalar::ScalarValue; +use datafusion::{datasource::TableProvider, physical_plan::collect}; use datafusion::{ error::{DataFusionError, Result}, physical_plan::DisplayFormatType, @@ -34,7 +34,8 @@ use datafusion::logical_plan::{ col, Expr, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE, }; use datafusion::physical_plan::{ - ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, + ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; use futures::stream::Stream; @@ -145,6 +146,40 @@ impl ExecutionPlan for CustomExecutionPlan { } } } + + fn statistics(&self) -> Statistics { + let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); + Statistics { + is_exact: true, + num_rows: Some(batch.num_rows()), + total_byte_size: None, + column_statistics: Some( + self.projection + .clone() + .unwrap_or_else(|| (0..batch.columns().len()).collect()) + .iter() + .map(|i| ColumnStatistics { + null_count: Some(batch.column(*i).null_count()), + min_value: Some(ScalarValue::Int32(aggregate::min( + batch + .column(*i) + .as_any() + .downcast_ref::>() + .unwrap(), + ))), + max_value: Some(ScalarValue::Int32(aggregate::max( + batch + .column(*i) + .as_any() + .downcast_ref::>() + .unwrap(), + ))), + ..Default::default() + }) + .collect(), + ), + } + } } impl TableProvider for CustomTableProvider { @@ -167,10 +202,6 @@ impl TableProvider for CustomTableProvider { projection: projection.clone(), })) } - - fn statistics(&self) -> Statistics { - Statistics::default() - } } 
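The `optimizers_catch_all_statistics` test below exercises the `aggregate_statistics` physical optimization: when a source reports exact statistics, aggregates such as COUNT(*), MIN and MAX can be answered from the statistics alone and the scan is replaced by an `EmptyExec`. As a minimal illustrative sketch only (the helper name `count_rows_from_stats` is hypothetical and not part of this patch), the core idea for the COUNT(*) case looks like this:

use std::sync::Arc;
use datafusion::physical_plan::{ExecutionPlan, Statistics};

// Only an exact row count may replace the result of a COUNT(*) aggregation;
// an inexact or unknown num_rows must leave the plan untouched.
fn count_rows_from_stats(plan: &Arc<dyn ExecutionPlan>) -> Option<usize> {
    let stats: Statistics = plan.statistics();
    if stats.is_exact {
        stats.num_rows
    } else {
        None
    }
}

MIN and MAX follow the same pattern through `ColumnStatistics::min_value` and `max_value`, which is why the `statistics()` implementation above fills them in with Arrow's aggregate kernels.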
#[tokio::test] @@ -218,3 +249,52 @@ async fn custom_source_dataframe() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn optimizers_catch_all_statistics() { + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Arc::new(CustomTableProvider)) + .unwrap(); + + let df = ctx + .sql("SELECT count(*), min(c1), max(c1) from test") + .unwrap(); + + let physical_plan = ctx.create_physical_plan(&df.to_logical_plan()).unwrap(); + + // when the optimization kicks in, the source is replaced by an EmptyExec + assert!( + contains_empty_exec(Arc::clone(&physical_plan)), + "Expected aggregate_statistics optimizations missing: {:?}", + physical_plan + ); + + let expected = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("COUNT(UInt8(1))", DataType::UInt64, false), + Field::new("MIN(test.c1)", DataType::Int32, false), + Field::new("MAX(test.c1)", DataType::Int32, false), + ])), + vec![ + Arc::new(UInt64Array::from(vec![4])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![100])), + ], + ) + .unwrap(); + + let actual = collect(physical_plan).await.unwrap(); + + assert_eq!(actual.len(), 1); + assert_eq!(format!("{:?}", actual[0]), format!("{:?}", expected)); +} + +fn contains_empty_exec(plan: Arc<dyn ExecutionPlan>) -> bool { + if plan.as_any().is::<EmptyExec>() { + true + } else if plan.children().len() != 1 { + false + } else { + contains_empty_exec(Arc::clone(&plan.children()[0])) + } +} diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index 07b0eb2bb2ce..e0102c4f1bcc 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -19,15 +19,13 @@ use arrow::array::{as_primitive_array, Int32Builder, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion::datasource::datasource::{ - Statistics, TableProvider, TableProviderFilterPushDown, -}; +use datafusion::datasource::datasource::{TableProvider, TableProviderFilterPushDown}; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; @@ -98,6 +96,12 @@ impl ExecutionPlan for CustomPlan { } } } + + fn statistics(&self) -> Statistics { + // here we could provide more accurate statistics + // but we want to test the filter pushdown, not the CBOs + Statistics::default() + } } #[derive(Clone)] @@ -145,10 +149,6 @@ impl TableProvider for CustomProvider { } } - fn statistics(&self) -> Statistics { - Statistics::default() - } - fn supports_filter_pushdown(&self, _: &Expr) -> Result<TableProviderFilterPushDown> { Ok(TableProviderFilterPushDown::Exact) } diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index ccff292e0d07..eaf988fba0cf 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -1896,6 +1896,28 @@ async fn left_join() -> Result<()> { Ok(()) } +#[tokio::test] +async fn left_join_unbalanced() -> Result<()> { + // t1 has more rows than t2, so the hash_build_probe_order optimizer should kick in + let mut ctx = create_join_context_unbalanced("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 
ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + let expected = vec![ + vec!["11", "a", "z"], + vec!["22", "b", "y"], + vec!["33", "c", "NULL"], + vec!["44", "d", "x"], + vec!["77", "e", "NULL"], + ]; + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } + Ok(()) +} + #[tokio::test] async fn right_join() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; @@ -2069,6 +2091,43 @@ async fn cross_join() { assert_eq!(4 * 4 * 2, actual.len()); } +#[tokio::test] +async fn cross_join_unbalanced() { + // t1 has more rows than t2, so the hash_build_probe_order optimizer should kick in + let mut ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); + + // the order of the values is not deterministic, so we need to sort to check the values + let sql = + "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name"; + let actual = execute(&mut ctx, sql).await; + + assert_eq!( + actual, + [ + ["11", "a", "z"], + ["11", "a", "y"], + ["11", "a", "x"], + ["11", "a", "w"], + ["22", "b", "z"], + ["22", "b", "y"], + ["22", "b", "x"], + ["22", "b", "w"], + ["33", "c", "z"], + ["33", "c", "y"], + ["33", "c", "x"], + ["33", "c", "w"], + ["44", "d", "z"], + ["44", "d", "y"], + ["44", "d", "x"], + ["44", "d", "w"], + ["77", "e", "z"], + ["77", "e", "y"], + ["77", "e", "x"], + ["77", "e", "w"] + ] + ); +} + fn create_join_context( column_left: &str, column_right: &str, @@ -2154,6 +2213,55 @@ fn create_join_context_qualified() -> Result<ExecutionContext> { Ok(ctx) } +/// the table with column_left has more rows than the table with column_right +fn create_join_context_unbalanced( + column_left: &str, + column_right: &str, +) -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new(column_left, DataType::UInt32, true), + Field::new("t1_name", DataType::Utf8, true), + ])); + let t1_data = RecordBatch::try_new( + t1_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 33, 44, 77])), + Arc::new(StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ])), + ], + )?; + let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?; + ctx.register_table("t1", Arc::new(t1_table))?; + + let t2_schema = Arc::new(Schema::new(vec![ + Field::new(column_right, DataType::UInt32, true), + Field::new("t2_name", DataType::Utf8, true), + ])); + let t2_data = RecordBatch::try_new( + t2_schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![11, 22, 44, 55])), + Arc::new(StringArray::from(vec![ + Some("z"), + Some("y"), + Some("x"), + Some("w"), + ])), + ], + )?; + let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?; + ctx.register_table("t2", Arc::new(t2_table))?; + + Ok(ctx) +} + #[tokio::test] async fn csv_explain() { // This test uses the execute function that creates the full plan cycle: logical, optimized logical, and physical, diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs new file mode 100644 index 000000000000..a2375ad282c0 --- /dev/null +++ b/datafusion/tests/statistics.rs @@ -0,0 +1,284 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains end to end tests of statistics propagation + +use std::{any::Any, sync::Arc}; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::{ + datasource::TableProvider, + error::{DataFusionError, Result}, + logical_plan::Expr, + physical_plan::{ + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, + }, + prelude::ExecutionContext, + scalar::ScalarValue, +}; + +use async_trait::async_trait; + +/// This is a testing structure for statistics +/// It will act both as a table provider and execution plan +#[derive(Debug, Clone)] +struct StatisticsValidation { + stats: Statistics, + schema: Arc, +} + +impl StatisticsValidation { + fn new(stats: Statistics, schema: Schema) -> Self { + assert!( + stats + .column_statistics + .as_ref() + .map(|cols| cols.len() == schema.fields().len()) + .unwrap_or(true), + "if defined, the column statistics vector length should be the number of fields" + ); + Self { + stats, + schema: Arc::new(schema), + } + } +} + +impl TableProvider for StatisticsValidation { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn scan( + &self, + projection: &Option>, + _batch_size: usize, + filters: &[Expr], + // limit is ignored because it is not mandatory for a `TableProvider` to honor it + _limit: Option, + ) -> Result> { + // Filters should not be pushed down as they are marked as unsupported by default. 
+ assert_eq!( + 0, + filters.len(), + "Unsupported expressions should not be pushed down" + ); + let projection = match projection.clone() { + Some(p) => p, + None => (0..self.schema.fields().len()).collect(), + }; + let projected_schema = Schema::new( + projection + .iter() + .map(|i| self.schema.field(*i).clone()) + .collect(), + ); + + let current_stat = self.stats.clone(); + + let proj_col_stats = current_stat + .column_statistics + .map(|col_stat| projection.iter().map(|i| col_stat[*i].clone()).collect()); + + Ok(Arc::new(Self::new( + Statistics { + is_exact: current_stat.is_exact, + num_rows: current_stat.num_rows, + column_statistics: proj_col_stats, + // TODO stats: knowing the type of the new columns we can guess the output size + total_byte_size: None, + }, + projected_schema, + ))) + } +} + +#[async_trait] +impl ExecutionPlan for StatisticsValidation { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(2) + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + &self, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + if children.is_empty() { + Ok(Arc::new(self.clone())) + } else { + Err(DataFusionError::Internal( + "Children cannot be replaced in CustomExecutionPlan".to_owned(), + )) + } + } + + async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchStream> { + unimplemented!("This plan only serves for testing statistics") + } + + fn statistics(&self) -> Statistics { + self.stats.clone() + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!( + f, + "StatisticsValidation: col_count={}, row_count={:?}", + self.schema.fields().len(), + self.stats.num_rows, + ) + } + } + } +} + +fn init_ctx(stats: Statistics, schema: Schema) -> Result<ExecutionContext> { + let mut ctx = ExecutionContext::new(); + let provider: Arc<dyn TableProvider> = + Arc::new(StatisticsValidation::new(stats, schema)); + ctx.register_table("stats_table", provider)?; + Ok(ctx) +} + +fn fully_defined() -> (Statistics, Schema) { + ( + Statistics { + num_rows: Some(13), + is_exact: true, + total_byte_size: None, // ignore byte size for now + column_statistics: Some(vec![ + ColumnStatistics { + distinct_count: Some(2), + max_value: Some(ScalarValue::Int32(Some(1023))), + min_value: Some(ScalarValue::Int32(Some(-24))), + null_count: Some(0), + }, + ColumnStatistics { + distinct_count: Some(13), + max_value: Some(ScalarValue::Int64(Some(5486))), + min_value: Some(ScalarValue::Int64(Some(-6783))), + null_count: Some(5), + }, + ]), + }, + Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int64, false), + ]), + ) +} + +#[tokio::test] +async fn sql_basic() -> Result<()> { + let (stats, schema) = fully_defined(); + let mut ctx = init_ctx(stats.clone(), schema)?; + + let df = ctx.sql("SELECT * from stats_table").unwrap(); + + let physical_plan = ctx.create_physical_plan(&df.to_logical_plan()).unwrap(); + + // the statistics should be those of the source + assert_eq!(stats, physical_plan.statistics()); + + Ok(()) +} + +#[tokio::test] +async fn sql_filter() -> Result<()> { + let (stats, schema) = fully_defined(); + let mut ctx = init_ctx(stats, schema)?; + + let df = ctx.sql("SELECT * FROM stats_table WHERE c1 = 5").unwrap(); + + let physical_plan = ctx.create_physical_plan(&df.to_logical_plan()).unwrap(); + + // with a filtering condition we lose all knowledge about the 
statistics + assert_eq!(Statistics::default(), physical_plan.statistics()); + + Ok(()) +} + +#[tokio::test] +async fn sql_limit() -> Result<()> { + let (stats, schema) = fully_defined(); + let mut ctx = init_ctx(stats.clone(), schema)?; + + let df = ctx.sql("SELECT * FROM stats_table LIMIT 5").unwrap(); + let physical_plan = ctx.create_physical_plan(&df.to_logical_plan()).unwrap(); + // when the limit is smaller than the original number of rows + // we lose all statistics except for the number of rows, which becomes the limit + assert_eq!( + Statistics { + num_rows: Some(5), + is_exact: true, + ..Default::default() + }, + physical_plan.statistics() + ); + + let df = ctx.sql("SELECT * FROM stats_table LIMIT 100").unwrap(); + let physical_plan = ctx.create_physical_plan(&df.to_logical_plan()).unwrap(); + // when the limit is larger than the original number of rows, statistics remain unchanged + assert_eq!(stats, physical_plan.statistics()); + + Ok(()) +} + +#[tokio::test] +async fn sql_window() -> Result<()> { + let (stats, schema) = fully_defined(); + let mut ctx = init_ctx(stats.clone(), schema)?; + + let df = ctx + .sql("SELECT c2, sum(c1) over (partition by c2) FROM stats_table") + .unwrap(); + + let physical_plan = ctx.create_physical_plan(&df.to_logical_plan()).unwrap(); + + let result = physical_plan.statistics(); + + assert_eq!(stats.num_rows, result.num_rows); + assert!(result.column_statistics.is_some()); + let col_stats = result.column_statistics.unwrap(); + assert_eq!(2, col_stats.len()); + assert_eq!(stats.column_statistics.unwrap()[1], col_stats[0]); + + Ok(()) +} diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs index dfcdcf55221b..1a0bbe0174a0 100644 --- a/datafusion/tests/user_defined_plan.rs +++ b/datafusion/tests/user_defined_plan.rs @@ -76,7 +76,7 @@ use datafusion::{ physical_plan::{ planner::{DefaultPhysicalPlanner, ExtensionPlanner}, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, - RecordBatchStream, SendableRecordBatchStream, + RecordBatchStream, SendableRecordBatchStream, Statistics, }, prelude::{ExecutionConfig, ExecutionContext}, }; @@ -163,9 +163,9 @@ async fn topk_plan() -> Result<()> { let mut ctx = setup_table(make_topk_context()).await?; let expected = vec![ - "| logical_plan after topk | TopK: k=3 |", - "| | Projection: #sales.customer_id, #sales.revenue |", - "| | TableScan: sales projection=Some([0, 1]) |", + "| logical_plan after topk | TopK: k=3 |", + "| | Projection: #sales.customer_id, #sales.revenue |", + "| | TableScan: sales projection=Some([0, 1]) |", ].join("\n"); let explain_query = format!("EXPLAIN VERBOSE {}", QUERY); @@ -423,6 +423,12 @@ impl ExecutionPlan for TopKExec { } } } + + fn statistics(&self) -> Statistics { + // to improve the optimizability of this plan, + // better statistics inference could be provided + Statistics::default() + } } // A very specialized TopK implementation