Skip to content

Commit

Permalink
ESQL: Speed type error testing (#119678)
Browse files Browse the repository at this point in the history
This shaves a few minutes off of the ESQL build:
```
14m 50s -> 12m 38s
```

It does so by moving the type error testing from parameterized tests to
a single, standalone test per scalar that checks the errors for all
unsupported types. It gets the list from the parameterized tests the
same way as we were doing. But it's *fast*.
AND, this will let us test a huge number of combinations without nearly
as much overhead as we had before.

In the worst case, unary math functions, this doesn't save any time.
Maybe .1 second per function. For binary math functions it saves a
*little* time. About a second per function.

But for non-math, multivalued functions: wow. IpPrefix is ternary and
its test goes from 56.8 seconds to 2.6 seconds! Here are a few examples.

|        name        | time before | time after | tests before | tests after |
| -----------------: | -----: | -----------: | ----: | ----: |
|                Sin |   2.6s |         2.5s |   400 |   291 |
|              ATan2 |  17.4s |        16.1s |  8270 |  5961 |
|           IpPrefix |  56.8s |  🎉 2.6s | 40650 |   191 |
|             Equals |  69.9s |        50.6s | 30130 | 28131 |
|          NotEquals |  67.1s |        46.8s | 30100 | 28101 |
|        GreaterThan |  63.7s |        57.8s | 29940 | 27791 |
| GreaterThanOrEqual |  61.1s |        61.6s | 29940 | 27791 |
|           LessThan |  63.7s |        61.3s | 29940 | 27791 |
|    LessThanOrEqual |  61.1s |        59.8s | 29940 | 27791 |
|               Case | 115.3s | 🎉 45.1s | 63756 | 13236 |
|           DateDiff |   3.4s |         4.0s?|   507 |   271 |
|        DateExtract |  12.1s |         3.4s |  3406 |   156 |
|         DateFormat |   8.1s |         2.4s |  2849 |   100 |
|          DateParse |  10.6s |         2.8s |  2992 |   276 |
|          DateTrunc |  10.9s |         3.4s |  3320 |   790 |
|         ByteLength |   5.7s |         4.0s |   520 |   391 |
|           EndsWith |  13.7s |         7.2s |  3880 |  1411 |
|               Hash |  30.7s |        17.4s |  3980 |  1511 |
|              LTrim |  27.1s |        29.0s?|  2840 |  2711 |
|             Locate |  85.3s | 🎉 10.3s | 44310 |  1461 |
|            Replace |  96.5s | 🎉 10.1s | 42010 |  1711 |
|              RTrim |  15.6s |        20.0s?|  2840 |  2711 |
|              Split |   6.6s |         4.0s |  3360 |   397 |
|         StartsWith |   5.5s |  🎉 0.7s |  2800 |   330 |
|          Substring | 115.2s |  🎉 2.7s | 85386 |   483 |
|               Trim |  17.4s |        17.8s |  2840 |  2710 |

Gradle Enterprise is also not happy with the raw *number* of tests ESQL
runs. So lowering the overall number is important. See the table above.
This strategy is *super* effective for that. It takes us
```
769459 -> 470429
```
  • Loading branch information
nik9000 authored Jan 8, 2025
1 parent 28ce53f commit f099658
Show file tree
Hide file tree
Showing 73 changed files with 1,673 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ protected static List<TestCaseSupplier> anyNullIsNull(
ExpectedType expectedType,
ExpectedEvaluatorToString evaluatorToString
) {
typesRequired(testCaseSuppliers);
List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
suppliers.addAll(testCaseSuppliers);

Expand Down Expand Up @@ -274,7 +273,7 @@ protected static List<TestCaseSupplier> anyNullIsNull(
}

@FunctionalInterface
protected interface PositionalErrorMessageSupplier {
public interface PositionalErrorMessageSupplier {
/**
* This interface defines functions to supply error messages for incorrect types in specific positions. Functions which have
* the same type requirements for all positions can simplify this with a lambda returning a string constant.
Expand All @@ -291,7 +290,9 @@ protected interface PositionalErrorMessageSupplier {
/**
* Adds test cases containing unsupported parameter types that assert
* that they throw type errors.
* @deprecated make a subclass of {@link ErrorsForCasesWithoutExamplesTestCase} instead
*/
@Deprecated
protected static List<TestCaseSupplier> errorsForCasesWithoutExamples(
List<TestCaseSupplier> testCaseSuppliers,
PositionalErrorMessageSupplier positionalErrorMessageSupplier
Expand Down Expand Up @@ -331,11 +332,14 @@ protected interface TypeErrorMessageSupplier {
String apply(boolean includeOrdinal, List<Set<DataType>> validPerPosition, List<DataType> types);
}

/**
* @deprecated make a subclass of {@link ErrorsForCasesWithoutExamplesTestCase} instead
*/
@Deprecated
protected static List<TestCaseSupplier> errorsForCasesWithoutExamples(
List<TestCaseSupplier> testCaseSuppliers,
TypeErrorMessageSupplier typeErrorMessageSupplier
) {
typesRequired(testCaseSuppliers);
List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
suppliers.addAll(testCaseSuppliers);

Expand All @@ -346,7 +350,7 @@ protected static List<TestCaseSupplier> errorsForCasesWithoutExamples(
.map(s -> s.types().size())
.collect(Collectors.toSet())
.stream()
.flatMap(count -> allPermutations(count))
.flatMap(AbstractFunctionTestCase::allPermutations)
.filter(types -> valid.contains(types) == false)
/*
* Skip any cases with more than one null. Our tests don't generate
Expand All @@ -366,10 +370,6 @@ private static List<DataType> append(List<DataType> orig, DataType extra) {
return longer;
}

/**
 * Streams the {@link DataType}s returned by {@link DataType#types()} that
 * pass the {@code DataType::isRepresentable} check.
 *
 * @return a fresh stream of the representable types
 */
protected static Stream<DataType> representable() {
    Stream<DataType> everyType = DataType.types().stream();
    return everyType.filter(DataType::isRepresentable);
}

protected static TestCaseSupplier typeErrorSupplier(
boolean includeOrdinal,
List<Set<DataType>> validPerPosition,
Expand Down Expand Up @@ -398,7 +398,7 @@ protected static TestCaseSupplier typeErrorSupplier(
);
}

private static List<Set<DataType>> validPerPosition(Set<List<DataType>> valid) {
static List<Set<DataType>> validPerPosition(Set<List<DataType>> valid) {
int max = valid.stream().mapToInt(List::size).max().getAsInt();
List<Set<DataType>> result = new ArrayList<>(max);
for (int i = 0; i < max; i++) {
Expand Down Expand Up @@ -1327,17 +1327,6 @@ public void allMemoryReleased() {
}
}

/**
 * Validate that the types are known for all the test cases already created.
 *
 * @param suppliers list of suppliers before adding in the illegal type combinations
 * @throws IllegalArgumentException listing, one per line, the name of every supplier
 *         whose {@code types()} is {@code null}
 */
protected static void typesRequired(List<TestCaseSupplier> suppliers) {
    List<String> untypedNames = new ArrayList<>();
    for (TestCaseSupplier supplier : suppliers) {
        if (supplier.types() == null) {
            untypedNames.add(supplier.name());
        }
    }
    String bad = String.join("\n", untypedNames);
    if (bad.isEmpty()) {
        return;
    }
    throw new IllegalArgumentException("types required but not found for these tests:\n" + bad);
}

/**
* Returns true if the current test case is for an aggregation function.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
* </p>
*
* @param entirelyNullPreservesType See {@link #anyNullIsNull(boolean, List)}
* @deprecated use {@link #parameterSuppliersFromTypedDataWithDefaultChecksNoErrors}
* and make a subclass of {@link ErrorsForCasesWithoutExamplesTestCase}.
* It's a <strong>lot</strong> faster.
*/
@Deprecated
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecks(
boolean entirelyNullPreservesType,
List<TestCaseSupplier> suppliers,
Expand All @@ -72,6 +76,23 @@ protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultCh
);
}

/**
 * Converts a list of test cases into a list of parameter suppliers after
 * adding a default set of extra test cases.
 * <p>
 * The suppliers are first passed through {@code randomizeBytesRefsOffset},
 * then through {@code anyNullIsNull}, and the result is handed to
 * {@code parameterSuppliersFromTypedData}.
 * </p>
 * <p>
 * Use if possible, as this method may get updated with new checks in the future.
 * </p>
 *
 * @param entirelyNullPreservesType See {@link #anyNullIsNull(boolean, List)}
 * @param suppliers the test cases to convert
 * @return the parameters built from the extended list of test cases
 */
protected static Iterable<Object[]> parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(
    // TODO once parameterSuppliersFromTypedDataWithDefaultChecks has been removed, rename this method to that name.
    boolean entirelyNullPreservesType,
    List<TestCaseSupplier> suppliers
) {
    List<TestCaseSupplier> withRandomOffsets = randomizeBytesRefsOffset(suppliers);
    List<TestCaseSupplier> withNullCases = anyNullIsNull(entirelyNullPreservesType, withRandomOffsets);
    return parameterSuppliersFromTypedData(withNullCases);
}

/**
* Converts a list of test cases into a list of parameter suppliers.
* Also, adds a default set of extra test cases.
Expand Down Expand Up @@ -364,43 +385,10 @@ public void testFold() {
}
}

/**
 * Builds the type error message expected from a binary operator for the given argument types.
 * <p>
 * The positional message from {@code typeErrorMessage} is used when it can be built. When that
 * call throws {@link IllegalStateException}, the message describes the combination of the
 * first two types instead.
 * </p>
 *
 * @param includeOrdinal forwarded to {@code typeErrorMessage}
 * @param validPerPosition the valid types for each argument position
 * @param types the argument types being checked; the first two entries are read
 * @param positionalErrorMessageSupplier forwarded to {@code typeErrorMessage}
 * @return the expected error message
 */
public static String errorMessageStringForBinaryOperators(
    boolean includeOrdinal,
    List<Set<DataType>> validPerPosition,
    List<DataType> types,
    PositionalErrorMessageSupplier positionalErrorMessageSupplier
) {
    try {
        return typeErrorMessage(includeOrdinal, validPerPosition, types, positionalErrorMessageSupplier);
    } catch (IllegalStateException e) {
        // This means all the positional args were okay, so the expected error is from the combination
        DataType left = types.get(0);
        if (left.equals(DataType.UNSIGNED_LONG)) {
            String rightName = types.get(1).typeName();
            return "first argument of [] is [unsigned_long] and second is ["
                + rightName
                + "]. [unsigned_long] can only be operated on together with another [unsigned_long]";
        }
        DataType right = types.get(1);
        if (right.equals(DataType.UNSIGNED_LONG)) {
            return "first argument of [] is ["
                + left.typeName()
                + "] and second is [unsigned_long]. [unsigned_long] can only be operated on together with another [unsigned_long]";
        }
        // Numeric types are reported by family rather than by their individual type name.
        String leftDescription;
        if (left.isNumeric()) {
            leftDescription = "numeric";
        } else {
            leftDescription = left.typeName();
        }
        return "first argument of [] is ["
            + leftDescription
            + "] so second argument must also be ["
            + leftDescription
            + "] but was ["
            + right.typeName()
            + "]";
    }
}

/**
* Adds test cases containing unsupported parameter types that immediately fail.
*/
protected static List<TestCaseSupplier> failureForCasesWithoutExamples(List<TestCaseSupplier> testCaseSuppliers) {
typesRequired(testCaseSuppliers);
List<TestCaseSupplier> suppliers = new ArrayList<>(testCaseSuppliers.size());
suppliers.addAll(testCaseSuppliers);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.hamcrest.Matcher;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
import static org.hamcrest.Matchers.greaterThan;

public abstract class ErrorsForCasesWithoutExamplesTestCase extends ESTestCase {
protected abstract List<TestCaseSupplier> cases();

/**
* Build the expression being tested, for the given source and list of arguments. Test classes need to implement this
* to have something to test.
*
* @param source the source
* @param args arg list from the test case, should match the length expected
* @return an expression for evaluating the function being tested on the given arguments
*/
protected abstract Expression build(Source source, List<Expression> args);

protected abstract Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature);

protected final List<TestCaseSupplier> paramsToSuppliers(Iterable<Object[]> cases) {
List<TestCaseSupplier> result = new ArrayList<>();
for (Object[] c : cases) {
if (c.length != 1) {
throw new IllegalArgumentException("weird layout for test cases");
}
TestCaseSupplier supplier = (TestCaseSupplier) c[0];
result.add(supplier);
}
return result;
}

public final void test() {
int checked = 0;
List<TestCaseSupplier> cases = cases();
Set<List<DataType>> valid = cases.stream().map(TestCaseSupplier::types).collect(Collectors.toSet());
List<Set<DataType>> validPerPosition = AbstractFunctionTestCase.validPerPosition(valid);
Iterable<List<DataType>> missingSignatures = missingSignatures(cases, valid)::iterator;
for (List<DataType> signature : missingSignatures) {
logger.debug("checking {}", signature);
List<Expression> args = new ArrayList<>(signature.size());
for (DataType type : signature) {
args.add(randomLiteral(type));
}
Expression expression = build(Source.synthetic(sourceForSignature(signature)), args);
assertTrue("expected unresolved " + expression, expression.typeResolved().unresolved());
assertThat(expression.typeResolved().message(), expectedTypeErrorMatcher(validPerPosition, signature));
checked++;
}
logger.info("checked {} signatures", checked);
assertThat("didn't check any signatures", checked, greaterThan(0));
}

private Stream<List<DataType>> missingSignatures(List<TestCaseSupplier> cases, Set<List<DataType>> valid) {
return cases.stream()
.map(s -> s.types().size())
.collect(Collectors.toSet())
.stream()
.flatMap(AbstractFunctionTestCase::allPermutations)
.filter(types -> valid.contains(types) == false)
/*
* Skip any cases with more than one null. Our tests don't generate
* the full combinatorial explosions of all nulls - just a single null.
* Hopefully <null>, <null> cases will function the same as <null>, <valid>
* cases.
*/
.filter(types -> types.stream().filter(t -> t == DataType.NULL).count() <= 1);
}

protected static String sourceForSignature(List<DataType> signature) {
StringBuilder source = new StringBuilder();
for (DataType type : signature) {
if (false == source.isEmpty()) {
source.append(", ");
}
source.append(type.typeName());
}
return source.toString();
}

/**
* Build the expected error message for an invalid type signature.
*/
protected static String typeErrorMessage(
boolean includeOrdinal,
List<Set<DataType>> validPerPosition,
List<DataType> signature,
AbstractFunctionTestCase.PositionalErrorMessageSupplier expectedTypeSupplier
) {
int badArgPosition = -1;
for (int i = 0; i < signature.size(); i++) {
if (validPerPosition.get(i).contains(signature.get(i)) == false) {
badArgPosition = i;
break;
}
}
if (badArgPosition == -1) {
throw new IllegalStateException(
"Can't generate error message for these types, you probably need a custom error message function"
);
}
String ordinal = includeOrdinal ? TypeResolutions.ParamOrdinal.fromIndex(badArgPosition).name().toLowerCase(Locale.ROOT) + " " : "";
String source = sourceForSignature(signature);
String expectedTypeString = expectedTypeSupplier.apply(validPerPosition.get(badArgPosition), badArgPosition);
String name = signature.get(badArgPosition).typeName();
return ordinal + "argument of [" + source + "] must be [" + expectedTypeString + "], found value [] type [" + name + "]";
}

protected static String errorMessageStringForBinaryOperators(
List<Set<DataType>> validPerPosition,
List<DataType> signature,
AbstractFunctionTestCase.PositionalErrorMessageSupplier positionalErrorMessageSupplier
) {
try {
return typeErrorMessage(true, validPerPosition, signature, positionalErrorMessageSupplier);
} catch (IllegalStateException e) {
String source = sourceForSignature(signature);
// This means all the positional args were okay, so the expected error is from the combination
if (signature.get(0).equals(DataType.UNSIGNED_LONG)) {
return "first argument of ["
+ source
+ "] is [unsigned_long] and second is ["
+ signature.get(1).typeName()
+ "]. [unsigned_long] can only be operated on together with another [unsigned_long]";

}
if (signature.get(1).equals(DataType.UNSIGNED_LONG)) {
return "first argument of ["
+ source
+ "] is ["
+ signature.get(0).typeName()
+ "] and second is [unsigned_long]. [unsigned_long] can only be operated on together with another [unsigned_long]";
}
return "first argument of ["
+ source
+ "] is ["
+ (signature.get(0).isNumeric() ? "numeric" : signature.get(0).typeName())
+ "] so second argument must also be ["
+ (signature.get(0).isNumeric() ? "numeric" : signature.get(0).typeName())
+ "] but was ["
+ signature.get(1).typeName()
+ "]";

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.expression.function.grouping;

import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.ErrorsForCasesWithoutExamplesTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.hamcrest.Matcher;

import java.util.List;
import java.util.Set;

import static org.hamcrest.Matchers.equalTo;

/**
 * Checks the type errors reported by {@link Categorize} for the signatures
 * that are missing from {@link CategorizeTests#parameters()}.
 */
public class CategorizeErrorTests extends ErrorsForCasesWithoutExamplesTestCase {
    /**
     * The supported signatures, taken from the parameters of {@link CategorizeTests}.
     */
    @Override
    protected List<TestCaseSupplier> cases() {
        Iterable<Object[]> supported = CategorizeTests.parameters();
        return paramsToSuppliers(supported);
    }

    /**
     * Builds a {@link Categorize} over the first argument.
     */
    @Override
    protected Expression build(Source source, List<Expression> args) {
        Expression field = args.get(0);
        return new Categorize(source, field);
    }

    /**
     * Expects the positional type error, without an ordinal, asking for a {@code string}.
     */
    @Override
    protected Matcher<String> expectedTypeErrorMatcher(List<Set<DataType>> validPerPosition, List<DataType> signature) {
        String expectedMessage = typeErrorMessage(false, validPerPosition, signature, (validTypes, position) -> "string");
        return equalTo(expectedMessage);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static Iterable<Object[]> parameters() {
)
);
}
return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "string");
return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
}

@Override
Expand Down
Loading

0 comments on commit f099658

Please sign in to comment.