Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
In support (#3)
Browse files Browse the repository at this point in the history
* Adding IN feature in new Engine

* IN: fix odd CSV behavior where COUNT returns a float in the legacy engine; some tests for the new engine had been disabled because of it
  • Loading branch information
FreCap authored May 22, 2021
1 parent b021028 commit 4a5f712
Show file tree
Hide file tree
Showing 19 changed files with 371 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ public static Literal timestampLiteral(String value) {
return literal(value, DataType.TIMESTAMP);
}

/**
 * Creates a {@link Literal} tagged as {@code DataType.ARRAY} that wraps the given list.
 *
 * @param value list to wrap; passed to {@code literal(...)} unchanged
 * @param <T> element type of the list
 * @return the literal produced by {@code literal(value, DataType.ARRAY)}
 */
public static <T> Literal arrayLiteral(List<T> value) {
return literal(value, DataType.ARRAY);
}

/**
 * Creates a {@link Literal} tagged as {@code DataType.DOUBLE} that wraps the given value.
 *
 * @param value double value to wrap; passed to {@code literal(...)} unchanged
 * @return the literal produced by {@code literal(value, DataType.DOUBLE)}
 */
public static Literal doubleLiteral(Double value) {
return literal(value, DataType.DOUBLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public enum DataType {
DATE(ExprCoreType.DATE),
TIME(ExprCoreType.TIME),
TIMESTAMP(ExprCoreType.TIMESTAMP),
ARRAY(ExprCoreType.ARRAY),
INTERVAL(ExprCoreType.INTERVAL);

@Getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
* Params include the field expression and/or wildcard field expression,
* nested field expression (@field).
* And the values that the field is mapped to (@valueList).
*
* @deprecated use function ("in") instead
*/
@Deprecated
@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public static ExprValue tupleValue(Map<String, Object> map) {
/**
* {@link ExprCollectionValue} constructor.
*/
public static ExprValue collectionValue(List<Object> list) {
public static <T> ExprValue collectionValue(List<T> list) {
List<ExprValue> valueList = new ArrayList<>();
list.forEach(o -> valueList.add(fromObjectValue(o)));
return new ExprCollectionValue(valueList);
Expand Down Expand Up @@ -129,6 +129,10 @@ public static ExprValue fromObjectValue(Object o) {
return stringValue((String) o);
} else if (o instanceof Float) {
return floatValue((Float) o);
} else if (o instanceof ExprValue) {
// Java has no primitive type that distinguishes TIMESTAMP, DATETIME and DATE,
// so allow passing an ExprValue that already carries this type information
return (ExprValue) o;
} else {
throw new ExpressionEvaluationException("unsupported object " + o.getClass());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionName;
import com.amazon.opendistroforelasticsearch.sql.expression.function.BuiltinFunctionRepository;
import com.amazon.opendistroforelasticsearch.sql.expression.window.ranking.RankingWindowFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
Expand Down Expand Up @@ -252,7 +254,7 @@ public FunctionExpression subtract(Expression... expressions) {
/**
 * Builds the built-in MULTIPLY function expression over the given operands
 * by delegating to {@code function(...)}.
 */
public FunctionExpression multiply(Expression... expressions) {
return function(BuiltinFunctionName.MULTIPLY, expressions);
}

/**
 * Builds the built-in ADDDATE function expression over the given arguments
 * by delegating to {@code function(...)}.
 */
public FunctionExpression adddate(Expression... expressions) {
return function(BuiltinFunctionName.ADDDATE, expressions);
}
Expand Down Expand Up @@ -364,7 +366,7 @@ public FunctionExpression module(Expression... expressions) {
/**
 * Builds the built-in SUBSTR function expression over the given arguments
 * by delegating to {@code function(...)}.
 */
public FunctionExpression substr(Expression... expressions) {
return function(BuiltinFunctionName.SUBSTR, expressions);
}

/**
 * Builds a function expression for {@code substring} over the given arguments.
 * Resolves to the same built-in name (SUBSTR) as {@code substr(...)}.
 * NOTE(review): presumably an intentional alias - confirm no separate SUBSTRING entry is expected.
 */
public FunctionExpression substring(Expression... expressions) {
return function(BuiltinFunctionName.SUBSTR, expressions);
}
Expand Down Expand Up @@ -588,4 +590,19 @@ public FunctionExpression castTimestamp(Expression value) {
return (FunctionExpression) repository
.compile(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), Arrays.asList(value));
}

/**
 * Builds the built-in IN function expression.
 *
 * <p>The argument list handed to {@code function(...)} is {@code field} followed by every
 * entry of {@code expressions}, in their original order.
 *
 * @param field expression whose value is looked up
 * @param expressions expressions describing the values to search in
 * @return the IN function expression
 */
public FunctionExpression in(Expression field, Expression... expressions) {
  // Slot 0 holds the field; the remaining slots mirror the varargs one-to-one.
  Expression[] arguments = new Expression[expressions.length + 1];
  arguments[0] = field;
  System.arraycopy(expressions, 0, arguments, 1, expressions.length);

  return function(BuiltinFunctionName.IN, arguments);
}

/**
 * Builds the negation of {@code in(field, expressions)}.
 *
 * @param field expression whose value is looked up
 * @param expressions expressions describing the values to search in
 * @return {@code not(...)} applied to the IN function expression
 */
public FunctionExpression not_in(Expression field, Expression... expressions) {
  FunctionExpression membership = in(field, expressions);
  return not(membership);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public enum BuiltinFunctionName {
GTE(FunctionName.of(">=")),
LIKE(FunctionName.of("like")),
NOT_LIKE(FunctionName.of("not like")),
IN(FunctionName.of("in")),

/**
* Aggregation Function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_MISSING;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.ARRAY;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DOUBLE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.FLOAT;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.LONG;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
Expand Down Expand Up @@ -63,6 +70,7 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(like());
repository.register(notLike());
repository.register(regexp());
repository.register(in());
}

/**
Expand Down Expand Up @@ -262,6 +270,26 @@ private static FunctionResolver notLike() {
STRING));
}

/**
 * Defines the resolver for the built-in IN function.
 *
 * <p>Eight implementations are registered, all backed by {@code OperatorUtils::in} wrapped in
 * {@code FunctionDSL.nullMissingHandling}. They differ only in the middle type argument:
 * INTEGER, STRING, LONG, FLOAT, DOUBLE, DATE, DATETIME and TIMESTAMP.
 * NOTE(review): the type triple is presumably (return type, field type, collection type),
 * i.e. BOOLEAN result for (type, ARRAY) arguments - confirm against FunctionDSL.impl.
 * NOTE(review): other core types (e.g. TIME, BOOLEAN) have no entry here; confirm that is
 * intended.
 */
private static FunctionResolver in() {
return FunctionDSL.define(BuiltinFunctionName.IN.getName(),
// numeric and string field types
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, INTEGER, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, STRING, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, LONG, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, FLOAT, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, DOUBLE, ARRAY),
// date/time field types
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, DATE, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, DATETIME, ARRAY),
FunctionDSL.impl(FunctionDSL.nullMissingHandling(OperatorUtils::in),
BOOLEAN, TIMESTAMP, ARRAY));
}

private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2,
Table<ExprValue, ExprValue, ExprValue> table) {
if (table.contains(arg1, arg2)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

package com.amazon.opendistroforelasticsearch.sql.utils;

import com.amazon.opendistroforelasticsearch.sql.data.model.AbstractExprNumberValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprStringValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import java.util.regex.Pattern;
import lombok.experimental.UtilityClass;
Expand Down Expand Up @@ -99,4 +101,24 @@ private static String patternToRegex(String patternString) {
regex.append('$');
return regex.toString();
}


/**
 * IN (..., ...) operator util.
 * Evaluates the expression { expr IN (collection of values..) }, i.e. decides
 * whether expr is contained in the given collection.
 *
 * @param expr value to look for
 * @param setOfValues value whose {@code collectionValue()} supplies the candidates
 * @return the membership result of {@code isIn} wrapped via {@code ExprBooleanValue.of}
 */
public static ExprBooleanValue in(ExprValue expr, ExprValue setOfValues) {
return ExprBooleanValue.of(isIn(expr, setOfValues));
}

/**
 * Tests whether {@code expr} matches at least one element of {@code setOfValues}.
 *
 * <p>The elements returned by {@code collectionValue()} are {@link ExprValue} wrappers, so the
 * comparison must be made between like kinds. Searching the list for an unwrapped
 * {@code value()} or {@code stringValue()} can never succeed, because {@code List.contains}
 * would then test a raw number or string for equality against an {@link ExprValue}.
 *
 * <ul>
 *   <li>numeric {@code expr}: matches a numeric candidate with an equal {@code value()}</li>
 *   <li>string {@code expr}: matches a string candidate with an equal {@code stringValue()}</li>
 *   <li>anything else: matches a candidate for which {@code expr.equals(candidate)} holds</li>
 * </ul>
 *
 * @param expr value to look for
 * @param setOfValues value whose {@code collectionValue()} supplies the candidates
 * @return true if any candidate matches, false otherwise (including an empty collection)
 */
private static boolean isIn(ExprValue expr, ExprValue setOfValues) {
  for (ExprValue candidate : setOfValues.collectionValue()) {
    if (expr instanceof AbstractExprNumberValue) {
      // Only unwrap candidates of the same family; other kinds cannot match a number.
      if (candidate instanceof AbstractExprNumberValue
          && expr.value().equals(candidate.value())) {
        return true;
      }
    } else if (expr instanceof ExprStringValue) {
      if (candidate instanceof ExprStringValue
          && expr.stringValue().equals(candidate.stringValue())) {
        return true;
      }
    } else if (expr.equals(candidate)) {
      return true;
    }
  }
  return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,24 @@ public void invalidConvertExprValue(ExprValue value, Function<ExprValue, Object>
assertThat(exception.getMessage(), Matchers.containsString("invalid"));
}

// Disabled: fromObjectValue now accepts ExprValue instances (needed for expr collections), so passing one no longer throws
// @Test
// public void unSupportedObject() {
// Exception exception = assertThrows(ExpressionEvaluationException.class,
// () -> ExprValueUtils.fromObjectValue(integerValue(1)));
// assertEquals(
// "unsupported object "
// + "class com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue",
// exception.getMessage());
// }

@Test
public void unSupportedObject() {
Exception exception = assertThrows(ExpressionEvaluationException.class,
() -> ExprValueUtils.fromObjectValue(integerValue(1)));
() -> ExprValueUtils.fromObjectValue(new Object()));
assertEquals(
"unsupported object "
+ "class com.amazon.opendistroforelasticsearch.sql.data.model.ExprIntegerValue",
+ "class java.lang.Object",
exception.getMessage());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.booleanValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.fromObjectValue;
import static com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils.missingValue;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.BOOLEAN;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATE;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.DATETIME;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.INTEGER;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.STRING;
import static com.amazon.opendistroforelasticsearch.sql.data.type.ExprCoreType.TIMESTAMP;
import static com.amazon.opendistroforelasticsearch.sql.utils.ComparisonUtil.compare;
import static com.amazon.opendistroforelasticsearch.sql.utils.OperatorUtils.matches;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprBooleanValue;
Expand All @@ -49,14 +52,15 @@
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprTupleValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValueUtils;
import com.amazon.opendistroforelasticsearch.sql.exception.ExpressionEvaluationException;
import com.amazon.opendistroforelasticsearch.sql.expression.DSL;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
import com.amazon.opendistroforelasticsearch.sql.expression.ExpressionTestBase;
import com.amazon.opendistroforelasticsearch.sql.expression.FunctionExpression;
import com.amazon.opendistroforelasticsearch.sql.utils.OperatorUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.sun.org.apache.xpath.internal.Arg;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
Expand Down Expand Up @@ -832,4 +836,66 @@ public void compare_int_long() {
FunctionExpression equal = dsl.equal(DSL.literal(1), DSL.literal(1L));
assertTrue(equal.valueOf(valueEnv()).booleanValue());
}

/**
 * Supplies (field, collection) argument pairs for the {@code in} and {@code not_in} tests.
 *
 * <p>Each plain-Java group is {@code [field, list of candidates]}; both parts are converted
 * with {@code fromObjectValue}. Date/time cases are built from typed values because a plain
 * Java string cannot express whether it is a DATE, DATETIME or TIMESTAMP.
 */
private static Stream<Arguments> testInArguments() {
  // Fully parameterized (no raw List) so the loop below needs no unchecked conversion.
  List<List<Object>> arguments = Arrays.asList(
      Arrays.asList(1, Arrays.asList(0, 2, 1, 3)),
      Arrays.asList(1, Arrays.asList(2, 0)),
      Arrays.asList(1L, Arrays.asList(1L, 2L, 3L)),
      Arrays.asList(2L, Arrays.asList(1L, 2L)),
      Arrays.asList(3F, Arrays.asList(1F, 2F)),
      Arrays.asList(0F, Arrays.asList(1F, 2F)),
      Arrays.asList(1D, Arrays.asList(1D, 1D)),
      Arrays.asList(1D, Arrays.asList(2D, 2D)),
      Arrays.asList("b", Arrays.asList("a", "c")),
      Arrays.asList("b", Arrays.asList("c", "a")),
      Arrays.asList("a", Arrays.asList("a", "b")),
      Arrays.asList("b", Arrays.asList("a", "b")),
      Arrays.asList("c", Arrays.asList("a", "b")),
      Arrays.asList("a", Arrays.asList("b", "c")),
      Arrays.asList("a", Arrays.asList("a", "a")),
      Arrays.asList("b", Arrays.asList("a", "a")));

  Stream.Builder<Arguments> builder = Stream.builder();
  for (List<Object> argGroup : arguments) {
    builder.add(Arguments.of(
        fromObjectValue(argGroup.get(0)),
        fromObjectValue(argGroup.get(1))));
  }
  builder
      // DATE that is absent from the collection
      .add(Arguments.of(fromObjectValue("2021-01-02", DATE),
          fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01", DATE),
              fromObjectValue("2021-01-03", DATE)))))
      // DATE that is present in the collection (previously an exact copy of the case above)
      .add(Arguments.of(fromObjectValue("2021-01-01", DATE),
          fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01", DATE),
              fromObjectValue("2021-01-03", DATE)))))
      // DATETIME that is absent from the collection
      .add(Arguments.of(fromObjectValue("2021-01-01 03:00:00", DATETIME),
          fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01 01:00:00", DATETIME),
              fromObjectValue("3021-01-01 02:00:00", DATETIME)))))
      // TIMESTAMP that is present in the collection
      .add(Arguments.of(fromObjectValue("2021-01-01 01:00:00", TIMESTAMP),
          fromObjectValue(Arrays.asList(fromObjectValue("2021-01-01 01:00:00", TIMESTAMP),
              fromObjectValue("3021-01-01 01:00:00", TIMESTAMP)))));
  return builder.build();
}

@ParameterizedTest(name = "in({0}, ({1}))")
@MethodSource("testInArguments")
public void in(ExprValue field, ExprValue arrayOfArgs) {
  // Build the IN expression through the DSL from two literals.
  FunctionExpression membership = dsl.in(DSL.literal(field), DSL.literal(arrayOfArgs));
  assertEquals(BOOLEAN, membership.type());

  // The expected value comes from OperatorUtils.in, so this checks that the DSL path
  // and the utility agree on the result for the same inputs.
  ExprValue expected = OperatorUtils.in(field, arrayOfArgs);
  ExprValue actual = membership.valueOf(valueEnv());
  assertEquals(expected, actual);
}

@ParameterizedTest(name = "not in({0}, ({1}))")
@MethodSource("testInArguments")
public void not_in(ExprValue field, ExprValue arrayOfArgs) {
  // Build the NOT IN expression through the DSL from two literals.
  FunctionExpression negated = dsl.not_in(DSL.literal(field), DSL.literal(arrayOfArgs));
  assertEquals(BOOLEAN, negated.type());

  // NOT IN must be the logical inverse of what OperatorUtils.in reports.
  boolean expected = !OperatorUtils.in(field, arrayOfArgs).booleanValue();
  assertEquals(expected, negated.valueOf(valueEnv()).booleanValue());
}

// Building IN with a string literal as the second argument, instead of a collection,
// is expected to fail with ExpressionEvaluationException as soon as the expression is
// built: nothing is evaluated inside the lambda.
@Test
public void in_not_an_array() {
assertThrows(ExpressionEvaluationException.class, () ->
dsl.in(DSL.literal(1), DSL.literal("1")));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.RangeQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.RangeQuery.Comparison;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.TermQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.TermsQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene.WildcardQuery;
import com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.serialization.ExpressionSerializer;
import com.amazon.opendistroforelasticsearch.sql.expression.Expression;
Expand Down Expand Up @@ -63,6 +64,7 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor<QueryBuilder, Obje
.put(BuiltinFunctionName.LTE.getName(), new RangeQuery(Comparison.LTE))
.put(BuiltinFunctionName.GTE.getName(), new RangeQuery(Comparison.GTE))
.put(BuiltinFunctionName.LIKE.getName(), new WildcardQuery())
.put(BuiltinFunctionName.IN.getName(), new TermsQuery())
.build();

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.opendistroforelasticsearch.sql.elasticsearch.storage.script.filter.lucene;

import com.amazon.opendistroforelasticsearch.sql.data.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.data.type.ExprType;
import java.util.stream.Collectors;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;


/**
 * Lucene query that builds a terms query, matching a field against a collection of values.
 */
public class TermsQuery extends LuceneQuery {

  /**
   * Builds a terms query for {@code fieldName} from the elements of {@code literal}.
   *
   * @param fieldName name of the field, passed through {@code convertTextToKeyword} first
   * @param fieldType type of the field, used by {@code convertTextToKeyword}
   * @param literal value whose {@code collectionValue()} elements become the terms
   * @return terms query over the converted field name and the unwrapped element values
   */
  @Override
  protected QueryBuilder doBuild(String fieldName, ExprType fieldType, ExprValue literal) {
    String targetField = convertTextToKeyword(fieldName, fieldType);

    // Unwrap every ExprValue element to the plain object it carries.
    java.util.List<Object> terms = literal.collectionValue().stream()
        .map(ExprValue::value)
        .collect(Collectors.toList());

    return QueryBuilders.termsQuery(targetField, terms);
  }

}
Loading

0 comments on commit 4a5f712

Please sign in to comment.