
Commit

tests: unit test
discord9 committed Jun 19, 2024
1 parent 96d8379 commit 98c1270
Showing 3 changed files with 296 additions and 39 deletions.
23 changes: 16 additions & 7 deletions src/flow/src/expr/scalar.rs
@@ -191,6 +191,14 @@ impl DfScalarFunction {
})
}

pub fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
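// Decode the scalar function implementation from the raw bytes and derive the
// DataFusion schema from the raw function's input schema.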
Ok(Self {
fn_impl: raw_fn.get_fn_impl()?,
df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
raw_fn,
})
}

// TODO(discord9): add RecordBatch support
pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
// first eval exprs to construct values to feed to datafusion
@@ -267,16 +275,18 @@ impl<'de> serde::de::Deserialize<'de> for DfScalarFunction {
D: serde::de::Deserializer<'de>,
{
let raw_fn = RawDfScalarFn::deserialize(deserializer)?;
let fn_impl = raw_fn.get_fn_impl().map_err(serde::de::Error::custom)?;
DfScalarFunction::new(raw_fn, fn_impl).map_err(serde::de::Error::custom)
DfScalarFunction::try_from_raw_fn(raw_fn).map_err(serde::de::Error::custom)
}
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RawDfScalarFn {
f: bytes::BytesMut,
input_schema: RelationDesc,
extensions: FunctionExtensions,
/// The raw bytes encoding the datafusion scalar function
pub(crate) f: bytes::BytesMut,
/// The input schema of the function
pub(crate) input_schema: RelationDesc,
/// Extensions mapping function references to function names
pub(crate) extensions: FunctionExtensions,
}

impl RawDfScalarFn {
@@ -879,8 +889,7 @@ mod test {
.unwrap();
let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
let fn_impl = raw_fn.get_fn_impl().unwrap();
let df_func = DfScalarFunction::new(raw_fn, fn_impl).unwrap();
let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).unwrap();
let as_str = serde_json::to_string(&df_func).unwrap();
let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap();
assert_eq!(df_func, from_str);
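
For context, the new constructor collapses the previous two-step construction (fetch the implementation with get_fn_impl, then call DfScalarFunction::new) into a single fallible call. A minimal call-site sketch, not part of this commit, assuming DfScalarFunction, RawDfScalarFn, and the crate's Error type are in scope:

fn build(raw_fn: RawDfScalarFn) -> Result<DfScalarFunction, Error> {
    // Before this commit, callers extracted the implementation themselves:
    //     let fn_impl = raw_fn.get_fn_impl()?;
    //     DfScalarFunction::new(raw_fn, fn_impl)
    // After this commit, one call derives both fn_impl and df_schema internally:
    DfScalarFunction::try_from_raw_fn(raw_fn)
}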
284 changes: 279 additions & 5 deletions src/flow/src/transform/aggr.rs
@@ -460,24 +460,302 @@ impl TypedPlan {

#[cfg(test)]
mod test {
use bytes::BytesMut;
use common_time::{DateTime, Interval};
use datatypes::prelude::ConcreteDataType;
use pretty_assertions::{assert_eq, assert_ne};

use super::*;
use crate::expr::{DfScalarFunction, RawDfScalarFn};
use crate::plan::{Plan, TypedPlan};
use crate::repr::{self, ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
/// TODO(discord9): add more illegal sql tests
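/// A GROUP BY key that never appears in the output should make plan conversion fail.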
#[tokio::test]
async fn tes_missing_key_check() {
async fn test_missing_key_check() {
let engine = create_test_query_engine();
let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
let plan = sql_to_substrait(engine.clone(), sql).await;

let mut ctx = create_test_ctx();
assert!(TypedPlan::from_substrait_plan(&mut ctx, &plan).is_err());
}

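/// `sum(abs(number))`: the DataFusion scalar `abs` is expected to run inside the
/// reduce's `val_plan`, i.e. before the `sum` aggregation.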
#[tokio::test]
async fn test_df_func_basic() {
let engine = create_test_query_engine();
let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
let plan = sql_to_substrait(engine.clone(), sql).await;

let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();

let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
schema: RelationType::new(vec![
ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
])
.with_key(vec![2])
.with_time_index(Some(1))
.with_autos(&[2])
.into_named(vec![
None,
Some("window_start".to_string()),
Some("window_end".to_string()),
]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(ConcreteDataType::datetime_datatype(), false),
])
.into_named(vec![
Some("number".to_string()),
Some("ts".to_string()),
]),
),
),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Interval::from_month_day_nano(
0,
0,
1_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Interval::from_month_day_nano(
0,
0,
1_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(2)
.map(vec![ScalarExpr::CallDf {
df_scalar_fn: DfScalarFunction::try_from_raw_fn(
RawDfScalarFn {
f: BytesMut::from(
b"\x08\x01\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
.as_ref(),
),
input_schema: RelationType::new(vec![ColumnType::new(
ConcreteDataType::uint32_datatype(),
false,
)])
.into_unnamed(),
extensions: FunctionExtensions {
anchor_to_name: BTreeMap::from([
(0, "tumble".to_string()),
(1, "abs".to_string()),
(2, "sum".to_string()),
]),
},
},
)
.unwrap(),
exprs: vec![ScalarExpr::Column(0)],
}])
.unwrap()
.project(vec![2])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
])
.with_key(vec![1])
.with_time_index(Some(0))
.with_autos(&[1])
.into_named(vec![
Some("window_start".to_string()),
Some("window_end".to_string()),
None,
]),
),
),
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::Column(2),
ScalarExpr::Column(3),
ScalarExpr::Column(0),
ScalarExpr::Column(1),
])
.unwrap()
.project(vec![4, 5, 6])
.unwrap(),
},
};
assert_eq!(expected, flow_plan);
}

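/// `abs(sum(number))`: the DataFusion scalar `abs` is expected to run in the
/// post-reduce `mfp`, i.e. after the `sum` aggregation.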
#[tokio::test]
async fn test_df_func_expr_tree() {
let engine = create_test_query_engine();
let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
let plan = sql_to_substrait(engine.clone(), sql).await;

let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).unwrap();

let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt32,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
schema: RelationType::new(vec![
ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
])
.with_key(vec![2])
.with_time_index(Some(1))
.with_autos(&[2])
.into_named(vec![
None,
Some("window_start".to_string()),
Some("window_end".to_string()),
]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(ConcreteDataType::datetime_datatype(), false),
])
.into_named(vec![
Some("number".to_string()),
Some("ts".to_string()),
]),
),
),
key_val_plan: KeyValPlan {
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Interval::from_month_day_nano(
0,
0,
1_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Interval::from_month_day_nano(
0,
0,
1_000_000_000,
),
start_time: Some(DateTime::new(1625097600000)),
},
),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
val_plan: MapFilterProject::new(2)
.project(vec![0, 1])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::datetime_datatype(), false), // window start
ColumnType::new(CDT::datetime_datatype(), false), // window end
ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
])
.with_key(vec![1])
.with_time_index(Some(0))
.with_autos(&[1])
.into_named(vec![
Some("window_start".to_string()),
Some("window_end".to_string()),
None,
]),
),
),
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::Column(2),
ScalarExpr::CallDf {
df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
input_schema: RelationType::new(vec![ColumnType::new(
ConcreteDataType::uint64_datatype(),
true,
)])
.into_unnamed(),
extensions: FunctionExtensions {
anchor_to_name: BTreeMap::from([
(0, "abs".to_string()),
(1, "tumble".to_string()),
(2, "sum".to_string()),
]),
},
})
.unwrap(),
exprs: vec![ScalarExpr::Column(3)],
},
ScalarExpr::Column(0),
ScalarExpr::Column(1),
])
.unwrap()
.project(vec![4, 5, 6])
.unwrap(),
},
};
assert_eq!(expected, flow_plan);
}

/// TODO(discord9): add more illegal sql tests
#[tokio::test]
async fn test_tumble_composite() {
@@ -655,10 +933,6 @@ mod test {
Some("window_start".to_string()),
Some("window_end".to_string()),
]),
// TODO(discord9): mfp indirectly ref to key columns
/*
.with_key(vec![1])
.with_time_index(Some(0)),*/
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
