diff --git a/Cargo.toml b/Cargo.toml index 98e665b3..064da6cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "fnck_sql" -version = "0.0.8" +version = "0.0.9" edition = "2021" authors = ["Kould ", "Xwg "] description = "SQL as a Function for Rust" @@ -41,6 +41,7 @@ chrono = { version = "0.4" } comfy-table = { version = "7" } csv = { version = "1" } dirs = { version = "5" } +fixedbitset = { version = "0.4" } itertools = { version = "0.12" } ordered-float = { version = "4" } paste = { version = "1" } diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 93a9e967..db84d16a 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,7 +1,6 @@ use crate::catalog::TableName; use crate::errors::DatabaseError; use crate::expression::ScalarExpression; -use crate::types::tuple::EMPTY_TUPLE; use crate::types::value::DataValue; use crate::types::{ColumnId, LogicalType}; use fnck_sql_serde_macros::ReferenceSerialization; @@ -170,7 +169,7 @@ impl ColumnCatalog { self.desc .default .as_ref() - .map(|expr| expr.eval(&EMPTY_TUPLE, &[])) + .map(|expr| expr.eval(None)) .transpose() } diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index ba0ae31b..c028aead 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -96,7 +96,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Update { } for (i, column) in input_schema.iter().enumerate() { if let Some(expr) = exprs_map.get(&column.id()) { - tuple.values[i] = throw!(expr.eval(&tuple, &input_schema)); + tuple.values[i] = throw!(expr.eval(Some((&tuple, &input_schema)))); } } tuple.clear_id(); diff --git a/src/execution/dql/aggregate/hash_agg.rs b/src/execution/dql/aggregate/hash_agg.rs index 39e2abe0..9debaee2 100644 --- a/src/execution/dql/aggregate/hash_agg.rs +++ b/src/execution/dql/aggregate/hash_agg.rs @@ -69,14 +69,14 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashAggExecutor { if args.len() > 1 { throw!(Err(DatabaseError::UnsupportedStmt("currently aggregate functions only support a single Column as a parameter".to_string()))) } - values.push(throw!(args[0].eval(&tuple, &schema_ref))); + values.push(throw!(args[0].eval(Some((&tuple, &schema_ref))))); } else { unreachable!() } } let group_keys: Vec = throw!(groupby_exprs .iter() - .map(|expr| expr.eval(&tuple, &schema_ref)) + .map(|expr| expr.eval(Some((&tuple, &schema_ref)))) .try_collect()); let entry = match group_hash_accs.entry(group_keys) { diff --git a/src/execution/dql/aggregate/simple_agg.rs b/src/execution/dql/aggregate/simple_agg.rs index 2fb13dcd..d6063911 100644 --- a/src/execution/dql/aggregate/simple_agg.rs +++ b/src/execution/dql/aggregate/simple_agg.rs @@ -50,7 +50,8 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for SimpleAggExecutor { let values: Vec = throw!(agg_calls .iter() .map(|expr| match expr { - ScalarExpression::AggCall { args, .. } => args[0].eval(&tuple, &schema), + ScalarExpression::AggCall { args, .. } => + args[0].eval(Some((&tuple, &schema))), _ => unreachable!(), }) .try_collect()); diff --git a/src/execution/dql/filter.rs b/src/execution/dql/filter.rs index 57d3cc65..21ce815a 100644 --- a/src/execution/dql/filter.rs +++ b/src/execution/dql/filter.rs @@ -40,7 +40,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for Filter { while let CoroutineState::Yielded(tuple) = Pin::new(&mut coroutine).resume(()) { let tuple = throw!(tuple); - if throw!(throw!(predicate.eval(&tuple, &schema)).is_true()) { + if throw!(throw!(predicate.eval(Some((&tuple, &schema)))).is_true()) { yield Ok(tuple); } } diff --git a/src/execution/dql/join/hash_join.rs b/src/execution/dql/join/hash_join.rs index 510b99cf..b943e7b1 100644 --- a/src/execution/dql/join/hash_join.rs +++ b/src/execution/dql/join/hash_join.rs @@ -9,8 +9,8 @@ use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::throw; use crate::types::tuple::{Schema, Tuple}; use crate::types::value::{DataValue, NULL_VALUE}; -use crate::utils::bit_vector::BitVector; use ahash::{HashMap, HashMapExt}; +use fixedbitset::FixedBitSet; use itertools::Itertools; use std::ops::Coroutine; use std::ops::CoroutineState; @@ -49,7 +49,7 @@ impl HashJoin { let mut values = Vec::with_capacity(on_keys.len()); for expr in on_keys { - values.push(expr.eval(tuple, schema)?); + values.push(expr.eval(Some((tuple, schema)))?); } Ok(values) } @@ -62,7 +62,7 @@ impl HashJoin { left_schema_len: usize, ) -> Result, DatabaseError> { if let (Some(expr), false) = (filter, matches!(join_ty, JoinType::Full | JoinType::Cross)) { - match &expr.eval(&tuple, schema)? { + match &expr.eval(Some((&tuple, schema)))? { DataValue::Boolean(Some(false) | None) => { let full_schema_len = schema.len(); @@ -193,7 +193,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashJoin { if *is_filtered { continue; } else { - bits_option = Some(BitVector::new(tuples.len())); + bits_option = Some(FixedBitSet::with_capacity(tuples.len())); } } JoinType::LeftAnti => continue, @@ -214,7 +214,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashJoin { left_schema_len )) { if let Some(bits) = bits_option.as_mut() { - bits.set_bit(i, true); + bits.insert(i); } else { yield Ok(tuple); } @@ -223,7 +223,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for HashJoin { if let Some(bits) = bits_option { let mut cnt = 0; tuples.retain(|_| { - let res = bits.get_bit(cnt); + let res = bits.contains(cnt); cnt += 1; res }); diff --git a/src/execution/dql/join/nested_loop_join.rs b/src/execution/dql/join/nested_loop_join.rs index c84b44e8..7df5df1e 100644 --- a/src/execution/dql/join/nested_loop_join.rs +++ b/src/execution/dql/join/nested_loop_join.rs @@ -13,7 +13,7 @@ use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::throw; use crate::types::tuple::{Schema, SchemaRef, Tuple}; use crate::types::value::{DataValue, NULL_VALUE}; -use crate::utils::bit_vector::BitVector; +use fixedbitset::FixedBitSet; use itertools::Itertools; use std::ops::Coroutine; use std::ops::CoroutineState; @@ -146,7 +146,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { let right_schema_len = eq_cond.right_schema.len(); let mut left_coroutine = build_read(left_input, cache, transaction); - let mut bitmap: Option = None; + let mut bitmap: Option = None; let mut first_matches = Vec::new(); while let CoroutineState::Yielded(left_tuple) = @@ -177,7 +177,8 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { } (Some(filter), true) => { let new_tuple = Self::merge_tuple(&left_tuple, &right_tuple, &ty); - let value = throw!(filter.eval(&new_tuple, &output_schema_ref)); + let value = + throw!(filter.eval(Some((&new_tuple, &output_schema_ref)))); match &value { DataValue::Boolean(Some(true)) => { let tuple = match ty { @@ -215,7 +216,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { break; } if let Some(bits) = bitmap.as_mut() { - bits.set_bit(right_idx, true); + bits.insert(right_idx); } else if matches!(ty, JoinType::Full) { first_matches.push(right_idx); } @@ -227,7 +228,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { } if matches!(self.ty, JoinType::Full) && bitmap.is_none() { - bitmap = Some(BitVector::new(right_idx)); + bitmap = Some(FixedBitSet::with_capacity(right_idx)); } // handle no matched tuple case @@ -256,7 +257,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { if matches!(ty, JoinType::Full) { for idx in first_matches.into_iter() { - bitmap.as_mut().unwrap().set_bit(idx, true); + bitmap.as_mut().unwrap().insert(idx); } let mut right_coroutine = build_read(right_input.clone(), cache, transaction); @@ -264,7 +265,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for NestedLoopJoin { while let CoroutineState::Yielded(right_tuple) = Pin::new(&mut right_coroutine).resume(()) { - if !bitmap.as_ref().unwrap().get_bit(idx) { + if !bitmap.as_ref().unwrap().contains(idx) { let mut right_tuple: Tuple = throw!(right_tuple); let mut values = vec![NULL_VALUE.clone(); right_schema_len]; values.append(&mut right_tuple.values); diff --git a/src/execution/dql/projection.rs b/src/execution/dql/projection.rs index 9d3bf75e..3584912f 100644 --- a/src/execution/dql/projection.rs +++ b/src/execution/dql/projection.rs @@ -58,7 +58,7 @@ impl Projection { let mut values = Vec::with_capacity(exprs.len()); for expr in exprs.iter() { - values.push(expr.eval(tuple, schmea)?); + values.push(expr.eval(Some((tuple, schmea)))?); } Ok(values) } diff --git a/src/execution/dql/sort.rs b/src/execution/dql/sort.rs index 84a7ce17..8bc40fd1 100644 --- a/src/execution/dql/sort.rs +++ b/src/execution/dql/sort.rs @@ -133,7 +133,8 @@ impl SortBy { let mut key = BumpBytes::new_in(arena); let tuple = tuple.as_ref().map(|(_, tuple)| tuple).unwrap(); - expr.eval(tuple, schema)?.memcomparable_encode(&mut key)?; + expr.eval(Some((tuple, schema)))? + .memcomparable_encode(&mut key)?; if !asc { for byte in key.iter_mut() { *byte ^= 0xFF; @@ -169,7 +170,7 @@ impl SortBy { debug_assert!(tuple.is_some()); let (_, tuple) = tuple.as_ref().unwrap(); - eval_values[x].push(expr.eval(tuple, schema)?); + eval_values[x].push(expr.eval(Some((tuple, schema)))?); } } diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 747a6edb..bcf77587 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -11,17 +11,10 @@ use regex::Regex; use sqlparser::ast::{CharLengthUnits, TrimWhereField}; use std::cmp; use std::cmp::Ordering; -use std::sync::LazyLock; - -static NULL_VALUE: LazyLock = LazyLock::new(|| DataValue::Null); macro_rules! eval_to_num { - ($num_expr:expr, $tuple:expr, $schema:expr) => { - if let Some(num_i32) = $num_expr - .eval($tuple, $schema)? - .cast(&LogicalType::Integer)? - .i32() - { + ($num_expr:expr, $tuple:expr) => { + if let Some(num_i32) = $num_expr.eval($tuple)?.cast(&LogicalType::Integer)?.i32() { num_i32 } else { return Ok(DataValue::Utf8 { @@ -34,7 +27,7 @@ macro_rules! eval_to_num { } impl ScalarExpression { - pub fn eval(&self, tuple: &Tuple, schema: &[ColumnRef]) -> Result { + pub fn eval(&self, tuple: Option<(&Tuple, &[ColumnRef])>) -> Result { let check_cast = |value: DataValue, return_type: &LogicalType| { if value.logical_type() != *return_type { return value.cast(return_type); @@ -45,16 +38,21 @@ impl ScalarExpression { match self { ScalarExpression::Constant(val) => Ok(val.clone()), ScalarExpression::ColumnRef(col) => { + let Some((tuple, schema)) = tuple else { + return Ok(DataValue::Null); + }; let value = schema .iter() .find_position(|tul_col| tul_col.summary() == col.summary()) - .map(|(i, _)| &tuple.values[i]) - .unwrap_or(&NULL_VALUE) - .clone(); + .map(|(i, _)| tuple.values[i].clone()) + .unwrap_or(DataValue::Null); Ok(value) } ScalarExpression::Alias { expr, alias } => { + let Some((tuple, schema)) = tuple else { + return Ok(DataValue::Null); + }; if let Some(value) = schema .iter() .find_position(|tul_col| match alias { @@ -65,24 +63,22 @@ impl ScalarExpression { alias_expr.output_column().summary() == tul_col.summary() } }) - .map(|(i, _)| &tuple.values[i]) + .map(|(i, _)| tuple.values[i].clone()) { return Ok(value.clone()); } - expr.eval(tuple, schema) - } - ScalarExpression::TypeCast { expr, ty, .. } => { - Ok(expr.eval(tuple, schema)?.cast(ty)?) + expr.eval(Some((tuple, schema))) } + ScalarExpression::TypeCast { expr, ty, .. } => Ok(expr.eval(tuple)?.cast(ty)?), ScalarExpression::Binary { left_expr, right_expr, evaluator, .. } => { - let left = left_expr.eval(tuple, schema)?; - let right = right_expr.eval(tuple, schema)?; + let left = left_expr.eval(tuple)?; + let right = right_expr.eval(tuple)?; Ok(evaluator .as_ref() @@ -91,7 +87,7 @@ impl ScalarExpression { .binary_eval(&left, &right)) } ScalarExpression::IsNull { expr, negated } => { - let mut is_null = expr.eval(tuple, schema)?.is_null(); + let mut is_null = expr.eval(tuple)?.is_null(); if *negated { is_null = !is_null; } @@ -102,13 +98,13 @@ impl ScalarExpression { args, negated, } => { - let value = expr.eval(tuple, schema)?; + let value = expr.eval(tuple)?; if value.is_null() { return Ok(DataValue::Boolean(None)); } let mut is_in = false; for arg in args { - let arg_value = arg.eval(tuple, schema)?; + let arg_value = arg.eval(tuple)?; if arg_value.is_null() { return Ok(DataValue::Boolean(None)); @@ -126,7 +122,7 @@ impl ScalarExpression { ScalarExpression::Unary { expr, evaluator, .. } => { - let value = expr.eval(tuple, schema)?; + let value = expr.eval(tuple)?; Ok(evaluator .as_ref() @@ -143,9 +139,9 @@ impl ScalarExpression { right_expr, negated, } => { - let value = expr.eval(tuple, schema)?; - let left = left_expr.eval(tuple, schema)?; - let right = right_expr.eval(tuple, schema)?; + let value = expr.eval(tuple)?; + let left = left_expr.eval(tuple)?; + let right = right_expr.eval(tuple)?; let mut is_between = match ( value.partial_cmp(&left).map(Ordering::is_ge), @@ -166,12 +162,12 @@ impl ScalarExpression { from_expr, } => { if let Some(mut string) = expr - .eval(tuple, schema)? + .eval(tuple)? .cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))? .utf8() { if let Some(from_expr) = from_expr { - let mut from = eval_to_num!(from_expr, tuple, schema).saturating_sub(1); + let mut from = eval_to_num!(from_expr, tuple).saturating_sub(1); let len_i = string.len() as i32; while from < 0 { @@ -187,8 +183,7 @@ impl ScalarExpression { string = string.split_off(from as usize); } if let Some(for_expr) = for_expr { - let for_i = - cmp::min(eval_to_num!(for_expr, tuple, schema) as usize, string.len()); + let for_i = cmp::min(eval_to_num!(for_expr, tuple) as usize, string.len()); let _ = string.split_off(for_i); } @@ -208,7 +203,7 @@ impl ScalarExpression { ScalarExpression::Position { expr, in_expr } => { let unpack = |expr: &ScalarExpression| -> Result { Ok(expr - .eval(tuple, schema)? + .eval(tuple)? .cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))? .utf8() .unwrap_or("".to_owned())) @@ -226,14 +221,14 @@ impl ScalarExpression { } => { let mut value = None; if let Some(string) = expr - .eval(tuple, schema)? + .eval(tuple)? .cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))? .utf8() { let mut trim_what = String::from(" "); if let Some(trim_what_expr) = trim_what_expr { trim_what = trim_what_expr - .eval(tuple, schema)? + .eval(tuple)? .cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))? .utf8() .unwrap_or_default(); @@ -263,23 +258,24 @@ impl ScalarExpression { unit: CharLengthUnits::Characters, }) } - ScalarExpression::Reference { pos, .. } => Ok(tuple - .values - .get(*pos) - .unwrap_or_else(|| &NULL_VALUE) - .clone()), + ScalarExpression::Reference { pos, .. } => { + let Some((tuple, _)) = tuple else { + return Ok(DataValue::Null); + }; + Ok(tuple.values.get(*pos).cloned().unwrap_or(DataValue::Null)) + } ScalarExpression::Tuple(exprs) => { let mut values = Vec::with_capacity(exprs.len()); for expr in exprs { - values.push(expr.eval(tuple, schema)?); + values.push(expr.eval(tuple)?); } Ok(DataValue::Tuple( (!values.is_empty()).then_some((values, false)), )) } ScalarExpression::ScalaFunction(ScalarFunction { inner, args, .. }) => { - inner.eval(args, tuple, schema)?.cast(inner.return_type()) + inner.eval(args, tuple)?.cast(inner.return_type()) } ScalarExpression::Empty => unreachable!(), ScalarExpression::If { @@ -288,10 +284,10 @@ impl ScalarExpression { right_expr, ty, } => { - if condition.eval(tuple, schema)?.is_true()? { - check_cast(left_expr.eval(tuple, schema)?, ty) + if condition.eval(tuple)?.is_true()? { + check_cast(left_expr.eval(tuple)?, ty) } else { - check_cast(right_expr.eval(tuple, schema)?, ty) + check_cast(right_expr.eval(tuple)?, ty) } } ScalarExpression::IfNull { @@ -299,10 +295,10 @@ impl ScalarExpression { right_expr, ty, } => { - let mut value = left_expr.eval(tuple, schema)?; + let mut value = left_expr.eval(tuple)?; if value.is_null() { - value = right_expr.eval(tuple, schema)?; + value = right_expr.eval(tuple)?; } check_cast(value, ty) } @@ -311,10 +307,10 @@ impl ScalarExpression { right_expr, ty, } => { - let mut value = left_expr.eval(tuple, schema)?; + let mut value = left_expr.eval(tuple)?; - if right_expr.eval(tuple, schema)? == value { - value = NULL_VALUE.clone(); + if right_expr.eval(tuple)? == value { + value = DataValue::Null; } check_cast(value, ty) } @@ -322,14 +318,14 @@ impl ScalarExpression { let mut value = None; for expr in exprs { - let temp = expr.eval(tuple, schema)?; + let temp = expr.eval(tuple)?; if !temp.is_null() { value = Some(temp); break; } } - check_cast(value.unwrap_or_else(|| NULL_VALUE.clone()), ty) + check_cast(value.unwrap_or(DataValue::Null), ty) } ScalarExpression::CaseWhen { operand_expr, @@ -341,10 +337,10 @@ impl ScalarExpression { let mut result = None; if let Some(expr) = operand_expr { - operand_value = Some(expr.eval(tuple, schema)?); + operand_value = Some(expr.eval(tuple)?); } for (when_expr, result_expr) in expr_pairs { - let mut when_value = when_expr.eval(tuple, schema)?; + let mut when_value = when_expr.eval(tuple)?; let is_true = if let Some(operand_value) = &operand_value { let ty = operand_value.logical_type(); let evaluator = @@ -361,16 +357,16 @@ impl ScalarExpression { when_value.is_true()? }; if is_true { - result = Some(result_expr.eval(tuple, schema)?); + result = Some(result_expr.eval(tuple)?); break; } } if result.is_none() { if let Some(expr) = else_expr { - result = Some(expr.eval(tuple, schema)?); + result = Some(expr.eval(tuple)?); } } - check_cast(result.unwrap_or_else(|| NULL_VALUE.clone()), ty) + check_cast(result.unwrap_or(DataValue::Null), ty) } ScalarExpression::TableFunction(_) => unreachable!(), } diff --git a/src/expression/function/scala.rs b/src/expression/function/scala.rs index 60409b69..f351ad24 100644 --- a/src/expression/function/scala.rs +++ b/src/expression/function/scala.rs @@ -54,8 +54,7 @@ pub trait ScalarFunctionImpl: Debug + Send + Sync { fn eval( &self, args: &[ScalarExpression], - tuple: &Tuple, - schema: &[ColumnRef], + tuple: Option<(&Tuple, &[ColumnRef])>, ) -> Result; // TODO: Exploiting monotonicity when optimizing `ScalarFunctionImpl::monotonicity()` diff --git a/src/function/char_length.rs b/src/function/char_length.rs index d4eeb834..dd5b7693 100644 --- a/src/function/char_length.rs +++ b/src/function/char_length.rs @@ -35,11 +35,9 @@ impl ScalarFunctionImpl for CharLength { fn eval( &self, exprs: &[ScalarExpression], - tuples: &Tuple, - columns: &[ColumnRef], + tuples: Option<(&Tuple, &[ColumnRef])>, ) -> Result { - let value = exprs[0].eval(tuples, columns)?; - let mut value = DataValue::clone(&value); + let mut value = exprs[0].eval(tuples)?; if !matches!(value.logical_type(), LogicalType::Varchar(_, _)) { value = value.cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?; } diff --git a/src/function/current_date.rs b/src/function/current_date.rs index 7e7a4d85..f1519a93 100644 --- a/src/function/current_date.rs +++ b/src/function/current_date.rs @@ -37,8 +37,7 @@ impl ScalarFunctionImpl for CurrentDate { fn eval( &self, _: &[ScalarExpression], - _: &Tuple, - _: &[ColumnRef], + _: Option<(&Tuple, &[ColumnRef])>, ) -> Result { Ok(DataValue::Date32(Some(Local::now().num_days_from_ce()))) } diff --git a/src/function/lower.rs b/src/function/lower.rs index 655c7506..57aa83fc 100644 --- a/src/function/lower.rs +++ b/src/function/lower.rs @@ -37,11 +37,9 @@ impl ScalarFunctionImpl for Lower { fn eval( &self, exprs: &[ScalarExpression], - tuples: &Tuple, - columns: &[ColumnRef], + tuples: Option<(&Tuple, &[ColumnRef])>, ) -> Result { - let value = exprs[0].eval(tuples, columns)?; - let mut value = DataValue::clone(&value); + let mut value = exprs[0].eval(tuples)?; if !matches!(value.logical_type(), LogicalType::Varchar(_, _)) { value = value.cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?; } diff --git a/src/function/numbers.rs b/src/function/numbers.rs index adbda58f..d5e7cfd5 100644 --- a/src/function/numbers.rs +++ b/src/function/numbers.rs @@ -5,8 +5,8 @@ use crate::errors::DatabaseError; use crate::expression::function::table::TableFunctionImpl; use crate::expression::function::FunctionSummary; use crate::expression::ScalarExpression; +use crate::types::tuple::SchemaRef; use crate::types::tuple::Tuple; -use crate::types::tuple::{SchemaRef, EMPTY_TUPLE}; use crate::types::value::DataValue; use crate::types::LogicalType; use serde::Deserialize; @@ -52,7 +52,7 @@ impl TableFunctionImpl for Numbers { &self, args: &[ScalarExpression], ) -> Result>>, DatabaseError> { - let mut value = args[0].eval(&EMPTY_TUPLE, &[])?; + let mut value = args[0].eval(None)?; if value.logical_type() != LogicalType::Integer { value = value.cast(&LogicalType::Integer)?; diff --git a/src/function/upper.rs b/src/function/upper.rs index 531cc9b0..29a3e9b5 100644 --- a/src/function/upper.rs +++ b/src/function/upper.rs @@ -37,11 +37,9 @@ impl ScalarFunctionImpl for Upper { fn eval( &self, exprs: &[ScalarExpression], - tuples: &Tuple, - columns: &[ColumnRef], + tuples: Option<(&Tuple, &[ColumnRef])>, ) -> Result { - let value = exprs[0].eval(tuples, columns)?; - let mut value = DataValue::clone(&value); + let mut value = exprs[0].eval(tuples)?; if !matches!(value.logical_type(), LogicalType::Varchar(_, _)) { value = value.cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?; } diff --git a/src/macros/mod.rs b/src/macros/mod.rs index b3f7499e..c3abfc50 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -93,11 +93,11 @@ macro_rules! scala_function { #[typetag::serde] impl ::fnck_sql::expression::function::scala::ScalarFunctionImpl for $struct_name { #[allow(unused_variables, clippy::redundant_closure_call)] - fn eval(&self, args: &[::fnck_sql::expression::ScalarExpression], tuple: &::fnck_sql::types::tuple::Tuple, schema: &[::fnck_sql::catalog::column::ColumnRef]) -> Result<::fnck_sql::types::value::DataValue, ::fnck_sql::errors::DatabaseError> { + fn eval(&self, args: &[::fnck_sql::expression::ScalarExpression], tuple: Option<(&::fnck_sql::types::tuple::Tuple, &[::fnck_sql::catalog::column::ColumnRef])>) -> Result<::fnck_sql::types::value::DataValue, ::fnck_sql::errors::DatabaseError> { let mut _index = 0; $closure($({ - let mut value = args[_index].eval(tuple, schema)?; + let mut value = args[_index].eval(tuple)?; _index += 1; if value.logical_type() != $arg_ty { @@ -184,7 +184,7 @@ macro_rules! table_function { let mut _index = 0; $closure($({ - let mut value = args[_index].eval(&::fnck_sql::types::tuple::EMPTY_TUPLE, &[])?; + let mut value = args[_index].eval(None)?; _index += 1; if value.logical_type() != $arg_ty { diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs index d06dde5c..82aa1327 100644 --- a/src/optimizer/core/memo.rs +++ b/src/optimizer/core/memo.rs @@ -2,7 +2,6 @@ use crate::errors::DatabaseError; use crate::optimizer::core::pattern::PatternMatcher; use crate::optimizer::core::rule::{ImplementationRule, MatchPattern}; use crate::optimizer::core::statistics_meta::StatisticMetaLoader; -use crate::optimizer::heuristic::batch::HepMatchOrder; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::heuristic::matcher::HepMatcher; use crate::optimizer::rule::implementation::ImplementationRuleImpl; @@ -47,7 +46,7 @@ impl Memo { return Err(DatabaseError::EmptyPlan); } - for node_id in graph.nodes_iter(HepMatchOrder::BottomUp, None) { + for node_id in graph.nodes_iter(None) { for rule in implementations { if HepMatcher::new(rule.pattern(), node_id, graph).match_opt_expr() { let op = graph.operator(node_id); diff --git a/src/optimizer/heuristic/batch.rs b/src/optimizer/heuristic/batch.rs index 84fbc763..85b92f17 100644 --- a/src/optimizer/heuristic/batch.rs +++ b/src/optimizer/heuristic/batch.rs @@ -23,40 +23,28 @@ impl HepBatch { } } -#[derive(Clone)] -pub struct HepBatchStrategy { +#[derive(Clone, Copy)] +pub enum HepBatchStrategy { /// An execution_ap strategy for rules that indicates the maximum number of executions. If the /// execution_ap reaches fix point (i.e. converge) before maxIterations, it will stop. /// /// Fix Point means that plan tree not changed after applying all rules. - pub max_iteration: usize, - /// An order to traverse the plan tree nodes. - pub match_order: HepMatchOrder, + MaxTimes(usize), + #[allow(dead_code)] + LoopIfApplied, } impl HepBatchStrategy { pub fn once_topdown() -> Self { - HepBatchStrategy { - max_iteration: 1, - match_order: HepMatchOrder::TopDown, - } + HepBatchStrategy::MaxTimes(1) } pub fn fix_point_topdown(max_iteration: usize) -> Self { - HepBatchStrategy { - max_iteration, - match_order: HepMatchOrder::TopDown, - } + HepBatchStrategy::MaxTimes(max_iteration) } -} -#[derive(Clone, Copy)] -pub enum HepMatchOrder { - /// Match from root down. A match attempt at an ancestor always precedes all match attempts at - /// its descendants. - TopDown, - /// Match from leaves up. A match attempt at a descendant precedes all match attempts at its - /// ancestors. #[allow(dead_code)] - BottomUp, + pub fn loop_if_applied() -> Self { + HepBatchStrategy::LoopIfApplied + } } diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 841d6d74..f6de8c61 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -1,7 +1,7 @@ use crate::optimizer::core::memo::Memo; -use crate::optimizer::heuristic::batch::HepMatchOrder; use crate::planner::operator::Operator; use crate::planner::{Childrens, LogicalPlan}; +use fixedbitset::FixedBitSet; use itertools::Itertools; use petgraph::stable_graph::{NodeIndex, StableDiGraph}; use petgraph::visit::{Bfs, EdgeRef}; @@ -136,29 +136,6 @@ impl HepGraph { self.graph.remove_node(source_id) } - /// Traverse the graph in BFS order. - fn bfs(&self, start: HepNodeId) -> Vec { - let mut ids = Vec::with_capacity(self.graph.node_count()); - let mut iter = Bfs::new(&self.graph, start); - while let Some(node_id) = iter.next(&self.graph) { - ids.push(node_id); - } - ids - } - - /// Use bfs to traverse the graph and return node ids - pub fn nodes_iter( - &self, - order: HepMatchOrder, - start_option: Option, - ) -> Box> { - let ids = self.bfs(start_option.unwrap_or(self.root_index)); - match order { - HepMatchOrder::TopDown => Box::new(ids.into_iter()), - HepMatchOrder::BottomUp => Box::new(ids.into_iter().rev()), - } - } - #[allow(dead_code)] pub fn node(&self, node_id: HepNodeId) -> Option<&Operator> { self.graph.node_weight(node_id) @@ -200,6 +177,15 @@ impl HepGraph { self.build_childrens(self.root_index, memo) } + /// Use bfs to traverse the graph and return node ids + pub fn nodes_iter(&self, start_option: Option) -> HepGraphIter { + let inner = Bfs::new(&self.graph, start_option.unwrap_or(self.root_index)); + HepGraphIter { + inner, + graph: &self.graph, + } + } + fn build_childrens(&mut self, start: HepNodeId, memo: Option<&Memo>) -> Option { let physical_option = memo.and_then(|memo| memo.cheapest_physical_option(&start)); @@ -230,6 +216,19 @@ impl HepGraph { } } +pub struct HepGraphIter<'a> { + inner: Bfs, + graph: &'a StableDiGraph, +} + +impl Iterator for HepGraphIter<'_> { + type Item = HepNodeId; + + fn next(&mut self) -> Option { + self.inner.next(self.graph) + } +} + #[cfg(test)] mod tests { use crate::binder::test::build_t1_table; diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 1e3d0493..c8195dd7 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -1,5 +1,4 @@ use crate::optimizer::core::pattern::{Pattern, PatternChildrenPredicate, PatternMatcher}; -use crate::optimizer::heuristic::batch::HepMatchOrder; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; /// Use pattern to determines which rule can be applied @@ -30,10 +29,7 @@ impl PatternMatcher for HepMatcher<'_, '_> { match &self.pattern.children { PatternChildrenPredicate::Recursive => { // check - for node_id in self - .graph - .nodes_iter(HepMatchOrder::TopDown, Some(self.start_id)) - { + for node_id in self.graph.nodes_iter(Some(self.start_id)) { if !(self.pattern.predicate)(self.graph.operator(node_id)) { return false; } diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index 16d5f4c5..116d310e 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -47,14 +47,16 @@ impl HepOptimizer { loader: Option<&StatisticMetaLoader<'_, T>>, ) -> Result { for ref batch in self.batches { - let mut batch_over = false; - let mut iteration = 1usize; - - while iteration <= batch.strategy.max_iteration && !batch_over { - if Self::apply_batch(&mut self.graph, batch)? { - iteration += 1; - } else { - batch_over = true + match batch.strategy { + HepBatchStrategy::MaxTimes(max_iteration) => { + for _ in 0..max_iteration { + if !Self::apply_batch(&mut self.graph, batch)? { + break; + } + } + } + HepBatchStrategy::LoopIfApplied => { + while Self::apply_batch(&mut self.graph, batch)? {} } } } @@ -73,22 +75,21 @@ impl HepOptimizer { } fn apply_batch( - graph: &mut HepGraph, - HepBatch { - rules, strategy, .. - }: &HepBatch, + graph: *mut HepGraph, + HepBatch { rules, .. }: &HepBatch, ) -> Result { - let before_version = graph.version; + let before_version = unsafe { &*graph }.version; for rule in rules { - for node_id in graph.nodes_iter(strategy.match_order, None) { - if Self::apply_rule(graph, rule, node_id)? { + // SAFETY: after successfully modifying the graph, the iterator is no longer used. + for node_id in unsafe { &*graph }.nodes_iter(None) { + if Self::apply_rule(unsafe { &mut *graph }, rule, node_id)? { break; } } } - Ok(before_version != graph.version) + Ok(before_version != unsafe { &*graph }.version) } fn apply_rule( diff --git a/src/optimizer/rule/normalization/pushdown_limit.rs b/src/optimizer/rule/normalization/pushdown_limit.rs index a32efd61..fd6497a0 100644 --- a/src/optimizer/rule/normalization/pushdown_limit.rs +++ b/src/optimizer/rule/normalization/pushdown_limit.rs @@ -107,13 +107,15 @@ impl NormalizationRule for PushLimitIntoScan { fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { if let Operator::Limit(limit_op) = graph.operator(node_id) { if let Some(child_index) = graph.eldest_child_at(node_id) { - if let Operator::TableScan(scan_op) = graph.operator(child_index) { - let mut new_scan_op = scan_op.clone(); - - new_scan_op.limit = (limit_op.offset, limit_op.limit); + let mut is_apply = false; + let limit = (limit_op.offset, limit_op.limit); + if let Operator::TableScan(scan_op) = graph.operator_mut(child_index) { + scan_op.limit = limit; + is_apply = true; + } + if is_apply { graph.remove_node(node_id, false); - graph.replace_node(child_index, Operator::TableScan(new_scan_op)); } } } diff --git a/src/optimizer/rule/normalization/pushdown_predicates.rs b/src/optimizer/rule/normalization/pushdown_predicates.rs index 6778d817..d2d1a51a 100644 --- a/src/optimizer/rule/normalization/pushdown_predicates.rs +++ b/src/optimizer/rule/normalization/pushdown_predicates.rs @@ -72,6 +72,7 @@ fn reduce_filters(filters: Vec, having: bool) -> Option Result<(), DatabaseError> { - if let Operator::Filter(mut filter_op) = graph.operator(node_id).clone() { + let mut is_optimized = false; + if let Operator::Filter(filter_op) = graph.operator_mut(node_id) { + if filter_op.is_optimized { + return Ok(()); + } filter_op.predicate.simplify()?; filter_op.predicate.constant_calculation()?; - - graph.replace_node(node_id, Operator::Filter(filter_op)) + filter_op.is_optimized = true; + is_optimized = true; + } + if is_optimized { + graph.version += 1; } Ok(()) diff --git a/src/planner/operator/filter.rs b/src/planner/operator/filter.rs index 69700a6c..c8a9fd2f 100644 --- a/src/planner/operator/filter.rs +++ b/src/planner/operator/filter.rs @@ -9,13 +9,18 @@ use super::Operator; #[derive(Debug, PartialEq, Eq, Clone, Hash, ReferenceSerialization)] pub struct FilterOperator { pub predicate: ScalarExpression, + pub is_optimized: bool, pub having: bool, } impl FilterOperator { pub fn build(predicate: ScalarExpression, children: LogicalPlan, having: bool) -> LogicalPlan { LogicalPlan::new( - Operator::Filter(FilterOperator { predicate, having }), + Operator::Filter(FilterOperator { + predicate, + is_optimized: false, + having, + }), Childrens::Only(children), ) } diff --git a/src/types/tuple.rs b/src/types/tuple.rs index a57a5cdf..3b7ceaf4 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -9,13 +9,6 @@ use comfy_table::{Cell, Table}; use itertools::Itertools; use std::io::Cursor; use std::sync::Arc; -use std::sync::LazyLock; - -pub static EMPTY_TUPLE: LazyLock = LazyLock::new(|| Tuple { - pk_indices: None, - values: vec![], - id_buf: None, -}); const BITS_MAX_INDEX: usize = 8; @@ -80,32 +73,32 @@ impl Tuple { bits & (1 << (7 - i)) > 0 } - let values_len = table_types.len(); - let mut tuple_values = Vec::with_capacity(values_len); - let bits_len = (values_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; + let types_len = table_types.len(); + let bits_len = (types_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; + let mut values = vec![DataValue::Null; projections.len()]; let mut projection_i = 0; let mut cursor = Cursor::new(&bytes[bits_len..]); for (i, logic_type) in table_types.iter().enumerate() { - if projection_i >= values_len || projection_i > projections.len() - 1 { + if projections.len() <= projection_i { break; } + debug_assert!(projection_i < types_len); if is_none(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { - if projections[projection_i] == i { - tuple_values.push(DataValue::none(logic_type)); - projection_i += 1; - } - } else if let Some(value) = + projection_i += 1; + continue; + } + if let Some(value) = DataValue::from_raw(&mut cursor, logic_type, projections[projection_i] == i)? { - tuple_values.push(value); + values[projection_i] = value; projection_i += 1; } } Ok(Tuple { pk_indices: Some(pk_indices.clone()), - values: tuple_values, + values, id_buf: None, }) } diff --git a/src/utils/bit_vector.rs b/src/utils/bit_vector.rs deleted file mode 100644 index 4a0fd674..00000000 --- a/src/utils/bit_vector.rs +++ /dev/null @@ -1,96 +0,0 @@ -use itertools::Itertools; - -#[derive(Debug, Default)] -pub struct BitVector { - #[allow(dead_code)] - len: u64, - bit_groups: Vec, -} - -impl BitVector { - pub fn new(len: usize) -> BitVector { - BitVector { - len: len as u64, - bit_groups: vec![0; (len + 7) / 8], - } - } - - pub fn set_bit(&mut self, index: usize, value: bool) { - let byte_index = index / 8; - let bit_index = index % 8; - - if value { - self.bit_groups[byte_index] |= 1 << bit_index; - } else { - self.bit_groups[byte_index] &= !(1 << bit_index); - } - } - - pub fn get_bit(&self, index: usize) -> bool { - self.bit_groups[index / 8] >> (index % 8) & 1 != 0 - } - - #[allow(dead_code)] - pub fn len(&self) -> usize { - self.len as usize - } - - #[allow(dead_code)] - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - #[allow(dead_code)] - pub fn to_raw(&self, bytes: &mut Vec) { - bytes.extend(self.len.to_le_bytes()); - - for bits in &self.bit_groups { - bytes.extend(bits.to_le_bytes()); - } - } - - #[allow(dead_code)] - pub fn from_raw(bytes: &[u8]) -> Self { - let len = u64::from_le_bytes(bytes[0..8].try_into().unwrap()); - let bit_groups = bytes[8..] - .iter() - .map(|bit| i8::from_le_bytes([*bit])) - .collect_vec(); - - BitVector { len, bit_groups } - } -} - -#[cfg(test)] -mod tests { - use crate::utils::bit_vector::BitVector; - - #[test] - fn bit_vector_serialization() { - let mut vector = BitVector::new(100); - - vector.set_bit(99, true); - - let mut bytes = Vec::new(); - - vector.to_raw(&mut bytes); - let vector = BitVector::from_raw(&bytes); - - for i in 0..98 { - assert!(!vector.get_bit(i)); - } - assert!(vector.get_bit(99)); - } - - #[test] - fn bit_vector_simple() { - let mut vector = BitVector::new(100); - - vector.set_bit(99, true); - - for i in 0..98 { - assert!(!vector.get_bit(i)); - } - assert!(vector.get_bit(99)); - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 81efc1b4..dde0d096 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,2 +1 @@ -pub(crate) mod bit_vector; pub(crate) mod lru; diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs index 33e246a7..1db6317d 100644 --- a/tests/macros-test/src/main.rs +++ b/tests/macros-test/src/main.rs @@ -109,8 +109,7 @@ mod test { unit: CharLengthUnits::Characters, }), ], - &Tuple::new(None, vec![]), - &vec![], + None, )?; println!("{:?}", function); diff --git a/tpcc/Cargo.toml b/tpcc/Cargo.toml index 2a39d5bb..7bba2487 100644 --- a/tpcc/Cargo.toml +++ b/tpcc/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] clap = { version = "4", features = ["derive"] } chrono = { version = "0.4" } -fnck_sql = { version = "0.0.8", path = "..", package = "fnck_sql" } +fnck_sql = { version = "0.0.9", path = "..", package = "fnck_sql" } indicatif = { version = "0.17" } ordered-float = { version = "4" } rand = { version = "0.8" }