fix(flow): fix call df func bug&sqlness test (#4165)
* tests: flow sqlness tests

* tests: WIP df func test

* fix: use schema before expand for transform expr

* tests: some basic flow tests

* tests: unit test

* chore: dep use rev not patch

* fix: weird sqlness error?

* refactor: per review

* fix: temp sqlness bug

* fix: use fixed sqlness

* fix: impl drop as async shutdown

* refactor: per bot's review

* tests: drop worker handler both sync/async

* docs: add rationale for test

* refactor: per review

* chore: fmt
discord9 authored Jun 24, 2024
1 parent 0139a70 commit 5179174
Showing 14 changed files with 932 additions and 198 deletions.
30 changes: 28 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions src/flow/src/adapter.rs
@@ -321,7 +321,7 @@ impl FlownodeManager {
schema
.get_name(*i)
.clone()
-             .unwrap_or_else(|| format!("Col_{i}"))
+             .unwrap_or_else(|| format!("col_{i}"))
})
.collect_vec()
})
@@ -344,7 +344,7 @@ impl FlownodeManager {
.get(idx)
.cloned()
.flatten()
-             .unwrap_or(format!("Col_{}", idx));
+             .unwrap_or(format!("col_{}", idx));
let ret = ColumnSchema::new(name, typ.scalar_type, typ.nullable);
if schema.typ().time_index == Some(idx) {
ret.with_time_index(true)
2 changes: 2 additions & 0 deletions src/flow/src/adapter/table_source.rs
@@ -148,6 +148,8 @@ impl TableSource {
column_types,
keys,
time_index,
+         // by default table schema's column are all non-auto
+         auto_columns: vec![],
},
names: col_names,
},
26 changes: 22 additions & 4 deletions src/flow/src/adapter/worker.rs
@@ -206,10 +206,15 @@ impl WorkerHandle {

impl Drop for WorkerHandle {
fn drop(&mut self) {
-         if let Err(err) = self.shutdown_blocking() {
-             common_telemetry::error!("Fail to shutdown worker: {:?}", err)
+         let ret = futures::executor::block_on(async { self.shutdown().await });
+         if let Err(ret) = ret {
+             common_telemetry::error!(
+                 ret;
+                 "While dropping Worker Handle, failed to shutdown worker, worker might be in inconsistent state."
+             );
+         } else {
+             info!("Flow Worker shutdown due to Worker Handle dropped.")
          }
-         info!("Flow Worker shutdown due to Worker Handle dropped.")
}
}

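The handle now drives its async shutdown to completion from the synchronous Drop impl. A minimal standalone sketch of this drop-as-async-shutdown pattern (hypothetical `Handle` type and a `futures` dependency assumed; not the crate's actual API):

struct Handle;

impl Handle {
    // Stand-in for WorkerHandle::shutdown; the real method signals the
    // worker over an async channel and awaits its acknowledgement.
    async fn shutdown(&mut self) -> Result<(), String> {
        Ok(())
    }
}

impl Drop for Handle {
    fn drop(&mut self) {
        // Drop is synchronous, so drive the async shutdown to completion
        // on the current thread before the handle goes away.
        if let Err(err) = futures::executor::block_on(self.shutdown()) {
            eprintln!("failed to shut down worker: {err:?}");
        }
    }
}

fn main() {
    let handle = Handle;
    drop(handle); // runs the async shutdown synchronously
}

One caveat: `block_on` inside Drop can deadlock if the shutdown future ever needs the very thread being blocked, which is presumably why the `drop_handle` test below exercises the drop path from a plain `std::thread` rather than from inside an async runtime.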
@@ -496,6 +501,19 @@ mod test {
use crate::plan::Plan;
use crate::repr::{RelationType, Row};

+     #[test]
+     fn drop_handle() {
+         let (tx, rx) = oneshot::channel();
+         let worker_thread_handle = std::thread::spawn(move || {
+             let (handle, mut worker) = create_worker();
+             tx.send(handle).unwrap();
+             worker.run();
+         });
+         let handle = rx.blocking_recv().unwrap();
+         drop(handle);
+         worker_thread_handle.join().unwrap();
+     }
+
#[tokio::test]
pub async fn test_simple_get_with_worker_and_handle() {
let (tx, rx) = oneshot::channel();
@@ -532,7 +550,7 @@
tx.send((Row::empty(), 0, 0)).unwrap();
handle.run_available(0).await.unwrap();
assert_eq!(sink_rx.recv().await.unwrap().0, Row::empty());
-         handle.shutdown().await.unwrap();
+         drop(handle);
worker_thread_handle.join().unwrap();
}
}
76 changes: 61 additions & 15 deletions src/flow/src/expr/scalar.rs
@@ -21,10 +21,10 @@ use bytes::BytesMut;
use common_error::ext::BoxedError;
use common_recordbatch::DfRecordBatch;
use datafusion_physical_expr::PhysicalExpr;
+ use datatypes::arrow_array;
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::Value;
- use datatypes::{arrow_array, value};
use prost::Message;
use serde::{Deserialize, Serialize};
use snafu::{ensure, ResultExt};
@@ -155,8 +155,10 @@ pub enum ScalarExpr {
exprs: Vec<ScalarExpr>,
},
CallDf {
-         // TODO(discord9): support shuffle
+         /// invariant: the input args set inside this [`DfScalarFunction`] is
+         /// always col(0) to col(n-1) where n is the length of `expr`
          df_scalar_fn: DfScalarFunction,
+         exprs: Vec<ScalarExpr>,
},
/// Conditionally evaluated expressions.
///
@@ -189,8 +191,27 @@ impl DfScalarFunction {
})
}

+     pub fn try_from_raw_fn(raw_fn: RawDfScalarFn) -> Result<Self, Error> {
+         Ok(Self {
+             fn_impl: raw_fn.get_fn_impl()?,
+             df_schema: Arc::new(raw_fn.input_schema.to_df_schema()?),
+             raw_fn,
+         })
+     }
+
+     /// eval a list of expressions using input values
+     fn eval_args(values: &[Value], exprs: &[ScalarExpr]) -> Result<Vec<Value>, EvalError> {
+         exprs
+             .iter()
+             .map(|expr| expr.eval(values))
+             .collect::<Result<_, _>>()
+     }
+
      // TODO(discord9): add RecordBatch support
-     pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
+     pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result<Value, EvalError> {
+         // first eval exprs to construct values to feed to datafusion
+         let values: Vec<_> = Self::eval_args(values, exprs)?;
+
if values.is_empty() {
return InvalidArgumentSnafu {
reason: "values is empty".to_string(),
@@ -259,16 +280,18 @@ impl<'de> serde::de::Deserialize<'de> for DfScalarFunction {
D: serde::de::Deserializer<'de>,
{
let raw_fn = RawDfScalarFn::deserialize(deserializer)?;
-         let fn_impl = raw_fn.get_fn_impl().map_err(serde::de::Error::custom)?;
-         DfScalarFunction::new(raw_fn, fn_impl).map_err(serde::de::Error::custom)
+         DfScalarFunction::try_from_raw_fn(raw_fn).map_err(serde::de::Error::custom)
}
}

#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RawDfScalarFn {
-     f: bytes::BytesMut,
-     input_schema: RelationDesc,
-     extensions: FunctionExtensions,
+     /// The raw bytes encoded datafusion scalar function
+     pub(crate) f: bytes::BytesMut,
+     /// The input schema of the function
+     pub(crate) input_schema: RelationDesc,
+     /// Extension contains mapping from function reference to function name
+     pub(crate) extensions: FunctionExtensions,
}

impl RawDfScalarFn {
@@ -354,7 +377,7 @@ impl ScalarExpr {
Ok(ColumnType::new_nullable(func.signature().output))
}
ScalarExpr::If { then, .. } => then.typ(context),
-             ScalarExpr::CallDf { df_scalar_fn } => {
+             ScalarExpr::CallDf { df_scalar_fn, .. } => {
let arrow_typ = df_scalar_fn
.fn_impl
// TODO(discord9): get scheme from args instead?
@@ -445,7 +468,10 @@ impl ScalarExpr {
}
.fail(),
},
-             ScalarExpr::CallDf { df_scalar_fn } => df_scalar_fn.eval(values),
+             ScalarExpr::CallDf {
+                 df_scalar_fn,
+                 exprs,
+             } => df_scalar_fn.eval(values, exprs),
}
}

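This is the substance of the fix: the DataFusion function no longer reads the raw input row directly. Its arguments are evaluated from `exprs` first, and the wrapped function only ever sees the results as positional columns col(0)..col(n-1), matching the invariant documented on the `CallDf` variant. A standalone sketch of that calling convention, using simple stand-in types rather than the crate's `ScalarExpr`/`Value`:

// An "expression" here just reads something out of the input row.
type Expr = fn(&[i64]) -> i64;

// Mirrors DfScalarFunction::eval_args: evaluate every argument expression
// against the input row, producing the positional argument list.
fn eval_args(row: &[i64], exprs: &[Expr]) -> Vec<i64> {
    exprs.iter().map(|e| e(row)).collect()
}

fn main() {
    let row = [10_i64, -3];
    // Argument 0 selects row column 1; argument 1 sums both row columns.
    let exprs: Vec<Expr> = vec![|r| r[1], |r| r[0] + r[1]];
    // The inner function then sees col(0) = -3 and col(1) = 7, regardless
    // of which row columns the argument expressions touched.
    assert_eq!(eval_args(&row, &exprs), vec![-3, 7]);
}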
@@ -614,7 +640,15 @@
f(then)?;
f(els)
}
-             _ => Ok(()),
+             ScalarExpr::CallDf {
+                 df_scalar_fn: _,
+                 exprs,
+             } => {
+                 for expr in exprs {
+                     f(expr)?;
+                 }
+                 Ok(())
+             }
}
}

@@ -650,7 +684,15 @@
f(then)?;
f(els)
}
-             _ => Ok(()),
+             ScalarExpr::CallDf {
+                 df_scalar_fn: _,
+                 exprs,
+             } => {
+                 for expr in exprs {
+                     f(expr)?;
+                 }
+                 Ok(())
+             }
}
}
}
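Both visitor arms above replace what was previously a catch-all `_ => Ok(()),`, so child-visiting now recurses into `CallDf`'s argument expressions. Presumably this is the heart of the "call df func bug" in the title: with the wildcard arm, transforms that rewrite column references never descended into a DataFusion function call's arguments, leaving them pointing at stale columns.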
@@ -852,11 +894,15 @@ mod test {
.unwrap();
let extensions = FunctionExtensions::from_iter(vec![(0, "abs")]);
let raw_fn = RawDfScalarFn::from_proto(&raw_scalar_func, input_schema, extensions).unwrap();
-         let fn_impl = raw_fn.get_fn_impl().unwrap();
-         let df_func = DfScalarFunction::new(raw_fn, fn_impl).unwrap();
+         let df_func = DfScalarFunction::try_from_raw_fn(raw_fn).unwrap();
let as_str = serde_json::to_string(&df_func).unwrap();
let from_str: DfScalarFunction = serde_json::from_str(&as_str).unwrap();
assert_eq!(df_func, from_str);
-         assert_eq!(df_func.eval(&[Value::Null]).unwrap(), Value::Int64(1));
+         assert_eq!(
+             df_func
+                 .eval(&[Value::Null], &[ScalarExpr::Column(0)])
+                 .unwrap(),
+             Value::Int64(1)
+         );
}
}
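The serde round trip in this test works because only `RawDfScalarFn` derives `Serialize`/`Deserialize`; the compiled `fn_impl` is rebuilt from the raw bytes on deserialization via `try_from_raw_fn`. A minimal sketch of that rebuild-on-deserialize pattern (hypothetical types; `serde` and `serde_json` dependencies assumed):

use serde::{Deserialize, Serialize};

// Only the raw, portable form is serialized.
#[derive(Serialize, Deserialize)]
struct RawFn {
    bytes: Vec<u8>,
}

// The full form also carries a derived artifact that is cheap to rebuild
// but pointless to serialize; a String stands in for the compiled
// physical expression.
struct FullFn {
    raw: RawFn,
    compiled: String,
}

impl FullFn {
    fn try_from_raw(raw: RawFn) -> Result<Self, String> {
        let compiled = String::from_utf8(raw.bytes.clone()).map_err(|e| e.to_string())?;
        Ok(Self { raw, compiled })
    }
}

fn main() {
    let raw = RawFn { bytes: b"abs".to_vec() };
    let json = serde_json::to_string(&raw).unwrap();
    let restored: RawFn = serde_json::from_str(&json).unwrap();
    let full = FullFn::try_from_raw(restored).unwrap();
    assert_eq!(full.raw.bytes, b"abs".to_vec());
    assert_eq!(full.compiled, "abs");
}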
29 changes: 27 additions & 2 deletions src/flow/src/repr/relation.rs
@@ -21,7 +21,9 @@ use itertools::Itertools;
use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt, ResultExt};

- use crate::adapter::error::{DatafusionSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu};
+ use crate::adapter::error::{
+     DatafusionSnafu, InternalSnafu, InvalidQuerySnafu, Result, UnexpectedSnafu,
+ };
use crate::expr::{MapFilterProject, SafeMfpPlan, ScalarExpr};

/// a set of column indices that are "keys" for the collection.
@@ -93,13 +95,19 @@ pub struct RelationType {
///
/// A collection can contain multiple sets of keys, although it is common to
/// have either zero or one sets of key indices.
+     #[serde(default)]
pub keys: Vec<Key>,
/// optionally indicate the column that is TIME INDEX
pub time_index: Option<usize>,
+     /// mark all the columns that are added automatically by flow, but are not present in original sql
+     pub auto_columns: Vec<usize>,
}

impl RelationType {
+     pub fn with_autos(mut self, auto_cols: &[usize]) -> Self {
+         self.auto_columns = auto_cols.to_vec();
+         self
+     }
+
/// Trying to apply a mpf on current types, will return a new RelationType
/// with the new types, will also try to preserve keys&time index information
/// if the old key&time index columns are preserve in given mfp
@@ -155,10 +163,16 @@
let time_index = self
.time_index
.and_then(|old| old_to_new_col.get(&old).cloned());
+         let auto_columns = self
+             .auto_columns
+             .iter()
+             .filter_map(|old| old_to_new_col.get(old).cloned())
+             .collect_vec();
Ok(Self {
column_types: mfp_out_types,
keys,
time_index,
+             auto_columns,
})
}
/// Constructs a `RelationType` representing the relation with no columns and
@@ -175,6 +189,7 @@ impl RelationType {
column_types,
keys: Vec::new(),
time_index: None,
+             auto_columns: vec![],
}
}

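For context, a sketch of how a flow might mark an injected column with the new field (hypothetical column layout; `RelationType::new`, `ColumnType::new`, and `with_time_index` assumed to behave as they are used elsewhere in this crate):

// A relation (number, window_start) where window_start was injected by
// the flow engine as the TIME INDEX and is absent from the user's SQL:
let typ = RelationType::new(vec![
    ColumnType::new(ConcreteDataType::int64_datatype(), false),
    ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
])
.with_time_index(Some(1))
.with_autos(&[1]); // column 1 is auto-generated, not user-written

Because `mfp_out_types` above remaps `auto_columns` through the same old-to-new mapping as keys and the time index, the marker survives projections that keep the column and is dropped when the column itself is projected away.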
@@ -340,6 +355,16 @@ pub struct RelationDesc {
}

impl RelationDesc {
+     pub fn len(&self) -> Result<usize> {
+         ensure!(
+             self.typ.column_types.len() == self.names.len(),
+             InternalSnafu {
+                 reason: "Expect typ and names field to be of same length"
+             }
+         );
+         Ok(self.names.len())
+     }
+
pub fn to_df_schema(&self) -> Result<DFSchema> {
let fields: Vec<_> = self
.iter()
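`len` doubles as an invariant check: a hypothetical caller such as

    let n = raw_fn.input_schema.len()?; // Internal error if names/types disagree

gets an `InternalSnafu`-backed error instead of silently trusting a `RelationDesc` whose `names` and `column_types` have drifted out of sync.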