Skip to content

Commit

Permalink
bug(column): incorrect logical not (#11)
Browse files Browse the repository at this point in the history
- fix issue with logical not operator setup
- add unit test for functionality
  • Loading branch information
sjrusso8 authored Apr 8, 2024
1 parent 6650707 commit 89c5d85
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
10 changes: 9 additions & 1 deletion src/column.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! [Column] represents a column in a DataFrame that holds a [spark::Expression]
use std::convert::From;
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Rem, Sub};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Not, Rem, Sub};

use crate::spark;

Expand Down Expand Up @@ -397,3 +397,11 @@ impl BitXor for Column {
invoke_func("^", vec![self, other])
}
}

impl Not for Column {
type Output = Self;

fn not(self) -> Self::Output {
invoke_func("not", vec![self])
}
}
69 changes: 66 additions & 3 deletions src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ pub fn negate<T: expressions::ToExpr>(col: T) -> Column
where
Vec<T>: expressions::ToVecExpr,
{
invoke_func("not", vec![col])
invoke_func("negative", vec![col])
}

pub fn pow<T: ToExpr>(col1: T, col2: T) -> Column
Expand Down Expand Up @@ -813,9 +813,9 @@ mod tests {

assert_eq!(expected, res);

// negate isin
// Logical NOT for column ISIN
let res = df
.filter(negate(col("name").isin(vec!["Tom", "Bob"])))
.filter(!col("name").isin(vec!["Tom", "Bob"]))
.select("name")
.collect()
.await?;
Expand All @@ -828,4 +828,67 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_func_col_expr() -> Result<(), SparkError> {
let spark = setup().await;

let name: ArrayRef = Arc::new(StringArray::from(vec!["Alice", "Bob"]));

let data = RecordBatch::try_from_iter(vec![("name", name.clone())])?;

let df = spark.createDataFrame(&data)?;

let res = df
.select([col("name"), expr("length(name)")])
.collect()
.await?;

let length: ArrayRef = Arc::new(Int32Array::from(vec![5, 3]));

let expected = RecordBatch::try_from_iter(vec![("name", name), ("length(name)", length)])?;

assert_eq!(expected, res);
Ok(())
}

#[tokio::test]
async fn test_func_greatest() -> Result<(), SparkError> {
let spark = setup().await;

let a: ArrayRef = Arc::new(Int64Array::from(vec![1]));
let b: ArrayRef = Arc::new(Int64Array::from(vec![4]));
let c: ArrayRef = Arc::new(Int64Array::from(vec![4]));

let data = RecordBatch::try_from_iter(vec![("a", a), ("b", b.clone()), ("c", c)])?;

let df = spark.createDataFrame(&data)?;

let res = df.select(greatest(["a", "b", "c"])).collect().await?;

let expected = RecordBatch::try_from_iter(vec![("greatest(a, b, c)", b)])?;

assert_eq!(expected, res);
Ok(())
}

#[tokio::test]
async fn test_func_least() -> Result<(), SparkError> {
let spark = setup().await;

let a: ArrayRef = Arc::new(Int64Array::from(vec![1]));
let b: ArrayRef = Arc::new(Int64Array::from(vec![4]));
let c: ArrayRef = Arc::new(Int64Array::from(vec![4]));

let data = RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b), ("c", c)])?;

let df = spark.createDataFrame(&data)?;

let res = df.select(least(["a", "b", "c"])).collect().await?;

let expected = RecordBatch::try_from_iter(vec![("least(a, b, c)", a)])?;

assert_eq!(expected, res);
Ok(())
}
}

0 comments on commit 89c5d85

Please sign in to comment.