From bbaa21fd4608ae9c5a07191fc52aabf571d7bbf2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BA=8C=E6=89=8B=E6=8E=89=E5=8C=85=E5=B7=A5=E7=A8=8B?=
 =?UTF-8?q?=E5=B8=88?= <rustin.liu@gmail.com>
Date: Fri, 7 Jul 2023 20:39:46 +0800
Subject: [PATCH] feat: support InList clause in streaming SQL (#1694)

* feat: support InList clause in streaming SQL

Signed-off-by: hi-rustin <rustin.liu@gmail.com>

* feat: support not in list

Signed-off-by: hi-rustin <rustin.liu@gmail.com>

* fix: use users

Signed-off-by: hi-rustin <rustin.liu@gmail.com>

* Remove broken tests

Signed-off-by: hi-rustin <rustin.liu@gmail.com>

* Fix fmt

Signed-off-by: hi-rustin <rustin.liu@gmail.com>

* Better test code

Signed-off-by: hi-rustin <rustin.liu@gmail.com>

---------

Signed-off-by: hi-rustin <rustin.liu@gmail.com>
Co-authored-by: Dario Pizzamiglio <1150680+mediuminvader@users.noreply.github.com>
---
 dozer-sql/src/pipeline/expression/builder.rs  | 28 ++++++
 .../src/pipeline/expression/execution.rs      | 37 ++++++++
 dozer-sql/src/pipeline/expression/in_list.rs  | 27 ++++++
 dozer-sql/src/pipeline/expression/mod.rs      |  1 +
 .../src/pipeline/expression/tests/in_list.rs  | 86 +++++++++++++++++++
 .../src/pipeline/expression/tests/mod.rs      |  1 +
 6 files changed, 180 insertions(+)
 create mode 100644 dozer-sql/src/pipeline/expression/in_list.rs
 create mode 100644 dozer-sql/src/pipeline/expression/tests/in_list.rs

diff --git a/dozer-sql/src/pipeline/expression/builder.rs b/dozer-sql/src/pipeline/expression/builder.rs
index da143bcf68..59927c27aa 100644
--- a/dozer-sql/src/pipeline/expression/builder.rs
+++ b/dozer-sql/src/pipeline/expression/builder.rs
@@ -108,6 +108,12 @@ impl ExpressionBuilder {
                 escape_char,
                 schema,
             ),
+            SqlExpr::InList {
+                expr,
+                list,
+                negated,
+            } => self.parse_sql_in_list_operator(parse_aggregations, expr, list, *negated, schema),
+
             SqlExpr::Cast { expr, data_type } => {
                 self.parse_sql_cast_operator(parse_aggregations, expr, data_type, schema)
             }
@@ -764,6 +770,28 @@ impl ExpressionBuilder {
             return_type,
         })
     }
+
+    /// Builds an `Expression::InList` from a SQL `expr [NOT] IN (list)` clause, parsing the
+    /// tested expression first and then each list item; the first parse error is returned.
+    fn parse_sql_in_list_operator(
+        &mut self,
+        parse_aggregations: bool,
+        expr: &Expr,
+        list: &[Expr],
+        negated: bool,
+        schema: &Schema,
+    ) -> Result<Expression, PipelineError> {
+        let target = Box::new(self.parse_sql_expression(parse_aggregations, expr, schema)?);
+        let mut items = Vec::with_capacity(list.len());
+        for item in list {
+            items.push(self.parse_sql_expression(parse_aggregations, item, schema)?);
+        }
+        Ok(Expression::InList {
+            expr: target,
+            list: items,
+            negated,
+        })
+    }
 }
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq)]
diff --git a/dozer-sql/src/pipeline/expression/execution.rs b/dozer-sql/src/pipeline/expression/execution.rs
index 946b20a370..e88c2b616c 100644
--- a/dozer-sql/src/pipeline/expression/execution.rs
+++ b/dozer-sql/src/pipeline/expression/execution.rs
@@ -22,6 +22,7 @@ use uuid::Uuid;
 
 use super::aggregate::AggregateFunctionType;
 use super::cast::CastOperatorType;
+use super::in_list::evaluate_in_list;
 use super::scalar::string::{evaluate_like, get_like_operator_type};
 
 #[derive(Clone, Debug, PartialEq)]
@@ -73,6 +74,11 @@ pub enum Expression {
         pattern: Box<Expression>,
         escape: Option<char>,
     },
+    InList {
+        expr: Box<Expression>,
+        list: Vec<Expression>,
+        negated: bool,
+    },
     Now {
         fun: DateTimeFunctionType,
     },
@@ -219,6 +225,22 @@ impl Expression {
                 pattern,
                 escape: _,
             } => arg.to_string(schema) + " LIKE " + pattern.to_string(schema).as_str(),
+            Expression::InList {
+                expr,
+                list,
+                negated,
+            } => {
+                expr.to_string(schema)
+                    + if *negated { " NOT" } else { "" }
+                    + " IN ("
+                    + list
+                        .iter()
+                        .map(|e| e.to_string(schema))
+                        .collect::<Vec<String>>()
+                        .join(",")
+                        .as_str()
+                    + ")"
+            }
             Expression::GeoFunction { fun, args } => {
                 fun.to_string()
                     + "("
@@ -318,6 +340,11 @@ impl ExpressionExecutor for Expression {
                 pattern,
                 escape,
             } => evaluate_like(schema, arg, pattern, *escape, record),
+            Expression::InList {
+                expr,
+                list,
+                negated,
+            } => evaluate_in_list(schema, expr, list, *negated, record),
             Expression::Cast { arg, typ } => typ.evaluate(schema, arg, record),
             Expression::GeoFunction { fun, args } => fun.evaluate(schema, args, record),
             Expression::ConditionalExpression { fun, args } => fun.evaluate(schema, args, record),
@@ -384,6 +411,16 @@ impl ExpressionExecutor for Expression {
                 pattern,
                 escape: _,
             } => get_like_operator_type(arg, pattern, schema),
+            Expression::InList {
+                expr: _,
+                list: _,
+                negated: _,
+            } => Ok(ExpressionType::new(
+                FieldType::Boolean,
+                false,
+                SourceDefinition::Dynamic,
+                false,
+            )),
             Expression::Cast { arg, typ } => typ.get_return_type(schema, arg),
             Expression::GeoFunction { fun, args } => get_geo_function_type(fun, args, schema),
             Expression::DateTimeFunction { fun, arg } => {
diff --git a/dozer-sql/src/pipeline/expression/in_list.rs b/dozer-sql/src/pipeline/expression/in_list.rs
new file mode 100644
index 0000000000..99fd93c43b
--- /dev/null
+++ b/dozer-sql/src/pipeline/expression/in_list.rs
@@ -0,0 +1,27 @@
+use dozer_types::types::{Field, Record, Schema};
+
+use crate::pipeline::errors::PipelineError;
+use crate::pipeline::expression::execution::{Expression, ExpressionExecutor};
+
+/// Evaluates `expr [NOT] IN (list)` against `record`, yielding a `Field::Boolean`.
+///
+/// List items are evaluated in order and only until the first one equal to `expr`;
+/// any evaluation error is propagated to the caller.
+pub(crate) fn evaluate_in_list(
+    schema: &Schema,
+    expr: &Expression,
+    list: &[Expression],
+    negated: bool,
+    record: &Record,
+) -> Result<Field, PipelineError> {
+    let needle = expr.evaluate(record, schema)?;
+    let mut found = false;
+    for candidate in list {
+        if needle == candidate.evaluate(record, schema)? {
+            found = true;
+            break;
+        }
+    }
+    // `found != negated` flips the answer exactly when NOT IN was requested.
+    Ok(Field::Boolean(found != negated))
+}
diff --git a/dozer-sql/src/pipeline/expression/mod.rs b/dozer-sql/src/pipeline/expression/mod.rs
index accad562aa..3ec8b05c08 100644
--- a/dozer-sql/src/pipeline/expression/mod.rs
+++ b/dozer-sql/src/pipeline/expression/mod.rs
@@ -8,6 +8,7 @@ pub mod conditional;
 mod datetime;
 pub mod execution;
 pub mod geo;
+pub mod in_list;
 mod json_functions;
 pub mod logical;
 pub mod mathematical;
diff --git a/dozer-sql/src/pipeline/expression/tests/in_list.rs b/dozer-sql/src/pipeline/expression/tests/in_list.rs
new file mode 100644
index 0000000000..819cc25da5
--- /dev/null
+++ b/dozer-sql/src/pipeline/expression/tests/in_list.rs
@@ -0,0 +1,86 @@
+use crate::pipeline::expression::tests::test_common::run_fct;
+use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition};
+
+#[test]
+fn test_in_list() {
+    // `IN` is expected to be true only when the probed value, 42, is one of the
+    // listed items. Each table row pairs a query with that expected
+    // outcome.
+
+    // Literal on the left-hand side: run with an empty schema and no values.
+    let literal_cases = [
+        ("SELECT 42 IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", false),
+        ("SELECT 42 IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42)", true),
+    ];
+    for (sql, is_member) in literal_cases {
+        let actual = run_fct(sql, Schema::empty(), vec![]);
+        let expected = Field::Boolean(is_member);
+        assert_eq!(actual, expected);
+    }
+
+    // Column on the left-hand side: `age` is declared as an Int field and every
+    // query is evaluated against the single value 42.
+    let schema = Schema::empty()
+        .field(
+            FieldDefinition::new(
+                String::from("age"),
+                FieldType::Int,
+                false,
+                SourceDefinition::Dynamic,
+            ),
+            false,
+        )
+        .clone();
+
+    let column_cases = [
+        ("SELECT age IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) FROM users", false),
+        ("SELECT age IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42) FROM users", true),
+    ];
+    for (sql, is_member) in column_cases {
+        let actual = run_fct(sql, schema.clone(), vec![Field::Int(42)]);
+        let expected = Field::Boolean(is_member);
+        assert_eq!(actual, expected);
+    }
+}
+
+#[test]
+fn test_not_in_list() {
+    // `NOT IN` is expected to be true only when the probed value, 42, is absent
+    // from the listed items. Each table row pairs a query with that expected
+    // outcome.
+
+    // Literal on the left-hand side: run with an empty schema and no values.
+    let literal_cases = [
+        ("SELECT 42 NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", true),
+        ("SELECT 42 NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42)", false),
+    ];
+    for (sql, is_absent) in literal_cases {
+        let actual = run_fct(sql, Schema::empty(), vec![]);
+        let expected = Field::Boolean(is_absent);
+        assert_eq!(actual, expected);
+    }
+
+    // Column on the left-hand side: `age` is declared as an Int field and every
+    // query is evaluated against the single value 42.
+    let schema = Schema::empty()
+        .field(
+            FieldDefinition::new(
+                String::from("age"),
+                FieldType::Int,
+                false,
+                SourceDefinition::Dynamic,
+            ),
+            false,
+        )
+        .clone();
+
+    let column_cases = [
+        ("SELECT age NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) FROM users", true),
+        ("SELECT age NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42) FROM users", false),
+    ];
+    for (sql, is_absent) in column_cases {
+        let actual = run_fct(sql, schema.clone(), vec![Field::Int(42)]);
+        let expected = Field::Boolean(is_absent);
+        assert_eq!(actual, expected);
+    }
+}
diff --git a/dozer-sql/src/pipeline/expression/tests/mod.rs b/dozer-sql/src/pipeline/expression/tests/mod.rs
index 8c7b40f852..fb3368fadf 100644
--- a/dozer-sql/src/pipeline/expression/tests/mod.rs
+++ b/dozer-sql/src/pipeline/expression/tests/mod.rs
@@ -15,6 +15,7 @@ mod conditional;
 mod datetime;
 #[cfg(test)]
 mod distance;
+mod in_list;
 #[cfg(test)]
 mod json_functions;
 #[cfg(test)]