Skip to content

Commit

Permalink
Call func
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewmturner committed Jan 30, 2025
1 parent b7f1fb4 commit de8022c
Showing 1 changed file with 77 additions and 5 deletions.
82 changes: 77 additions & 5 deletions src/execution/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,35 @@ use std::sync::Arc;

use color_eyre::{eyre::eyre, Result};
use datafusion::{
arrow::{array::ArrayRef, datatypes::DataType},
arrow::{
array::{Array, ArrayRef, PrimitiveArray},
datatypes::{self, ArrowPrimitiveType, DataType},
},
logical_expr::{ColumnarValue, ScalarUDF, Volatility},
prelude::create_udf,
};
use datafusion_common::{DataFusionError, Result as DFResult};
use log::{error, info};
use wasmtime::{Instance, Module, Store};
use wasmtime::{Instance, Module, Store, Val};

use crate::config::{WasmFuncDetails, WasmUdfConfig};

fn get_arrow_value<T>(args: &[ArrayRef], row_ix: usize, col_ix: usize) -> DFResult<T::Native>
where
T: ArrowPrimitiveType,
{
args.get(col_ix)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Error casting column {col_ix:?} to array of primitive values"
))
})
.map(|arr| arr.value(row_ix))
}

pub fn udf_signature_from_func_details(
func_details: &WasmFuncDetails,
) -> Result<(Vec<DataType>, DataType)> {
Expand All @@ -44,7 +63,14 @@ pub fn udf_signature_from_func_details(
Ok((input_types?, return_type))
}

fn validate_args(args: &[ColumnarValue], input_types: &[DataType]) -> DFResult<()> {
fn validate_func(args: &[ColumnarValue], input_types: &[DataType]) -> DFResult<()> {
// Check that there is at least one `ColumnarValue`. Strictly speaking this might not be
// needed, but for the immediate future I believe it will always be the case
if args.len() == 0 {
return Err(DataFusionError::Execution(
"There must be at least one argument".to_string(),
));
}
// First check that the defined input_types and args have same number of columns
if args.len() != input_types.len() {
return Err(DataFusionError::Execution(
Expand All @@ -55,23 +81,68 @@ fn validate_args(args: &[ColumnarValue], input_types: &[DataType]) -> DFResult<(
Ok(())
}

/// Extract the relevant row and column indices and convert to WASM values that can be passed to
/// WASM func
fn create_wasm_func_params(vals: &[Arc<dyn Array>], row_idx: usize) -> DFResult<Vec<Val>> {
(0..vals.len())
.map(|col_idx| match vals[col_idx].data_type() {
DataType::Int32 => {
let arrow_val = get_arrow_value::<datatypes::Int32Type>(vals, row_idx, col_idx)?;
Ok(Val::I32(arrow_val))
}
DataType::Int64 => {
let arrow_val = get_arrow_value::<datatypes::Int64Type>(vals, row_idx, col_idx)?;
Ok(Val::I64(arrow_val))
}
DataType::Float32 => {
let arrow_val = get_arrow_value::<datatypes::Float32Type>(vals, row_idx, col_idx)?;
Ok(Val::F32(arrow_val.to_bits()))
}
DataType::Float64 => {
let arrow_val = get_arrow_value::<datatypes::Float64Type>(vals, row_idx, col_idx)?;
Ok(Val::F64(arrow_val.to_bits()))
}

_ => Err(DataFusionError::Execution(
"Unsupported column type for WASM scalar function".to_string(),
)),
})
.collect()
}

fn create_wasm_udf_impl(
module_bytes: Vec<u8>,
func_name: String,
input_types: Vec<DataType>,
return_type: DataType,
) -> impl Fn(&[ColumnarValue]) -> DFResult<ColumnarValue> {
move |args: &[ColumnarValue]| {
// First validate the arguments
validate_args(args, &input_types)?;
validate_func(args, &input_types)?;
// Load the function again
let mut store = Store::<()>::default();
let module = Module::from_binary(store.engine(), &module_bytes)
.map_err(|e| DataFusionError::Internal(format!("Error loading module: {e:?}")))?;
let instance = Instance::new(&mut store, &module, &[])
.map_err(|e| DataFusionError::Internal(format!("Error instantiating module: {e:?}")))?;
let func = instance.get_func(&mut store, &func_name).ok_or_else(|| {
DataFusionError::Execution(format!("Unable to access function {func_name}"))
})?;

let vals = ColumnarValue::values_to_arrays(args)?;
let first = &vals[0];
let val_count = vals.first().unwrap().len();
let mut results: Vec<Val> = Vec::with_capacity(val_count);
for row_idx in 0..val_count {
let params = create_wasm_func_params(&vals, row_idx)?;
func.call(&mut store, &params, &mut results[row_idx..row_idx + 1])
.map_err(|e| {
DataFusionError::Execution(format!(
"Error executing function {func_name:?}: {e:?}"
))
})?;
}

let first = vals.first().unwrap();
Ok(ColumnarValue::Array(Arc::clone(first)))
}
}
Expand All @@ -94,6 +165,7 @@ pub fn create_wasm_udfs(wasm_udf_config: &WasmUdfConfig) -> Result<Vec<ScalarUDF
} else {
let udf_impl = create_wasm_udf_impl(
module_bytes.to_owned(),
func_details.name.to_string(),
input_types.clone(),
return_type.clone(),
);
Expand Down

0 comments on commit de8022c

Please sign in to comment.