Skip to content

Commit

Permalink
feat(cube): Add cube_ext and string_boolean_coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
paveltiunov committed Nov 25, 2024
1 parent a43ce8b commit fc9ab6b
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 0 deletions.
8 changes: 8 additions & 0 deletions datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,14 @@ pub mod variable {
pub use datafusion_expr::var_provider::{VarProvider, VarType};
}

/// Re-exports every public item of `datafusion_physical_plan::cube_ext`
/// under the `cube_ext` module name.
pub mod cube_ext {
pub use datafusion_physical_plan::cube_ext::*;
}

/// Re-exports every public item of the `datafusion_common` crate under the
/// `dfschema` module name.
pub mod dfschema {
pub use datafusion_common::*;
}

#[cfg(test)]
pub mod test;
pub mod test_util;
Expand Down
14 changes: 14 additions & 0 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
.or_else(|| list_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
.or_else(|| string_numeric_coercion(lhs_type, rhs_type))
.or_else(|| string_boolean_coercion(lhs_type, rhs_type))
.or_else(|| string_temporal_coercion(lhs_type, rhs_type))
.or_else(|| binary_coercion(lhs_type, rhs_type))
.or_else(|| struct_coercion(lhs_type, rhs_type))
Expand Down Expand Up @@ -536,6 +537,19 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
}
}

/// Finds the common type used to compare a boolean operand with a
/// `Utf8`/`LargeUtf8` operand.
///
/// The string type wins regardless of which side it appears on. Any other
/// combination of types yields `None`.
fn string_boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;
    match (lhs_type, rhs_type) {
        (Utf8, Boolean) | (Boolean, Utf8) => Some(Utf8),
        (LargeUtf8, Boolean) | (Boolean, LargeUtf8) => Some(LargeUtf8),
        _ => None,
    }
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`.
///
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ parking_lot = { workspace = true }
pin-project-lite = "^0.2.7"
rand = { workspace = true }
tokio = { workspace = true }
serde = { version = "1.0.214", features = ["derive"] }
tracing = "0.1.25"
tracing-futures = { version = "0.2.5" }

[dev-dependencies]
rstest = { workspace = true }
Expand Down
124 changes: 124 additions & 0 deletions datafusion/physical-plan/src/cube_ext/catch_unwind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::error::ArrowError;
use futures::future::FutureExt;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::panic::{catch_unwind, AssertUnwindSafe};
use datafusion_common::DataFusionError;

/// Error value describing a panic, carrying the panic message as text.
///
/// Displays as `Panic: <msg>` and implements [`std::error::Error`], so it can
/// be boxed or chained like any other error type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PanicError {
    /// Text describing the panic.
    pub msg: String,
}

impl PanicError {
    /// Creates a `PanicError` that carries `msg`.
    pub fn new(msg: String) -> PanicError {
        PanicError { msg }
    }
}

impl Display for PanicError {
    /// Writes the message prefixed with `Panic: `.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "Panic: {}", self.msg)
    }
}

// No underlying source error is available, so the default trait methods suffice.
impl std::error::Error for PanicError {}

impl From<PanicError> for ArrowError {
fn from(error: PanicError) -> Self {
ArrowError::ComputeError(format!("Panic: {}", error.msg))
}
}

impl From<PanicError> for DataFusionError {
fn from(error: PanicError) -> Self {
DataFusionError::Internal(error.msg)
}
}

/// Turns a caught panic payload into a [`PanicError`].
///
/// A `String` or `&str` payload becomes the error message; any other payload
/// type is reported as `"unknown cause"`.
fn panic_error_from_payload(payload: Box<dyn std::any::Any + Send>) -> PanicError {
    let msg = match payload.downcast::<String>() {
        Ok(text) => *text,
        Err(payload) => match payload.downcast::<&str>() {
            Ok(text) => (*text).to_string(),
            Err(_) => "unknown cause".to_string(),
        },
    };
    PanicError::new(msg)
}

/// Runs `f`, converting a panic raised inside it into `Err(PanicError)`.
///
/// Returns `Ok` with the closure's result when it completes without panicking.
/// The closure is wrapped in `AssertUnwindSafe`, so no `UnwindSafe` bound is
/// required of the caller.
pub fn try_with_catch_unwind<F, R>(f: F) -> Result<R, PanicError>
where
    F: FnOnce() -> R,
{
    catch_unwind(AssertUnwindSafe(f)).map_err(panic_error_from_payload)
}

/// Awaits `future`, converting a panic raised while it runs into
/// `Err(PanicError)`.
///
/// Returns `Ok` with the future's output when it completes without panicking.
/// The future is wrapped in `AssertUnwindSafe`, so no `UnwindSafe` bound is
/// required of the caller.
pub async fn async_try_with_catch_unwind<F, R>(future: F) -> Result<R, PanicError>
where
    F: Future<Output = R>,
{
    AssertUnwindSafe(future)
        .catch_unwind()
        .await
        .map_err(panic_error_from_payload)
}

// Unit tests for the synchronous and asynchronous panic-catching helpers.
#[cfg(test)]
mod tests {
use super::*;
use std::panic;

// Covers `try_with_catch_unwind`: a normal return, a panic with a plain
// literal message, and a panic with a formatted message.
#[test]
fn test_try_with_catch_unwind() {
// No panic: the closure's value is passed through as `Ok`.
assert_eq!(
try_with_catch_unwind(|| "ok".to_string()),
Ok("ok".to_string())
);
// Panic with a literal message: the message ends up in the `PanicError`.
assert_eq!(
try_with_catch_unwind(|| panic!("oops")),
Err(PanicError::new("oops".to_string()))
);
// Panic with a formatted message: the rendered text ends up in the `PanicError`.
assert_eq!(
try_with_catch_unwind(|| panic!("oops{}", "ie")),
Err(PanicError::new("oopsie".to_string()))
);
}

// Same three scenarios as above, driven through `async_try_with_catch_unwind`.
#[tokio::test]
async fn test_async_try_with_catch_unwind() {
// No panic: the future's output is passed through as `Ok`.
assert_eq!(
async_try_with_catch_unwind(async { "ok".to_string() }).await,
Ok("ok".to_string())
);
// Panic with a literal message.
assert_eq!(
async_try_with_catch_unwind(async { panic!("oops") }).await,
Err(PanicError::new("oops".to_string()))
);
// Panic with a formatted message.
assert_eq!(
async_try_with_catch_unwind(async { panic!("oops{}", "ie") }).await,
Err(PanicError::new("oopsie".to_string()))
);
}
}
21 changes: 21 additions & 0 deletions datafusion/physical-plan/src/cube_ext/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

pub mod catch_unwind;

mod spawn;
pub use spawn::*;
149 changes: 149 additions & 0 deletions datafusion/physical-plan/src/cube_ext/spawn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::future::Future;
use crate::cube_ext::catch_unwind::{
async_try_with_catch_unwind, try_with_catch_unwind, PanicError,
};
use futures::sink::SinkExt;
use tokio::task::JoinHandle;
use tracing_futures::Instrument;

/// Calls [tokio::spawn] and additionally enables tracing of the spawned task as part of the current
/// computation. This is CubeStore approach to tracing, so all code must use this function instead
/// of replace [tokio::spawn].
pub fn spawn<T>(task: T) -> JoinHandle<T::Output>
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
if let Some(s) = new_subtask_span() {
tokio::spawn(async move {
let _p = s.parent; // ensure parent stays alive.
task.instrument(s.child).await
})
} else {
tokio::spawn(task)
}
}

/// Propagates current span to blocking operation. See [spawn] for details.
pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
if let Some(s) = new_subtask_span() {
tokio::task::spawn_blocking(move || {
let _p = s.parent; // ensure parent stays alive.
s.child.in_scope(f)
})
} else {
tokio::task::spawn_blocking(f)
}
}

/// Pair of tracing spans produced by `new_subtask_span` for a task that is
/// about to be spawned.
struct SpawnSpans {
/// Span that was current when the subtask span was created.
parent: tracing::Span,
/// The "subtask" span, created with `parent` as its parent.
child: tracing::Span,
}

/// Creates a "subtask" span parented to the current span.
///
/// Returns `None` when the current span is disabled; otherwise returns both
/// the current span and the newly created child.
fn new_subtask_span() -> Option<SpawnSpans> {
    let current = tracing::Span::current();
    if current.is_disabled() {
        None
    } else {
        // TODO: ensure this is always enabled.
        let subtask = tracing::info_span!(parent: &current, "subtask");
        Some(SpawnSpans {
            parent: current,
            child: subtask,
        })
    }
}

/// Spawns future `f` via [spawn], catching panics.
///
/// The task resolves to the future's own result, or to `Err(E::from(panic))`
/// when the future panics.
pub fn spawn_with_catch_unwind<F, T, E>(f: F) -> JoinHandle<Result<T, E>>
where
    F: Future<Output = Result<T, E>> + Send + 'static,
    T: Send + 'static,
    E: From<PanicError> + Send + 'static,
{
    spawn(async move {
        async_try_with_catch_unwind(f)
            .await
            .unwrap_or_else(|panic| Err(E::from(panic)))
    })
}

/// Spawns future `f` via [spawn] and feeds its result into the `tx` oneshot
/// channel, catching panics.
///
/// A panic is delivered through `tx` as `Err(E::from(panic))`. The task itself
/// resolves to whatever `tx.send` returns.
pub fn spawn_oneshot_with_catch_unwind<F, T, E>(
    f: F,
    tx: futures::channel::oneshot::Sender<Result<T, E>>,
) -> JoinHandle<Result<(), Result<T, E>>>
where
    F: Future<Output = Result<T, E>> + Send + 'static,
    T: Send + 'static,
    E: From<PanicError> + Send + 'static,
{
    spawn(async move {
        let outcome = async_try_with_catch_unwind(f)
            .await
            .unwrap_or_else(|panic| Err(E::from(panic)));
        tx.send(outcome)
    })
}

/// Spawns future `f` via [spawn]; a panic inside it is caught and fed into the
/// `tx` mpsc channel as `Err(E::from(panic))`.
///
/// Nothing is sent when `f` completes normally, and a failed send is ignored.
pub fn spawn_mpsc_with_catch_unwind<F, T, E>(
    f: F,
    mut tx: futures::channel::mpsc::Sender<Result<T, E>>,
) -> JoinHandle<()>
where
    F: Future<Output = ()> + Send + 'static,
    T: Send + 'static,
    E: From<PanicError> + Send + 'static,
{
    spawn(async move {
        if let Err(panic) = async_try_with_catch_unwind(f).await {
            // The send result is deliberately discarded.
            tx.send(Err(E::from(panic))).await.ok();
        }
    })
}

/// Runs fn `f` via [spawn_blocking]; a panic inside it is caught and fed into
/// the `tx` mpsc channel as `Err(E::from(panic))`.
///
/// The value returned by `f` is discarded, nothing is sent when `f` completes
/// normally, and a failed send is ignored.
pub fn spawn_blocking_mpsc_with_catch_unwind<F, R, T, E>(
    f: F,
    tx: tokio::sync::mpsc::Sender<Result<T, E>>,
) -> JoinHandle<()>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
    T: Send + 'static,
    E: From<PanicError> + Send + 'static,
{
    spawn_blocking(move || {
        if let Err(panic) = try_with_catch_unwind(f) {
            // The send result is deliberately discarded.
            tx.blocking_send(Err(E::from(panic))).ok();
        }
    })
}
1 change: 1 addition & 0 deletions datafusion/physical-plan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pub mod unnest;
pub mod values;
pub mod windows;
pub mod work_table;
pub mod cube_ext;

pub mod udaf {
pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
Expand Down

0 comments on commit fc9ab6b

Please sign in to comment.