Skip to content

Commit

Permalink
Update test-cases
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Feb 13, 2025
1 parent b87db9f commit 39bf113
Showing 1 changed file with 88 additions and 235 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,20 @@
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.data.model.ExprNullValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
Expand All @@ -43,6 +49,47 @@
public class TrendlineOperatorTest extends PhysicalPlanTestBase {
@Mock private PhysicalPlan inputPlan;

static Stream<Arguments> supportedDataTypes() {
return Stream.of(SMA, WMA)
.flatMap(
trendlineType ->
Stream.of(
Arguments.of(trendlineType, ExprCoreType.SHORT),
Arguments.of(trendlineType, ExprCoreType.INTEGER),
Arguments.of(trendlineType, ExprCoreType.LONG),
Arguments.of(trendlineType, ExprCoreType.FLOAT),
Arguments.of(trendlineType, ExprCoreType.DOUBLE)));
}

static Stream<Arguments> invalidArguments() {
return Stream.of(SMA, WMA)
.flatMap(
trendlineType ->
Stream.of(
// WMA
Arguments.of(
2,
AstDSL.field("distance"),
"distance_alias",
trendlineType,
ExprCoreType.ARRAY,
"DateType - Array"),
Arguments.of(
-100,
AstDSL.field("distance"),
"distance_alias",
trendlineType,
ExprCoreType.INTEGER,
"DataPoints - Negative"),
Arguments.of(
0,
AstDSL.field("distance"),
"distance_alias",
trendlineType,
ExprCoreType.INTEGER,
"DataPoints - zero")));
}

@Test
public void calculates_simple_moving_average_one_field_one_sample() {
mockPlanWithData(List.of(tupleValue(ImmutableMap.of("distance", 100, "time", 10))));
Expand Down Expand Up @@ -112,84 +159,6 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows()
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))));
}

@Test
public void calculates_simple_moving_average_data_type_support_short() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA),
ExprCoreType.SHORT)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))));
}

@Test
public void calculates_simple_moving_average_data_type_support_long() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA),
ExprCoreType.SHORT)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))));
}

@Test
public void calculates_simple_moving_average_data_type_support_float() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA),
ExprCoreType.FLOAT)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))));
}

@Test
public void calculates_simple_moving_average_multiple_computations() {
mockPlanWithData(
Expand Down Expand Up @@ -301,20 +270,6 @@ public void use_null_value() {
tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100))));
}

@Test
public void use_illegal_core_type() {
assertThrows(
SemanticCheckException.class,
() -> {
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA),
ExprCoreType.ARRAY)));
});
}

@Test
public void calculates_simple_moving_average_date() {
mockPlanWithData(
Expand Down Expand Up @@ -503,122 +458,6 @@ public void calculates_weighted_moving_average_one_field_two_samples_three_rows(
"distance", 200, "time", 10, "distance_alias", 199.99999999999997))));
}

@Test
public void calculates_weighted_moving_average_data_type_support_short() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.SHORT)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
tupleValue(
ImmutableMap.of(
"distance", 200, "time", 10, "distance_alias", 199.99999999999997))));
}

@Test
public void calculates_weighted_moving_average_data_type_support_integer() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.INTEGER)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
tupleValue(
ImmutableMap.of(
"distance", 200, "time", 10, "distance_alias", 199.99999999999997))));
}

@Test
public void calculates_weighted_moving_average_data_type_support_long() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.LONG)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
tupleValue(
ImmutableMap.of(
"distance", 200, "time", 10, "distance_alias", 199.99999999999997))));
}

@Test
public void calculates_weighted_moving_average_data_type_support_float() {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.FLOAT)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
tupleValue(
ImmutableMap.of(
"distance", 200, "time", 10, "distance_alias", 199.99999999999997))));
}

@Test
public void calculates_weighted_moving_average_multiple_computations() {
mockPlanWithData(
Expand Down Expand Up @@ -807,43 +646,57 @@ public void calculates_weighted_moving_average_timestamp() {
Instant.EPOCH.plusMillis(1333)))));
}

@Test
public void use_illegal_core_type_wma() {
assertThrows(
SemanticCheckException.class,
() ->
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.ARRAY))));
}
@ParameterizedTest
@MethodSource("supportedDataTypes")
public void trendLine_dataType_support(
Trendline.TrendlineType trendlineType, ExprCoreType supportedType) {
mockPlanWithData(
List.of(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10)),
tupleValue(ImmutableMap.of("distance", 200, "time", 10))));

@Test
public void use_invalid_dataPoints_zero() {
assertThrows(
SemanticCheckException.class,
() ->
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(0, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.INTEGER))));
var plan =
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", WMA),
supportedType)));

List<ExprValue> result = execute(plan);
assertEquals(3, result.size());
assertThat(
String.format(
"Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()),
result,
containsInAnyOrder(
tupleValue(ImmutableMap.of("distance", 100, "time", 10)),
tupleValue(
ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 166.66666666666663)),
tupleValue(
ImmutableMap.of(
"distance", 200, "time", 10, "distance_alias", 199.99999999999997))));
}

@Test
public void use_invalid_dataPoints_negative() {
@ParameterizedTest
@MethodSource("invalidArguments")
public void use_invalid_configuration(
Integer dataPoints,
Field field,
String alias,
Trendline.TrendlineType trendlineType,
ExprCoreType dataType,
String errorMessage) {
assertThrows(
SemanticCheckException.class,
() ->
new TrendlineOperator(
inputPlan,
Collections.singletonList(
Pair.of(
AstDSL.computation(-100, AstDSL.field("distance"), "distance_alias", WMA),
ExprCoreType.INTEGER))));
AstDSL.computation(dataPoints, field, alias, trendlineType), dataType))),
"Unsupported arguments: " + errorMessage);
}

private void mockPlanWithData(List<ExprValue> inputs) {
Expand Down

0 comments on commit 39bf113

Please sign in to comment.