+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#dateCases}.
+ *
+ * Helper function for if you want to specify your min and max range as dates instead of longs.
+ */
+ public static List dateCases(Instant min, Instant max) {
+ return dateCases(min.toEpochMilli(), max.toEpochMilli());
+ }
+
/**
* Generate cases for {@link DataType#DATETIME}.
*
@@ -1045,6 +1056,19 @@ public static List dateCases(long min, long max) {
return cases;
}
+ /**
+ *
+ * @return randomized valid date formats
+ */
+ public static List dateFormatCases() {
+ return List.of(
+ new TypedDataSupplier("", () -> new BytesRef(ESTestCase.randomDateFormatterPattern()), DataType.KEYWORD),
+ new TypedDataSupplier("", () -> new BytesRef(ESTestCase.randomDateFormatterPattern()), DataType.TEXT),
+ new TypedDataSupplier("", () -> new BytesRef("yyyy"), DataType.KEYWORD),
+ new TypedDataSupplier("", () -> new BytesRef("yyyy"), DataType.TEXT)
+ );
+ }
+
/**
* Generate cases for {@link DataType#DATE_NANOS}.
*
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatErrorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatErrorTests.java
index 985f1144fbcf2..a5e6514b3e02c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatErrorTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatErrorTests.java
@@ -28,11 +28,22 @@ protected List cases() {
@Override
protected Expression build(Source source, List args) {
- return new DateFormat(source, args.get(0), args.get(1), EsqlTestUtils.TEST_CFG);
+ return new DateFormat(source, args.get(0), args.size() == 2 ? args.get(1) : null, EsqlTestUtils.TEST_CFG);
}
@Override
protected Matcher expectedTypeErrorMatcher(List> validPerPosition, List signature) {
+ // Single argument version
+ String source = sourceForSignature(signature);
+ String name = signature.get(0).typeName();
+ if (signature.size() == 1) {
+ return equalTo("first argument of [" + source + "] must be [datetime], found value [] type [" + name + "]");
+ }
+ // Two argument version
+ // Handle the weird case where we're calling the two argument version with the date first instead of the format.
+ if (signature.get(0).isDate()) {
+ return equalTo("first argument of [" + source + "] must be [string], found value [] type [" + name + "]");
+ }
return equalTo(typeErrorMessage(true, validPerPosition, signature, (v, p) -> switch (p) {
case 0 -> "string";
case 1 -> "datetime";
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatTests.java
index 8dfdd1ba486c7..3dd1f3e629da4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormatTests.java
@@ -11,18 +11,21 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.lucene.BytesRefs;
+import org.elasticsearch.common.time.DateFormatter;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.elasticsearch.xpack.esql.expression.function.scalar.AbstractConfigurationFunctionTestCase;
import org.elasticsearch.xpack.esql.session.Configuration;
+import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter;
+import java.time.Instant;
+import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
-import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.matchesPattern;
public class DateFormatTests extends AbstractConfigurationFunctionTestCase {
public DateFormatTests(@Name("TestCase") Supplier testCaseSupplier) {
@@ -31,39 +34,35 @@ public DateFormatTests(@Name("TestCase") Supplier tes
@ParametersFactory
public static Iterable parameters() {
- return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(
- true,
- List.of(
- new TestCaseSupplier(
- List.of(DataType.KEYWORD, DataType.DATETIME),
- () -> new TestCaseSupplier.TestCase(
- List.of(
- new TestCaseSupplier.TypedData(new BytesRef("yyyy"), DataType.KEYWORD, "formatter"),
- new TestCaseSupplier.TypedData(1687944333000L, DataType.DATETIME, "val")
- ),
- "DateFormatEvaluator[val=Attribute[channel=1], formatter=Attribute[channel=0], locale=en_US]",
- DataType.KEYWORD,
- equalTo(BytesRefs.toBytesRef("2023"))
- )
+ List suppliers = new ArrayList<>();
+ // Formatter supplied cases
+ suppliers.addAll(
+ TestCaseSupplier.forBinaryNotCasting(
+ (format, value) -> new BytesRef(
+ DateFormatter.forPattern(((BytesRef) format).utf8ToString()).formatMillis(((Instant) value).toEpochMilli())
),
- new TestCaseSupplier(
- List.of(DataType.TEXT, DataType.DATETIME),
- () -> new TestCaseSupplier.TestCase(
- List.of(
- new TestCaseSupplier.TypedData(new BytesRef("yyyy"), DataType.TEXT, "formatter"),
- new TestCaseSupplier.TypedData(1687944333000L, DataType.DATETIME, "val")
- ),
- "DateFormatEvaluator[val=Attribute[channel=1], formatter=Attribute[channel=0], locale=en_US]",
- DataType.KEYWORD,
- equalTo(BytesRefs.toBytesRef("2023"))
- )
- )
+ DataType.KEYWORD,
+ TestCaseSupplier.dateFormatCases(),
+ TestCaseSupplier.dateCases(Instant.parse("1900-01-01T00:00:00.00Z"), Instant.parse("9999-12-31T00:00:00.00Z")),
+ matchesPattern("DateFormatEvaluator\\[val=Attribute\\[channel=1], formatter=Attribute\\[(channel=0|\\w+)], locale=en_US]"),
+ (lhs, rhs) -> List.of(),
+ false
)
);
+ // Default formatter cases
+ TestCaseSupplier.unary(
+ suppliers,
+ "DateFormatConstantEvaluator[val=Attribute[channel=0], formatter=format[strict_date_optional_time] locale[]]",
+ TestCaseSupplier.dateCases(Instant.parse("1900-01-01T00:00:00.00Z"), Instant.parse("9999-12-31T00:00:00.00Z")),
+ DataType.KEYWORD,
+ (value) -> new BytesRef(EsqlDataTypeConverter.DEFAULT_DATE_TIME_FORMATTER.formatMillis(((Instant) value).toEpochMilli())),
+ List.of()
+ );
+ return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
}
@Override
protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) {
- return new DateFormat(source, args.get(0), args.get(1), configuration);
+ return new DateFormat(source, args.get(0), args.size() == 2 ? args.get(1) : null, configuration);
}
}
From 838a41a8391f1d579a1c8e77c8630a42cddcd087 Mon Sep 17 00:00:00 2001
From: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com>
Date: Mon, 13 Jan 2025 09:48:23 -0500
Subject: [PATCH 09/44] [ML] Adding docs for the unified inference API
(#118696)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Including examples
* Using js instead of json
* Adding unified docs to main page
* Adding missing description text
* Refactoring to remove unified route
* Addign back references to the _unified route
* Update docs/reference/inference/chat-completion-inference.asciidoc
Co-authored-by: István Zoltán Szabó
* Address feedback
---------
Co-authored-by: István Zoltán Szabó
---
.../chat-completion-inference.asciidoc | 417 ++++++++++++++++++
.../inference/inference-apis.asciidoc | 10 +-
.../inference/inference-shared.asciidoc | 46 +-
.../inference/put-inference.asciidoc | 8 +-
.../inference/service-openai.asciidoc | 12 +-
.../inference/stream-inference.asciidoc | 6 +-
6 files changed, 487 insertions(+), 12 deletions(-)
create mode 100644 docs/reference/inference/chat-completion-inference.asciidoc
diff --git a/docs/reference/inference/chat-completion-inference.asciidoc b/docs/reference/inference/chat-completion-inference.asciidoc
new file mode 100644
index 0000000000000..83a8f94634f2f
--- /dev/null
+++ b/docs/reference/inference/chat-completion-inference.asciidoc
@@ -0,0 +1,417 @@
+[role="xpack"]
+[[chat-completion-inference-api]]
+=== Chat completion inference API
+
+Streams a chat completion response.
+
+IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI, Anthropic, Watsonx.ai, or Hugging Face.
+For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models.
+However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>.
+
+
+[discrete]
+[[chat-completion-inference-api-request]]
+==== {api-request-title}
+
+`POST /_inference//_unified`
+
+`POST /_inference/chat_completion//_unified`
+
+
+[discrete]
+[[chat-completion-inference-api-prereqs]]
+==== {api-prereq-title}
+
+* Requires the `monitor_inference` <>
+(the built-in `inference_admin` and `inference_user` roles grant this privilege)
+* You must use a client that supports streaming.
+
+
+[discrete]
+[[chat-completion-inference-api-desc]]
+==== {api-description-title}
+
+The chat completion {infer} API enables real-time responses for chat completion tasks by delivering answers incrementally, reducing response times during computation.
+It only works with the `chat_completion` task type for `openai` and `elastic` {infer} services.
+
+[NOTE]
+====
+The `chat_completion` task type is only available within the _unified API and only supports streaming.
+====
+
+[discrete]
+[[chat-completion-inference-api-path-params]]
+==== {api-path-parms-title}
+
+``::
+(Required, string)
+The unique identifier of the {infer} endpoint.
+
+
+``::
+(Optional, string)
+The type of {infer} task that the model performs. If included, this must be set to the value `chat_completion`.
+
+
+[discrete]
+[[chat-completion-inference-api-request-body]]
+==== {api-request-body-title}
+
+`messages`::
+(Required, array of objects) A list of objects representing the conversation.
+Requests should generally only add new messages from the user (role `user`). The other message roles (`assistant`, `system`, or `tool`) should generally only be copied from the response to a previous completion request, such that the messages array is built up throughout a conversation.
++
+.Assistant message
+[%collapsible%closed]
+=====
+`content`::
+(Required unless `tool_calls` is specified, string or array of objects)
+The contents of the message.
++
+include::inference-shared.asciidoc[tag=chat-completion-schema-content-with-examples]
++
+`role`::
+(Required, string)
+The role of the message author. This should be set to `assistant` for this type of message.
++
+`tool_calls`::
+(Optional, array of objects)
+The tool calls generated by the model.
++
+.Examples
+[%collapsible%closed]
+======
+[source,js]
+------------------------------------------------------------
+{
+ "tool_calls": [
+ {
+ "id": "call_KcAjWtAww20AihPHphUh46Gd",
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "arguments": "{\"location\":\"Boston, MA\"}"
+ }
+ }
+ ]
+}
+------------------------------------------------------------
+// NOTCONSOLE
+======
++
+`id`:::
+(Required, string)
+The identifier of the tool call.
++
+`type`:::
+(Required, string)
+The type of tool call. This must be set to the value `function`.
++
+`function`:::
+(Required, object)
+The function that the model called.
++
+`name`::::
+(Required, string)
+The name of the function to call.
++
+`arguments`::::
+(Required, string)
+The arguments to call the function with in JSON format.
+=====
++
+.System message
+[%collapsible%closed]
+=====
+`content`:::
+(Required, string or array of objects)
+The contents of the message.
++
+include::inference-shared.asciidoc[tag=chat-completion-schema-content-with-examples]
++
+`role`:::
+(Required, string)
+The role of the message author. This should be set to `system` for this type of message.
+=====
++
+.Tool message
+[%collapsible%closed]
+=====
+`content`::
+(Required, string or array of objects)
+The contents of the message.
++
+include::inference-shared.asciidoc[tag=chat-completion-schema-content-with-examples]
++
+`role`::
+(Required, string)
+The role of the message author. This should be set to `tool` for this type of message.
++
+`tool_call_id`::
+(Required, string)
+The tool call that this message is responding to.
+=====
++
+.User message
+[%collapsible%closed]
+=====
+`content`::
+(Required, string or array of objects)
+The contents of the message.
++
+include::inference-shared.asciidoc[tag=chat-completion-schema-content-with-examples]
++
+`role`::
+(Required, string)
+The role of the message author. This should be set to `user` for this type of message.
+=====
+
+`model`::
+(Optional, string)
+The ID of the model to use. By default, the model ID is set to the value included when creating the inference endpoint.
+
+`max_completion_tokens`::
+(Optional, integer)
+The upper bound limit for the number of tokens that can be generated for a completion request.
+
+`stop`::
+(Optional, array of strings)
+A sequence of strings to control when the model should stop generating additional tokens.
+
+`temperature`::
+(Optional, float)
+The sampling temperature to use.
+
+`tools`::
+(Optional, array of objects)
+A list of tools that the model can call.
++
+.Structure
+[%collapsible%closed]
+=====
+`type`::
+(Required, string)
+The type of tool, must be set to the value `function`.
++
+`function`::
+(Required, object)
+The function definition.
++
+`description`:::
+(Optional, string)
+A description of what the function does. This is used by the model to choose when and how to call the function.
++
+`name`:::
+(Required, string)
+The name of the function.
++
+`parameters`:::
+(Optional, object)
+The parameters the functional accepts. This should be formatted as a JSON object.
++
+`strict`:::
+(Optional, boolean)
+Whether to enable schema adherence when generating the function call.
+=====
++
+.Examples
+[%collapsible%closed]
+======
+[source,js]
+------------------------------------------------------------
+{
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_price_of_item",
+ "description": "Get the current price of an item",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "item": {
+ "id": "12345"
+ },
+ "unit": {
+ "type": "currency"
+ }
+ }
+ }
+ }
+ }
+ ]
+}
+------------------------------------------------------------
+// NOTCONSOLE
+======
+
+`tool_choice`::
+(Optional, string or object)
+Controls which tool is called by the model.
++
+String representation:::
+One of `auto`, `none`, or `requrired`. `auto` allows the model to choose between calling tools and generating a message. `none` causes the model to not call any tools. `required` forces the model to call one or more tools.
++
+Object representation:::
++
+.Structure
+[%collapsible%closed]
+=====
+`type`::
+(Required, string)
+The type of the tool. This must be set to the value `function`.
++
+`function`::
+(Required, object)
++
+`name`:::
+(Required, string)
+The name of the function to call.
+=====
++
+.Examples
+[%collapsible%closed]
+=====
+[source,js]
+------------------------------------------------------------
+{
+ "tool_choice": {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather"
+ }
+ }
+}
+------------------------------------------------------------
+// NOTCONSOLE
+=====
+
+`top_p`::
+(Optional, float)
+Nucleus sampling, an alternative to sampling with temperature.
+
+[discrete]
+[[chat-completion-inference-api-example]]
+==== {api-examples-title}
+
+The following example performs a chat completion on the example question with streaming.
+
+
+[source,console]
+------------------------------------------------------------
+POST _inference/chat_completion/openai-completion/_stream
+{
+ "model": "gpt-4o",
+ "messages": [
+ {
+ "role": "user",
+ "content": "What is Elastic?"
+ }
+ ]
+}
+------------------------------------------------------------
+// TEST[skip:TBD]
+
+The following example performs a chat completion using an Assistant message with `tool_calls`.
+
+[source,console]
+------------------------------------------------------------
+POST _inference/chat_completion/openai-completion/_stream
+{
+ "messages": [
+ {
+ "role": "assistant",
+ "content": "Let's find out what the weather is",
+ "tool_calls": [ <1>
+ {
+ "id": "call_KcAjWtAww20AihPHphUh46Gd",
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "arguments": "{\"location\":\"Boston, MA\"}"
+ }
+ }
+ ]
+ },
+ { <2>
+ "role": "tool",
+ "content": "The weather is cold",
+ "tool_call_id": "call_KcAjWtAww20AihPHphUh46Gd"
+ }
+ ]
+}
+------------------------------------------------------------
+// TEST[skip:TBD]
+
+<1> Each tool call needs a corresponding Tool message.
+<2> The corresponding Tool message.
+
+The following example performs a chat completion using a User message with `tools` and `tool_choice`.
+
+[source,console]
+------------------------------------------------------------
+POST _inference/chat_completion/openai-completion/_stream
+{
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "What's the price of a scarf?"
+ }
+ ]
+ }
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_price",
+ "description": "Get the current price of a item",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "item": {
+ "id": "123"
+ }
+ }
+ }
+ }
+ }
+ ],
+ "tool_choice": {
+ "type": "function",
+ "function": {
+ "name": "get_current_price"
+ }
+ }
+}
+------------------------------------------------------------
+// TEST[skip:TBD]
+
+The API returns the following response when a request is made to the OpenAI service:
+
+
+[source,txt]
+------------------------------------------------------------
+event: message
+data: {"chat_completion":{"id":"chatcmpl-Ae0TWsy2VPnSfBbv5UztnSdYUMFP3","choices":[{"delta":{"content":"","role":"assistant"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}}
+
+event: message
+data: {"chat_completion":{"id":"chatcmpl-Ae0TWsy2VPnSfBbv5UztnSdYUMFP3","choices":[{"delta":{"content":Elastic"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}}
+
+event: message
+data: {"chat_completion":{"id":"chatcmpl-Ae0TWsy2VPnSfBbv5UztnSdYUMFP3","choices":[{"delta":{"content":" is"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}}
+
+(...)
+
+event: message
+data: {"chat_completion":{"id":"chatcmpl-Ae0TWsy2VPnSfBbv5UztnSdYUMFP3","choices":[],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk","usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}} <1>
+
+event: message
+data: [DONE]
+------------------------------------------------------------
+// NOTCONSOLE
+
+<1> The last object message of the stream contains the token usage information.
diff --git a/docs/reference/inference/inference-apis.asciidoc b/docs/reference/inference/inference-apis.asciidoc
index ca273afc478ea..4f27409973ca2 100644
--- a/docs/reference/inference/inference-apis.asciidoc
+++ b/docs/reference/inference/inference-apis.asciidoc
@@ -26,6 +26,7 @@ the following APIs to manage {infer} models and perform {infer}:
* <>
* <>
* <>
+* <>
* <>
[[inference-landscape]]
@@ -34,9 +35,9 @@ image::images/inference-landscape.jpg[A representation of the Elastic inference
An {infer} endpoint enables you to use the corresponding {ml} model without
manual deployment and apply it to your data at ingestion time through
-<>.
+<>.
-Choose a model from your provider or use ELSER – a retrieval model trained by
+Choose a model from your provider or use ELSER – a retrieval model trained by
Elastic –, then create an {infer} endpoint by the <>.
Now use <> to perform
<> on your data.
@@ -67,7 +68,7 @@ The following list contains the default {infer} endpoints listed by `inference_i
Use the `inference_id` of the endpoint in a <> field definition or when creating an <>.
The API call will automatically download and deploy the model which might take a couple of minutes.
Default {infer} enpoints have {ml-docs}/ml-nlp-auto-scale.html#nlp-model-adaptive-allocations[adaptive allocations] enabled.
-For these models, the minimum number of allocations is `0`.
+For these models, the minimum number of allocations is `0`.
If there is no {infer} activity that uses the endpoint, the number of allocations will scale down to `0` automatically after 15 minutes.
@@ -84,7 +85,7 @@ Returning a long document in search results is less useful than providing the mo
Each chunk will include the text subpassage and the corresponding embedding generated from it.
By default, documents are split into sentences and grouped in sections up to 250 words with 1 sentence overlap so that each chunk shares a sentence with the previous chunk.
-Overlapping ensures continuity and prevents vital contextual information in the input text from being lost by a hard break.
+Overlapping ensures continuity and prevents vital contextual information in the input text from being lost by a hard break.
{es} uses the https://unicode-org.github.io/icu-docs/[ICU4J] library to detect word and sentence boundaries for chunking.
https://unicode-org.github.io/icu/userguide/boundaryanalysis/#word-boundary[Word boundaries] are identified by following a series of rules, not just the presence of a whitespace character.
@@ -135,6 +136,7 @@ PUT _inference/sparse_embedding/small_chunk_size
include::delete-inference.asciidoc[]
include::get-inference.asciidoc[]
include::post-inference.asciidoc[]
+include::chat-completion-inference.asciidoc[]
include::put-inference.asciidoc[]
include::stream-inference.asciidoc[]
include::update-inference.asciidoc[]
diff --git a/docs/reference/inference/inference-shared.asciidoc b/docs/reference/inference/inference-shared.asciidoc
index da497c6581e5d..b133c54082810 100644
--- a/docs/reference/inference/inference-shared.asciidoc
+++ b/docs/reference/inference/inference-shared.asciidoc
@@ -41,7 +41,7 @@ end::chunking-settings[]
tag::chunking-settings-max-chunking-size[]
Specifies the maximum size of a chunk in words.
Defaults to `250`.
-This value cannot be higher than `300` or lower than `20` (for `sentence` strategy) or `10` (for `word` strategy).
+This value cannot be higher than `300` or lower than `20` (for `sentence` strategy) or `10` (for `word` strategy).
end::chunking-settings-max-chunking-size[]
tag::chunking-settings-overlap[]
@@ -63,4 +63,48 @@ Specifies the chunking strategy.
It could be either `sentence` or `word`.
end::chunking-settings-strategy[]
+tag::chat-completion-schema-content-with-examples[]
+.Examples
+[%collapsible%closed]
+======
+String example
+[source,js]
+------------------------------------------------------------
+{
+ "content": "Some string"
+}
+------------------------------------------------------------
+// NOTCONSOLE
+
+Object example
+[source,js]
+------------------------------------------------------------
+{
+ "content": [
+ {
+ "text": "Some text",
+ "type": "text"
+ }
+ ]
+}
+------------------------------------------------------------
+// NOTCONSOLE
+======
+
+String representation:::
+(Required, string)
+The text content.
++
+Object representation:::
+`text`::::
+(Required, string)
+The text content.
++
+`type`::::
+(Required, string)
+This must be set to the value `text`.
+end::chat-completion-schema-content-with-examples[]
+tag::chat-completion-docs[]
+For more information on how to use the `chat_completion` task type, please refer to the <>.
+end::chat-completion-docs[]
diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc
index f0c15323863d7..da07d1d3e7d84 100644
--- a/docs/reference/inference/put-inference.asciidoc
+++ b/docs/reference/inference/put-inference.asciidoc
@@ -42,7 +42,7 @@ include::inference-shared.asciidoc[tag=inference-id]
include::inference-shared.asciidoc[tag=task-type]
+
--
-Refer to the service list in the <> for the available task types.
+Refer to the service list in the <> for the available task types.
--
@@ -61,7 +61,7 @@ The create {infer} API enables you to create an {infer} endpoint and configure a
The following services are available through the {infer} API.
-You can find the available task types next to the service name.
+You can find the available task types next to the service name.
Click the links to review the configuration details of the services:
* <> (`completion`, `rerank`, `sparse_embedding`, `text_embedding`)
@@ -73,10 +73,10 @@ Click the links to review the configuration details of the services:
* <> (`rerank`, `sparse_embedding`, `text_embedding` - this service is for built-in models and models uploaded through Eland)
* <> (`sparse_embedding`)
* <> (`completion`, `text_embedding`)
-* <> (`rerank`, `text_embedding`)
+* <> (`rerank`, `text_embedding`)
* <> (`text_embedding`)
* <> (`text_embedding`)
-* <> (`completion`, `text_embedding`)
+* <> (`chat_completion`, `completion`, `text_embedding`)
* <> (`text_embedding`)
* <> (`text_embedding`, `rerank`)
diff --git a/docs/reference/inference/service-openai.asciidoc b/docs/reference/inference/service-openai.asciidoc
index e4be7f18e09dd..590f280b1c494 100644
--- a/docs/reference/inference/service-openai.asciidoc
+++ b/docs/reference/inference/service-openai.asciidoc
@@ -31,10 +31,18 @@ include::inference-shared.asciidoc[tag=task-type]
--
Available task types:
+* `chat_completion`,
* `completion`,
* `text_embedding`.
--
+[NOTE]
+====
+The `chat_completion` task type only supports streaming and only through the `_unified` API.
+
+include::inference-shared.asciidoc[tag=chat-completion-docs]
+====
+
[discrete]
[[infer-service-openai-api-request-body]]
==== {api-request-body-title}
@@ -61,7 +69,7 @@ include::inference-shared.asciidoc[tag=chunking-settings-strategy]
`service`::
(Required, string)
-The type of service supported for the specified task type. In this case,
+The type of service supported for the specified task type. In this case,
`openai`.
`service_settings`::
@@ -176,4 +184,4 @@ PUT _inference/completion/openai-completion
}
}
------------------------------------------------------------
-// TEST[skip:TBD]
\ No newline at end of file
+// TEST[skip:TBD]
diff --git a/docs/reference/inference/stream-inference.asciidoc b/docs/reference/inference/stream-inference.asciidoc
index 42abb589f9afd..4a3ce31909712 100644
--- a/docs/reference/inference/stream-inference.asciidoc
+++ b/docs/reference/inference/stream-inference.asciidoc
@@ -38,8 +38,12 @@ However, if you do not plan to use the {infer} APIs to use these models or if yo
==== {api-description-title}
The stream {infer} API enables real-time responses for completion tasks by delivering answers incrementally, reducing response times during computation.
-It only works with the `completion` task type.
+It only works with the `completion` and `chat_completion` task types.
+[NOTE]
+====
+include::inference-shared.asciidoc[tag=chat-completion-docs]
+====
[discrete]
[[stream-inference-api-path-params]]
From 67e1bd46a6d45e0d1e2f7d6f22486413f42bf48e Mon Sep 17 00:00:00 2001
From: Jan Kuipers <148754765+jan-elastic@users.noreply.github.com>
Date: Mon, 13 Jan 2025 15:50:56 +0100
Subject: [PATCH 10/44] Clean up lingering tasks after DatafeedJobsIT.
(#120040)
---
.../elasticsearch/xpack/ml/integration/DatafeedJobsIT.java | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedJobsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedJobsIT.java
index 5287d149fae3d..367c1cee8b0ee 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedJobsIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedJobsIT.java
@@ -11,6 +11,8 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
+import org.elasticsearch.action.admin.cluster.node.tasks.cancel.TransportCancelTasksAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.ReferenceDocs;
@@ -80,6 +82,10 @@ public class DatafeedJobsIT extends MlNativeAutodetectIntegTestCase {
public void cleanup() {
updateClusterSettings(Settings.builder().putNull("logger.org.elasticsearch.xpack.ml.datafeed"));
cleanUp();
+ // Race conditions between closing and killing tasks in these tests,
+ // sometimes result in lingering persistent tasks (such as "_close"),
+ // which cause subsequent tests to fail.
+ client().execute(TransportCancelTasksAction.TYPE, new CancelTasksRequest());
}
public void testLookbackOnly() throws Exception {
From c990377c955ffeca7d08235411a0c468bba1ac36 Mon Sep 17 00:00:00 2001
From: Nik Everett
Date: Mon, 13 Jan 2025 10:04:27 -0500
Subject: [PATCH 11/44] ESQL: Limit memory usage of `fold` (#118602)
`fold` can be surprisingly heavy! The maximally efficient/paranoid thing
would be to fold each expression one time, in the constant folding rule,
and then store the result as a `Literal`. But this PR doesn't do that
because it's a big change. Instead, it creates the infrastructure for
tracking memory usage for folding as plugs it into as many places as
possible. That's not perfect, but it's better.
This infrastructure limit the allocations of fold similar to the
`CircuitBreaker` infrastructure we use for values, but it's different
in a critical way: you don't manually free any of the values. This is
important because the plan itself isn't `Releasable`, which is required
when using a real CircuitBreaker. We could have tried to make the plan
releasable, but that'd be a huge change.
Right now there's a single limit of 5% of heap per query. We create the
limit at the start of query planning and use it throughout planning.
There are about 40 places that don't yet use it. We should get them
plugged in as quick as we can manage. After that, we should look to the
maximally efficient/paranoid thing that I mentioned about waiting for
constant folding. That's an even bigger change, one I'm not equipped
to make on my own.
---
.../compute/operator/EvalBenchmark.java | 25 ++-
docs/changelog/118602.yaml | 5 +
.../functions/kibana/definition/bucket.json | 2 +-
.../kibana/definition/match_operator.json | 2 +-
.../functions/kibana/docs/match_operator.md | 9 +-
.../common/settings/Setting.java | 2 +-
.../xpack/esql/heap_attack/HeapAttackIT.java | 35 +++-
.../esql/core/expression/Expression.java | 12 +-
.../esql/core/expression/Expressions.java | 6 +-
.../esql/core/expression/FoldContext.java | 178 ++++++++++++++++++
.../xpack/esql/core/expression/Foldables.java | 4 +-
.../xpack/esql/core/expression/Literal.java | 6 +-
.../esql/core/expression/TypeResolutions.java | 10 +-
.../function/scalar/UnaryScalarFunction.java | 3 +-
.../expression/predicate/BinaryPredicate.java | 5 +-
.../esql/core/expression/predicate/Range.java | 22 +--
.../expression/predicate/logical/Not.java | 5 +-
.../expression/predicate/nulls/IsNotNull.java | 5 +-
.../expression/predicate/nulls/IsNull.java | 5 +-
.../predicate/operator/arithmetic/Neg.java | 5 +-
.../predicate/regex/RegexMatch.java | 3 +-
.../core/expression/FoldContextTests.java | 97 ++++++++++
.../core/expression/predicate/RangeTests.java | 8 +-
.../core/optimizer/OptimizerRulesTests.java | 5 +-
.../xpack/esql/EsqlTestUtils.java | 6 +
.../xpack/esql/analysis/Analyzer.java | 15 +-
.../xpack/esql/evaluator/EvalMapper.java | 39 ++--
.../evaluator/mapper/EvaluatorMapper.java | 104 ++++++++--
.../evaluator/mapper/ExpressionMapper.java | 3 +-
.../xpack/esql/execution/PlanExecutor.java | 4 +-
.../function/aggregate/AggregateFunction.java | 4 +-
.../function/aggregate/CountDistinct.java | 5 +-
.../function/aggregate/Percentile.java | 3 +-
.../expression/function/aggregate/Rate.java | 3 +-
.../expression/function/aggregate/Top.java | 5 +-
.../function/aggregate/WeightedAvg.java | 13 +-
.../function/fulltext/FullTextFunction.java | 3 +-
.../expression/function/fulltext/Match.java | 3 +-
.../expression/function/grouping/Bucket.java | 21 ++-
.../function/grouping/GroupingFunction.java | 5 +-
.../function/scalar/EsqlScalarFunction.java | 5 +-
.../function/scalar/conditional/Case.java | 8 +-
.../convert/FoldablesConvertFunction.java | 5 +-
.../function/scalar/date/DateDiff.java | 2 +-
.../function/scalar/date/DateExtract.java | 9 +-
.../function/scalar/date/DateFormat.java | 2 +-
.../function/scalar/date/DateParse.java | 2 +-
.../function/scalar/date/DateTrunc.java | 2 +-
.../expression/function/scalar/date/Now.java | 3 +-
.../expression/function/scalar/math/E.java | 3 +-
.../expression/function/scalar/math/Pi.java | 3 +-
.../expression/function/scalar/math/Tau.java | 3 +-
.../function/scalar/multivalue/MvConcat.java | 5 +-
.../multivalue/MvPSeriesWeightedSum.java | 2 +-
.../function/scalar/multivalue/MvSlice.java | 4 +-
.../function/scalar/multivalue/MvSort.java | 16 +-
.../scalar/spatial/SpatialContains.java | 7 +-
.../scalar/spatial/SpatialDisjoint.java | 7 +-
.../spatial/SpatialEvaluatorFactory.java | 12 +-
.../scalar/spatial/SpatialIntersects.java | 7 +-
.../scalar/spatial/SpatialRelatesUtils.java | 22 ++-
.../scalar/spatial/SpatialWithin.java | 7 +-
.../function/scalar/spatial/StDistance.java | 11 +-
.../function/scalar/string/Hash.java | 2 +-
.../function/scalar/string/RLike.java | 5 +-
.../function/scalar/string/Repeat.java | 2 +-
.../function/scalar/string/Replace.java | 2 +-
.../function/scalar/string/Space.java | 2 +-
.../function/scalar/string/Split.java | 7 +-
.../function/scalar/string/WildcardLike.java | 5 +-
.../function/scalar/util/Delay.java | 9 +-
.../DateTimeArithmeticOperation.java | 29 ++-
.../arithmetic/EsqlArithmeticOperation.java | 5 +-
.../predicate/operator/arithmetic/Neg.java | 9 +-
.../comparison/EsqlBinaryComparison.java | 5 +-
.../predicate/operator/comparison/In.java | 5 +-
.../comparison/InsensitiveEquals.java | 7 +-
.../comparison/InsensitiveEqualsMapper.java | 14 +-
.../LocalLogicalOptimizerContext.java | 5 +-
.../LocalPhysicalOptimizerContext.java | 3 +-
.../optimizer/LogicalOptimizerContext.java | 15 +-
.../BooleanFunctionEqualsElimination.java | 3 +-
.../rules/logical/BooleanSimplification.java | 3 +-
.../logical/CombineBinaryComparisons.java | 34 ++--
.../rules/logical/CombineDisjunctions.java | 7 +-
.../rules/logical/ConstantFolding.java | 5 +-
.../rules/logical/ConvertStringToByteRef.java | 4 +-
.../optimizer/rules/logical/FoldNull.java | 3 +-
.../rules/logical/LiteralsOnTheRight.java | 3 +-
.../rules/logical/OptimizerRules.java | 19 +-
.../rules/logical/PartiallyFoldCase.java | 5 +-
.../rules/logical/PropagateEmptyRelation.java | 25 ++-
.../rules/logical/PropagateEquals.java | 38 ++--
.../rules/logical/PropagateEvalFoldables.java | 9 +-
.../rules/logical/PropagateNullable.java | 3 +-
.../logical/PushDownAndCombineLimits.java | 19 +-
.../rules/logical/ReplaceRegexMatch.java | 3 +-
.../logical/ReplaceRowAsLocalRelation.java | 10 +-
.../ReplaceStatsFilteredAggWithEval.java | 2 +-
...laceStringCasingWithInsensitiveEquals.java | 22 ++-
.../SimplifyComparisonsArithmetics.java | 38 ++--
.../rules/logical/SkipQueryOnLimitZero.java | 11 +-
.../logical/SplitInWithFoldableValue.java | 5 +-
.../logical/SubstituteFilteredExpression.java | 3 +-
.../logical/SubstituteSpatialSurrogates.java | 3 +-
.../local/LocalPropagateEmptyRelation.java | 12 +-
.../local/EnableSpatialDistancePushdown.java | 41 ++--
.../physical/local/PushTopNToSource.java | 17 +-
.../xpack/esql/parser/ExpressionBuilder.java | 13 +-
.../xpack/esql/parser/LogicalPlanBuilder.java | 7 +-
.../xpack/esql/plan/logical/Enrich.java | 3 +-
.../AbstractPhysicalOperationProviders.java | 12 +-
.../planner/EsPhysicalOperationProviders.java | 9 +-
.../planner/EsqlExpressionTranslators.java | 17 +-
.../esql/planner/LocalExecutionPlanner.java | 23 +--
.../xpack/esql/planner/PlannerUtils.java | 16 +-
.../xpack/esql/planner/TypeConverter.java | 36 ++--
.../esql/planner/mapper/MapperUtils.java | 3 +-
.../xpack/esql/plugin/ComputeService.java | 32 +++-
.../xpack/esql/plugin/QueryPragmas.java | 15 ++
.../esql/plugin/TransportEsqlQueryAction.java | 4 +
.../xpack/esql/session/Configuration.java | 8 +
.../xpack/esql/session/EsqlSession.java | 8 +-
.../esql/type/EsqlDataTypeConverter.java | 5 +-
.../elasticsearch/xpack/esql/CsvTests.java | 35 ++--
.../xpack/esql/analysis/AnalyzerTests.java | 22 +--
.../mapper/EvaluatorMapperTests.java | 43 +++++
.../function/AbstractAggregationTestCase.java | 8 +-
.../function/AbstractFunctionTestCase.java | 23 ++-
.../AbstractScalarFunctionTestCase.java | 10 +-
.../function/CheckLicenseTests.java | 5 +-
.../expression/function/TestCaseSupplier.java | 6 +-
.../scalar/conditional/CaseExtraTests.java | 61 +++---
.../scalar/conditional/CaseTests.java | 9 +-
.../scalar/date/DateExtractTests.java | 3 +-
.../function/scalar/nulls/CoalesceTests.java | 33 ++--
.../function/scalar/string/RLikeTests.java | 5 +-
.../function/scalar/string/ToLowerTests.java | 3 +-
.../function/scalar/string/ToUpperTests.java | 3 +-
.../scalar/string/WildcardLikeTests.java | 5 +-
.../operator/arithmetic/NegTests.java | 3 +-
.../operator/comparison/InTests.java | 13 +-
.../comparison/InsensitiveEqualsTests.java | 57 +++---
.../LocalLogicalPlanOptimizerTests.java | 16 +-
.../LocalPhysicalPlanOptimizerTests.java | 40 ++--
.../optimizer/LogicalPlanOptimizerTests.java | 176 ++++++++---------
.../optimizer/PhysicalPlanOptimizerTests.java | 54 +++---
.../esql/optimizer/TestPlannerOptimizer.java | 12 +-
.../rules/LogicalOptimizerContextTests.java | 62 ++++++
...BooleanFunctionEqualsEliminationTests.java | 13 +-
.../logical/BooleanSimplificationTests.java | 38 ++--
.../CombineBinaryComparisonsTests.java | 76 +++-----
.../logical/CombineDisjunctionsTests.java | 30 +--
.../rules/logical/ConstantFoldingTests.java | 79 ++++----
.../rules/logical/FoldNullTests.java | 122 ++++++------
.../logical/LiteralsOnTheRightTests.java | 3 +-
.../rules/logical/PropagateEqualsTests.java | 77 +++-----
.../rules/logical/PropagateNullableTests.java | 27 ++-
.../rules/logical/ReplaceRegexMatchTests.java | 18 +-
.../physical/local/PushTopNToSourceTests.java | 3 +-
.../xpack/esql/parser/ExpressionTests.java | 11 +-
.../esql/parser/StatementParserTests.java | 39 ++--
.../xpack/esql/plan/QueryPlanTests.java | 23 +--
.../xpack/esql/planner/EvalMapperTests.java | 3 +-
.../xpack/esql/planner/FilterTests.java | 4 +-
.../planner/LocalExecutionPlannerTests.java | 6 +-
.../TestPhysicalOperationProviders.java | 9 +-
.../esql/plugin/ClusterRequestTests.java | 5 +-
.../DataNodeRequestSerializationTests.java | 3 +-
.../esql/stats/PlanExecutorMetricsTests.java | 3 +
170 files changed, 1858 insertions(+), 972 deletions(-)
create mode 100644 docs/changelog/118602.yaml
create mode 100644 x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java
create mode 100644 x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java
create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java
create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java
diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java
index 9aab4a3e3210f..d3259b9604717 100644
--- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java
@@ -27,6 +27,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -71,12 +72,11 @@ public class EvalBenchmark {
BigArrays.NON_RECYCLING_INSTANCE
);
+ private static final FoldContext FOLD_CONTEXT = FoldContext.small();
+
private static final int BLOCK_LENGTH = 8 * 1024;
- static final DriverContext driverContext = new DriverContext(
- BigArrays.NON_RECYCLING_INSTANCE,
- BlockFactory.getInstance(new NoopCircuitBreaker("noop"), BigArrays.NON_RECYCLING_INSTANCE)
- );
+ static final DriverContext driverContext = new DriverContext(BigArrays.NON_RECYCLING_INSTANCE, blockFactory);
static {
// Smoke test all the expected values and force loading subclasses more like prod
@@ -114,11 +114,12 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
return switch (operation) {
case "abs" -> {
FieldAttribute longField = longField();
- yield EvalMapper.toEvaluator(new Abs(Source.EMPTY, longField), layout(longField)).get(driverContext);
+ yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Abs(Source.EMPTY, longField), layout(longField)).get(driverContext);
}
case "add" -> {
FieldAttribute longField = longField();
yield EvalMapper.toEvaluator(
+ FOLD_CONTEXT,
new Add(Source.EMPTY, longField, new Literal(Source.EMPTY, 1L, DataType.LONG)),
layout(longField)
).get(driverContext);
@@ -126,6 +127,7 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
case "add_double" -> {
FieldAttribute doubleField = doubleField();
yield EvalMapper.toEvaluator(
+ FOLD_CONTEXT,
new Add(Source.EMPTY, doubleField, new Literal(Source.EMPTY, 1D, DataType.DOUBLE)),
layout(doubleField)
).get(driverContext);
@@ -140,7 +142,8 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
lhs = new Add(Source.EMPTY, lhs, new Literal(Source.EMPTY, 1L, DataType.LONG));
rhs = new Add(Source.EMPTY, rhs, new Literal(Source.EMPTY, 1L, DataType.LONG));
}
- yield EvalMapper.toEvaluator(new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2)).get(driverContext);
+ yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2))
+ .get(driverContext);
}
case "date_trunc" -> {
FieldAttribute timestamp = new FieldAttribute(
@@ -149,6 +152,7 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
new EsField("timestamp", DataType.DATETIME, Map.of(), true)
);
yield EvalMapper.toEvaluator(
+ FOLD_CONTEXT,
new DateTrunc(Source.EMPTY, new Literal(Source.EMPTY, Duration.ofHours(24), DataType.TIME_DURATION), timestamp),
layout(timestamp)
).get(driverContext);
@@ -156,6 +160,7 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
case "equal_to_const" -> {
FieldAttribute longField = longField();
yield EvalMapper.toEvaluator(
+ FOLD_CONTEXT,
new Equals(Source.EMPTY, longField, new Literal(Source.EMPTY, 100_000L, DataType.LONG)),
layout(longField)
).get(driverContext);
@@ -163,21 +168,21 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) {
case "long_equal_to_long" -> {
FieldAttribute lhs = longField();
FieldAttribute rhs = longField();
- yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext);
+ yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext);
}
case "long_equal_to_int" -> {
FieldAttribute lhs = longField();
FieldAttribute rhs = intField();
- yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext);
+ yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext);
}
case "mv_min", "mv_min_ascending" -> {
FieldAttribute longField = longField();
- yield EvalMapper.toEvaluator(new MvMin(Source.EMPTY, longField), layout(longField)).get(driverContext);
+ yield EvalMapper.toEvaluator(FOLD_CONTEXT, new MvMin(Source.EMPTY, longField), layout(longField)).get(driverContext);
}
case "rlike" -> {
FieldAttribute keywordField = keywordField();
RLike rlike = new RLike(Source.EMPTY, keywordField, new RLikePattern(".ar"));
- yield EvalMapper.toEvaluator(rlike, layout(keywordField)).get(driverContext);
+ yield EvalMapper.toEvaluator(FOLD_CONTEXT, rlike, layout(keywordField)).get(driverContext);
}
default -> throw new UnsupportedOperationException();
};
diff --git a/docs/changelog/118602.yaml b/docs/changelog/118602.yaml
new file mode 100644
index 0000000000000..a75c5dcf11da3
--- /dev/null
+++ b/docs/changelog/118602.yaml
@@ -0,0 +1,5 @@
+pr: 118602
+summary: Limit memory usage of `fold`
+area: ES|QL
+type: bug
+issues: []
diff --git a/docs/reference/esql/functions/kibana/definition/bucket.json b/docs/reference/esql/functions/kibana/definition/bucket.json
index 3d96de05c8407..f9c7f2f27d6f9 100644
--- a/docs/reference/esql/functions/kibana/definition/bucket.json
+++ b/docs/reference/esql/functions/kibana/definition/bucket.json
@@ -1599,7 +1599,7 @@
"FROM sample_data \n| WHERE @timestamp >= NOW() - 1 day and @timestamp < NOW()\n| STATS COUNT(*) BY bucket = BUCKET(@timestamp, 25, NOW() - 1 day, NOW())",
"FROM employees\n| WHERE hire_date >= \"1985-01-01T00:00:00Z\" AND hire_date < \"1986-01-01T00:00:00Z\"\n| STATS AVG(salary) BY bucket = BUCKET(hire_date, 20, \"1985-01-01T00:00:00Z\", \"1986-01-01T00:00:00Z\")\n| SORT bucket",
"FROM employees\n| STATS s1 = b1 + 1, s2 = BUCKET(salary / 1000 + 999, 50.) + 2 BY b1 = BUCKET(salary / 100 + 99, 50.), b2 = BUCKET(salary / 1000 + 999, 50.)\n| SORT b1, b2\n| KEEP s1, b1, s2, b2",
- "FROM employees \n| STATS dates = VALUES(birth_date) BY b = BUCKET(birth_date + 1 HOUR, 1 YEAR) - 1 HOUR\n| EVAL d_count = MV_COUNT(dates)\n| SORT d_count\n| LIMIT 3"
+ "FROM employees\n| STATS dates = MV_SORT(VALUES(birth_date)) BY b = BUCKET(birth_date + 1 HOUR, 1 YEAR) - 1 HOUR\n| EVAL d_count = MV_COUNT(dates)\n| SORT d_count, b\n| LIMIT 3"
],
"preview" : false,
"snapshot_only" : false
diff --git a/docs/reference/esql/functions/kibana/definition/match_operator.json b/docs/reference/esql/functions/kibana/definition/match_operator.json
index 44233bbddb653..c8cbf1cf9d966 100644
--- a/docs/reference/esql/functions/kibana/definition/match_operator.json
+++ b/docs/reference/esql/functions/kibana/definition/match_operator.json
@@ -2,7 +2,7 @@
"comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.",
"type" : "operator",
"name" : "match_operator",
- "description" : "Performs a <> on the specified field. Returns true if the provided query matches the row.",
+ "description" : "Use `MATCH` to perform a <> on the specified field.\nUsing `MATCH` is equivalent to using the `match` query in the Elasticsearch Query DSL.\n\nMatch can be used on text fields, as well as other field types like boolean, dates, and numeric types.\n\nFor a simplified syntax, you can use the <> `:` operator instead of `MATCH`.\n\n`MATCH` returns true if the provided query matches the row.",
"signatures" : [
{
"params" : [
diff --git a/docs/reference/esql/functions/kibana/docs/match_operator.md b/docs/reference/esql/functions/kibana/docs/match_operator.md
index b0b6196798087..7681c2d1ce231 100644
--- a/docs/reference/esql/functions/kibana/docs/match_operator.md
+++ b/docs/reference/esql/functions/kibana/docs/match_operator.md
@@ -3,7 +3,14 @@ This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../READ
-->
### MATCH_OPERATOR
-Performs a <> on the specified field. Returns true if the provided query matches the row.
+Use `MATCH` to perform a <> on the specified field.
+Using `MATCH` is equivalent to using the `match` query in the Elasticsearch Query DSL.
+
+Match can be used on text fields, as well as other field types like boolean, dates, and numeric types.
+
+For a simplified syntax, you can use the <> `:` operator instead of `MATCH`.
+
+`MATCH` returns true if the provided query matches the row.
```
FROM books
diff --git a/server/src/main/java/org/elasticsearch/common/settings/Setting.java b/server/src/main/java/org/elasticsearch/common/settings/Setting.java
index aec9c108d898d..16c6844f46402 100644
--- a/server/src/main/java/org/elasticsearch/common/settings/Setting.java
+++ b/server/src/main/java/org/elasticsearch/common/settings/Setting.java
@@ -1727,7 +1727,7 @@ public static > Setting enumSetting(
*
* @param key the key for the setting
* @param defaultValue the default value for this setting
- * @param properties properties properties for this setting like scope, filtering...
+ * @param properties properties for this setting like scope, filtering...
* @return the setting object
*/
public static Setting memorySizeSetting(String key, ByteSizeValue defaultValue, Property... properties) {
diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java
index ace3db377664c..958132b3e4076 100644
--- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java
+++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java
@@ -194,6 +194,16 @@ private void assertCircuitBreaks(ThrowingRunnable r) throws IOException {
);
}
+ private void assertFoldCircuitBreaks(ThrowingRunnable r) throws IOException {
+ ResponseException e = expectThrows(ResponseException.class, r);
+ Map, ?> map = responseAsMap(e.getResponse());
+ logger.info("expected fold circuit breaking {}", map);
+ assertMap(
+ map,
+ matchesMap().entry("status", 400).entry("error", matchesMap().extraOk().entry("type", "fold_too_much_memory_exception"))
+ );
+ }
+
private void assertParseFailure(ThrowingRunnable r) throws IOException {
ResponseException e = expectThrows(ResponseException.class, r);
Map, ?> map = responseAsMap(e.getResponse());
@@ -325,11 +335,23 @@ public void testManyConcatFromRow() throws IOException {
assertManyStrings(resp, strings);
}
+ /**
+ * Hits a circuit breaker by building many moderately long strings.
+ */
+ public void testHugeManyConcatFromRow() throws IOException {
+ assertFoldCircuitBreaks(
+ () -> manyConcat(
+ "ROW a=9999999999999, b=99999999999999999, c=99999999999999999, d=99999999999999999, e=99999999999999999",
+ 5000
+ )
+ );
+ }
+
/**
* Fails to parse a huge huge query.
*/
public void testHugeHugeManyConcatFromRow() throws IOException {
- assertParseFailure(() -> manyConcat("ROW a=9999, b=9999, c=9999, d=9999, e=9999", 50000));
+ assertParseFailure(() -> manyConcat("ROW a=9999, b=9999, c=9999, d=9999, e=9999", 6000));
}
/**
@@ -387,13 +409,20 @@ public void testHugeManyRepeat() throws IOException {
* Returns many moderately long strings.
*/
public void testManyRepeatFromRow() throws IOException {
- int strings = 10000;
+ int strings = 300;
Response resp = manyRepeat("ROW a = 99", strings);
assertManyStrings(resp, strings);
}
/**
- * Fails to parse a huge huge query.
+ * Hits a circuit breaker by building many moderately long strings.
+ */
+ public void testHugeManyRepeatFromRow() throws IOException {
+ assertFoldCircuitBreaks(() -> manyRepeat("ROW a = 99", 400));
+ }
+
+ /**
+ * Fails to parse a huge, huge query.
*/
public void testHugeHugeManyRepeatFromRow() throws IOException {
assertParseFailure(() -> manyRepeat("ROW a = 99", 100000));
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java
index 00765a8c0528c..b254612a700df 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java
@@ -78,12 +78,20 @@ public Expression(Source source, List children) {
super(source, children);
}
- // whether the expression can be evaluated statically (folded) or not
+ /**
+ * Whether the expression can be evaluated statically, aka "folded", or not.
+ */
public boolean foldable() {
return false;
}
- public Object fold() {
+ /**
+ * Evaluate this expression statically to a constant. It is an error to call
+ * this if {@link #foldable} returns false.
+ */
+ public Object fold(FoldContext ctx) {
+ // TODO After removing FoldContext.unbounded from non-test code examine all calls
+ // for places we should use instanceof Literal instead
throw new QlIllegalArgumentException("Should not fold expression");
}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java
index 4e4338aad3704..739333ded0fde 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java
@@ -107,10 +107,10 @@ public static boolean foldable(List extends Expression> exps) {
return true;
}
- public static List fold(List extends Expression> exps) {
+ public static List fold(FoldContext ctx, List extends Expression> exps) {
List folded = new ArrayList<>(exps.size());
for (Expression exp : exps) {
- folded.add(exp.fold());
+ folded.add(exp.fold(ctx));
}
return folded;
@@ -135,7 +135,7 @@ public static String name(Expression e) {
/**
* Is this {@linkplain Expression} guaranteed to have
* only the {@code null} value. {@linkplain Expression}s that
- * {@link Expression#fold()} to {@code null} may
+ * {@link Expression#fold} to {@code null} may
* return {@code false} here, but should eventually be folded
* into a {@link Literal} containing {@code null} which will return
* {@code true} from here.
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java
new file mode 100644
index 0000000000000..25da44c5fd226
--- /dev/null
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.core.expression;
+
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.breaker.CircuitBreakingException;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.unit.MemorySizeValue;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.xpack.esql.core.QlClientException;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+
+import java.util.Objects;
+
+/**
+ * Context passed to {@link Expression#fold}. This is not thread safe.
+ */
+public class FoldContext {
+ private static final long SMALL = MemorySizeValue.parseBytesSizeValueOrHeapRatio("5%", "small").getBytes();
+
+ /**
+ * {@link Expression#fold} using less than 5% of heap. Fine in tests but otherwise
+ * calling this is a signal that you either, shouldn't be calling {@link Expression#fold}
+ * at all, or should pass in a shared {@link FoldContext} made by {@code Configuration}.
+ */
+ public static FoldContext small() {
+ return new FoldContext(SMALL);
+ }
+
+ private final long initialAllowedBytes;
+ private long allowedBytes;
+
+ public FoldContext(long allowedBytes) {
+ this.initialAllowedBytes = allowedBytes;
+ this.allowedBytes = allowedBytes;
+ }
+
+ /**
+ * The maximum allowed bytes. {@link #allowedBytes()} will be the same as this
+ * for an unused context.
+ */
+ public long initialAllowedBytes() {
+ return initialAllowedBytes;
+ }
+
+ /**
+ * The remaining allowed bytes.
+ */
+ long allowedBytes() {
+ return allowedBytes;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o == null || getClass() != o.getClass()) return false;
+ FoldContext that = (FoldContext) o;
+ return initialAllowedBytes == that.initialAllowedBytes && allowedBytes == that.allowedBytes;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(initialAllowedBytes, allowedBytes);
+ }
+
+ @Override
+ public String toString() {
+ return "FoldContext[" + allowedBytes + '/' + initialAllowedBytes + ']';
+ }
+
+ /**
+ * Track an allocation. Best to call this before allocating
+ * if possible, but after is ok if the allocation is small.
+ *
+ * Note that, unlike {@link CircuitBreaker}, you don't have
+ * to free this allocation later. This is important because the query plan
+ * doesn't implement {@link Releasable} so it can't free
+ * consistently. But when you have to allocate big chunks of memory during
+ * folding and know that you are returning the memory it is kindest to
+ * call this with a negative number, effectively giving those bytes back.
+ *
+ */
+ public void trackAllocation(Source source, long bytes) {
+ allowedBytes -= bytes;
+ assert allowedBytes <= initialAllowedBytes : "returned more bytes than it used";
+ if (allowedBytes < 0) {
+ throw new FoldTooMuchMemoryException(source, bytes, initialAllowedBytes);
+ }
+ }
+
+ /**
+ * Adapt this into a {@link CircuitBreaker} suitable for building bounded local
+ * DriverContext. This is absolutely an abuse of the {@link CircuitBreaker} contract
+ * and only methods used by BlockFactory are implemented. And this'll throw a
+ * {@link FoldTooMuchMemoryException} instead of the standard {@link CircuitBreakingException}.
+ * This works for the common folding implementation though.
+ */
+ public CircuitBreaker circuitBreakerView(Source source) {
+ return new CircuitBreaker() {
+ @Override
+ public void circuitBreak(String fieldName, long bytesNeeded) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException {
+ trackAllocation(source, bytes);
+ }
+
+ @Override
+ public void addWithoutBreaking(long bytes) {
+ assert bytes <= 0 : "we only expect this to be used for deallocation";
+ allowedBytes -= bytes;
+ assert allowedBytes <= initialAllowedBytes : "returned more bytes than it used";
+ }
+
+ @Override
+ public long getUsed() {
+ /*
+ * This isn't expected to be used by we can implement it so we may as
+ * well. Maybe it'll be useful for debugging one day.
+ */
+ return initialAllowedBytes - allowedBytes;
+ }
+
+ @Override
+ public long getLimit() {
+ /*
+ * This isn't expected to be used by we can implement it so we may as
+ * well. Maybe it'll be useful for debugging one day.
+ */
+ return initialAllowedBytes;
+ }
+
+ @Override
+ public double getOverhead() {
+ return 1.0;
+ }
+
+ @Override
+ public long getTrippedCount() {
+ return 0;
+ }
+
+ @Override
+ public String getName() {
+ return REQUEST;
+ }
+
+ @Override
+ public Durability getDurability() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setLimitAndOverhead(long limit, double overhead) {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ public static class FoldTooMuchMemoryException extends QlClientException {
+ protected FoldTooMuchMemoryException(Source source, long bytesForExpression, long initialAllowedBytes) {
+ super(
+ "line {}:{}: Folding query used more than {}. The expression that pushed past the limit is [{}] which needed {}.",
+ source.source().getLineNumber(),
+ source.source().getColumnNumber(),
+ ByteSizeValue.ofBytes(initialAllowedBytes),
+ source.text(),
+ ByteSizeValue.ofBytes(bytesForExpression)
+ );
+ }
+ }
+}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java
index 601758bca5918..233113c3fe1b8 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java
@@ -10,9 +10,9 @@
public abstract class Foldables {
- public static Object valueOf(Expression e) {
+ public static Object valueOf(FoldContext ctx, Expression e) {
if (e.foldable()) {
- return e.fold();
+ return e.fold(ctx);
}
throw new QlIllegalArgumentException("Cannot determine value for {}", e);
}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java
index 53f559c5c82fe..afe616489d81d 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java
@@ -98,7 +98,7 @@ public boolean resolved() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
return value;
}
@@ -138,7 +138,7 @@ public String nodeString() {
* Utility method for creating a literal out of a foldable expression.
* Throws an exception if the expression is not foldable.
*/
- public static Literal of(Expression foldable) {
+ public static Literal of(FoldContext ctx, Expression foldable) {
if (foldable.foldable() == false) {
throw new QlIllegalArgumentException("Foldable expression required for Literal creation; received unfoldable " + foldable);
}
@@ -147,7 +147,7 @@ public static Literal of(Expression foldable) {
return (Literal) foldable;
}
- return new Literal(foldable.source(), foldable.fold(), foldable.dataType());
+ return new Literal(foldable.source(), foldable.fold(ctx), foldable.dataType());
}
public static Literal of(Expression source, Object value) {
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java
index b817ec17c7bda..842f3c0ddadd7 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java
@@ -133,6 +133,14 @@ public static TypeResolution isFoldable(Expression e, String operationName, Para
return TypeResolution.TYPE_RESOLVED;
}
+ /**
+ * Is this {@link Expression#foldable()} and not {@code null}.
+ *
+ * @deprecated instead of calling this, check for a {@link Literal} containing
+ * {@code null}. Foldable expressions will be folded by other rules,
+ * eventually, to a {@link Literal}.
+ */
+ @Deprecated
public static TypeResolution isNotNullAndFoldable(Expression e, String operationName, ParamOrdinal paramOrd) {
TypeResolution resolution = isFoldable(e, operationName, paramOrd);
@@ -140,7 +148,7 @@ public static TypeResolution isNotNullAndFoldable(Expression e, String operation
return resolution;
}
- if (e.dataType() == DataType.NULL || e.fold() == null) {
+ if (e.dataType() == DataType.NULL || e.fold(FoldContext.small()) == null) {
resolution = new TypeResolution(
format(
null,
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java
index 8704a42ed33e2..36517b1be9ce7 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java
@@ -9,6 +9,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;
@@ -53,5 +54,5 @@ public boolean foldable() {
}
@Override
- public abstract Object fold();
+ public abstract Object fold(FoldContext ctx);
}
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java
index be5caedacd50a..bf5549b31e5fa 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.esql.core.expression.predicate;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -29,8 +30,8 @@ protected BinaryPredicate(Source source, Expression left, Expression right, F fu
@SuppressWarnings("unchecked")
@Override
- public R fold() {
- return function().apply((T) left().fold(), (U) right().fold());
+ public R fold(FoldContext ctx) {
+ return function().apply((T) left().fold(ctx), (U) right().fold(ctx));
}
@Override
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java
index 5de09f40437c7..a4e4685f764e8 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java
@@ -8,6 +8,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
@@ -99,23 +100,24 @@ public boolean foldable() {
}
// We cannot fold the bounds here; but if they're already literals, we can check if the range is always empty.
- if (lower() instanceof Literal && upper() instanceof Literal) {
- return areBoundariesInvalid();
+ if (lower() instanceof Literal l && upper() instanceof Literal u) {
+ return areBoundariesInvalid(l.value(), u.value());
}
}
-
return false;
}
@Override
- public Object fold() {
- if (areBoundariesInvalid()) {
+ public Object fold(FoldContext ctx) {
+ Object lowerValue = lower.fold(ctx);
+ Object upperValue = upper.fold(ctx);
+ if (areBoundariesInvalid(lowerValue, upperValue)) {
return Boolean.FALSE;
}
- Object val = value.fold();
- Integer lowerCompare = BinaryComparison.compare(lower.fold(), val);
- Integer upperCompare = BinaryComparison.compare(val, upper().fold());
+ Object val = value.fold(ctx);
+ Integer lowerCompare = BinaryComparison.compare(lower.fold(ctx), val);
+ Integer upperCompare = BinaryComparison.compare(val, upper().fold(ctx));
boolean lowerComparsion = lowerCompare == null ? false : (includeLower ? lowerCompare <= 0 : lowerCompare < 0);
boolean upperComparsion = upperCompare == null ? false : (includeUpper ? upperCompare <= 0 : upperCompare < 0);
return lowerComparsion && upperComparsion;
@@ -125,9 +127,7 @@ public Object fold() {
* Check whether the boundaries are invalid ( upper < lower) or not.
* If they are, the value does not have to be evaluated.
*/
- protected boolean areBoundariesInvalid() {
- Object lowerValue = lower.fold();
- Object upperValue = upper.fold();
+ protected boolean areBoundariesInvalid(Object lowerValue, Object upperValue) {
if (DataType.isDateTime(value.dataType()) || DataType.isDateTime(lower.dataType()) || DataType.isDateTime(upper.dataType())) {
try {
if (upperValue instanceof String upperString) {
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java
index c4983b49a6bc8..218f61856accc 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java
@@ -10,6 +10,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
@@ -56,8 +57,8 @@ protected TypeResolution resolveType() {
}
@Override
- public Object fold() {
- return apply(field().fold());
+ public Object fold(FoldContext ctx) {
+ return apply(field().fold(ctx));
}
private static Boolean apply(Object input) {
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java
index 9879a1f5ffc29..f5542ff7c3de5 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java
@@ -9,6 +9,7 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable;
@@ -49,8 +50,8 @@ protected IsNotNull replaceChild(Expression newChild) {
}
@Override
- public Object fold() {
- return field().fold() != null && DataType.isNull(field().dataType()) == false;
+ public Object fold(FoldContext ctx) {
+ return DataType.isNull(field().dataType()) == false && field().fold(ctx) != null;
}
@Override
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java
index d88945045b03e..bb85791a9f85e 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java
@@ -9,6 +9,7 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable;
@@ -45,8 +46,8 @@ protected IsNull replaceChild(Expression newChild) {
}
@Override
- public Object fold() {
- return field().fold() == null || DataType.isNull(field().dataType());
+ public Object fold(FoldContext ctx) {
+ return DataType.isNull(field().dataType()) || field().fold(ctx) == null;
}
@Override
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java
index 9a8a14f320cd6..b0e79704f5fda 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java
@@ -8,6 +8,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -53,8 +54,8 @@ protected TypeResolution resolveType() {
}
@Override
- public Object fold() {
- return Arithmetics.negate((Number) field().fold());
+ public Object fold(FoldContext ctx) {
+ return Arithmetics.negate((Number) field().fold(ctx));
}
@Override
diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java
index 0f9116ade5a31..a4a0a6217161e 100644
--- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java
+++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.core.expression.predicate.regex;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -62,7 +63,7 @@ public boolean foldable() {
}
@Override
- public Boolean fold() {
+ public Boolean fold(FoldContext ctx) {
throw new UnsupportedOperationException();
}
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java
new file mode 100644
index 0000000000000..2080f4007777c
--- /dev/null
+++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.core.expression;
+
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.EqualsHashCodeTestUtils;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class FoldContextTests extends ESTestCase {
+ public void testEq() {
+ EqualsHashCodeTestUtils.checkEqualsAndHashCode(randomFoldContext(), this::copy, this::mutate);
+ }
+
+ private FoldContext randomFoldContext() {
+ FoldContext ctx = new FoldContext(randomNonNegativeLong());
+ if (randomBoolean()) {
+ ctx.trackAllocation(Source.EMPTY, randomLongBetween(0, ctx.initialAllowedBytes()));
+ }
+ return ctx;
+ }
+
+ private FoldContext copy(FoldContext ctx) {
+ FoldContext copy = new FoldContext(ctx.initialAllowedBytes());
+ copy.trackAllocation(Source.EMPTY, ctx.initialAllowedBytes() - ctx.allowedBytes());
+ return copy;
+ }
+
+ private FoldContext mutate(FoldContext ctx) {
+ if (randomBoolean()) {
+ FoldContext differentInitial = new FoldContext(ctx.initialAllowedBytes() + 1);
+ differentInitial.trackAllocation(Source.EMPTY, differentInitial.initialAllowedBytes() - ctx.allowedBytes());
+ assertThat(differentInitial.allowedBytes(), equalTo(ctx.allowedBytes()));
+ return differentInitial;
+ } else {
+ FoldContext differentAllowed = new FoldContext(ctx.initialAllowedBytes());
+ long allowed = randomValueOtherThan(ctx.allowedBytes(), () -> randomLongBetween(0, ctx.initialAllowedBytes()));
+ differentAllowed.trackAllocation(Source.EMPTY, ctx.initialAllowedBytes() - allowed);
+ assertThat(differentAllowed.allowedBytes(), equalTo(allowed));
+ return differentAllowed;
+ }
+ }
+
+ public void testTrackAllocation() {
+ FoldContext ctx = new FoldContext(10);
+ ctx.trackAllocation(Source.synthetic("shouldn't break"), 10);
+ Exception e = expectThrows(
+ FoldContext.FoldTooMuchMemoryException.class,
+ () -> ctx.trackAllocation(Source.synthetic("should break"), 1)
+ );
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "line -1:-1: Folding query used more than 10b. "
+ + "The expression that pushed past the limit is [should break] which needed 1b."
+ )
+ );
+ }
+
+ public void testCircuitBreakerViewBreaking() {
+ FoldContext ctx = new FoldContext(10);
+ ctx.circuitBreakerView(Source.synthetic("shouldn't break")).addEstimateBytesAndMaybeBreak(10, "test");
+ Exception e = expectThrows(
+ FoldContext.FoldTooMuchMemoryException.class,
+ () -> ctx.circuitBreakerView(Source.synthetic("should break")).addEstimateBytesAndMaybeBreak(1, "test")
+ );
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "line -1:-1: Folding query used more than 10b. "
+ + "The expression that pushed past the limit is [should break] which needed 1b."
+ )
+ );
+ }
+
+ public void testCircuitBreakerViewWithoutBreaking() {
+ FoldContext ctx = new FoldContext(10);
+ CircuitBreaker view = ctx.circuitBreakerView(Source.synthetic("shouldn't break"));
+ view.addEstimateBytesAndMaybeBreak(10, "test");
+ view.addWithoutBreaking(-1);
+ assertThat(view.getUsed(), equalTo(9L));
+ }
+
+ public void testToString() {
+ // Random looking numbers are indeed random. Just so we have consistent numbers to assert on in toString.
+ FoldContext ctx = new FoldContext(123);
+ ctx.trackAllocation(Source.EMPTY, 22);
+ assertThat(ctx.toString(), equalTo("FoldContext[101/123]"));
+ }
+}
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java
index ed4c6282368ca..cd15ed5a94cfc 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java
+++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.core.expression.predicate;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.DateUtils;
@@ -211,7 +212,11 @@ public void testAreBoundariesInvalid() {
(Boolean) test[7],
ZoneId.systemDefault()
);
- assertEquals("failed on test " + i + ": " + Arrays.toString(test), test[8], range.areBoundariesInvalid());
+ assertEquals(
+ "failed on test " + i + ": " + Arrays.toString(test),
+ test[8],
+ range.areBoundariesInvalid(range.lower().fold(FoldContext.small()), range.upper().fold(FoldContext.small()))
+ );
}
}
@@ -226,5 +231,4 @@ private static DataType randomNumericType() {
private static DataType randomTextType() {
return randomFrom(KEYWORD, TEXT);
}
-
}
diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java
index d323174d2d3d9..91b0564a5b404 100644
--- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java
+++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java
@@ -10,6 +10,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.predicate.Range;
@@ -104,7 +105,7 @@ public void testFoldExcludingRangeToFalse() {
Range r = rangeOf(fa, SIX, false, FIVE, true);
assertTrue(r.foldable());
- assertEquals(Boolean.FALSE, r.fold());
+ assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
}
// 6 < a <= 5.5 -> FALSE
@@ -113,7 +114,7 @@ public void testFoldExcludingRangeWithDifferentTypesToFalse() {
Range r = rangeOf(fa, SIX, false, L(5.5d), true);
assertTrue(r.foldable());
- assertEquals(Boolean.FALSE, r.fold());
+ assertEquals(Boolean.FALSE, r.fold(FoldContext.small()));
}
// Conjunction
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java
index 66fd7d3ee5eb5..7e25fb29fdb78 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java
@@ -41,6 +41,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.expression.predicate.Range;
@@ -61,6 +62,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
import org.elasticsearch.xpack.esql.index.EsIndex;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.QueryParam;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
@@ -350,6 +352,10 @@ public String toString() {
public static final Configuration TEST_CFG = configuration(new QueryPragmas(Settings.EMPTY));
+ public static LogicalOptimizerContext unboundLogicalOptimizerContext() {
+ return new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small());
+ }
+
public static final Verifier TEST_VERIFIER = new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L));
public static final QueryBuilderResolver MOCK_QUERY_BUILDER_RESOLVER = new MockQueryBuilderResolver();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
index 3d1bfdfd0ef42..a11b511cb83b7 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
@@ -25,6 +25,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@@ -325,7 +326,7 @@ protected LogicalPlan rule(Enrich plan, AnalyzerContext context) {
// the policy does not exist
return plan;
}
- final String policyName = (String) plan.policyName().fold();
+ final String policyName = (String) plan.policyName().fold(FoldContext.small() /* TODO remove me */);
final var resolved = context.enrichResolution().getResolvedPolicy(policyName, plan.mode());
if (resolved != null) {
var policy = new EnrichPolicy(resolved.matchType(), null, List.of(), resolved.matchField(), resolved.enrichFields());
@@ -1279,16 +1280,16 @@ private static boolean supportsStringImplicitCasting(DataType type) {
private static UnresolvedAttribute unresolvedAttribute(Expression value, String type, Exception e) {
String message = format(
"Cannot convert string [{}] to [{}], error [{}]",
- value.fold(),
+ value.fold(FoldContext.small() /* TODO remove me */),
type,
(e instanceof ParsingException pe) ? pe.getErrorMessage() : e.getMessage()
);
- return new UnresolvedAttribute(value.source(), String.valueOf(value.fold()), message);
+ return new UnresolvedAttribute(value.source(), String.valueOf(value.fold(FoldContext.small() /* TODO remove me */)), message);
}
private static Expression castStringLiteralToTemporalAmount(Expression from) {
try {
- TemporalAmount result = maybeParseTemporalAmount(from.fold().toString().strip());
+ TemporalAmount result = maybeParseTemporalAmount(from.fold(FoldContext.small() /* TODO remove me */).toString().strip());
if (result == null) {
return from;
}
@@ -1304,7 +1305,11 @@ private static Expression castStringLiteral(Expression from, DataType target) {
try {
return isTemporalAmount(target)
? castStringLiteralToTemporalAmount(from)
- : new Literal(from.source(), EsqlDataTypeConverter.convert(from.fold(), target), target);
+ : new Literal(
+ from.source(),
+ EsqlDataTypeConverter.convert(from.fold(FoldContext.small() /* TODO remove me */), target),
+ target
+ );
} catch (Exception e) {
return unresolvedAttribute(from, target.toString(), e);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java
index 9a2e9398f52fd..b9c2b92ea72dd 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java
@@ -23,6 +23,7 @@
import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
@@ -50,13 +51,23 @@ public final class EvalMapper {
private EvalMapper() {}
@SuppressWarnings({ "rawtypes", "unchecked" })
- public static ExpressionEvaluator.Factory toEvaluator(Expression exp, Layout layout) {
+ public static ExpressionEvaluator.Factory toEvaluator(FoldContext foldCtx, Expression exp, Layout layout) {
if (exp instanceof EvaluatorMapper m) {
- return m.toEvaluator(e -> toEvaluator(e, layout));
+ return m.toEvaluator(new EvaluatorMapper.ToEvaluator() {
+ @Override
+ public ExpressionEvaluator.Factory apply(Expression expression) {
+ return toEvaluator(foldCtx, expression, layout);
+ }
+
+ @Override
+ public FoldContext foldCtx() {
+ return foldCtx;
+ }
+ });
}
for (ExpressionMapper em : MAPPERS) {
if (em.typeToken.isInstance(exp)) {
- return em.map(exp, layout);
+ return em.map(foldCtx, exp, layout);
}
}
throw new QlIllegalArgumentException("Unsupported expression [{}]", exp);
@@ -64,9 +75,9 @@ public static ExpressionEvaluator.Factory toEvaluator(Expression exp, Layout lay
static class BooleanLogic extends ExpressionMapper {
@Override
- public ExpressionEvaluator.Factory map(BinaryLogic bc, Layout layout) {
- var leftEval = toEvaluator(bc.left(), layout);
- var rightEval = toEvaluator(bc.right(), layout);
+ public ExpressionEvaluator.Factory map(FoldContext foldCtx, BinaryLogic bc, Layout layout) {
+ var leftEval = toEvaluator(foldCtx, bc.left(), layout);
+ var rightEval = toEvaluator(foldCtx, bc.right(), layout);
/**
* Evaluator for the three-valued boolean expressions .
* We can't generate these with the {@link Evaluator} annotation because that
@@ -142,8 +153,8 @@ public void close() {
static class Nots extends ExpressionMapper {
@Override
- public ExpressionEvaluator.Factory map(Not not, Layout layout) {
- var expEval = toEvaluator(not.field(), layout);
+ public ExpressionEvaluator.Factory map(FoldContext foldCtx, Not not, Layout layout) {
+ var expEval = toEvaluator(foldCtx, not.field(), layout);
return dvrCtx -> new org.elasticsearch.xpack.esql.evaluator.predicate.operator.logical.NotEvaluator(
not.source(),
expEval.get(dvrCtx),
@@ -154,7 +165,7 @@ public ExpressionEvaluator.Factory map(Not not, Layout layout) {
static class Attributes extends ExpressionMapper {
@Override
- public ExpressionEvaluator.Factory map(Attribute attr, Layout layout) {
+ public ExpressionEvaluator.Factory map(FoldContext foldCtx, Attribute attr, Layout layout) {
record Attribute(int channel) implements ExpressionEvaluator {
@Override
public Block eval(Page page) {
@@ -189,7 +200,7 @@ public boolean eagerEvalSafeInLazy() {
static class Literals extends ExpressionMapper {
@Override
- public ExpressionEvaluator.Factory map(Literal lit, Layout layout) {
+ public ExpressionEvaluator.Factory map(FoldContext foldCtx, Literal lit, Layout layout) {
record LiteralsEvaluator(DriverContext context, Literal lit) implements ExpressionEvaluator {
@Override
public Block eval(Page page) {
@@ -246,8 +257,8 @@ private static Block block(Literal lit, BlockFactory blockFactory, int positions
static class IsNulls extends ExpressionMapper {
@Override
- public ExpressionEvaluator.Factory map(IsNull isNull, Layout layout) {
- var field = toEvaluator(isNull.field(), layout);
+ public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNull isNull, Layout layout) {
+ var field = toEvaluator(foldCtx, isNull.field(), layout);
return new IsNullEvaluatorFactory(field);
}
@@ -294,8 +305,8 @@ public String toString() {
static class IsNotNulls extends ExpressionMapper {
@Override
- public ExpressionEvaluator.Factory map(IsNotNull isNotNull, Layout layout) {
- return new IsNotNullEvaluatorFactory(toEvaluator(isNotNull.field(), layout));
+ public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNotNull isNotNull, Layout layout) {
+ return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout));
}
record IsNotNullEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory field) implements ExpressionEvaluator.Factory {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java
index d8692faef5290..5a8b3d32e7db0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java
@@ -7,11 +7,20 @@
package org.elasticsearch.xpack.esql.evaluator.mapper;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.indices.breaker.AllCircuitBreakerStats;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
+import org.elasticsearch.indices.breaker.CircuitBreakerStats;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
import org.elasticsearch.xpack.esql.planner.Layout;
import static org.elasticsearch.compute.data.BlockUtils.fromArrayRow;
@@ -23,9 +32,12 @@
public interface EvaluatorMapper {
interface ToEvaluator {
ExpressionEvaluator.Factory apply(Expression expression);
+
+ FoldContext foldCtx();
}
/**
+ * Convert this into an {@link ExpressionEvaluator}.
*
* Note for implementors:
* If you are implementing this function, you should call the passed-in
@@ -35,8 +47,8 @@ interface ToEvaluator {
*
* Note for Callers:
* If you are attempting to call this method, and you have an
- * {@link Expression} and a {@link org.elasticsearch.xpack.esql.planner.Layout},
- * you likely want to call {@link org.elasticsearch.xpack.esql.evaluator.EvalMapper#toEvaluator(Expression, Layout)}
+ * {@link Expression} and a {@link Layout},
+ * you likely want to call {@link EvalMapper#toEvaluator}
* instead. On the other hand, if you already have something that
* looks like the parameter for this method, you should call this method
* with that function.
@@ -56,19 +68,89 @@ interface ToEvaluator {
/**
* Fold using {@link #toEvaluator} so you don't need a "by hand"
- * implementation of fold. The evaluator that it makes is "funny"
- * in that it'll always call {@link Expression#fold}, but that's
- * good enough.
+ * implementation of {@link Expression#fold}.
*/
- default Object fold() {
- return toJavaObject(toEvaluator(e -> driverContext -> new ExpressionEvaluator() {
+ default Object fold(Source source, FoldContext ctx) {
+ /*
+ * OK! So! We're going to build a bunch of *stuff* that so that we can
+ * call toEvaluator and use it without standing up an entire compute
+ * engine.
+ *
+ * Step 1 is creation of a `toEvaluator` which we'll soon use to turn
+ * the *children* of this Expression into ExpressionEvaluators. They
+ * have to be foldable or else we wouldn't have ended up here. So!
+ * We just call `fold` on them and turn the result of that into a
+ * Block.
+ *
+ * If the tree of expressions is pretty deep that `fold` call will
+ * likely end up being implemented by calling this method for the
+ * child. That's fine. Recursion is how you process trees.
+ */
+ ToEvaluator foldChildren = new ToEvaluator() {
@Override
- public Block eval(Page page) {
- return fromArrayRow(driverContext.blockFactory(), e.fold())[0];
+ public ExpressionEvaluator.Factory apply(Expression expression) {
+ return driverContext -> new ExpressionEvaluator() {
+ @Override
+ public Block eval(Page page) {
+ return fromArrayRow(driverContext.blockFactory(), expression.fold(ctx))[0];
+ }
+
+ @Override
+ public void close() {}
+ };
}
@Override
- public void close() {}
- }).get(DriverContext.getLocalDriver()).eval(new Page(1)), 0);
+ public FoldContext foldCtx() {
+ return ctx;
+ }
+ };
+
+ /*
+ * Step 2 is to create a DriverContext that we can pass to the above.
+ * This DriverContext is mostly about delegating to the FoldContext.
+ * That'll cause us to break if we attempt to allocate a huge amount
+ * of memory. Neat.
+ *
+ * Specifically, we make a CircuitBreaker view of the FoldContext, then
+ * we wrap it in a CircuitBreakerService so we can feed it to a BigArray
+ * so we can feed *that* into a DriverContext. It's a bit hacky, but
+ * that's what's going on here.
+ */
+ CircuitBreaker breaker = ctx.circuitBreakerView(source);
+ BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() {
+ @Override
+ public CircuitBreaker getBreaker(String name) {
+ if (name.equals(CircuitBreaker.REQUEST) == false) {
+ throw new UnsupportedOperationException();
+ }
+ return breaker;
+ }
+
+ @Override
+ public AllCircuitBreakerStats stats() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public CircuitBreakerStats stats(String name) {
+ throw new UnsupportedOperationException();
+ }
+ }, CircuitBreaker.REQUEST).withCircuitBreaking();
+ DriverContext driverCtx = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
+
+ /*
+ * Finally we can call toEvaluator on ourselves! It'll fold our children,
+ * convert the result into Blocks, and then we'll run that with the memory
+ * breaking DriverContext.
+ *
+ * Then, finally finally, we turn the result into a java object to be compatible
+ * with the signature of `fold`.
+ */
+ Block block = toEvaluator(foldChildren).get(driverCtx).eval(new Page(1));
+ if (block.getPositionCount() != 1) {
+ throw new IllegalStateException("generated odd block from fold [" + block + "]");
+ }
+ return toJavaObject(block, 0);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java
index 5cd830058573f..5a76080e7995c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java
@@ -9,6 +9,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
import org.elasticsearch.xpack.esql.planner.Layout;
@@ -19,5 +20,5 @@ public ExpressionMapper() {
typeToken = ReflectionUtils.detectSuperTypeForRuleLike(getClass());
}
- public abstract ExpressionEvaluator.Factory map(E expression, Layout layout);
+ public abstract ExpressionEvaluator.Factory map(FoldContext foldCtx, E expression, Layout layout);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
index 974f029eab2ef..94913581f696d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xpack.esql.action.EsqlQueryRequest;
import org.elasticsearch.xpack.esql.analysis.PreAnalyzer;
import org.elasticsearch.xpack.esql.analysis.Verifier;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
@@ -56,6 +57,7 @@ public void esql(
EsqlQueryRequest request,
String sessionId,
Configuration cfg,
+ FoldContext foldContext,
EnrichPolicyResolver enrichPolicyResolver,
EsqlExecutionInfo executionInfo,
IndicesExpressionGrouper indicesExpressionGrouper,
@@ -71,7 +73,7 @@ public void esql(
enrichPolicyResolver,
preAnalyzer,
functionRegistry,
- new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg)),
+ new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
mapper,
verifier,
planningMetrics,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java
index 011fcaccf7fe4..8aa7f697489c6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
@@ -92,7 +93,8 @@ public List extends Expression> parameters() {
}
public boolean hasFilter() {
- return filter != null && (filter.foldable() == false || Boolean.TRUE.equals(filter.fold()) == false);
+ return filter != null
+ && (filter.foldable() == false || Boolean.TRUE.equals(filter.fold(FoldContext.small() /* TODO remove me */)) == false);
}
public Expression filter() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java
index 7436db9e00dd2..3170ae8f132c2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java
@@ -19,6 +19,7 @@
import org.elasticsearch.compute.aggregation.CountDistinctLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -210,7 +211,9 @@ protected TypeResolution resolveType() {
@Override
public AggregatorFunctionSupplier supplier(List inputChannels) {
DataType type = field().dataType();
- int precision = this.precision == null ? DEFAULT_PRECISION : ((Number) this.precision.fold()).intValue();
+ int precision = this.precision == null
+ ? DEFAULT_PRECISION
+ : ((Number) this.precision.fold(FoldContext.small() /* TODO remove me */)).intValue();
if (SUPPLIERS.containsKey(type) == false) {
// If the type checking did its job, this should never happen
throw EsqlIllegalArgumentException.illegalDataType(type);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java
index 0d57267da1e29..8c943c991d501 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java
@@ -16,6 +16,7 @@
import org.elasticsearch.compute.aggregation.PercentileIntAggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.PercentileLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -170,7 +171,7 @@ protected AggregatorFunctionSupplier doubleSupplier(List inputChannels)
}
private int percentileValue() {
- return ((Number) percentile.fold()).intValue();
+ return ((Number) percentile.fold(FoldContext.small() /* TODO remove me */)).intValue();
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java
index 87ac9b77a6826..85ae65b6c5dc3 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java
@@ -18,6 +18,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
@@ -156,7 +157,7 @@ long unitInMillis() {
}
final Object foldValue;
try {
- foldValue = unit.fold();
+ foldValue = unit.fold(FoldContext.small() /* TODO remove me */);
} catch (Exception e) {
throw new IllegalArgumentException("function [" + sourceText() + "] has invalid unit [" + unit.sourceText() + "]");
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
index 40777b4d78dc2..9be8c94266ee8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
@@ -21,6 +21,7 @@
import org.elasticsearch.compute.aggregation.TopLongAggregatorFunctionSupplier;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -115,11 +116,11 @@ Expression orderField() {
}
private int limitValue() {
- return (int) limitField().fold();
+ return (int) limitField().fold(FoldContext.small() /* TODO remove me */);
}
private String orderRawValue() {
- return BytesRefs.toString(orderField().fold());
+ return BytesRefs.toString(orderField().fold(FoldContext.small() /* TODO remove me */));
}
private boolean orderValue() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java
index 56d034a2eae1d..bab65653ba576 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java
@@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -114,9 +115,15 @@ protected Expression.TypeResolution resolveType() {
return resolution;
}
- if (weight.dataType() == DataType.NULL
- || (weight.foldable() && (weight.fold() == null || weight.fold().equals(0) || weight.fold().equals(0.0)))) {
- return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), weight.foldable() ? weight.fold() : null));
+ if (weight.dataType() == DataType.NULL) {
+ return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), null));
+ }
+ if (weight.foldable() == false) {
+ return TypeResolution.TYPE_RESOLVED;
+ }
+ Object weightVal = weight.fold(FoldContext.small()/* TODO remove me*/);
+ if (weightVal == null || weightVal.equals(0) || weightVal.equals(0.0)) {
+ return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), weightVal));
}
return TypeResolution.TYPE_RESOLVED;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
index 07c4bb282ba71..4da7c01139c24 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
@@ -111,7 +112,7 @@ public Expression query() {
* @return query expression as an object
*/
public Object queryAsObject() {
- Object queryAsObject = query().fold();
+ Object queryAsObject = query().fold(FoldContext.small() /* TODO remove me */);
if (queryAsObject instanceof BytesRef bytesRef) {
return bytesRef.utf8ToString();
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java
index ce23860dbdba7..7552b100119f0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java
@@ -18,6 +18,7 @@
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.planner.ExpressionTranslator;
import org.elasticsearch.xpack.esql.core.querydsl.query.QueryStringQuery;
@@ -222,7 +223,7 @@ public void postLogicalOptimizationVerification(Failures failures) {
@Override
public Object queryAsObject() {
- Object queryAsObject = query().fold();
+ Object queryAsObject = query().fold(FoldContext.small() /* TODO remove me */);
// Convert BytesRef to string for string-based values
if (queryAsObject instanceof BytesRef bytesRef) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java
index 113989323eff2..7a3e080f5c830 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java
@@ -18,6 +18,7 @@
import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Foldables;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
@@ -255,25 +256,25 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) {
Rounding.Prepared preparedRounding;
if (buckets.dataType().isWholeNumber()) {
- int b = ((Number) buckets.fold()).intValue();
- long f = foldToLong(from);
- long t = foldToLong(to);
+ int b = ((Number) buckets.fold(toEvaluator.foldCtx())).intValue();
+ long f = foldToLong(toEvaluator.foldCtx(), from);
+ long t = foldToLong(toEvaluator.foldCtx(), to);
preparedRounding = new DateRoundingPicker(b, f, t).pickRounding().prepareForUnknown();
} else {
assert DataType.isTemporalAmount(buckets.dataType()) : "Unexpected span data type [" + buckets.dataType() + "]";
- preparedRounding = DateTrunc.createRounding(buckets.fold(), DEFAULT_TZ);
+ preparedRounding = DateTrunc.createRounding(buckets.fold(toEvaluator.foldCtx()), DEFAULT_TZ);
}
return DateTrunc.evaluator(field.dataType(), source(), toEvaluator.apply(field), preparedRounding);
}
if (field.dataType().isNumeric()) {
double roundTo;
if (from != null) {
- int b = ((Number) buckets.fold()).intValue();
- double f = ((Number) from.fold()).doubleValue();
- double t = ((Number) to.fold()).doubleValue();
+ int b = ((Number) buckets.fold(toEvaluator.foldCtx())).intValue();
+ double f = ((Number) from.fold(toEvaluator.foldCtx())).doubleValue();
+ double t = ((Number) to.fold(toEvaluator.foldCtx())).doubleValue();
roundTo = pickRounding(b, f, t);
} else {
- roundTo = ((Number) buckets.fold()).doubleValue();
+ roundTo = ((Number) buckets.fold(toEvaluator.foldCtx())).doubleValue();
}
Literal rounding = new Literal(source(), roundTo, DataType.DOUBLE);
@@ -416,8 +417,8 @@ public void postLogicalOptimizationVerification(Failures failures) {
.add(to != null ? isFoldable(to, operation, FOURTH) : null);
}
- private long foldToLong(Expression e) {
- Object value = Foldables.valueOf(e);
+ private long foldToLong(FoldContext ctx, Expression e) {
+ Object value = Foldables.valueOf(ctx, e);
return DataType.isDateTime(e.dataType()) ? ((Number) value).longValue() : dateTimeToLong(((BytesRef) value).utf8ToString());
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java
index 0fee65d32ca98..fd025e5e67a7c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java
@@ -10,6 +10,7 @@
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
@@ -28,8 +29,8 @@ protected GroupingFunction(Source source, List fields) {
}
@Override
- public Object fold() {
- return EvaluatorMapper.super.fold();
+ public Object fold(FoldContext ctx) {
+ return EvaluatorMapper.super.fold(source(), ctx);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java
index 404ce7e3900c9..85d15f82f458a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.expression.function.scalar;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
@@ -34,7 +35,7 @@ protected EsqlScalarFunction(Source source, List fields) {
}
@Override
- public Object fold() {
- return EvaluatorMapper.super.fold();
+ public Object fold(FoldContext ctx) {
+ return EvaluatorMapper.super.fold(source(), ctx);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java
index 824f02ca7ccbb..236e625f7abe1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java
@@ -23,6 +23,7 @@
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
@@ -227,7 +228,7 @@ public boolean foldable() {
if (condition.condition.foldable() == false) {
return false;
}
- if (Boolean.TRUE.equals(condition.condition.fold())) {
+ if (Boolean.TRUE.equals(condition.condition.fold(FoldContext.small() /* TODO remove me - use literal true?*/))) {
/*
* `fold` can make four things here:
* 1. `TRUE`
@@ -264,7 +265,8 @@ public boolean foldable() {
* And those two combine so {@code EVAL c=CASE(false, foo, b, bar, true, bort, el)} becomes
* {@code EVAL c=CASE(b, bar, bort)}.
*/
- public Expression partiallyFold() {
+ public Expression partiallyFold(FoldContext ctx) {
+ // TODO don't throw away the results of any `fold`. That might mean looking for literal TRUE on the conditions.
List newChildren = new ArrayList<>(children().size());
boolean modified = false;
for (Condition condition : conditions) {
@@ -274,7 +276,7 @@ public Expression partiallyFold() {
continue;
}
modified = true;
- if (Boolean.TRUE.equals(condition.condition.fold())) {
+ if (Boolean.TRUE.equals(condition.condition.fold(ctx))) {
/*
* `fold` can make four things here:
* 1. `TRUE`
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java
index 842e899ebdac6..57f362f86ff4c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java
@@ -11,6 +11,7 @@
import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -65,8 +66,8 @@ protected final Map factories() {
}
@Override
- public final Object fold() {
- return foldToTemporalAmount(field(), sourceText(), dataType());
+ public final Object fold(FoldContext ctx) {
+ return foldToTemporalAmount(ctx, field(), sourceText(), dataType());
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java
index f6a23a5d5962e..b588832aba4cb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java
@@ -232,7 +232,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (unit.foldable()) {
try {
- Part datePartField = Part.resolve(((BytesRef) unit.fold()).utf8ToString());
+ Part datePartField = Part.resolve(((BytesRef) unit.fold(toEvaluator.foldCtx())).utf8ToString());
return new DateDiffConstantEvaluator.Factory(source(), datePartField, startTimestampEvaluator, endTimestampEvaluator);
} catch (IllegalArgumentException e) {
throw new InvalidArgumentException("invalid unit format for [{}]: {}", sourceText(), e.getMessage());
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java
index 501dfd431f106..7fc5d82441802 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java
@@ -16,6 +16,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -110,9 +111,9 @@ public String getWriteableName() {
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
var fieldEvaluator = toEvaluator.apply(children().get(1));
if (children().get(0).foldable()) {
- ChronoField chrono = chronoField();
+ ChronoField chrono = chronoField(toEvaluator.foldCtx());
if (chrono == null) {
- BytesRef field = (BytesRef) children().get(0).fold();
+ BytesRef field = (BytesRef) children().get(0).fold(toEvaluator.foldCtx());
throw new InvalidArgumentException("invalid date field for [{}]: {}", sourceText(), field.utf8ToString());
}
return new DateExtractConstantEvaluator.Factory(source(), fieldEvaluator, chrono, configuration().zoneId());
@@ -121,14 +122,14 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
return new DateExtractEvaluator.Factory(source(), fieldEvaluator, chronoEvaluator, configuration().zoneId());
}
- private ChronoField chronoField() {
+ private ChronoField chronoField(FoldContext ctx) {
// chronoField's never checked (the return is). The foldability test is done twice and type is checked in resolveType() already.
// TODO: move the slimmed down code here to toEvaluator?
if (chronoField == null) {
Expression field = children().get(0);
try {
if (field.foldable() && DataType.isString(field.dataType())) {
- chronoField = (ChronoField) STRING_TO_CHRONO_FIELD.convert(field.fold());
+ chronoField = (ChronoField) STRING_TO_CHRONO_FIELD.convert(field.fold(ctx));
}
} catch (Exception e) {
return null;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java
index 920a3bb1f4a13..29648d55cadd8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java
@@ -147,7 +147,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
throw new IllegalArgumentException("unsupported data type for format [" + format.dataType() + "]");
}
if (format.foldable()) {
- DateFormatter formatter = toFormatter(format.fold(), configuration().locale());
+ DateFormatter formatter = toFormatter(format.fold(toEvaluator.foldCtx()), configuration().locale());
return new DateFormatConstantEvaluator.Factory(source(), fieldEvaluator, formatter);
}
var formatEvaluator = toEvaluator.apply(format);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java
index e09fabab98d0f..7c38b54ed232b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java
@@ -143,7 +143,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
}
if (format.foldable()) {
try {
- DateFormatter formatter = toFormatter(format.fold());
+ DateFormatter formatter = toFormatter(format.fold(toEvaluator.foldCtx()));
return new DateParseConstantEvaluator.Factory(source(), fieldEvaluator, formatter);
} catch (IllegalArgumentException e) {
throw new InvalidArgumentException(e, "invalid date pattern for [{}]: {}", sourceText(), e.getMessage());
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java
index a35b67d7ac3fd..7983c38cc4288 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java
@@ -225,7 +225,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
}
Object foldedInterval;
try {
- foldedInterval = interval.fold();
+ foldedInterval = interval.fold(toEvaluator.foldCtx());
if (foldedInterval == null) {
throw new IllegalArgumentException("Interval cannot not be null");
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java
index d259fc6ae57ce..74c2da450995c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java
@@ -14,6 +14,7 @@
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -59,7 +60,7 @@ public String getWriteableName() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
return now;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java
index 757b67b47ce72..e1eceef7ed1f5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java
@@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
@@ -49,7 +50,7 @@ public String getWriteableName() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
return Math.E;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java
index 90a4f1f091e91..32b7a0ab88b4e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java
@@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
@@ -49,7 +50,7 @@ public String getWriteableName() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
return Math.PI;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java
index 17e5b027270d1..1a7669b7391e1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java
@@ -11,6 +11,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
@@ -51,7 +52,7 @@ public String getWriteableName() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
return TAU;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java
index 1996744a76567..26211258e6ca6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java
@@ -17,6 +17,7 @@
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
@@ -91,8 +92,8 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
}
@Override
- public Object fold() {
- return EvaluatorMapper.super.fold();
+ public Object fold(FoldContext ctx) {
+ return EvaluatorMapper.super.fold(source(), ctx);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java
index 4dd447f938880..d5093964145b7 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java
@@ -115,7 +115,7 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
source(),
toEvaluator.apply(field),
ctx -> new CompensatedSum(),
- (Double) p.fold()
+ (Double) p.fold(toEvaluator.foldCtx())
);
case NULL -> EvalOperator.CONSTANT_NULL_FACTORY;
default -> throw EsqlIllegalArgumentException.illegalDataType(field.dataType());
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java
index f4f9679dc3704..4a04524d1b23d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java
@@ -187,8 +187,8 @@ public boolean foldable() {
@Override
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (start.foldable() && end.foldable()) {
- int startOffset = stringToInt(String.valueOf(start.fold()));
- int endOffset = stringToInt(String.valueOf(end.fold()));
+ int startOffset = stringToInt(String.valueOf(start.fold(toEvaluator.foldCtx())));
+ int endOffset = stringToInt(String.valueOf(end.fold(toEvaluator.foldCtx())));
checkStartEnd(startOffset, endOffset);
}
return switch (PlannerUtils.toElementType(field.dataType())) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java
index 86538c828ece7..b68718acfcd0a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java
@@ -33,6 +33,7 @@
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -155,12 +156,12 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
sourceText(),
ASC.value(),
DESC.value(),
- ((BytesRef) order.fold()).utf8ToString()
+ ((BytesRef) order.fold(toEvaluator.foldCtx())).utf8ToString()
)
);
}
if (order != null && order.foldable()) {
- ordering = ((BytesRef) order.fold()).utf8ToString().equalsIgnoreCase((String) ASC.value());
+ ordering = ((BytesRef) order.fold(toEvaluator.foldCtx())).utf8ToString().equalsIgnoreCase((String) ASC.value());
}
return switch (PlannerUtils.toElementType(field.dataType())) {
@@ -238,7 +239,14 @@ public void postLogicalOptimizationVerification(Failures failures) {
failures.add(isFoldable(order, operation, SECOND));
if (isValidOrder() == false) {
failures.add(
- Failure.fail(order, INVALID_ORDER_ERROR, sourceText(), ASC.value(), DESC.value(), ((BytesRef) order.fold()).utf8ToString())
+ Failure.fail(
+ order,
+ INVALID_ORDER_ERROR,
+ sourceText(),
+ ASC.value(),
+ DESC.value(),
+ ((BytesRef) order.fold(FoldContext.small() /* TODO remove me */)).utf8ToString()
+ )
);
}
}
@@ -246,7 +254,7 @@ public void postLogicalOptimizationVerification(Failures failures) {
private boolean isValidOrder() {
boolean isValidOrder = true;
if (order != null && order.foldable()) {
- Object obj = order.fold();
+ Object obj = order.fold(FoldContext.small() /* TODO remove me */);
String o = null;
if (obj instanceof BytesRef ob) {
o = ob.utf8ToString();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java
index 9189c6a7b8f70..b15d04aa792d9 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java
@@ -26,6 +26,7 @@
import org.elasticsearch.lucene.spatial.CoordinateEncoder;
import org.elasticsearch.lucene.spatial.GeometryDocValueReader;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -214,10 +215,10 @@ protected NodeInfo extends Expression> info() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
try {
- GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left());
- Geometry rightGeom = makeGeometryFromLiteral(right());
+ GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left());
+ Geometry rightGeom = makeGeometryFromLiteral(ctx, right());
Component2D[] components = asLuceneComponent2Ds(crsType(), rightGeom);
return (crsType() == SpatialCrsType.GEO)
? GEO.geometryRelatesGeometries(docValueReader, components)
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java
index ee78f50c4d6bd..3e16fa163fcd6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java
@@ -23,6 +23,7 @@
import org.elasticsearch.lucene.spatial.CoordinateEncoder;
import org.elasticsearch.lucene.spatial.GeometryDocValueReader;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -129,10 +130,10 @@ protected NodeInfo extends Expression> info() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
try {
- GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left());
- Component2D component2D = asLuceneComponent2D(crsType(), right());
+ GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left());
+ Component2D component2D = asLuceneComponent2D(ctx, crsType(), right());
return (crsType() == SpatialCrsType.GEO)
? GEO.geometryRelatesGeometry(docValueReader, component2D)
: CARTESIAN.geometryRelatesGeometry(docValueReader, component2D);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java
index 1a51af8dfeeb4..dcd53075cf69c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java
@@ -171,7 +171,11 @@ protected static class SpatialEvaluatorWithConstantFactory extends SpatialEvalua
@Override
public EvalOperator.ExpressionEvaluator.Factory get(SpatialSourceSupplier s, EvaluatorMapper.ToEvaluator toEvaluator) {
- return factoryCreator.apply(s.source(), toEvaluator.apply(s.left()), asLuceneComponent2D(s.crsType(), s.right()));
+ return factoryCreator.apply(
+ s.source(),
+ toEvaluator.apply(s.left()),
+ asLuceneComponent2D(toEvaluator.foldCtx(), s.crsType(), s.right())
+ );
}
}
@@ -197,7 +201,11 @@ protected static class SpatialEvaluatorWithConstantArrayFactory extends SpatialE
@Override
public EvalOperator.ExpressionEvaluator.Factory get(SpatialSourceSupplier s, EvaluatorMapper.ToEvaluator toEvaluator) {
- return factoryCreator.apply(s.source(), toEvaluator.apply(s.left()), asLuceneComponent2Ds(s.crsType(), s.right()));
+ return factoryCreator.apply(
+ s.source(),
+ toEvaluator.apply(s.left()),
+ asLuceneComponent2Ds(toEvaluator.foldCtx(), s.crsType(), s.right())
+ );
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java
index 8d54e5ee443c2..601550cd173bb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java
@@ -23,6 +23,7 @@
import org.elasticsearch.lucene.spatial.CoordinateEncoder;
import org.elasticsearch.lucene.spatial.GeometryDocValueReader;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -127,10 +128,10 @@ protected NodeInfo extends Expression> info() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
try {
- GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left());
- Component2D component2D = asLuceneComponent2D(crsType(), right());
+ GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left());
+ Component2D component2D = asLuceneComponent2D(ctx, crsType(), right());
return (crsType() == SpatialCrsType.GEO)
? GEO.geometryRelatesGeometry(docValueReader, component2D)
: CARTESIAN.geometryRelatesGeometry(docValueReader, component2D);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java
index 6ae99ea8165cd..1b06c6dfd3dd5 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java
@@ -29,6 +29,7 @@
import org.elasticsearch.lucene.spatial.GeometryDocValueReader;
import org.elasticsearch.lucene.spatial.GeometryDocValueWriter;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes;
@@ -42,8 +43,8 @@
public class SpatialRelatesUtils {
/** Converts a {@link Expression} into a {@link Component2D}. */
- static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) {
- return asLuceneComponent2D(crsType, makeGeometryFromLiteral(expression));
+ static Component2D asLuceneComponent2D(FoldContext ctx, BinarySpatialFunction.SpatialCrsType crsType, Expression expression) {
+ return asLuceneComponent2D(crsType, makeGeometryFromLiteral(ctx, expression));
}
/** Converts a {@link Geometry} into a {@link Component2D}. */
@@ -66,8 +67,8 @@ static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType type
* Converts a {@link Expression} at a given {@code position} into a {@link Component2D} array.
* The reason for generating an array instead of a single component is for multi-shape support with ST_CONTAINS.
*/
- static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) {
- return asLuceneComponent2Ds(crsType, makeGeometryFromLiteral(expression));
+ static Component2D[] asLuceneComponent2Ds(FoldContext ctx, BinarySpatialFunction.SpatialCrsType crsType, Expression expression) {
+ return asLuceneComponent2Ds(crsType, makeGeometryFromLiteral(ctx, expression));
}
/**
@@ -90,9 +91,12 @@ static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType t
}
/** Converts a {@link Expression} into a {@link GeometryDocValueReader} */
- static GeometryDocValueReader asGeometryDocValueReader(BinarySpatialFunction.SpatialCrsType crsType, Expression expression)
- throws IOException {
- Geometry geometry = makeGeometryFromLiteral(expression);
+ static GeometryDocValueReader asGeometryDocValueReader(
+ FoldContext ctx,
+ BinarySpatialFunction.SpatialCrsType crsType,
+ Expression expression
+ ) throws IOException {
+ Geometry geometry = makeGeometryFromLiteral(ctx, expression);
if (crsType == BinarySpatialFunction.SpatialCrsType.GEO) {
return asGeometryDocValueReader(
CoordinateEncoder.GEO,
@@ -167,8 +171,8 @@ private static Geometry asGeometry(BytesRefBlock valueBlock, int position) {
* This function is used in two places, when evaluating a spatial constant in the SpatialRelatesFunction, as well as when
* we do lucene-pushdown of spatial functions.
*/
- public static Geometry makeGeometryFromLiteral(Expression expr) {
- return makeGeometryFromLiteralValue(valueOf(expr), expr.dataType());
+ public static Geometry makeGeometryFromLiteral(FoldContext ctx, Expression expr) {
+ return makeGeometryFromLiteralValue(valueOf(ctx, expr), expr.dataType());
}
private static Geometry makeGeometryFromLiteralValue(Object value, DataType dataType) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java
index 2005709cd37e9..9fcece1ce65bc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java
@@ -23,6 +23,7 @@
import org.elasticsearch.lucene.spatial.CoordinateEncoder;
import org.elasticsearch.lucene.spatial.GeometryDocValueReader;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -129,10 +130,10 @@ protected NodeInfo extends Expression> info() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
try {
- GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left());
- Component2D component2D = asLuceneComponent2D(crsType(), right());
+ GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left());
+ Component2D component2D = asLuceneComponent2D(ctx, crsType(), right());
return (crsType() == SpatialCrsType.GEO)
? GEO.geometryRelatesGeometry(docValueReader, component2D)
: CARTESIAN.geometryRelatesGeometry(docValueReader, component2D);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java
index 3cf042a2db828..f0c25e3289cc1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java
@@ -23,6 +23,7 @@
import org.elasticsearch.lucene.spatial.CoordinateEncoder;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -280,18 +281,18 @@ protected NodeInfo extends Expression> info() {
}
@Override
- public Object fold() {
- var leftGeom = makeGeometryFromLiteral(left());
- var rightGeom = makeGeometryFromLiteral(right());
+ public Object fold(FoldContext ctx) {
+ var leftGeom = makeGeometryFromLiteral(ctx, left());
+ var rightGeom = makeGeometryFromLiteral(ctx, right());
return (crsType() == SpatialCrsType.GEO) ? GEO.distance(leftGeom, rightGeom) : CARTESIAN.distance(leftGeom, rightGeom);
}
@Override
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (right().foldable()) {
- return toEvaluator(toEvaluator, left(), makeGeometryFromLiteral(right()), leftDocValues);
+ return toEvaluator(toEvaluator, left(), makeGeometryFromLiteral(toEvaluator.foldCtx(), right()), leftDocValues);
} else if (left().foldable()) {
- return toEvaluator(toEvaluator, right(), makeGeometryFromLiteral(left()), rightDocValues);
+ return toEvaluator(toEvaluator, right(), makeGeometryFromLiteral(toEvaluator.foldCtx(), left()), rightDocValues);
} else {
EvalOperator.ExpressionEvaluator.Factory leftE = toEvaluator.apply(left());
EvalOperator.ExpressionEvaluator.Factory rightE = toEvaluator.apply(right());
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java
index 52d33c0fc9d3d..be0a7b2fe27b2 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java
@@ -146,7 +146,7 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
if (algorithm.foldable()) {
try {
// hash function is created here in order to validate the algorithm is valid before evaluator is created
- var hf = HashFunction.create((BytesRef) algorithm.fold());
+ var hf = HashFunction.create((BytesRef) algorithm.fold(toEvaluator.foldCtx()));
return new HashConstantEvaluator.Factory(
source(),
context -> new BreakingBytesRefBuilder(context.breaker(), "hash"),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java
index 996c90a8e40bc..fb0aac0c85b38 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java
@@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -94,8 +95,8 @@ protected TypeResolution resolveType() {
}
@Override
- public Boolean fold() {
- return (Boolean) EvaluatorMapper.super.fold();
+ public Boolean fold(FoldContext ctx) {
+ return (Boolean) EvaluatorMapper.super.fold(source(), ctx);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java
index e91f03de3dd7e..363991d1556f1 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java
@@ -151,7 +151,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
ExpressionEvaluator.Factory strExpr = toEvaluator.apply(str);
if (number.foldable()) {
- int num = (int) number.fold();
+ int num = (int) number.fold(toEvaluator.foldCtx());
if (num < 0) {
throw new IllegalArgumentException("Number parameter cannot be negative, found [" + number + "]");
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java
index 4fa191244cb42..4b963b794aef0 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java
@@ -152,7 +152,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (regex.foldable() && regex.dataType() == DataType.KEYWORD) {
Pattern regexPattern;
try {
- regexPattern = Pattern.compile(((BytesRef) regex.fold()).utf8ToString());
+ regexPattern = Pattern.compile(((BytesRef) regex.fold(toEvaluator.foldCtx())).utf8ToString());
} catch (PatternSyntaxException pse) {
// TODO this is not right (inconsistent). See also https://github.com/elastic/elasticsearch/issues/100038
// this should generate a header warning and return null (as do the rest of this functionality in evaluators),
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java
index 3b9a466966911..e46c0a730431d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java
@@ -113,7 +113,7 @@ protected NodeInfo extends Expression> info() {
@Override
public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (field.foldable()) {
- Object folded = field.fold();
+ Object folded = field.fold(toEvaluator.foldCtx());
if (folded instanceof Integer num) {
checkNumber(num);
return toEvaluator.apply(new Literal(source(), " ".repeat(num), KEYWORD));
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java
index 24762122f755b..d0c1035978ff3 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java
@@ -17,6 +17,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -102,8 +103,8 @@ public boolean foldable() {
}
@Override
- public Object fold() {
- return EvaluatorMapper.super.fold();
+ public Object fold(FoldContext ctx) {
+ return EvaluatorMapper.super.fold(source(), ctx);
}
@Evaluator(extraName = "SingleByte")
@@ -163,7 +164,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
if (right().foldable() == false) {
return new SplitVariableEvaluator.Factory(source(), str, toEvaluator.apply(right()), context -> new BytesRef());
}
- BytesRef delim = (BytesRef) right().fold();
+ BytesRef delim = (BytesRef) right().fold(toEvaluator.foldCtx());
checkDelimiter(delim);
return new SplitSingleByteEvaluator.Factory(source(), str, delim.bytes[delim.offset], context -> new BytesRef());
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java
index d2edb0f92e8f2..65455c708cc9b 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java
@@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -100,8 +101,8 @@ protected TypeResolution resolveType() {
}
@Override
- public Boolean fold() {
- return (Boolean) EvaluatorMapper.super.fold();
+ public Boolean fold(FoldContext ctx) {
+ return (Boolean) EvaluatorMapper.super.fold(source(), ctx);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java
index 1d03f09c86409..3b17133bf4974 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java
@@ -14,6 +14,7 @@
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -84,15 +85,15 @@ public boolean foldable() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
return null;
}
- private long msValue() {
+ private long msValue(FoldContext ctx) {
if (field().foldable() == false) {
throw new IllegalArgumentException("function [" + sourceText() + "] has invalid argument [" + field().sourceText() + "]");
}
- var ms = field().fold();
+ var ms = field().fold(ctx);
if (ms instanceof Duration duration) {
return duration.toMillis();
}
@@ -101,7 +102,7 @@ private long msValue() {
@Override
public ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
- return new DelayEvaluator.Factory(source(), msValue());
+ return new DelayEvaluator.Factory(source(), msValue(toEvaluator.foldCtx()));
}
@Evaluator
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java
index 8bb166fac60bb..424c080c905e3 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java
@@ -11,6 +11,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.ExceptionUtils;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -111,7 +112,7 @@ protected TypeResolution checkCompatibility() {
/**
* Override this to allow processing literals of type {@link DataType#DATE_PERIOD} when folding constants.
- * Used in {@link DateTimeArithmeticOperation#fold()}.
+ * Used in {@link DateTimeArithmeticOperation#fold}.
* @param left the left period
* @param right the right period
* @return the result of the evaluation
@@ -120,7 +121,7 @@ protected TypeResolution checkCompatibility() {
/**
* Override this to allow processing literals of type {@link DataType#TIME_DURATION} when folding constants.
- * Used in {@link DateTimeArithmeticOperation#fold()}.
+ * Used in {@link DateTimeArithmeticOperation#fold}.
* @param left the left duration
* @param right the right duration
* @return the result of the evaluation
@@ -128,13 +129,13 @@ protected TypeResolution checkCompatibility() {
abstract Duration fold(Duration left, Duration right);
@Override
- public final Object fold() {
+ public final Object fold(FoldContext ctx) {
DataType leftDataType = left().dataType();
DataType rightDataType = right().dataType();
if (leftDataType == DATE_PERIOD && rightDataType == DATE_PERIOD) {
// Both left and right expressions are temporal amounts; we can assume they are both foldable.
- var l = left().fold();
- var r = right().fold();
+ var l = left().fold(ctx);
+ var r = right().fold(ctx);
if (l instanceof Collection> || r instanceof Collection>) {
return null;
}
@@ -148,8 +149,8 @@ public final Object fold() {
}
if (leftDataType == TIME_DURATION && rightDataType == TIME_DURATION) {
// Both left and right expressions are temporal amounts; we can assume they are both foldable.
- Duration l = (Duration) left().fold();
- Duration r = (Duration) right().fold();
+ Duration l = (Duration) left().fold(ctx);
+ Duration r = (Duration) right().fold(ctx);
try {
return fold(l, r);
} catch (ArithmeticException e) {
@@ -161,7 +162,7 @@ public final Object fold() {
if (isNull(leftDataType) || isNull(rightDataType)) {
return null;
}
- return super.fold();
+ return super.fold(ctx);
}
@Override
@@ -178,7 +179,11 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
temporalAmountArgument = left();
}
- return millisEvaluator.apply(source(), toEvaluator.apply(datetimeArgument), (TemporalAmount) temporalAmountArgument.fold());
+ return millisEvaluator.apply(
+ source(),
+ toEvaluator.apply(datetimeArgument),
+ (TemporalAmount) temporalAmountArgument.fold(toEvaluator.foldCtx())
+ );
} else if (dataType() == DATE_NANOS) {
// One of the arguments has to be a date_nanos and the other a temporal amount.
Expression dateNanosArgument;
@@ -191,7 +196,11 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
temporalAmountArgument = left();
}
- return nanosEvaluator.apply(source(), toEvaluator.apply(dateNanosArgument), (TemporalAmount) temporalAmountArgument.fold());
+ return nanosEvaluator.apply(
+ source(),
+ toEvaluator.apply(dateNanosArgument),
+ (TemporalAmount) temporalAmountArgument.fold(toEvaluator.foldCtx())
+ );
} else {
return super.toEvaluator(toEvaluator);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java
index 74394d796855f..e3248665ad486 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java
@@ -12,6 +12,7 @@
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.BinaryArithmeticOperation;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -120,8 +121,8 @@ public interface BinaryEvaluator {
}
@Override
- public Object fold() {
- return EvaluatorMapper.super.fold();
+ public Object fold(FoldContext ctx) {
+ return EvaluatorMapper.super.fold(source(), ctx);
}
public DataType dataType() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java
index fb32282005f02..6663ccf0ef7b6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java
@@ -14,6 +14,7 @@
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.ExceptionUtils;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -87,12 +88,12 @@ else if (type == DataType.LONG) {
}
@Override
- public final Object fold() {
+ public final Object fold(FoldContext ctx) {
DataType dataType = field().dataType();
// For date periods and time durations, we need to treat folding differently. These types are unrepresentable, so there is no
// evaluator for them - but the default folding requires an evaluator.
if (dataType == DATE_PERIOD) {
- Period fieldValue = (Period) field().fold();
+ Period fieldValue = (Period) field().fold(ctx);
try {
return fieldValue.negated();
} catch (ArithmeticException e) {
@@ -102,7 +103,7 @@ public final Object fold() {
}
}
if (dataType == TIME_DURATION) {
- Duration fieldValue = (Duration) field().fold();
+ Duration fieldValue = (Duration) field().fold(ctx);
try {
return fieldValue.negated();
} catch (ArithmeticException e) {
@@ -111,7 +112,7 @@ public final Object fold() {
throw ExceptionUtils.math(source(), e);
}
}
- return super.fold();
+ return super.fold(ctx);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java
index 3e2a21664aa7e..e56c19b26a902 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java
@@ -13,6 +13,7 @@
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -204,8 +205,8 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
}
@Override
- public Boolean fold() {
- return (Boolean) EvaluatorMapper.super.fold();
+ public Boolean fold(FoldContext ctx) {
+ return (Boolean) EvaluatorMapper.super.fold(source(), ctx);
}
@Override
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java
index f596d589cdde2..2061c2626aa45 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java
@@ -17,6 +17,7 @@
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Comparisons;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -205,11 +206,11 @@ public boolean foldable() {
}
@Override
- public Object fold() {
+ public Object fold(FoldContext ctx) {
if (Expressions.isGuaranteedNull(value) || list.stream().allMatch(Expressions::isGuaranteedNull)) {
return null;
}
- return super.fold();
+ return super.fold(ctx);
}
protected boolean areCompatible(DataType left, DataType right) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java
index c731e44197f2e..01564644bf5c7 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java
@@ -16,6 +16,7 @@
import org.elasticsearch.compute.ann.Evaluator;
import org.elasticsearch.compute.ann.Fixed;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -77,9 +78,9 @@ public static Automaton automaton(BytesRef val) {
}
@Override
- public Boolean fold() {
- BytesRef leftVal = BytesRefs.toBytesRef(left().fold());
- BytesRef rightVal = BytesRefs.toBytesRef(right().fold());
+ public Boolean fold(FoldContext ctx) {
+ BytesRef leftVal = BytesRefs.toBytesRef(left().fold(ctx));
+ BytesRef rightVal = BytesRefs.toBytesRef(right().fold(ctx));
if (leftVal == null || rightVal == null) {
return null;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java
index f5704239993f9..7ea95c764f36c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java
@@ -14,6 +14,7 @@
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.evaluator.mapper.ExpressionMapper;
@@ -28,15 +29,15 @@ public class InsensitiveEqualsMapper extends ExpressionMapper
InsensitiveEqualsEvaluator.Factory::new;
@Override
- public final ExpressionEvaluator.Factory map(InsensitiveEquals bc, Layout layout) {
+ public final ExpressionEvaluator.Factory map(FoldContext foldCtx, InsensitiveEquals bc, Layout layout) {
DataType leftType = bc.left().dataType();
DataType rightType = bc.right().dataType();
- var leftEval = toEvaluator(bc.left(), layout);
- var rightEval = toEvaluator(bc.right(), layout);
+ var leftEval = toEvaluator(foldCtx, bc.left(), layout);
+ var rightEval = toEvaluator(foldCtx, bc.right(), layout);
if (DataType.isString(leftType)) {
if (bc.right().foldable() && DataType.isString(rightType)) {
- BytesRef rightVal = BytesRefs.toBytesRef(bc.right().fold());
+ BytesRef rightVal = BytesRefs.toBytesRef(bc.right().fold(FoldContext.small() /* TODO remove me */));
Automaton automaton = InsensitiveEquals.automaton(rightVal);
return dvrCtx -> new InsensitiveEqualsConstantEvaluator(
bc.source(),
@@ -51,13 +52,14 @@ public final ExpressionEvaluator.Factory map(InsensitiveEquals bc, Layout layout
}
public static ExpressionEvaluator.Factory castToEvaluator(
+ FoldContext foldCtx,
InsensitiveEquals op,
Layout layout,
DataType required,
TriFunction factory
) {
- var lhs = Cast.cast(op.source(), op.left().dataType(), required, toEvaluator(op.left(), layout));
- var rhs = Cast.cast(op.source(), op.right().dataType(), required, toEvaluator(op.right(), layout));
+ var lhs = Cast.cast(op.source(), op.left().dataType(), required, toEvaluator(foldCtx, op.left(), layout));
+ var rhs = Cast.cast(op.source(), op.right().dataType(), required, toEvaluator(foldCtx, op.right(), layout));
return factory.apply(op.source(), lhs, rhs);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java
index ef5cf50c76541..183008f900c5d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java
@@ -7,6 +7,7 @@
package org.elasticsearch.xpack.esql.optimizer;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.stats.SearchStats;
@@ -15,8 +16,8 @@
public final class LocalLogicalOptimizerContext extends LogicalOptimizerContext {
private final SearchStats searchStats;
- public LocalLogicalOptimizerContext(Configuration configuration, SearchStats searchStats) {
- super(configuration);
+ public LocalLogicalOptimizerContext(Configuration configuration, FoldContext foldCtx, SearchStats searchStats) {
+ super(configuration, foldCtx);
this.searchStats = searchStats;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java
index c11e1a4ec49e4..22e07b45310fb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java
@@ -7,7 +7,8 @@
package org.elasticsearch.xpack.esql.optimizer;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.stats.SearchStats;
-public record LocalPhysicalOptimizerContext(Configuration configuration, SearchStats searchStats) {}
+public record LocalPhysicalOptimizerContext(Configuration configuration, FoldContext foldCtx, SearchStats searchStats) {}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java
index 67148e67cbc19..da2d583674a90 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java
@@ -7,37 +7,44 @@
package org.elasticsearch.xpack.esql.optimizer;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.session.Configuration;
import java.util.Objects;
public class LogicalOptimizerContext {
private final Configuration configuration;
+ private final FoldContext foldCtx;
- public LogicalOptimizerContext(Configuration configuration) {
+ public LogicalOptimizerContext(Configuration configuration, FoldContext foldCtx) {
this.configuration = configuration;
+ this.foldCtx = foldCtx;
}
public Configuration configuration() {
return configuration;
}
+ public FoldContext foldCtx() {
+ return foldCtx;
+ }
+
@Override
public boolean equals(Object obj) {
if (obj == this) return true;
if (obj == null || obj.getClass() != this.getClass()) return false;
var that = (LogicalOptimizerContext) obj;
- return Objects.equals(this.configuration, that.configuration);
+ return this.configuration.equals(that.configuration) && this.foldCtx.equals(that.foldCtx);
}
@Override
public int hashCode() {
- return Objects.hash(configuration);
+ return Objects.hash(configuration, foldCtx);
}
@Override
public String toString() {
- return "LogicalOptimizerContext[" + "configuration=" + configuration + ']';
+ return "LogicalOptimizerContext[configuration=" + configuration + ", foldCtx=" + foldCtx + ']';
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java
index 3152f9b574767..5f463f2aa4c78 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java
@@ -13,6 +13,7 @@
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
@@ -28,7 +29,7 @@ public BooleanFunctionEqualsElimination() {
}
@Override
- public Expression rule(BinaryComparison bc) {
+ public Expression rule(BinaryComparison bc, LogicalOptimizerContext ctx) {
if ((bc instanceof Equals || bc instanceof NotEquals) && bc.left() instanceof Function) {
// for expression "==" or "!=" TRUE/FALSE, return the expression itself or its negated variant
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java
index 73d1ea1fb6e8f..e1803872fd606 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.util.List;
@@ -35,7 +36,7 @@ public BooleanSimplification() {
}
@Override
- public Expression rule(ScalarFunction e) {
+ public Expression rule(ScalarFunction e, LogicalOptimizerContext ctx) {
if (e instanceof And || e instanceof Or) {
return simplifyAndOr((BinaryPredicate, ?, ?, ?>) e);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java
index 3f47c74aaf814..1c290a7c4c4fd 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic;
@@ -20,6 +21,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.util.ArrayList;
import java.util.List;
@@ -31,17 +33,17 @@ public CombineBinaryComparisons() {
}
@Override
- public Expression rule(BinaryLogic e) {
+ public Expression rule(BinaryLogic e, LogicalOptimizerContext ctx) {
if (e instanceof And and) {
- return combine(and);
+ return combine(ctx.foldCtx(), and);
} else if (e instanceof Or or) {
- return combine(or);
+ return combine(ctx.foldCtx(), or);
}
return e;
}
// combine conjunction
- private static Expression combine(And and) {
+ private static Expression combine(FoldContext ctx, And and) {
List bcs = new ArrayList<>();
List exps = new ArrayList<>();
boolean changed = false;
@@ -58,13 +60,13 @@ private static Expression combine(And and) {
});
for (Expression ex : andExps) {
if (ex instanceof BinaryComparison bc && (ex instanceof Equals || ex instanceof NotEquals) == false) {
- if (bc.right().foldable() && (findExistingComparison(bc, bcs, true))) {
+ if (bc.right().foldable() && (findExistingComparison(ctx, bc, bcs, true))) {
changed = true;
} else {
bcs.add(bc);
}
} else if (ex instanceof NotEquals neq) {
- if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(neq, bcs)) {
+ if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(ctx, neq, bcs)) {
// the non-equality can simply be dropped: either superfluous or has been merged with an updated range/inequality
changed = true;
} else { // not foldable OR not overlapping
@@ -78,13 +80,13 @@ private static Expression combine(And and) {
}
// combine disjunction
- private static Expression combine(Or or) {
+ private static Expression combine(FoldContext ctx, Or or) {
List bcs = new ArrayList<>();
List exps = new ArrayList<>();
boolean changed = false;
for (Expression ex : Predicates.splitOr(or)) {
if (ex instanceof BinaryComparison bc) {
- if (bc.right().foldable() && findExistingComparison(bc, bcs, false)) {
+ if (bc.right().foldable() && findExistingComparison(ctx, bc, bcs, false)) {
changed = true;
} else {
bcs.add(bc);
@@ -100,8 +102,8 @@ private static Expression combine(Or or) {
* Find commonalities between the given comparison in the given list.
* The method can be applied both for conjunctive (AND) or disjunctive purposes (OR).
*/
- private static boolean findExistingComparison(BinaryComparison main, List bcs, boolean conjunctive) {
- Object value = main.right().fold();
+ private static boolean findExistingComparison(FoldContext ctx, BinaryComparison main, List bcs, boolean conjunctive) {
+ Object value = main.right().fold(ctx);
// NB: the loop modifies the list (hence why the int is used)
for (int i = 0; i < bcs.size(); i++) {
BinaryComparison other = bcs.get(i);
@@ -113,7 +115,7 @@ private static boolean findExistingComparison(BinaryComparison main, List bcs) {
- Object neqVal = notEquals.right().fold();
+ private static boolean notEqualsIsRemovableFromConjunction(FoldContext ctx, NotEquals notEquals, List bcs) {
+ Object neqVal = notEquals.right().fold(ctx);
Integer comp;
// check on "condition-overlapping" inequalities:
@@ -183,7 +185,7 @@ private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals,
BinaryComparison bc = bcs.get(i);
if (notEquals.left().semanticEquals(bc.left())) {
if (bc instanceof LessThan || bc instanceof LessThanOrEqual) {
- comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null;
+ comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold(ctx)) : null;
if (comp != null) {
if (comp >= 0) {
if (comp == 0 && bc instanceof LessThanOrEqual) { // a != 2 AND a <= 2 -> a < 2
@@ -193,7 +195,7 @@ private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals,
} // else: comp < 0 : a != 2 AND a <= 3 -> nop
} // else: non-comparable, nop
} else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
- comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null;
+ comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold(ctx)) : null;
if (comp != null) {
if (comp <= 0) {
if (comp == 0 && bc instanceof GreaterThanOrEqual) { // a != 2 AND a >= 2 -> a > 2
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java
index 5cb377de47efc..e1cda9cb149d4 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.time.ZoneId;
import java.util.ArrayList;
@@ -61,7 +62,7 @@ protected static CIDRMatch createCIDRMatch(Expression k, List v) {
}
@Override
- public Expression rule(Or or) {
+ public Expression rule(Or or, LogicalOptimizerContext ctx) {
Expression e = or;
// look only at equals, In and CIDRMatch
List exps = splitOr(e);
@@ -78,7 +79,7 @@ public Expression rule(Or or) {
if (eq.right().foldable()) {
ins.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right());
if (eq.left().dataType() == DataType.IP) {
- Object value = eq.right().fold();
+ Object value = eq.right().fold(ctx.foldCtx());
// ImplicitCasting and ConstantFolding(includes explicit casting) are applied before CombineDisjunctions.
// They fold the input IP string to an internal IP format. These happen to Equals and IN, but not for CIDRMatch,
// as CIDRMatch takes strings as input, ImplicitCasting does not apply to it, and the first input to CIDRMatch is a
@@ -101,7 +102,7 @@ public Expression rule(Or or) {
if (in.value().dataType() == DataType.IP) {
List values = new ArrayList<>(in.list().size());
for (Expression i : in.list()) {
- Object value = i.fold();
+ Object value = i.fold(ctx.foldCtx());
// Same as Equals.
if (value instanceof BytesRef bytesRef) {
value = ipToString(bytesRef);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java
index 82fe2c6bddf50..27eec8de59020 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java
@@ -9,6 +9,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
public final class ConstantFolding extends OptimizerRules.OptimizerExpressionRule {
@@ -17,7 +18,7 @@ public ConstantFolding() {
}
@Override
- public Expression rule(Expression e) {
- return e.foldable() ? Literal.of(e) : e;
+ public Expression rule(Expression e, LogicalOptimizerContext ctx) {
+ return e.foldable() ? Literal.of(ctx.foldCtx(), e) : e;
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java
index 0604750883f14..b716d8f012d21 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java
@@ -10,6 +10,7 @@
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.util.ArrayList;
import java.util.List;
@@ -21,7 +22,8 @@ public ConvertStringToByteRef() {
}
@Override
- protected Expression rule(Literal lit) {
+ protected Expression rule(Literal lit, LogicalOptimizerContext ctx) {
+ // TODO we shouldn't be emitting String into Literals at all
Object value = lit.value();
if (value == null) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java
index 747864625e65c..cf4c7f19baafe 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
public class FoldNull extends OptimizerRules.OptimizerExpressionRule {
@@ -23,7 +24,7 @@ public FoldNull() {
}
@Override
- public Expression rule(Expression e) {
+ public Expression rule(Expression e, LogicalOptimizerContext ctx) {
Expression result = tryReplaceIsNullIsNotNull(e);
// convert an aggregate null filter into a false
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java
index d96c73d5ee4f1..6504e6042c33a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java
@@ -9,6 +9,7 @@
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
public final class LiteralsOnTheRight extends OptimizerRules.OptimizerExpressionRule> {
@@ -17,7 +18,7 @@ public LiteralsOnTheRight() {
}
@Override
- public BinaryOperator, ?, ?, ?> rule(BinaryOperator, ?, ?, ?> be) {
+ public BinaryOperator, ?, ?, ?> rule(BinaryOperator, ?, ?, ?> be, LogicalOptimizerContext ctx) {
return be.left() instanceof Literal && (be.right() instanceof Literal) == false ? be.swapLeftAndRight() : be;
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java
index 2a0b2a6af36aa..169ac2ac8c0fe 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java
@@ -8,6 +8,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.util.ReflectionUtils;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
import org.elasticsearch.xpack.esql.rule.Rule;
@@ -36,7 +37,10 @@ public final LogicalPlan apply(LogicalPlan plan) {
protected abstract LogicalPlan rule(SubPlan plan);
}
- public abstract static class OptimizerExpressionRule extends Rule {
+ public abstract static class OptimizerExpressionRule extends ParameterizedRule<
+ LogicalPlan,
+ LogicalPlan,
+ LogicalOptimizerContext> {
private final TransformDirection direction;
// overriding type token which returns the correct class but does an uncheck cast to LogicalPlan due to its generic bound
@@ -49,17 +53,13 @@ public OptimizerExpressionRule(TransformDirection direction) {
}
@Override
- public final LogicalPlan apply(LogicalPlan plan) {
+ public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
return direction == TransformDirection.DOWN
- ? plan.transformExpressionsDown(expressionTypeToken, this::rule)
- : plan.transformExpressionsUp(expressionTypeToken, this::rule);
+ ? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx))
+ : plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx));
}
- protected LogicalPlan rule(LogicalPlan plan) {
- return plan;
- }
-
- protected abstract Expression rule(E e);
+ protected abstract Expression rule(E e, LogicalOptimizerContext ctx);
public Class expressionToken() {
return expressionTypeToken;
@@ -82,6 +82,7 @@ protected ParameterizedOptimizerRule(TransformDirection direction) {
this.direction = direction;
}
+ @Override
public final LogicalPlan apply(LogicalPlan plan, P context) {
return direction == TransformDirection.DOWN
? plan.transformDown(typeToken(), t -> rule(t, context))
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java
index 118e4fc170520..0111c7cdd806a 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java
@@ -9,6 +9,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN;
@@ -28,7 +29,7 @@ public PartiallyFoldCase() {
}
@Override
- protected Expression rule(Case c) {
- return c.partiallyFold();
+ protected Expression rule(Case c, LogicalOptimizerContext ctx) {
+ return c.partiallyFold(ctx.foldCtx());
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java
index 8437b79454884..b6f185c856693 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java
@@ -12,9 +12,11 @@
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Alias;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
@@ -26,15 +28,18 @@
import java.util.List;
@SuppressWarnings("removal")
-public class PropagateEmptyRelation extends OptimizerRules.OptimizerRule {
+public class PropagateEmptyRelation extends OptimizerRules.ParameterizedOptimizerRule {
+ public PropagateEmptyRelation() {
+ super(OptimizerRules.TransformDirection.DOWN);
+ }
@Override
- protected LogicalPlan rule(UnaryPlan plan) {
+ protected LogicalPlan rule(UnaryPlan plan, LogicalOptimizerContext ctx) {
LogicalPlan p = plan;
if (plan.child() instanceof LocalRelation local && local.supplier() == LocalSupplier.EMPTY) {
// only care about non-grouped aggs might return something (count)
if (plan instanceof Aggregate agg && agg.groupings().isEmpty()) {
- List emptyBlocks = aggsFromEmpty(agg.aggregates());
+ List emptyBlocks = aggsFromEmpty(ctx.foldCtx(), agg.aggregates());
p = replacePlanByRelation(plan, LocalSupplier.of(emptyBlocks.toArray(Block[]::new)));
} else {
p = PruneEmptyPlans.skipPlan(plan);
@@ -43,14 +48,14 @@ protected LogicalPlan rule(UnaryPlan plan) {
return p;
}
- private List aggsFromEmpty(List extends NamedExpression> aggs) {
+ private List aggsFromEmpty(FoldContext foldCtx, List extends NamedExpression> aggs) {
List blocks = new ArrayList<>();
var blockFactory = PlannerUtils.NON_BREAKING_BLOCK_FACTORY;
int i = 0;
for (var agg : aggs) {
// there needs to be an alias
if (Alias.unwrap(agg) instanceof AggregateFunction aggFunc) {
- aggOutput(agg, aggFunc, blockFactory, blocks);
+ aggOutput(foldCtx, agg, aggFunc, blockFactory, blocks);
} else {
throw new EsqlIllegalArgumentException("Did not expect a non-aliased aggregation {}", agg);
}
@@ -61,9 +66,15 @@ private List aggsFromEmpty(List extends NamedExpression> aggs) {
/**
* The folded aggregation output - this variant is for the coordinator/final.
*/
- protected void aggOutput(NamedExpression agg, AggregateFunction aggFunc, BlockFactory blockFactory, List blocks) {
+ protected void aggOutput(
+ FoldContext foldCtx,
+ NamedExpression agg,
+ AggregateFunction aggFunc,
+ BlockFactory blockFactory,
+ List blocks
+ ) {
// look for count(literal) with literal != null
- Object value = aggFunc instanceof Count count && (count.foldable() == false || count.fold() != null) ? 0L : null;
+ Object value = aggFunc instanceof Count count && (count.foldable() == false || count.fold(foldCtx) != null) ? 0L : null;
var wrapper = BlockUtils.wrapperFor(blockFactory, PlannerUtils.toElementType(aggFunc.dataType()), 1);
wrapper.accept(value);
blocks.add(wrapper.builder().build());
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java
index 0bd98db1e1d7a..5a1677f2759e3 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java
@@ -23,6 +23,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.util.ArrayList;
import java.util.Iterator;
@@ -41,17 +42,18 @@ public PropagateEquals() {
super(OptimizerRules.TransformDirection.DOWN);
}
- public Expression rule(BinaryLogic e) {
+ @Override
+ public Expression rule(BinaryLogic e, LogicalOptimizerContext ctx) {
if (e instanceof And) {
- return propagate((And) e);
+ return propagate((And) e, ctx);
} else if (e instanceof Or) {
- return propagate((Or) e);
+ return propagate((Or) e, ctx);
}
return e;
}
// combine conjunction
- private static Expression propagate(And and) {
+ private static Expression propagate(And and, LogicalOptimizerContext ctx) {
List ranges = new ArrayList<>();
// Only equalities, not-equalities and inequalities with a foldable .right are extracted separately;
// the others go into the general 'exps'.
@@ -72,7 +74,7 @@ private static Expression propagate(And and) {
if (otherEq.right().foldable() && DataType.isDateTime(otherEq.left().dataType()) == false) {
for (BinaryComparison eq : equals) {
if (otherEq.left().semanticEquals(eq.left())) {
- Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold());
+ Integer comp = BinaryComparison.compare(eq.right().fold(ctx.foldCtx()), otherEq.right().fold(ctx.foldCtx()));
if (comp != null) {
// var cannot be equal to two different values at the same time
if (comp != 0) {
@@ -108,7 +110,7 @@ private static Expression propagate(And and) {
// check
for (BinaryComparison eq : equals) {
- Object eqValue = eq.right().fold();
+ Object eqValue = eq.right().fold(ctx.foldCtx());
for (Iterator iterator = ranges.iterator(); iterator.hasNext();) {
Range range = iterator.next();
@@ -116,7 +118,7 @@ private static Expression propagate(And and) {
if (range.value().semanticEquals(eq.left())) {
// if equals is outside the interval, evaluate the whole expression to FALSE
if (range.lower().foldable()) {
- Integer compare = BinaryComparison.compare(range.lower().fold(), eqValue);
+ Integer compare = BinaryComparison.compare(range.lower().fold(ctx.foldCtx()), eqValue);
if (compare != null && (
// eq outside the lower boundary
compare > 0 ||
@@ -126,7 +128,7 @@ private static Expression propagate(And and) {
}
}
if (range.upper().foldable()) {
- Integer compare = BinaryComparison.compare(range.upper().fold(), eqValue);
+ Integer compare = BinaryComparison.compare(range.upper().fold(ctx.foldCtx()), eqValue);
if (compare != null && (
// eq outside the upper boundary
compare < 0 ||
@@ -146,7 +148,7 @@ private static Expression propagate(And and) {
for (Iterator iter = notEquals.iterator(); iter.hasNext();) {
NotEquals neq = iter.next();
if (eq.left().semanticEquals(neq.left())) {
- Integer comp = BinaryComparison.compare(eqValue, neq.right().fold());
+ Integer comp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx()));
if (comp != null) {
if (comp == 0) { // clashing and conflicting: a = 1 AND a != 1
return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN);
@@ -162,7 +164,7 @@ private static Expression propagate(And and) {
for (Iterator iter = inequalities.iterator(); iter.hasNext();) {
BinaryComparison bc = iter.next();
if (eq.left().semanticEquals(bc.left())) {
- Integer compare = BinaryComparison.compare(eqValue, bc.right().fold());
+ Integer compare = BinaryComparison.compare(eqValue, bc.right().fold(ctx.foldCtx()));
if (compare != null) {
if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { // a = 2 AND a <= ?
if ((compare == 0 && bc instanceof LessThan) || // a = 2 AND a < 2
@@ -191,7 +193,7 @@ private static Expression propagate(And and) {
// a = 2 OR a < 3 -> a < 3; a = 2 OR a < 1 -> nop
// a = 2 OR 3 < a < 5 -> nop; a = 2 OR 1 < a < 3 -> 1 < a < 3; a = 2 OR 0 < a < 1 -> nop
// a = 2 OR a != 2 -> TRUE; a = 2 OR a = 5 -> nop; a = 2 OR a != 5 -> a != 5
- private static Expression propagate(Or or) {
+ private static Expression propagate(Or or, LogicalOptimizerContext ctx) {
List exps = new ArrayList<>();
List equals = new ArrayList<>(); // foldable right term Equals
List notEquals = new ArrayList<>(); // foldable right term NotEquals
@@ -230,13 +232,13 @@ private static Expression propagate(Or or) {
// evaluate the impact of each Equal over the different types of Expressions
for (Iterator iterEq = equals.iterator(); iterEq.hasNext();) {
Equals eq = iterEq.next();
- Object eqValue = eq.right().fold();
+ Object eqValue = eq.right().fold(ctx.foldCtx());
boolean removeEquals = false;
// Equals OR NotEquals
for (NotEquals neq : notEquals) {
if (eq.left().semanticEquals(neq.left())) { // a = 2 OR a != ? -> ...
- Integer comp = BinaryComparison.compare(eqValue, neq.right().fold());
+ Integer comp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx()));
if (comp != null) {
if (comp == 0) { // a = 2 OR a != 2 -> TRUE
return TRUE;
@@ -257,8 +259,12 @@ private static Expression propagate(Or or) {
for (int i = 0; i < ranges.size(); i++) { // might modify list, so use index loop
Range range = ranges.get(i);
if (eq.left().semanticEquals(range.value())) {
- Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare(eqValue, range.lower().fold()) : null;
- Integer upperComp = range.upper().foldable() ? BinaryComparison.compare(eqValue, range.upper().fold()) : null;
+ Integer lowerComp = range.lower().foldable()
+ ? BinaryComparison.compare(eqValue, range.lower().fold(ctx.foldCtx()))
+ : null;
+ Integer upperComp = range.upper().foldable()
+ ? BinaryComparison.compare(eqValue, range.upper().fold(ctx.foldCtx()))
+ : null;
if (lowerComp != null && lowerComp == 0) {
if (range.includeLower() == false) { // a = 2 OR 2 < a < ? -> 2 <= a < ?
@@ -312,7 +318,7 @@ private static Expression propagate(Or or) {
for (int i = 0; i < inequalities.size(); i++) {
BinaryComparison bc = inequalities.get(i);
if (eq.left().semanticEquals(bc.left())) {
- Integer comp = BinaryComparison.compare(eqValue, bc.right().fold());
+ Integer comp = BinaryComparison.compare(eqValue, bc.right().fold(ctx.foldCtx()));
if (comp != null) {
if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) {
if (comp < 0) { // a = 1 OR a > 2 -> nop
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java
index 73eaa9220fd84..66cdc992a91cb 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java
@@ -12,19 +12,20 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-import org.elasticsearch.xpack.esql.rule.Rule;
+import org.elasticsearch.xpack.esql.rule.ParameterizedRule;
/**
* Replace any reference attribute with its source, if it does not affect the result.
* This avoids ulterior look-ups between attributes and its source across nodes.
*/
-public final class PropagateEvalFoldables extends Rule {
+public final class PropagateEvalFoldables extends ParameterizedRule {
@Override
- public LogicalPlan apply(LogicalPlan plan) {
+ public LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) {
var collectRefs = new AttributeMap();
java.util.function.Function replaceReference = r -> collectRefs.resolve(r, r);
@@ -39,7 +40,7 @@ public LogicalPlan apply(LogicalPlan plan) {
shouldCollect = c.foldable();
}
if (shouldCollect) {
- collectRefs.put(a.toAttribute(), Literal.of(c));
+ collectRefs.put(a.toAttribute(), Literal.of(ctx.foldCtx(), c));
}
});
if (collectRefs.isEmpty()) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java
index 738ca83b47e42..e3165180e331c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.util.ArrayList;
import java.util.LinkedHashSet;
@@ -33,7 +34,7 @@ public PropagateNullable() {
}
@Override
- public Expression rule(And and) {
+ public Expression rule(And and, LogicalOptimizerContext ctx) {
List splits = Predicates.splitAnd(and);
Set nullExpressions = new LinkedHashSet<>();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java
index 1cacebdf27cd2..969a6bb713eca 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Enrich;
import org.elasticsearch.xpack.esql.plan.logical.Eval;
@@ -20,14 +21,18 @@
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
-public final class PushDownAndCombineLimits extends OptimizerRules.OptimizerRule {
+public final class PushDownAndCombineLimits extends OptimizerRules.ParameterizedOptimizerRule {
+
+ public PushDownAndCombineLimits() {
+ super(OptimizerRules.TransformDirection.DOWN);
+ }
@Override
- public LogicalPlan rule(Limit limit) {
+ public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
if (limit.child() instanceof Limit childLimit) {
var limitSource = limit.limit();
- var l1 = (int) limitSource.fold();
- var l2 = (int) childLimit.limit().fold();
+ var l1 = (int) limitSource.fold(ctx.foldCtx());
+ var l2 = (int) childLimit.limit().fold(ctx.foldCtx());
return new Limit(limit.source(), Literal.of(limitSource, Math.min(l1, l2)), childLimit.child());
} else if (limit.child() instanceof UnaryPlan unary) {
if (unary instanceof Eval || unary instanceof Project || unary instanceof RegexExtract || unary instanceof Enrich) {
@@ -41,7 +46,7 @@ public LogicalPlan rule(Limit limit) {
// we add an inner limit to MvExpand and just push down the existing limit, ie.
// | MV_EXPAND | LIMIT N -> | LIMIT N | MV_EXPAND (with limit N)
var limitSource = limit.limit();
- var limitVal = (int) limitSource.fold();
+ var limitVal = (int) limitSource.fold(ctx.foldCtx());
Integer mvxLimit = mvx.limit();
if (mvxLimit == null || mvxLimit > limitVal) {
mvx = new MvExpand(mvx.source(), mvx.child(), mvx.target(), mvx.expanded(), limitVal);
@@ -54,8 +59,8 @@ public LogicalPlan rule(Limit limit) {
else {
Limit descendantLimit = descendantLimit(unary);
if (descendantLimit != null) {
- var l1 = (int) limit.limit().fold();
- var l2 = (int) descendantLimit.limit().fold();
+ var l1 = (int) limit.limit().fold(ctx.foldCtx());
+ var l2 = (int) descendantLimit.limit().fold(ctx.foldCtx());
if (l2 <= l1) {
return new Limit(limit.source(), Literal.of(limit.limit(), l2), limit.child());
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java
index 1a8f8a164cc1b..7953b2b28eaaa 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java
@@ -14,6 +14,7 @@
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.StringPattern;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.parser.ParsingException;
public final class ReplaceRegexMatch extends OptimizerRules.OptimizerExpressionRule> {
@@ -23,7 +24,7 @@ public ReplaceRegexMatch() {
}
@Override
- public Expression rule(RegexMatch> regexMatch) {
+ public Expression rule(RegexMatch> regexMatch, LogicalOptimizerContext ctx) {
Expression e = regexMatch;
StringPattern pattern = regexMatch.pattern();
boolean matchesAll;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java
index eebeb1dc14f48..9e7b6ce80422d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Row;
import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation;
@@ -17,13 +18,16 @@
import java.util.ArrayList;
import java.util.List;
-public final class ReplaceRowAsLocalRelation extends OptimizerRules.OptimizerRule {
+public final class ReplaceRowAsLocalRelation extends OptimizerRules.ParameterizedOptimizerRule {
+ public ReplaceRowAsLocalRelation() {
+ super(OptimizerRules.TransformDirection.DOWN);
+ }
@Override
- protected LogicalPlan rule(Row row) {
+ protected LogicalPlan rule(Row row, LogicalOptimizerContext context) {
var fields = row.fields();
List values = new ArrayList<>(fields.size());
- fields.forEach(f -> values.add(f.child().fold()));
+ fields.forEach(f -> values.add(f.child().fold(context.foldCtx())));
var blocks = BlockUtils.fromListRow(PlannerUtils.NON_BREAKING_BLOCK_FACTORY, values);
return new LocalRelation(row.source(), row.output(), LocalSupplier.of(blocks));
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java
index 2cafcc2e07052..a7e56a5f25fc8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java
@@ -49,7 +49,7 @@ protected LogicalPlan rule(Aggregate aggregate) {
&& alias.child() instanceof AggregateFunction aggFunction
&& aggFunction.hasFilter()
&& aggFunction.filter() instanceof Literal literal
- && Boolean.FALSE.equals(literal.fold())) {
+ && Boolean.FALSE.equals(literal.value())) {
Object value = aggFunction instanceof Count || aggFunction instanceof CountDistinct ? 0L : null;
Alias newAlias = alias.replaceChild(Literal.of(aggFunction, value));
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java
index 0fea7cf8ddc1f..053441bce5e1f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java
@@ -18,6 +18,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InsensitiveEquals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
public class ReplaceStringCasingWithInsensitiveEquals extends OptimizerRules.OptimizerExpressionRule {
@@ -26,30 +27,35 @@ public ReplaceStringCasingWithInsensitiveEquals() {
}
@Override
- protected Expression rule(ScalarFunction sf) {
+ protected Expression rule(ScalarFunction sf, LogicalOptimizerContext ctx) {
Expression e = sf;
if (sf instanceof BinaryComparison bc) {
- e = rewriteBinaryComparison(sf, bc, false);
+ e = rewriteBinaryComparison(ctx, sf, bc, false);
} else if (sf instanceof Not not && not.field() instanceof BinaryComparison bc) {
- e = rewriteBinaryComparison(sf, bc, true);
+ e = rewriteBinaryComparison(ctx, sf, bc, true);
}
return e;
}
- private static Expression rewriteBinaryComparison(ScalarFunction sf, BinaryComparison bc, boolean negated) {
+ private static Expression rewriteBinaryComparison(
+ LogicalOptimizerContext ctx,
+ ScalarFunction sf,
+ BinaryComparison bc,
+ boolean negated
+ ) {
Expression e = sf;
if (bc.left() instanceof ChangeCase changeCase && bc.right().foldable()) {
if (bc instanceof Equals) {
- e = replaceChangeCase(bc, changeCase, negated);
+ e = replaceChangeCase(ctx, bc, changeCase, negated);
} else if (bc instanceof NotEquals) { // not actually used currently, `!=` is built as `NOT(==)` already
- e = replaceChangeCase(bc, changeCase, negated == false);
+ e = replaceChangeCase(ctx, bc, changeCase, negated == false);
}
}
return e;
}
- private static Expression replaceChangeCase(BinaryComparison bc, ChangeCase changeCase, boolean negated) {
- var foldedRight = BytesRefs.toString(bc.right().fold());
+ private static Expression replaceChangeCase(LogicalOptimizerContext ctx, BinaryComparison bc, ChangeCase changeCase, boolean negated) {
+ var foldedRight = BytesRefs.toString(bc.right().fold(ctx.foldCtx()));
var field = unwrapCase(changeCase.field());
var e = changeCase.caseType().matchesCase(foldedRight)
? new InsensitiveEquals(bc.source(), field, bc.right())
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java
index d3a9970896c16..60ff161651f2d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -15,6 +16,7 @@
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.BinaryComparisonInversible;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.time.DateTimeException;
import java.util.List;
@@ -41,20 +43,20 @@ public SimplifyComparisonsArithmetics(BiFunction ty
}
@Override
- protected Expression rule(BinaryComparison bc) {
+ protected Expression rule(BinaryComparison bc, LogicalOptimizerContext ctx) {
// optimize only once the expression has a literal on the right side of the binary comparison
if (bc.right() instanceof Literal) {
if (bc.left() instanceof ArithmeticOperation) {
- return simplifyBinaryComparison(bc);
+ return simplifyBinaryComparison(ctx.foldCtx(), bc);
}
if (bc.left() instanceof Neg) {
- return foldNegation(bc);
+ return foldNegation(ctx.foldCtx(), bc);
}
}
return bc;
}
- private Expression simplifyBinaryComparison(BinaryComparison comparison) {
+ private Expression simplifyBinaryComparison(FoldContext foldContext, BinaryComparison comparison) {
ArithmeticOperation operation = (ArithmeticOperation) comparison.left();
// Use symbol comp: SQL operations aren't available in this package (as dependencies)
String opSymbol = operation.symbol();
@@ -64,9 +66,9 @@ private Expression simplifyBinaryComparison(BinaryComparison comparison) {
}
OperationSimplifier simplification = null;
if (isMulOrDiv(opSymbol)) {
- simplification = new MulDivSimplifier(comparison);
+ simplification = new MulDivSimplifier(foldContext, comparison);
} else if (opSymbol.equals(ADD.symbol()) || opSymbol.equals(SUB.symbol())) {
- simplification = new AddSubSimplifier(comparison);
+ simplification = new AddSubSimplifier(foldContext, comparison);
}
return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply();
@@ -76,16 +78,16 @@ private static boolean isMulOrDiv(String opSymbol) {
return opSymbol.equals(MUL.symbol()) || opSymbol.equals(DIV.symbol());
}
- private static Expression foldNegation(BinaryComparison bc) {
+ private static Expression foldNegation(FoldContext ctx, BinaryComparison bc) {
Literal bcLiteral = (Literal) bc.right();
- Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral));
+ Expression literalNeg = tryFolding(ctx, new Neg(bcLiteral.source(), bcLiteral));
return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg));
}
- private static Expression tryFolding(Expression expression) {
+ private static Expression tryFolding(FoldContext ctx, Expression expression) {
if (expression.foldable()) {
try {
- expression = new Literal(expression.source(), expression.fold(), expression.dataType());
+ expression = new Literal(expression.source(), expression.fold(ctx), expression.dataType());
} catch (ArithmeticException | DateTimeException e) {
// null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped.
expression = null;
@@ -95,6 +97,7 @@ private static Expression tryFolding(Expression expression) {
}
private abstract static class OperationSimplifier {
+ final FoldContext foldContext;
final BinaryComparison comparison;
final Literal bcLiteral;
final ArithmeticOperation operation;
@@ -102,7 +105,8 @@ private abstract static class OperationSimplifier {
final Expression opRight;
final Literal opLiteral;
- OperationSimplifier(BinaryComparison comparison) {
+ OperationSimplifier(FoldContext foldContext, BinaryComparison comparison) {
+ this.foldContext = foldContext;
this.comparison = comparison;
operation = (ArithmeticOperation) comparison.left();
bcLiteral = (Literal) comparison.right();
@@ -151,7 +155,7 @@ final Expression apply() {
Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse()
.create(bcl.source(), bcl, opRight);
- bcRightExpression = tryFolding(bcRightExpression);
+ bcRightExpression = tryFolding(foldContext, bcRightExpression);
return bcRightExpression != null
? postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression)))
: comparison;
@@ -169,8 +173,8 @@ Expression postProcess(BinaryComparison binaryComparison) {
private static class AddSubSimplifier extends OperationSimplifier {
- AddSubSimplifier(BinaryComparison comparison) {
- super(comparison);
+ AddSubSimplifier(FoldContext foldContext, BinaryComparison comparison) {
+ super(foldContext, comparison);
}
@Override
@@ -182,7 +186,7 @@ boolean isOpUnsafe() {
if (operation.symbol().equals(SUB.symbol()) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX
// if next simplification step would fail on overflow anyways, skip the optimisation already
- return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null;
+ return tryFolding(foldContext, new Sub(EMPTY, opLeft, bcLiteral)) == null;
}
return false;
@@ -194,8 +198,8 @@ private static class MulDivSimplifier extends OperationSimplifier {
private final boolean isDiv; // and not MUL.
private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal)
- MulDivSimplifier(BinaryComparison comparison) {
- super(comparison);
+ MulDivSimplifier(FoldContext foldContext, BinaryComparison comparison) {
+ super(foldContext, comparison);
isDiv = operation.symbol().equals(DIV.symbol());
opRightSign = sign(opRight);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java
index 5d98d941bb207..c6d62dee0ba42 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java
@@ -7,14 +7,19 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.plan.logical.Limit;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
-public final class SkipQueryOnLimitZero extends OptimizerRules.OptimizerRule {
+public final class SkipQueryOnLimitZero extends OptimizerRules.ParameterizedOptimizerRule {
+ public SkipQueryOnLimitZero() {
+ super(OptimizerRules.TransformDirection.DOWN);
+ }
+
@Override
- protected LogicalPlan rule(Limit limit) {
+ protected LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
if (limit.limit().foldable()) {
- if (Integer.valueOf(0).equals((limit.limit().fold()))) {
+ if (Integer.valueOf(0).equals((limit.limit().fold(ctx.foldCtx())))) {
return PruneEmptyPlans.skipPlan(limit);
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java
index 9e9ae6a9a559d..870464feb4867 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java
@@ -11,6 +11,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import java.util.ArrayList;
import java.util.List;
@@ -25,7 +26,7 @@ public SplitInWithFoldableValue() {
}
@Override
- public Expression rule(In in) {
+ public Expression rule(In in, LogicalOptimizerContext ctx) {
if (in.value().foldable()) {
List foldables = new ArrayList<>(in.list().size());
List nonFoldables = new ArrayList<>(in.list().size());
@@ -36,7 +37,7 @@ public Expression rule(In in) {
nonFoldables.add(e);
}
});
- if (foldables.size() > 0 && nonFoldables.size() > 0) {
+ if (foldables.isEmpty() == false && nonFoldables.isEmpty() == false) {
In withFoldables = new In(in.source(), in.value(), foldables);
In withoutFoldables = new In(in.source(), in.value(), nonFoldables);
return new Or(in.source(), withFoldables, withoutFoldables);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java
index c8369d2b08a34..62a00b79d7333 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java
@@ -9,6 +9,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.OptimizerExpressionRule;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection;
@@ -21,7 +22,7 @@ public SubstituteFilteredExpression() {
}
@Override
- protected Expression rule(FilteredExpression filteredExpression) {
+ protected Expression rule(FilteredExpression filteredExpression, LogicalOptimizerContext ctx) {
return filteredExpression.surrogate();
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java
index 93512d80e1708..4b68ee941bc92 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
/**
* Currently this works similarly to SurrogateExpression, leaving the logic inside the expressions,
@@ -23,7 +24,7 @@ public SubstituteSpatialSurrogates() {
}
@Override
- protected SpatialRelatesFunction rule(SpatialRelatesFunction function) {
+ protected SpatialRelatesFunction rule(SpatialRelatesFunction function, LogicalOptimizerContext ctx) {
return function.surrogate();
}
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java
index d29da1354ef3c..9259a50d5ff9e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java
@@ -11,6 +11,7 @@
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
@@ -25,19 +26,24 @@
* Local aggregation can only produce intermediate state that get wired into the global agg.
*/
public class LocalPropagateEmptyRelation extends PropagateEmptyRelation {
-
/**
* Local variant of the aggregation that returns the intermediate value.
*/
@Override
- protected void aggOutput(NamedExpression agg, AggregateFunction aggFunc, BlockFactory blockFactory, List blocks) {
+ protected void aggOutput(
+ FoldContext foldCtx,
+ NamedExpression agg,
+ AggregateFunction aggFunc,
+ BlockFactory blockFactory,
+ List blocks
+ ) {
List output = AbstractPhysicalOperationProviders.intermediateAttributes(List.of(agg), List.of());
for (Attribute o : output) {
DataType dataType = o.dataType();
// boolean right now is used for the internal #seen so always return true
var value = dataType == DataType.BOOLEAN ? true
// look for count(literal) with literal != null
- : aggFunc instanceof Count count && (count.foldable() == false || count.fold() != null) ? 0L
+ : aggFunc instanceof Count count && (count.foldable() == false || count.fold(foldCtx) != null) ? 0L
// otherwise nullify
: null;
var wrapper = BlockUtils.wrapperFor(blockFactory, PlannerUtils.toElementType(dataType), 1);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java
index dfb1dbc8bc8f3..afeab28745c65 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@@ -76,22 +77,33 @@ public class EnableSpatialDistancePushdown extends PhysicalOptimizerRules.Parame
protected PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext ctx) {
PhysicalPlan plan = filterExec;
if (filterExec.child() instanceof EsQueryExec esQueryExec) {
- plan = rewrite(filterExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats()));
+ plan = rewrite(ctx.foldCtx(), filterExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats()));
} else if (filterExec.child() instanceof EvalExec evalExec && evalExec.child() instanceof EsQueryExec esQueryExec) {
- plan = rewriteBySplittingFilter(filterExec, evalExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats()));
+ plan = rewriteBySplittingFilter(
+ ctx.foldCtx(),
+ filterExec,
+ evalExec,
+ esQueryExec,
+ LucenePushdownPredicates.from(ctx.searchStats())
+ );
}
return plan;
}
- private FilterExec rewrite(FilterExec filterExec, EsQueryExec esQueryExec, LucenePushdownPredicates lucenePushdownPredicates) {
+ private FilterExec rewrite(
+ FoldContext ctx,
+ FilterExec filterExec,
+ EsQueryExec esQueryExec,
+ LucenePushdownPredicates lucenePushdownPredicates
+ ) {
// Find and rewrite any binary comparisons that involve a distance function and a literal
var rewritten = filterExec.condition().transformDown(EsqlBinaryComparison.class, comparison -> {
ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType());
if (comparison.left() instanceof StDistance dist && comparison.right().foldable()) {
- return rewriteComparison(comparison, dist, comparison.right(), comparisonType);
+ return rewriteComparison(ctx, comparison, dist, comparison.right(), comparisonType);
} else if (comparison.right() instanceof StDistance dist && comparison.left().foldable()) {
- return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType));
+ return rewriteComparison(ctx, comparison, dist, comparison.left(), ComparisonType.invert(comparisonType));
}
return comparison;
});
@@ -120,6 +132,7 @@ private FilterExec rewrite(FilterExec filterExec, EsQueryExec esQueryExec, Lucen
*
*/
private PhysicalPlan rewriteBySplittingFilter(
+ FoldContext ctx,
FilterExec filterExec,
EvalExec evalExec,
EsQueryExec esQueryExec,
@@ -142,7 +155,7 @@ private PhysicalPlan rewriteBySplittingFilter(
for (Expression exp : splitAnd(filterExec.condition())) {
Expression resExp = exp.transformUp(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r));
// Find and rewrite any binary comparisons that involve a distance function and a literal
- var rewritten = rewriteDistanceFilters(resExp, distances);
+ var rewritten = rewriteDistanceFilters(ctx, resExp, distances);
// If all pushable StDistance functions were found and re-written, we need to re-write the FILTER/EVAL combination
if (rewritten.equals(resExp) == false && canPushToSource(rewritten, lucenePushdownPredicates)) {
pushable.add(rewritten);
@@ -181,40 +194,42 @@ private Map getPushableDistances(List aliases, Lucene
return distances;
}
- private Expression rewriteDistanceFilters(Expression expr, Map distances) {
+ private Expression rewriteDistanceFilters(FoldContext ctx, Expression expr, Map distances) {
return expr.transformDown(EsqlBinaryComparison.class, comparison -> {
ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType());
if (comparison.left() instanceof ReferenceAttribute r && distances.containsKey(r.id()) && comparison.right().foldable()) {
StDistance dist = distances.get(r.id());
- return rewriteComparison(comparison, dist, comparison.right(), comparisonType);
+ return rewriteComparison(ctx, comparison, dist, comparison.right(), comparisonType);
} else if (comparison.right() instanceof ReferenceAttribute r
&& distances.containsKey(r.id())
&& comparison.left().foldable()) {
StDistance dist = distances.get(r.id());
- return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType));
+ return rewriteComparison(ctx, comparison, dist, comparison.left(), ComparisonType.invert(comparisonType));
}
return comparison;
});
}
private Expression rewriteComparison(
+ FoldContext ctx,
EsqlBinaryComparison comparison,
StDistance dist,
Expression literal,
ComparisonType comparisonType
) {
- Object value = literal.fold();
+ Object value = literal.fold(ctx);
if (value instanceof Number number) {
if (dist.right().foldable()) {
- return rewriteDistanceFilter(comparison, dist.left(), dist.right(), number, comparisonType);
+ return rewriteDistanceFilter(ctx, comparison, dist.left(), dist.right(), number, comparisonType);
} else if (dist.left().foldable()) {
- return rewriteDistanceFilter(comparison, dist.right(), dist.left(), number, comparisonType);
+ return rewriteDistanceFilter(ctx, comparison, dist.right(), dist.left(), number, comparisonType);
}
}
return comparison;
}
private Expression rewriteDistanceFilter(
+ FoldContext ctx,
EsqlBinaryComparison comparison,
Expression spatialExp,
Expression literalExp,
@@ -222,7 +237,7 @@ private Expression rewriteDistanceFilter(
ComparisonType comparisonType
) {
DataType shapeDataType = getShapeDataType(spatialExp);
- Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(literalExp);
+ Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(ctx, literalExp);
if (geometry instanceof Point point) {
double distance = number.doubleValue();
Source source = comparison.source();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java
index 2b531257e594a..24df3c1db234e 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java
@@ -14,6 +14,7 @@
import org.elasticsearch.xpack.esql.core.expression.AttributeMap;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@@ -61,7 +62,7 @@ public class PushTopNToSource extends PhysicalOptimizerRules.ParameterizedOptimi
@Override
protected PhysicalPlan rule(TopNExec topNExec, LocalPhysicalOptimizerContext ctx) {
- Pushable pushable = evaluatePushable(topNExec, LucenePushdownPredicates.from(ctx.searchStats()));
+ Pushable pushable = evaluatePushable(ctx.foldCtx(), topNExec, LucenePushdownPredicates.from(ctx.searchStats()));
return pushable.rewrite(topNExec);
}
@@ -95,18 +96,18 @@ private EsQueryExec.Sort sort() {
return new EsQueryExec.GeoDistanceSort(fieldAttribute.exactAttribute(), order.direction(), point.getLat(), point.getLon());
}
- private static PushableGeoDistance from(StDistance distance, Order order) {
+ private static PushableGeoDistance from(FoldContext ctx, StDistance distance, Order order) {
if (distance.left() instanceof Attribute attr && distance.right().foldable()) {
- return from(attr, distance.right(), order);
+ return from(ctx, attr, distance.right(), order);
} else if (distance.right() instanceof Attribute attr && distance.left().foldable()) {
- return from(attr, distance.left(), order);
+ return from(ctx, attr, distance.left(), order);
}
return null;
}
- private static PushableGeoDistance from(Attribute attr, Expression foldable, Order order) {
+ private static PushableGeoDistance from(FoldContext ctx, Attribute attr, Expression foldable, Order order) {
if (attr instanceof FieldAttribute fieldAttribute) {
- Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(foldable);
+ Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(ctx, foldable);
if (geometry instanceof Point point) {
return new PushableGeoDistance(fieldAttribute, order, point);
}
@@ -122,7 +123,7 @@ public PhysicalPlan rewrite(TopNExec topNExec) {
}
}
- private static Pushable evaluatePushable(TopNExec topNExec, LucenePushdownPredicates lucenePushdownPredicates) {
+ private static Pushable evaluatePushable(FoldContext ctx, TopNExec topNExec, LucenePushdownPredicates lucenePushdownPredicates) {
PhysicalPlan child = topNExec.child();
if (child instanceof EsQueryExec queryExec
&& queryExec.canPushSorts()
@@ -164,7 +165,7 @@ && canPushDownOrders(topNExec.order(), lucenePushdownPredicates)) {
if (distances.containsKey(resolvedAttribute.id())) {
StDistance distance = distances.get(resolvedAttribute.id());
StDistance d = (StDistance) distance.transformDown(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r));
- PushableGeoDistance pushableGeoDistance = PushableGeoDistance.from(d, order);
+ PushableGeoDistance pushableGeoDistance = PushableGeoDistance.from(ctx, d, order);
if (pushableGeoDistance != null) {
pushableSorts.add(pushableGeoDistance.sort());
} else {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
index 81d43bc68b79e..eb81446f9ddea 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
@@ -686,12 +687,20 @@ public Expression visitRegexBooleanExpression(EsqlBaseParser.RegexBooleanExpress
RegexMatch> result = switch (type) {
case EsqlBaseParser.LIKE -> {
try {
- yield new WildcardLike(source, left, new WildcardPattern(pattern.fold().toString()));
+ yield new WildcardLike(
+ source,
+ left,
+ new WildcardPattern(pattern.fold(FoldContext.small() /* TODO remove me */).toString())
+ );
} catch (InvalidArgumentException e) {
throw new ParsingException(source, "Invalid pattern for LIKE [{}]: [{}]", pattern, e.getMessage());
}
}
- case EsqlBaseParser.RLIKE -> new RLike(source, left, new RLikePattern(pattern.fold().toString()));
+ case EsqlBaseParser.RLIKE -> new RLike(
+ source,
+ left,
+ new RLikePattern(pattern.fold(FoldContext.small() /* TODO remove me */).toString())
+ );
default -> throw new ParsingException("Invalid predicate type for [{}]", source.text());
};
return ctx.NOT() == null ? result : new Not(source, result);
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
index 49d77bc36fb2e..4edd0470058db 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java
@@ -23,6 +23,7 @@
import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
@@ -157,7 +158,7 @@ public PlanFactory visitEvalCommand(EsqlBaseParser.EvalCommandContext ctx) {
public PlanFactory visitGrokCommand(EsqlBaseParser.GrokCommandContext ctx) {
return p -> {
Source source = source(ctx);
- String pattern = visitString(ctx.string()).fold().toString();
+ String pattern = visitString(ctx.string()).fold(FoldContext.small() /* TODO remove me */).toString();
Grok.Parser grokParser;
try {
grokParser = Grok.pattern(source, pattern);
@@ -188,7 +189,7 @@ private void validateGrokPattern(Source source, Grok.Parser grokParser, String p
@Override
public PlanFactory visitDissectCommand(EsqlBaseParser.DissectCommandContext ctx) {
return p -> {
- String pattern = visitString(ctx.string()).fold().toString();
+ String pattern = visitString(ctx.string()).fold(FoldContext.small() /* TODO remove me */).toString();
Map options = visitCommandOptions(ctx.commandOptions());
String appendSeparator = "";
for (Map.Entry item : options.entrySet()) {
@@ -243,7 +244,7 @@ public Map visitCommandOptions(EsqlBaseParser.CommandOptionsCont
}
Map result = new HashMap<>();
for (EsqlBaseParser.CommandOptionContext option : ctx.commandOption()) {
- result.put(visitIdentifier(option.identifier()), expression(option.constant()).fold());
+ result.put(visitIdentifier(option.identifier()), expression(option.constant()).fold(FoldContext.small() /* TODO remove me */));
}
return result;
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java
index 6755a7fa30af9..9b81060349815 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java
@@ -25,6 +25,7 @@
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
@@ -149,7 +150,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(policyName());
out.writeNamedWriteable(matchField());
if (out.getTransportVersion().before(TransportVersions.V_8_13_0)) {
- out.writeString(BytesRefs.toString(policyName().fold())); // old policy name
+ out.writeString(BytesRefs.toString(policyName().fold(FoldContext.small() /* TODO remove me */))); // old policy name
}
policy().writeTo(out);
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
index 57ba1c8016feb..072bae21da2a3 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
@@ -25,6 +25,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
@@ -47,9 +48,11 @@
public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {
private final AggregateMapper aggregateMapper = new AggregateMapper();
+ private final FoldContext foldContext;
private final AnalysisRegistry analysisRegistry;
- AbstractPhysicalOperationProviders(AnalysisRegistry analysisRegistry) {
+ AbstractPhysicalOperationProviders(FoldContext foldContext, AnalysisRegistry analysisRegistry) {
+ this.foldContext = foldContext;
this.analysisRegistry = analysisRegistry;
}
@@ -251,6 +254,7 @@ public static List intermediateAttributes(List extends NamedExpress
private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, AggregatorMode mode) {}
private void aggregatesToFactory(
+
List extends NamedExpression> aggregates,
AggregatorMode mode,
Layout layout,
@@ -311,7 +315,11 @@ else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
// apply the filter only in the initial phase - as the rest of the data is already filtered
if (aggregateFunction.hasFilter() && mode.isInputPartial() == false) {
- EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(aggregateFunction.filter(), layout);
+ EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(
+ foldContext,
+ aggregateFunction.filter(),
+ layout
+ );
aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory);
}
consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode));
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
index b1fe0e7a7cf54..8b63a146f2e5d 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
@@ -47,6 +47,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
@@ -94,8 +95,8 @@ public interface ShardContext extends org.elasticsearch.compute.lucene.ShardCont
private final List shardContexts;
- public EsPhysicalOperationProviders(List shardContexts, AnalysisRegistry analysisRegistry) {
- super(analysisRegistry);
+ public EsPhysicalOperationProviders(FoldContext foldContext, List shardContexts, AnalysisRegistry analysisRegistry) {
+ super(foldContext, analysisRegistry);
this.shardContexts = shardContexts;
}
@@ -161,7 +162,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec,
List sorts = esQueryExec.sorts();
assert esQueryExec.estimatedRowSize() != null : "estimated row size not initialized";
int rowEstimatedSize = esQueryExec.estimatedRowSize();
- int limit = esQueryExec.limit() != null ? (Integer) esQueryExec.limit().fold() : NO_LIMIT;
+ int limit = esQueryExec.limit() != null ? (Integer) esQueryExec.limit().fold(context.foldCtx()) : NO_LIMIT;
boolean scoring = esQueryExec.attrs()
.stream()
.anyMatch(a -> a instanceof MetadataAttribute && a.name().equals(MetadataAttribute.SCORE));
@@ -217,7 +218,7 @@ public LuceneCountOperator.Factory countSource(LocalExecutionPlannerContext cont
querySupplier(queryBuilder),
context.queryPragmas().dataPartitioning(),
context.queryPragmas().taskConcurrency(),
- limit == null ? NO_LIMIT : (Integer) limit.fold()
+ limit == null ? NO_LIMIT : (Integer) limit.fold(context.foldCtx())
);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java
index a1765977ee9c2..c185bd5729879 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.TranslationAware;
import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
@@ -144,7 +145,7 @@ public static void checkInsensitiveComparison(InsensitiveEquals bc) {
static Query translate(InsensitiveEquals bc) {
TypedAttribute attribute = checkIsPushableAttribute(bc.left());
Source source = bc.source();
- BytesRef value = BytesRefs.toBytesRef(valueOf(bc.right()));
+ BytesRef value = BytesRefs.toBytesRef(valueOf(FoldContext.small() /* TODO remove me */, bc.right()));
String name = pushableAttributeName(attribute);
return new TermQuery(source, name, value.utf8ToString(), true);
}
@@ -188,7 +189,7 @@ static Query translate(BinaryComparison bc, TranslatorHandler handler) {
TypedAttribute attribute = checkIsPushableAttribute(bc.left());
Source source = bc.source();
String name = handler.nameOf(attribute);
- Object result = bc.right().fold();
+ Object result = bc.right().fold(FoldContext.small() /* TODO remove me */);
Object value = result;
String format = null;
boolean isDateLiteralComparison = false;
@@ -269,7 +270,7 @@ private static Query translateOutOfRangeComparisons(BinaryComparison bc) {
return null;
}
Source source = bc.source();
- Object value = valueOf(bc.right());
+ Object value = valueOf(FoldContext.small() /* TODO remove me */, bc.right());
// Comparisons with multi-values always return null in ESQL.
if (value instanceof List>) {
@@ -369,7 +370,7 @@ public static Query doTranslate(ScalarFunction f, TranslatorHandler handler) {
if (f instanceof CIDRMatch cm) {
if (cm.ipField() instanceof FieldAttribute fa && Expressions.foldable(cm.matches())) {
String targetFieldName = handler.nameOf(fa.exactAttribute());
- Set set = new LinkedHashSet<>(Expressions.fold(cm.matches()));
+ Set set = new LinkedHashSet<>(Expressions.fold(FoldContext.small() /* TODO remove me */, cm.matches()));
Query query = new TermsQuery(f.source(), targetFieldName, set);
// CIDR_MATCH applies only to single values.
@@ -420,7 +421,7 @@ static Query translate(
String name = handler.nameOf(attribute);
try {
- Geometry shape = SpatialRelatesUtils.makeGeometryFromLiteral(constantExpression);
+ Geometry shape = SpatialRelatesUtils.makeGeometryFromLiteral(FoldContext.small() /* TODO remove me */, constantExpression);
return new SpatialRelatesQuery(bc.source(), name, bc.queryRelation(), shape, attribute.dataType());
} catch (IllegalArgumentException e) {
throw new QlIllegalArgumentException(e.getMessage(), e);
@@ -461,7 +462,7 @@ private static Query translate(In in, TranslatorHandler handler) {
queries.add(query);
}
} else {
- terms.add(valueOf(rhs));
+ terms.add(valueOf(FoldContext.small() /* TODO remove me */, rhs));
}
}
}
@@ -487,8 +488,8 @@ public static Query doTranslate(Range r, TranslatorHandler handler) {
}
private static RangeQuery translate(Range r, TranslatorHandler handler) {
- Object lower = valueOf(r.lower());
- Object upper = valueOf(r.upper());
+ Object lower = valueOf(FoldContext.small() /* TODO remove me */, r.lower());
+ Object upper = valueOf(FoldContext.small() /* TODO remove me */, r.upper());
String format = null;
DataType dataType = r.value().dataType();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
index af38551c1ad06..ecd0284c7cb57 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
@@ -22,7 +22,6 @@
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory;
-import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory;
import org.elasticsearch.compute.operator.LocalSourceOperator;
import org.elasticsearch.compute.operator.LocalSourceOperator.LocalSourceFactory;
@@ -55,6 +54,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
@@ -161,13 +161,14 @@ public LocalExecutionPlanner(
/**
* turn the given plan into a list of drivers to execute
*/
- public LocalExecutionPlan plan(PhysicalPlan localPhysicalPlan) {
+ public LocalExecutionPlan plan(FoldContext foldCtx, PhysicalPlan localPhysicalPlan) {
var context = new LocalExecutionPlannerContext(
new ArrayList<>(),
new Holder<>(DriverParallelism.SINGLE),
configuration.pragmas(),
bigArrays,
blockFactory,
+ foldCtx,
settings
);
@@ -397,7 +398,7 @@ private PhysicalOperation planEval(EvalExec eval, LocalExecutionPlannerContext c
PhysicalOperation source = plan(eval.child(), context);
for (Alias field : eval.fields()) {
- var evaluatorSupplier = EvalMapper.toEvaluator(field.child(), source.layout);
+ var evaluatorSupplier = EvalMapper.toEvaluator(context.foldCtx(), field.child(), source.layout);
Layout.Builder layout = source.layout.builder();
layout.append(field.toAttribute());
source = source.with(new EvalOperatorFactory(evaluatorSupplier), layout.build());
@@ -418,7 +419,7 @@ private PhysicalOperation planDissect(DissectExec dissect, LocalExecutionPlanner
source = source.with(
new StringExtractOperator.StringExtractOperatorFactory(
patternNames,
- EvalMapper.toEvaluator(expr, layout),
+ EvalMapper.toEvaluator(context.foldCtx(), expr, layout),
() -> (input) -> dissect.parser().parser().parse(input)
),
layout
@@ -450,7 +451,7 @@ private PhysicalOperation planGrok(GrokExec grok, LocalExecutionPlannerContext c
source = source.with(
new ColumnExtractOperator.Factory(
types,
- EvalMapper.toEvaluator(grok.inputExpression(), layout),
+ EvalMapper.toEvaluator(context.foldCtx(), grok.inputExpression(), layout),
() -> new GrokEvaluatorExtracter(grok.pattern().grok(), grok.pattern().pattern(), fieldToPos, fieldToType)
),
layout
@@ -599,10 +600,6 @@ private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlan
);
}
- private ExpressionEvaluator.Factory toEvaluator(Expression exp, Layout layout) {
- return EvalMapper.toEvaluator(exp, layout);
- }
-
private PhysicalOperation planLocal(LocalSourceExec localSourceExec, LocalExecutionPlannerContext context) {
Layout.Builder layout = new Layout.Builder();
layout.append(localSourceExec.output());
@@ -657,12 +654,15 @@ private PhysicalOperation planProject(ProjectExec project, LocalExecutionPlanner
private PhysicalOperation planFilter(FilterExec filter, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(filter.child(), context);
// TODO: should this be extracted into a separate eval block?
- return source.with(new FilterOperatorFactory(toEvaluator(filter.condition(), source.layout)), source.layout);
+ return source.with(
+ new FilterOperatorFactory(EvalMapper.toEvaluator(context.foldCtx(), filter.condition(), source.layout)),
+ source.layout
+ );
}
private PhysicalOperation planLimit(LimitExec limit, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(limit.child(), context);
- return source.with(new Factory((Integer) limit.limit().fold()), source.layout);
+ return source.with(new Factory((Integer) limit.limit().fold(context.foldCtx)), source.layout);
}
private PhysicalOperation planMvExpand(MvExpandExec mvExpandExec, LocalExecutionPlannerContext context) {
@@ -783,6 +783,7 @@ public record LocalExecutionPlannerContext(
QueryPragmas queryPragmas,
BigArrays bigArrays,
BlockFactory blockFactory,
+ FoldContext foldCtx,
Settings settings
) {
void addDriverFactory(DriverFactory driverFactory) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java
index 5325145a77ade..2f4368155069f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java
@@ -20,6 +20,7 @@
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -159,13 +160,18 @@ private static , E extends T> void forEachUpWithChildren(
}
}
- public static PhysicalPlan localPlan(List searchContexts, Configuration configuration, PhysicalPlan plan) {
- return localPlan(configuration, plan, SearchContextStats.from(searchContexts));
+ public static PhysicalPlan localPlan(
+ List searchContexts,
+ Configuration configuration,
+ FoldContext foldCtx,
+ PhysicalPlan plan
+ ) {
+ return localPlan(configuration, foldCtx, plan, SearchContextStats.from(searchContexts));
}
- public static PhysicalPlan localPlan(Configuration configuration, PhysicalPlan plan, SearchStats searchStats) {
- final var logicalOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats));
- var physicalOptimizer = new LocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats));
+ public static PhysicalPlan localPlan(Configuration configuration, FoldContext foldCtx, PhysicalPlan plan, SearchStats searchStats) {
+ final var logicalOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, foldCtx, searchStats));
+ var physicalOptimizer = new LocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, foldCtx, searchStats));
return localPlan(plan, logicalOptimizer, physicalOptimizer);
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java
index 334875927eb96..4dea8a50b5c17 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java
@@ -14,6 +14,9 @@
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction;
class TypeConverter {
@@ -33,19 +36,26 @@ public static TypeConverter fromConvertFunction(AbstractConvertFunction convertF
BigArrays.NON_RECYCLING_INSTANCE
)
);
- return new TypeConverter(
- convertFunction.functionName(),
- convertFunction.toEvaluator(e -> driverContext -> new ExpressionEvaluator() {
- @Override
- public org.elasticsearch.compute.data.Block eval(Page page) {
- // This is a pass-through evaluator, since it sits directly on the source loading (no prior expressions)
- return page.getBlock(0);
- }
-
- @Override
- public void close() {}
- }).get(driverContext1)
- );
+ return new TypeConverter(convertFunction.functionName(), convertFunction.toEvaluator(new EvaluatorMapper.ToEvaluator() {
+ @Override
+ public ExpressionEvaluator.Factory apply(Expression expression) {
+ return driverContext -> new ExpressionEvaluator() {
+ @Override
+ public org.elasticsearch.compute.data.Block eval(Page page) {
+ // This is a pass-through evaluator, since it sits directly on the source loading (no prior expressions)
+ return page.getBlock(0);
+ }
+
+ @Override
+ public void close() {}
+ };
+ }
+
+ @Override
+ public FoldContext foldCtx() {
+ throw new IllegalStateException("not folding");
+ }
+ }).get(driverContext1));
}
public Block convert(Block block) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
index e881eabb38c43..b8f539ea307c9 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java
@@ -11,6 +11,7 @@
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -90,7 +91,7 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) {
enrich.mode(),
enrich.policy().getType(),
enrich.matchField(),
- BytesRefs.toString(enrich.policyName().fold()),
+ BytesRefs.toString(enrich.policyName().fold(FoldContext.small() /* TODO remove me */)),
enrich.policy().getMatchField(),
enrich.concreteIndices(),
enrich.enrichFields()
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
index 7223e6988bb19..a38236fe60954 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java
@@ -61,6 +61,7 @@
import org.elasticsearch.xpack.esql.action.EsqlQueryAction;
import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
@@ -140,6 +141,7 @@ public void execute(
CancellableTask rootTask,
PhysicalPlan physicalPlan,
Configuration configuration,
+ FoldContext foldContext,
EsqlExecutionInfo execInfo,
ActionListener listener
) {
@@ -174,6 +176,7 @@ public void execute(
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
List.of(),
configuration,
+ foldContext,
null,
null
);
@@ -226,6 +229,7 @@ public void execute(
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
List.of(),
configuration,
+ foldContext,
exchangeSource,
null
),
@@ -460,16 +464,16 @@ public SourceProvider createSourceProvider() {
context.exchangeSink(),
enrichLookupService,
lookupFromIndexService,
- new EsPhysicalOperationProviders(contexts, searchService.getIndicesService().getAnalysis())
+ new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis())
);
LOGGER.debug("Received physical plan:\n{}", plan);
- plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration, plan);
+ plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration, context.foldCtx(), plan);
// the planner will also set the driver parallelism in LocalExecutionPlanner.LocalExecutionPlan (used down below)
// it's doing this in the planning of EsQueryExec (the source of the data)
// see also EsPhysicalOperationProviders.sourcePhysicalOperation
- LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(plan);
+ LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(context.foldCtx(), plan);
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe());
}
@@ -715,7 +719,15 @@ public void onFailure(Exception e) {
};
acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> {
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH, ESQL_WORKER_THREAD_POOL_NAME);
- var computeContext = new ComputeContext(sessionId, clusterAlias, searchContexts, configuration, null, exchangeSink);
+ var computeContext = new ComputeContext(
+ sessionId,
+ clusterAlias,
+ searchContexts,
+ configuration,
+ configuration.newFoldContext(),
+ null,
+ exchangeSink
+ );
runCompute(parentTask, computeContext, request.plan(), batchListener);
}, batchListener::onFailure));
}
@@ -766,6 +778,7 @@ private void runComputeOnDataNode(
request.clusterAlias(),
List.of(),
request.configuration(),
+ new FoldContext(request.pragmas().foldLimit().getBytes()),
exchangeSource,
externalSink
),
@@ -901,7 +914,15 @@ void runComputeOnRemoteCluster(
exchangeSink.addCompletionListener(computeListener.acquireAvoid());
runCompute(
parentTask,
- new ComputeContext(localSessionId, clusterAlias, List.of(), configuration, exchangeSource, exchangeSink),
+ new ComputeContext(
+ localSessionId,
+ clusterAlias,
+ List.of(),
+ configuration,
+ configuration.newFoldContext(),
+ exchangeSource,
+ exchangeSink
+ ),
coordinatorPlan,
computeListener.acquireCompute(clusterAlias)
);
@@ -925,6 +946,7 @@ record ComputeContext(
String clusterAlias,
List searchContexts,
Configuration configuration,
+ FoldContext foldCtx,
ExchangeSourceHandler exchangeSource,
ExchangeSinkHandler exchangeSink
) {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java
index 58e80e569ee5e..2443c3f2cda62 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java
@@ -12,12 +12,14 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.compute.lucene.DataPartitioning;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverStatus;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
import java.io.IOException;
import java.util.Objects;
@@ -57,6 +59,8 @@ public final class QueryPragmas implements Writeable {
public static final Setting NODE_LEVEL_REDUCTION = Setting.boolSetting("node_level_reduction", true);
+ public static final Setting FOLD_LIMIT = Setting.memorySizeSetting("fold_limit", "5%");
+
public static final QueryPragmas EMPTY = new QueryPragmas(Settings.EMPTY);
private final Settings settings;
@@ -134,6 +138,17 @@ public boolean nodeLevelReduction() {
return NODE_LEVEL_REDUCTION.get(settings);
}
+ /**
+ * The maximum amount of memory we can use for {@link Expression#fold} during planing. This
+ * defaults to 5% of memory available on the current node. If this method is called on the
+ * coordinating node, this is 5% of the coordinating node's memory. If it's called on a data
+ * node, it's 5% of the data node. That's an exciting inconsistency. But it's
+ * important. Bigger nodes have more space to do folding.
+ */
+ public ByteSizeValue foldLimit() {
+ return FOLD_LIMIT.get(settings);
+ }
+
public boolean isEmpty() {
return settings.isEmpty();
}
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
index b44e249e38006..84173eeecc060 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java
@@ -43,6 +43,7 @@
import org.elasticsearch.xpack.esql.action.EsqlQueryResponse;
import org.elasticsearch.xpack.esql.action.EsqlQueryTask;
import org.elasticsearch.xpack.esql.core.async.AsyncTaskManagementService;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
@@ -189,11 +190,13 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener computeService.execute(
sessionId,
(CancellableTask) task,
plan,
configuration,
+ foldCtx,
executionInfo,
resultListener
);
@@ -201,6 +204,7 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener new EnrichPolicyResolver.UnresolvedPolicy((String) e.policyName().fold(), e.mode()))
+ .map(
+ e -> new EnrichPolicyResolver.UnresolvedPolicy(
+ (String) e.policyName().fold(FoldContext.small() /* TODO remove me*/),
+ e.mode()
+ )
+ )
.collect(Collectors.toSet());
final List indices = preAnalysis.indices;
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java
index 0847f71b1fb01..95ee6ab337bd6 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java
@@ -20,6 +20,7 @@
import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.Converter;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -239,9 +240,9 @@ public static Converter converterFor(DataType from, DataType to) {
return null;
}
- public static TemporalAmount foldToTemporalAmount(Expression field, String sourceText, DataType expectedType) {
+ public static TemporalAmount foldToTemporalAmount(FoldContext ctx, Expression field, String sourceText, DataType expectedType) {
if (field.foldable()) {
- Object v = field.fold();
+ Object v = field.fold(ctx);
if (v instanceof BytesRef b) {
try {
return EsqlDataTypeConverter.parseTemporalAmount(b.utf8ToString(), expectedType);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index 7d4374934ab82..ee5073c05cab1 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -55,6 +55,7 @@
import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
import org.elasticsearch.xpack.esql.analysis.PreAnalyzer;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.core.type.InvalidMappedField;
@@ -482,15 +483,16 @@ private static CsvTestsDataLoader.MultiIndexTestDataset testDatasets(LogicalPlan
return new CsvTestsDataLoader.MultiIndexTestDataset(indexName, datasets);
}
- private static TestPhysicalOperationProviders testOperationProviders(CsvTestsDataLoader.MultiIndexTestDataset datasets)
- throws Exception {
- var indexResolution = loadIndexResolution(datasets);
+ private static TestPhysicalOperationProviders testOperationProviders(
+ FoldContext foldCtx,
+ CsvTestsDataLoader.MultiIndexTestDataset datasets
+ ) throws Exception {
var indexPages = new ArrayList();
for (CsvTestsDataLoader.TestDataset dataset : datasets.datasets()) {
var testData = loadPageFromCsv(CsvTests.class.getResource("/data/" + dataset.dataFileName()), dataset.typeMapping());
indexPages.add(new TestPhysicalOperationProviders.IndexPage(dataset.indexName(), testData.v1(), testData.v2()));
}
- return TestPhysicalOperationProviders.create(indexPages);
+ return TestPhysicalOperationProviders.create(foldCtx, indexPages);
}
private ActualResults executePlan(BigArrays bigArrays) throws Exception {
@@ -498,6 +500,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
var testDatasets = testDatasets(parsed);
LogicalPlan analyzed = analyzedPlan(parsed, testDatasets);
+ FoldContext foldCtx = FoldContext.small();
EsqlSession session = new EsqlSession(
getTestName(),
configuration,
@@ -505,21 +508,21 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
null,
null,
functionRegistry,
- new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)),
+ new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)),
mapper,
TEST_VERIFIER,
new PlanningMetrics(),
null,
EsqlTestUtils.MOCK_QUERY_BUILDER_RESOLVER
);
- TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(testDatasets);
+ TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(foldCtx, testDatasets);
PlainActionFuture listener = new PlainActionFuture<>();
session.executeOptimizedPlan(
new EsqlQueryRequest(),
new EsqlExecutionInfo(randomBoolean()),
- planRunner(bigArrays, physicalOperationProviders),
+ planRunner(bigArrays, foldCtx, physicalOperationProviders),
session.optimizedPlan(analyzed),
listener.delegateFailureAndWrap(
// Wrap so we can capture the warnings in the calling thread
@@ -579,12 +582,13 @@ private void assertWarnings(List warnings) {
testCase.assertWarnings(false).assertWarnings(normalized);
}
- PlanRunner planRunner(BigArrays bigArrays, TestPhysicalOperationProviders physicalOperationProviders) {
- return (physicalPlan, listener) -> executeSubPlan(bigArrays, physicalOperationProviders, physicalPlan, listener);
+ PlanRunner planRunner(BigArrays bigArrays, FoldContext foldCtx, TestPhysicalOperationProviders physicalOperationProviders) {
+ return (physicalPlan, listener) -> executeSubPlan(bigArrays, foldCtx, physicalOperationProviders, physicalPlan, listener);
}
void executeSubPlan(
BigArrays bigArrays,
+ FoldContext foldCtx,
TestPhysicalOperationProviders physicalOperationProviders,
PhysicalPlan physicalPlan,
ActionListener listener
@@ -630,12 +634,17 @@ void executeSubPlan(
// replace fragment inside the coordinator plan
List drivers = new ArrayList<>();
- LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(new OutputExec(coordinatorPlan, collectedPages::add));
+ LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(
+ foldCtx,
+ new OutputExec(coordinatorPlan, collectedPages::add)
+ );
drivers.addAll(coordinatorNodeExecutionPlan.createDrivers(getTestName()));
if (dataNodePlan != null) {
var searchStats = new DisabledSearchStats();
- var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats));
- var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats));
+ var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, foldCtx, searchStats));
+ var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(
+ new LocalPhysicalOptimizerContext(configuration, foldCtx, searchStats)
+ );
var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer);
exchangeSource.addRemoteSink(
@@ -646,7 +655,7 @@ void executeSubPlan(
throw new AssertionError("expected no failure", e);
})
);
- LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan);
+ LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(foldCtx, csvDataNodePhysicalPlan);
drivers.addAll(dataNodeExecutionPlan.createDrivers(getTestName()));
Randomness.shuffle(drivers);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index dc4120f357725..2df6e30e96081 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -1077,7 +1077,7 @@ public void testImplicitLimit() {
from test
""");
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(DEFAULT_LIMIT));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
as(limit.child(), EsRelation.class);
}
@@ -1085,7 +1085,7 @@ public void testImplicitMaxLimitAfterLimit() {
for (int i = -1; i <= 1; i++) {
var plan = analyze("from test | limit " + (MAX_LIMIT + i));
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(MAX_LIMIT));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
limit = as(limit.child(), Limit.class);
as(limit.child(), EsRelation.class);
}
@@ -1102,7 +1102,7 @@ public void testImplicitMaxLimitAfterLimitAndNonLimit() {
for (int i = -1; i <= 1; i++) {
var plan = analyze("from test | limit " + (MAX_LIMIT + i) + " | eval s = salary * 10 | where s > 0");
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(MAX_LIMIT));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
var filter = as(limit.child(), Filter.class);
var eval = as(filter.child(), Eval.class);
limit = as(eval.child(), Limit.class);
@@ -1114,7 +1114,7 @@ public void testImplicitDefaultLimitAfterLimitAndBreaker() {
for (var breaker : List.of("stats c = count(salary) by last_name", "sort salary")) {
var plan = analyze("from test | limit 100000 | " + breaker);
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(MAX_LIMIT));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT));
}
}
@@ -1122,7 +1122,7 @@ public void testImplicitDefaultLimitAfterBreakerAndNonBreakers() {
for (var breaker : List.of("stats c = count(salary) by last_name", "eval c = salary | sort c")) {
var plan = analyze("from test | " + breaker + " | eval cc = c * 10 | where cc > 0");
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(DEFAULT_LIMIT));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT));
}
}
@@ -1428,7 +1428,7 @@ public void testEmptyEsRelationOnLimitZeroWithCount() throws IOException {
var plan = analyzeWithEmptyFieldCapsResponse(query);
var limit = as(plan, Limit.class);
limit = as(limit.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(0));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(0));
var orderBy = as(limit.child(), OrderBy.class);
var agg = as(orderBy.child(), Aggregate.class);
assertEmptyEsRelation(agg.child());
@@ -1443,7 +1443,7 @@ public void testEmptyEsRelationOnConstantEvalAndKeep() throws IOException {
var plan = analyzeWithEmptyFieldCapsResponse(query);
var limit = as(plan, Limit.class);
limit = as(limit.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(2));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(2));
var project = as(limit.child(), EsqlProject.class);
var eval = as(project.child(), Eval.class);
assertEmptyEsRelation(eval.child());
@@ -1460,7 +1460,7 @@ public void testEmptyEsRelationOnConstantEvalAndStats() throws IOException {
var agg = as(limit.child(), Aggregate.class);
var eval = as(agg.child(), Eval.class);
limit = as(eval.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(10));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(10));
assertEmptyEsRelation(limit.child());
}
@@ -2054,10 +2054,10 @@ public void testLookup() {
}
LogicalPlan plan = analyze(query);
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(1000));
var lookup = as(limit.child(), Lookup.class);
- assertThat(lookup.tableName().fold(), equalTo("int_number_names"));
+ assertThat(as(lookup.tableName(), Literal.class).value(), equalTo("int_number_names"));
assertMap(lookup.matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("int{r}")));
assertThat(
lookup.localRelation().output().stream().map(Object::toString).toList(),
@@ -2336,7 +2336,7 @@ public void testCoalesceWithMixedNumericTypes() {
projection = as(projections.get(3), ReferenceAttribute.class);
assertEquals(projection.name(), "w");
assertEquals(projection.dataType(), DataType.DOUBLE);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(as(limit.limit(), Literal.class).value(), equalTo(1000));
}
public void testNamedParamsForIdentifiers() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java
new file mode 100644
index 0000000000000..828f9e061686b
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.evaluator.mapper;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
+import org.hamcrest.Matchers;
+
+public class EvaluatorMapperTests extends ESTestCase {
+ public void testFoldCompletesWithPlentyOfMemory() {
+ Add add = new Add(
+ Source.synthetic("shouldn't break"),
+ new Literal(Source.EMPTY, 1, DataType.INTEGER),
+ new Literal(Source.EMPTY, 3, DataType.INTEGER)
+ );
+ assertEquals(add.fold(new FoldContext(100)), 4);
+ }
+
+ public void testFoldBreaksWithLittleMemory() {
+ Add add = new Add(
+ Source.synthetic("should break"),
+ new Literal(Source.EMPTY, 1, DataType.INTEGER),
+ new Literal(Source.EMPTY, 3, DataType.INTEGER)
+ );
+ Exception e = expectThrows(FoldContext.FoldTooMuchMemoryException.class, () -> add.fold(new FoldContext(10)));
+ assertThat(
+ e.getMessage(),
+ Matchers.equalTo(
+ "line -1:-1: Folding query used more than 10b. "
+ + "The expression that pushed past the limit is [should break] which needed 32b."
+ )
+ );
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
index c086245d6fd61..87ea6315d4f3b 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
@@ -22,6 +22,7 @@
import org.elasticsearch.core.Releasables;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.NumericUtils;
@@ -40,6 +41,7 @@
import java.util.stream.IntStream;
import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
@@ -263,12 +265,12 @@ private void evaluate(Expression evaluableExpression) {
assertTrue(evaluableExpression.foldable());
if (testCase.foldingExceptionClass() != null) {
- Throwable t = expectThrows(testCase.foldingExceptionClass(), evaluableExpression::fold);
+ Throwable t = expectThrows(testCase.foldingExceptionClass(), () -> evaluableExpression.fold(FoldContext.small()));
assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage()));
return;
}
- Object result = evaluableExpression.fold();
+ Object result = evaluableExpression.fold(FoldContext.small());
// Decode unsigned longs into BigIntegers
if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) {
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
@@ -289,7 +291,7 @@ private void resolveExpression(Expression expression, Consumer onAgg
expression = resolveSurrogates(expression);
// As expressions may be composed of multiple functions, we need to fold nulls bottom-up
- expression = expression.transformUp(e -> new FoldNull().rule(e));
+ expression = expression.transformUp(e -> new FoldNull().rule(e, unboundLogicalOptimizerContext()));
assertThat(expression.dataType(), equalTo(testCase.expectedType()));
Expression.TypeResolution resolution = expression.typeResolved();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
index 03b9dba298951..1cf087cf55ccd 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
@@ -34,6 +34,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull;
@@ -44,6 +45,7 @@
import org.elasticsearch.xpack.esql.core.util.NumericUtils;
import org.elasticsearch.xpack.esql.core.util.StringUtils;
import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
+import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
@@ -97,6 +99,7 @@
import static java.util.Map.entry;
import static org.elasticsearch.compute.data.BlockUtils.toJavaObject;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.endsWith;
@@ -530,14 +533,28 @@ protected final Expression buildLiteralExpression(TestCaseSupplier.TestCase test
return build(testCase.getSource(), testCase.getDataAsLiterals());
}
+ public static EvaluatorMapper.ToEvaluator toEvaluator() {
+ return new EvaluatorMapper.ToEvaluator() {
+ @Override
+ public ExpressionEvaluator.Factory apply(Expression expression) {
+ return evaluator(expression);
+ }
+
+ @Override
+ public FoldContext foldCtx() {
+ return FoldContext.small();
+ }
+ };
+ }
+
/**
* Convert an {@link Expression} tree into a {@link ExpressionEvaluator.Factory}
* for {@link ExpressionEvaluator}s in the same way as our planner.
*/
public static ExpressionEvaluator.Factory evaluator(Expression e) {
- e = new FoldNull().rule(e);
+ e = new FoldNull().rule(e, unboundLogicalOptimizerContext());
if (e.foldable()) {
- e = new Literal(e.source(), e.fold(), e.dataType());
+ e = new Literal(e.source(), e.fold(FoldContext.small()), e.dataType());
}
Layout.Builder builder = new Layout.Builder();
buildLayout(builder, e);
@@ -545,7 +562,7 @@ public static ExpressionEvaluator.Factory evaluator(Expression e) {
if (resolution.unresolved()) {
throw new AssertionError("expected resolved " + resolution.message());
}
- return EvalMapper.toEvaluator(e, builder.build());
+ return EvalMapper.toEvaluator(FoldContext.small(), e, builder.build());
}
protected final Page row(List values) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
index 65b9c447170f4..64086334b7251 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
@@ -20,6 +20,7 @@
import org.elasticsearch.core.Releasables;
import org.elasticsearch.indices.CrankyCircuitBreakerService;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.NumericUtils;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.FoldNull;
@@ -38,6 +39,7 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
@@ -132,7 +134,7 @@ public final void testEvaluate() {
if (resolution.unresolved()) {
throw new AssertionError("expected resolved " + resolution.message());
}
- expression = new FoldNull().rule(expression);
+ expression = new FoldNull().rule(expression, unboundLogicalOptimizerContext());
assertThat(expression.dataType(), equalTo(testCase.expectedType()));
logger.info("Result type: " + expression.dataType());
@@ -363,11 +365,11 @@ public void testFold() {
return;
}
assertFalse("expected resolved", expression.typeResolved().unresolved());
- Expression nullOptimized = new FoldNull().rule(expression);
+ Expression nullOptimized = new FoldNull().rule(expression, unboundLogicalOptimizerContext());
assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType()));
assertTrue(nullOptimized.foldable());
if (testCase.foldingExceptionClass() == null) {
- Object result = nullOptimized.fold();
+ Object result = nullOptimized.fold(FoldContext.small());
// Decode unsigned longs into BigIntegers
if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) {
result = NumericUtils.unsignedLongAsBigInteger((Long) result);
@@ -380,7 +382,7 @@ public void testFold() {
assertWarnings(testCase.getExpectedWarnings());
}
} else {
- Throwable t = expectThrows(testCase.foldingExceptionClass(), nullOptimized::fold);
+ Throwable t = expectThrows(testCase.foldingExceptionClass(), () -> nullOptimized.fold(FoldContext.small()));
assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage()));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java
index 19af9892015b2..e507640c7b23c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.function.Function;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -80,7 +81,9 @@ public EsqlFunctionRegistry snapshotRegistry() {
var plan = parser.createStatement(esql);
plan = plan.transformDown(
Limit.class,
- l -> Objects.equals(l.limit().fold(), 10) ? new LicensedLimit(l.source(), l.limit(), l.child(), functionLicenseFeature) : l
+ l -> Objects.equals(l.limit().fold(FoldContext.small()), 10)
+ ? new LicensedLimit(l.source(), l.limit(), l.child(), functionLicenseFeature)
+ : l
);
return analyzer(registry, operationMode).analyze(plan);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java
index f2bae0c5a4979..9bf063518d4ba 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java
@@ -1412,7 +1412,7 @@ public static final class TestCase {
/**
* Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator}
- * or {@link Expression#fold()} on the expression built by this.
+ * or {@link Expression#fold} on the expression built by this.
*/
private final String[] expectedBuildEvaluatorWarnings;
@@ -1542,7 +1542,7 @@ public String[] getExpectedWarnings() {
/**
* Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator}
- * or {@link Expression#fold()} on the expression built by this.
+ * or {@link Expression#fold} on the expression built by this.
*/
public String[] getExpectedBuildEvaluatorWarnings() {
return expectedBuildEvaluatorWarnings;
@@ -1624,7 +1624,7 @@ public TestCase withWarning(String warning) {
/**
* Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator}
- * or {@link Expression#fold()} on the expression built by this.
+ * or {@link Expression#fold} on the expression built by this.
*/
public TestCase withBuildEvaluatorWarning(String warning) {
return new TestCase(
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java
index de84086e3cb4e..911878a645b42 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java
@@ -19,9 +19,11 @@
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.junit.After;
@@ -67,7 +69,7 @@ public void testPartialFoldDropsFirstFalse() {
);
assertThat(c.foldable(), equalTo(false));
assertThat(
- c.partiallyFold(),
+ c.partiallyFold(FoldContext.small()),
equalTo(new Case(Source.synthetic("case"), field("last_cond", DataType.BOOLEAN), List.of(field("last", DataType.LONG))))
);
}
@@ -80,7 +82,7 @@ public void testPartialFoldMv() {
);
assertThat(c.foldable(), equalTo(false));
assertThat(
- c.partiallyFold(),
+ c.partiallyFold(FoldContext.small()),
equalTo(new Case(Source.synthetic("case"), field("last_cond", DataType.BOOLEAN), List.of(field("last", DataType.LONG))))
);
}
@@ -92,7 +94,7 @@ public void testPartialFoldNoop() {
List.of(field("first", DataType.LONG), field("last", DataType.LONG))
);
assertThat(c.foldable(), equalTo(false));
- assertThat(c.partiallyFold(), sameInstance(c));
+ assertThat(c.partiallyFold(FoldContext.small()), sameInstance(c));
}
public void testPartialFoldFirst() {
@@ -102,7 +104,7 @@ public void testPartialFoldFirst() {
List.of(field("first", DataType.LONG), field("last", DataType.LONG))
);
assertThat(c.foldable(), equalTo(false));
- assertThat(c.partiallyFold(), equalTo(field("first", DataType.LONG)));
+ assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("first", DataType.LONG)));
}
public void testPartialFoldFirstAfterKeepingUnknown() {
@@ -118,7 +120,7 @@ public void testPartialFoldFirstAfterKeepingUnknown() {
);
assertThat(c.foldable(), equalTo(false));
assertThat(
- c.partiallyFold(),
+ c.partiallyFold(FoldContext.small()),
equalTo(
new Case(
Source.synthetic("case"),
@@ -141,7 +143,7 @@ public void testPartialFoldSecond() {
)
);
assertThat(c.foldable(), equalTo(false));
- assertThat(c.partiallyFold(), equalTo(field("second", DataType.LONG)));
+ assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("second", DataType.LONG)));
}
public void testPartialFoldSecondAfterDroppingFalse() {
@@ -156,7 +158,7 @@ public void testPartialFoldSecondAfterDroppingFalse() {
)
);
assertThat(c.foldable(), equalTo(false));
- assertThat(c.partiallyFold(), equalTo(field("second", DataType.LONG)));
+ assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("second", DataType.LONG)));
}
public void testPartialFoldLast() {
@@ -171,7 +173,7 @@ public void testPartialFoldLast() {
)
);
assertThat(c.foldable(), equalTo(false));
- assertThat(c.partiallyFold(), equalTo(field("last", DataType.LONG)));
+ assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("last", DataType.LONG)));
}
public void testPartialFoldLastAfterKeepingUnknown() {
@@ -187,7 +189,7 @@ public void testPartialFoldLastAfterKeepingUnknown() {
);
assertThat(c.foldable(), equalTo(false));
assertThat(
- c.partiallyFold(),
+ c.partiallyFold(FoldContext.small()),
equalTo(
new Case(
Source.synthetic("case"),
@@ -203,7 +205,7 @@ public void testEvalCase() {
DriverContext driverContext = driverContext();
Page page = new Page(driverContext.blockFactory().newConstantIntBlockWith(0, 1));
try (
- EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(AbstractFunctionTestCase::evaluator).get(driverContext);
+ EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(AbstractFunctionTestCase.toEvaluator()).get(driverContext);
Block block = eval.eval(page)
) {
return toJavaObject(block, 0);
@@ -216,7 +218,7 @@ public void testEvalCase() {
public void testFoldCase() {
testCase(caseExpr -> {
assertTrue(caseExpr.foldable());
- return caseExpr.fold();
+ return caseExpr.fold(FoldContext.small());
});
}
@@ -265,22 +267,31 @@ public void testCaseWithIncompatibleTypes() {
public void testCaseIsLazy() {
Case caseExpr = caseExpr(true, 1, true, 2);
DriverContext driveContext = driverContext();
- EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(child -> {
- Object value = child.fold();
- if (value != null && value.equals(2)) {
- return dvrCtx -> new EvalOperator.ExpressionEvaluator() {
- @Override
- public Block eval(Page page) {
- fail("Unexpected evaluation of 4th argument");
- return null;
- }
+ EvaluatorMapper.ToEvaluator toEvaluator = new EvaluatorMapper.ToEvaluator() {
+ @Override
+ public EvalOperator.ExpressionEvaluator.Factory apply(Expression expression) {
+ Object value = expression.fold(FoldContext.small());
+ if (value != null && value.equals(2)) {
+ return dvrCtx -> new EvalOperator.ExpressionEvaluator() {
+ @Override
+ public Block eval(Page page) {
+ fail("Unexpected evaluation of 4th argument");
+ return null;
+ }
- @Override
- public void close() {}
- };
+ @Override
+ public void close() {}
+ };
+ }
+ return AbstractFunctionTestCase.evaluator(expression);
}
- return AbstractFunctionTestCase.evaluator(child);
- }).get(driveContext);
+
+ @Override
+ public FoldContext foldCtx() {
+ return FoldContext.small();
+ }
+ };
+ EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(toEvaluator).get(driveContext);
Page page = new Page(driveContext.blockFactory().newConstantIntBlockWith(0, 1));
try (Block block = evaluator.eval(page)) {
assertEquals(1, toJavaObject(block, 0));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java
index 05923246520fc..23a0f2307171c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java
@@ -14,6 +14,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.NumericUtils;
@@ -779,7 +780,7 @@ public void testFancyFolding() {
return;
}
assertThat(e.foldable(), equalTo(true));
- Object result = e.fold();
+ Object result = e.fold(FoldContext.small());
if (testCase.getExpectedBuildEvaluatorWarnings() != null) {
assertWarnings(testCase.getExpectedBuildEvaluatorWarnings());
}
@@ -799,18 +800,18 @@ public void testPartialFold() {
}
Case c = (Case) buildFieldExpression(testCase);
if (extra().expectedPartialFold == null) {
- assertThat(c.partiallyFold(), sameInstance(c));
+ assertThat(c.partiallyFold(FoldContext.small()), sameInstance(c));
return;
}
if (extra().expectedPartialFold.size() == 1) {
- assertThat(c.partiallyFold(), equalTo(extra().expectedPartialFold.get(0).asField()));
+ assertThat(c.partiallyFold(FoldContext.small()), equalTo(extra().expectedPartialFold.get(0).asField()));
return;
}
Case expected = build(
Source.synthetic("expected"),
extra().expectedPartialFold.stream().map(TestCaseSupplier.TypedData::asField).toList()
);
- assertThat(c.partiallyFold(), equalTo(expected));
+ assertThat(c.partiallyFold(FoldContext.small()), equalTo(expected));
}
private static Function addWarnings(List warnings) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java
index be978eda06758..cd27ce511b317 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java
@@ -15,6 +15,7 @@
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -99,7 +100,7 @@ public void testAllChronoFields() {
EsqlTestUtils.TEST_CFG
);
- assertThat(instance.fold(), is(date.getLong(value)));
+ assertThat(instance.fold(FoldContext.small()), is(date.getLong(value)));
assertThat(
DateExtract.process(epochMilli, new BytesRef(value.name()), EsqlTestUtils.TEST_CFG.zoneId()),
is(date.getLong(value))
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java
index 797c99992815e..688341ebaa2b7 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java
@@ -16,6 +16,7 @@
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -187,19 +188,27 @@ public void testCoalesceIsLazy() {
Layout.Builder builder = new Layout.Builder();
buildLayout(builder, exp);
Layout layout = builder.build();
- EvaluatorMapper.ToEvaluator toEvaluator = child -> {
- if (child == evil) {
- return dvrCtx -> new EvalOperator.ExpressionEvaluator() {
- @Override
- public Block eval(Page page) {
- throw new AssertionError("shouldn't be called");
- }
-
- @Override
- public void close() {}
- };
+ EvaluatorMapper.ToEvaluator toEvaluator = new EvaluatorMapper.ToEvaluator() {
+ @Override
+ public EvalOperator.ExpressionEvaluator.Factory apply(Expression expression) {
+ if (expression == evil) {
+ return dvrCtx -> new EvalOperator.ExpressionEvaluator() {
+ @Override
+ public Block eval(Page page) {
+ throw new AssertionError("shouldn't be called");
+ }
+
+ @Override
+ public void close() {}
+ };
+ }
+ return EvalMapper.toEvaluator(FoldContext.small(), expression, layout);
+ }
+
+ @Override
+ public FoldContext foldCtx() {
+ return FoldContext.small();
}
- return EvalMapper.toEvaluator(child, layout);
};
try (
EvalOperator.ExpressionEvaluator eval = exp.toEvaluator(toEvaluator).get(driverContext());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java
index 4f8adf3abaae6..6c41552a9fc52 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java
@@ -12,6 +12,7 @@
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -159,8 +160,8 @@ protected Expression build(Source source, List args) {
Expression expression = args.get(0);
Literal pattern = (Literal) args.get(1);
Literal caseInsensitive = args.size() > 2 ? (Literal) args.get(2) : null;
- String patternString = ((BytesRef) pattern.fold()).utf8ToString();
- boolean caseInsensitiveBool = caseInsensitive != null ? (boolean) caseInsensitive.fold() : false;
+ String patternString = ((BytesRef) pattern.fold(FoldContext.small())).utf8ToString();
+ boolean caseInsensitiveBool = caseInsensitive != null ? (boolean) caseInsensitive.fold(FoldContext.small()) : false;
logger.info("pattern={} caseInsensitive={}", patternString, caseInsensitiveBool);
return caseInsensitiveBool
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java
index b355feb6130a3..f779dd038454d 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java
@@ -15,6 +15,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -54,7 +55,7 @@ public void testRandomLocale() {
String testString = randomAlphaOfLength(10);
Configuration cfg = randomLocaleConfig();
ToLower func = new ToLower(Source.EMPTY, new Literal(Source.EMPTY, testString, DataType.KEYWORD), cfg);
- assertThat(BytesRefs.toBytesRef(testString.toLowerCase(cfg.locale())), equalTo(func.fold()));
+ assertThat(BytesRefs.toBytesRef(testString.toLowerCase(cfg.locale())), equalTo(func.fold(FoldContext.small())));
}
private Configuration randomLocaleConfig() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java
index fdae4f953a0fa..3957c2e1fb2c0 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java
@@ -15,6 +15,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -54,7 +55,7 @@ public void testRandomLocale() {
String testString = randomAlphaOfLength(10);
Configuration cfg = randomLocaleConfig();
ToUpper func = new ToUpper(Source.EMPTY, new Literal(Source.EMPTY, testString, DataType.KEYWORD), cfg);
- assertThat(BytesRefs.toBytesRef(testString.toUpperCase(cfg.locale())), equalTo(func.fold()));
+ assertThat(BytesRefs.toBytesRef(testString.toUpperCase(cfg.locale())), equalTo(func.fold(FoldContext.small())));
}
private Configuration randomLocaleConfig() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java
index eed2c7379e9e1..6626ac50d60b5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java
@@ -12,6 +12,7 @@
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern;
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -78,8 +79,8 @@ protected Expression build(Source source, List args) {
Literal pattern = (Literal) args.get(1);
if (args.size() > 2) {
Literal caseInsensitive = (Literal) args.get(2);
- assertThat(caseInsensitive.fold(), equalTo(false));
+ assertThat(caseInsensitive.fold(FoldContext.small()), equalTo(false));
}
- return new WildcardLike(source, expression, new WildcardPattern(((BytesRef) pattern.fold()).utf8ToString()));
+ return new WildcardLike(source, expression, new WildcardPattern(((BytesRef) pattern.fold(FoldContext.small())).utf8ToString()));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java
index a8c7b5b5a83fd..15860d35539e0 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -154,7 +155,7 @@ public void testEdgeCases() {
private Object foldTemporalAmount(Object val) {
Neg neg = new Neg(Source.EMPTY, new Literal(Source.EMPTY, val, typeOf(val)));
- return neg.fold();
+ return neg.fold(FoldContext.small());
}
private static DataType typeOf(Object val) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java
index b004adca351ab..80f67ec8e5e3a 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java
@@ -13,6 +13,7 @@
import org.elasticsearch.geo.GeometryTestUtils;
import org.elasticsearch.geo.ShapeTestUtils;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -48,27 +49,27 @@ public InTests(@Name("TestCase") Supplier testCaseSup
public void testInWithContainedValue() {
In in = new In(EMPTY, TWO, Arrays.asList(ONE, TWO, THREE));
- assertTrue((Boolean) in.fold());
+ assertTrue((Boolean) in.fold(FoldContext.small()));
}
public void testInWithNotContainedValue() {
In in = new In(EMPTY, THREE, Arrays.asList(ONE, TWO));
- assertFalse((Boolean) in.fold());
+ assertFalse((Boolean) in.fold(FoldContext.small()));
}
public void testHandleNullOnLeftValue() {
In in = new In(EMPTY, NULL, Arrays.asList(ONE, TWO, THREE));
- assertNull(in.fold());
+ assertNull(in.fold(FoldContext.small()));
in = new In(EMPTY, NULL, Arrays.asList(ONE, NULL, THREE));
- assertNull(in.fold());
+ assertNull(in.fold(FoldContext.small()));
}
public void testHandleNullsOnRightValue() {
In in = new In(EMPTY, THREE, Arrays.asList(ONE, NULL, THREE));
- assertTrue((Boolean) in.fold());
+ assertTrue((Boolean) in.fold(FoldContext.small()));
in = new In(EMPTY, ONE, Arrays.asList(TWO, NULL, THREE));
- assertNull(in.fold());
+ assertNull(in.fold(FoldContext.small()));
}
private static Literal L(Object value) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java
index faf0a0d8f418c..6fa1112f23f45 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java
@@ -10,6 +10,7 @@
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.of;
@@ -18,37 +19,37 @@
public class InsensitiveEqualsTests extends ESTestCase {
public void testFold() {
- assertTrue(insensitiveEquals(l("foo"), l("foo")).fold());
- assertTrue(insensitiveEquals(l("Foo"), l("foo")).fold());
- assertTrue(insensitiveEquals(l("Foo"), l("fOO")).fold());
- assertTrue(insensitiveEquals(l("foo*"), l("foo*")).fold());
- assertTrue(insensitiveEquals(l("foo*"), l("FOO*")).fold());
- assertTrue(insensitiveEquals(l("foo?bar"), l("foo?bar")).fold());
- assertTrue(insensitiveEquals(l("foo?bar"), l("FOO?BAR")).fold());
- assertFalse(insensitiveEquals(l("Foo"), l("fo*")).fold());
- assertFalse(insensitiveEquals(l("Fox"), l("fo?")).fold());
- assertFalse(insensitiveEquals(l("Foo"), l("*OO")).fold());
- assertFalse(insensitiveEquals(l("BarFooBaz"), l("*O*")).fold());
- assertFalse(insensitiveEquals(l("BarFooBaz"), l("bar*baz")).fold());
- assertFalse(insensitiveEquals(l("foo"), l("*")).fold());
+ assertTrue(insensitiveEquals(l("foo"), l("foo")).fold(FoldContext.small()));
+ assertTrue(insensitiveEquals(l("Foo"), l("foo")).fold(FoldContext.small()));
+ assertTrue(insensitiveEquals(l("Foo"), l("fOO")).fold(FoldContext.small()));
+ assertTrue(insensitiveEquals(l("foo*"), l("foo*")).fold(FoldContext.small()));
+ assertTrue(insensitiveEquals(l("foo*"), l("FOO*")).fold(FoldContext.small()));
+ assertTrue(insensitiveEquals(l("foo?bar"), l("foo?bar")).fold(FoldContext.small()));
+ assertTrue(insensitiveEquals(l("foo?bar"), l("FOO?BAR")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("Foo"), l("fo*")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("Fox"), l("fo?")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("Foo"), l("*OO")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("BarFooBaz"), l("*O*")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("BarFooBaz"), l("bar*baz")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("foo"), l("*")).fold(FoldContext.small()));
- assertFalse(insensitiveEquals(l("foo*bar"), l("foo\\*bar")).fold());
- assertFalse(insensitiveEquals(l("foo?"), l("foo\\?")).fold());
- assertFalse(insensitiveEquals(l("foo?bar"), l("foo\\?bar")).fold());
- assertFalse(insensitiveEquals(l(randomAlphaOfLength(10)), l("*")).fold());
- assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("???")).fold());
+ assertFalse(insensitiveEquals(l("foo*bar"), l("foo\\*bar")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("foo?"), l("foo\\?")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("foo?bar"), l("foo\\?bar")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l(randomAlphaOfLength(10)), l("*")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("???")).fold(FoldContext.small()));
- assertFalse(insensitiveEquals(l("foo"), l("bar")).fold());
- assertFalse(insensitiveEquals(l("foo"), l("ba*")).fold());
- assertFalse(insensitiveEquals(l("foo"), l("*a*")).fold());
- assertFalse(insensitiveEquals(l(""), l("bar")).fold());
- assertFalse(insensitiveEquals(l("foo"), l("")).fold());
- assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("??")).fold());
- assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("????")).fold());
+ assertFalse(insensitiveEquals(l("foo"), l("bar")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("foo"), l("ba*")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("foo"), l("*a*")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l(""), l("bar")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l("foo"), l("")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("??")).fold(FoldContext.small()));
+ assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("????")).fold(FoldContext.small()));
- assertNull(insensitiveEquals(l("foo"), Literal.NULL).fold());
- assertNull(insensitiveEquals(Literal.NULL, l("foo")).fold());
- assertNull(insensitiveEquals(Literal.NULL, Literal.NULL).fold());
+ assertNull(insensitiveEquals(l("foo"), Literal.NULL).fold(FoldContext.small()));
+ assertNull(insensitiveEquals(Literal.NULL, l("foo")).fold(FoldContext.small()));
+ assertNull(insensitiveEquals(Literal.NULL, Literal.NULL).fold(FoldContext.small()));
}
public void testProcess() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java
index 0c03556241d28..11cd123c731e8 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java
@@ -20,6 +20,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
@@ -70,6 +71,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForExistingField;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.hamcrest.Matchers.contains;
@@ -93,7 +95,7 @@ public static void init() {
mapping = loadMapping("mapping-basic.json");
EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(test);
- logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+ logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
@@ -161,7 +163,7 @@ public void testMissingFieldInProject() {
assertThat(Expressions.names(eval.fields()), contains("last_name"));
var alias = as(eval.fields().get(0), Alias.class);
var literal = as(alias.child(), Literal.class);
- assertThat(literal.fold(), is(nullValue()));
+ assertThat(literal.value(), is(nullValue()));
assertThat(literal.dataType(), is(DataType.KEYWORD));
var limit = as(eval.child(), Limit.class);
@@ -304,7 +306,7 @@ public void testMissingFieldInEval() {
var alias = as(eval.fields().get(0), Alias.class);
var literal = as(alias.child(), Literal.class);
- assertThat(literal.fold(), is(nullValue()));
+ assertThat(literal.value(), is(nullValue()));
assertThat(literal.dataType(), is(DataType.INTEGER));
var limit = as(eval.child(), Limit.class);
@@ -402,7 +404,7 @@ public void testSparseDocument() throws Exception {
EsIndex index = new EsIndex("large", large, Map.of("large", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(index);
- var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+ var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()),
@@ -411,7 +413,7 @@ public void testSparseDocument() throws Exception {
var analyzed = analyzer.analyze(parser.createStatement(query));
var optimized = logicalOptimizer.optimize(analyzed);
- var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, searchStats);
+ var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small(), searchStats);
var plan = new LocalLogicalPlanOptimizer(localContext).localOptimize(optimized);
var project = as(plan, Project.class);
@@ -423,7 +425,7 @@ public void testSparseDocument() throws Exception {
var eval = as(project.child(), Eval.class);
var field = eval.fields().get(0);
assertThat(Expressions.name(field), is("field005"));
- assertThat(Alias.unwrap(field).fold(), Matchers.nullValue());
+ assertThat(Alias.unwrap(field).fold(FoldContext.small()), Matchers.nullValue());
}
// InferIsNotNull
@@ -561,7 +563,7 @@ private LogicalPlan plan(String query) {
}
private LogicalPlan localPlan(LogicalPlan plan, SearchStats searchStats) {
- var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, searchStats);
+ var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small(), searchStats);
// System.out.println(plan);
var localPlan = new LocalLogicalPlanOptimizer(localContext).localOptimize(plan);
// System.out.println(localPlan);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
index 6dee34323443d..1536ed7f99fec 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
@@ -36,6 +36,7 @@
import org.elasticsearch.xpack.esql.analysis.Verifier;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -90,6 +91,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType;
import static org.hamcrest.Matchers.contains;
@@ -407,7 +409,7 @@ public void testMultiCountAllWithFilter() {
@SuppressWarnings("unchecked")
public void testSingleCountWithStatsFilter() {
// an optimizer that filters out the ExtractAggregateCommonFilter rule
- var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(config)) {
+ var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()) {
@Override
protected List> batches() {
var oldBatches = super.batches();
@@ -486,7 +488,7 @@ public void testQueryStringFunction() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var expected = QueryBuilders.queryStringQuery("last_name: Smith");
assertThat(query.query().toString(), is(expected.toString()));
}
@@ -515,7 +517,7 @@ public void testQueryStringFunctionConjunctionWhereOperands() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(2, 37, "emp_no > 10000");
var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource);
@@ -550,7 +552,7 @@ public void testQueryStringFunctionWithFunctionsPushedToLucene() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(2, 37, "cidr_match(ip, \"127.0.0.1/32\")");
var terms = wrapWithSingleQuery(queryText, QueryBuilders.termsQuery("ip", "127.0.0.1/32"), "ip", filterSource);
@@ -585,7 +587,7 @@ public void testQueryStringFunctionMultipleWhereClauses() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(3, 8, "emp_no > 10000");
var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource);
@@ -618,7 +620,7 @@ public void testQueryStringFunctionMultipleQstrClauses() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var queryStringLeft = QueryBuilders.queryStringQuery("last_name: Smith");
var queryStringRight = QueryBuilders.queryStringQuery("emp_no: [10010 TO *]");
@@ -647,7 +649,7 @@ public void testMatchFunction() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var expected = QueryBuilders.matchQuery("last_name", "Smith").lenient(true);
assertThat(query.query().toString(), is(expected.toString()));
}
@@ -676,7 +678,7 @@ public void testMatchFunctionConjunctionWhereOperands() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(2, 38, "emp_no > 10000");
var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource);
@@ -711,7 +713,7 @@ public void testMatchFunctionWithFunctionsPushedToLucene() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(2, 32, "cidr_match(ip, \"127.0.0.1/32\")");
var terms = wrapWithSingleQuery(queryText, QueryBuilders.termsQuery("ip", "127.0.0.1/32"), "ip", filterSource);
@@ -745,7 +747,7 @@ public void testMatchFunctionMultipleWhereClauses() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(3, 8, "emp_no > 10000");
var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource);
@@ -777,7 +779,7 @@ public void testMatchFunctionMultipleMatchClauses() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var queryStringLeft = QueryBuilders.matchQuery("last_name", "Smith").lenient(true);
var queryStringRight = QueryBuilders.matchQuery("first_name", "John").lenient(true);
@@ -806,7 +808,7 @@ public void testKqlFunction() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var expected = kqlQueryBuilder("last_name: Smith");
assertThat(query.query().toString(), is(expected.toString()));
}
@@ -835,7 +837,7 @@ public void testKqlFunctionConjunctionWhereOperands() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(2, 36, "emp_no > 10000");
var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource);
@@ -870,7 +872,7 @@ public void testKqlFunctionWithFunctionsPushedToLucene() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(2, 36, "cidr_match(ip, \"127.0.0.1/32\")");
var terms = wrapWithSingleQuery(queryText, QueryBuilders.termsQuery("ip", "127.0.0.1/32"), "ip", filterSource);
@@ -905,7 +907,7 @@ public void testKqlFunctionMultipleWhereClauses() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
Source filterSource = new Source(3, 8, "emp_no > 10000");
var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource);
@@ -938,7 +940,7 @@ public void testKqlFunctionMultipleKqlClauses() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var kqlQueryLeft = kqlQueryBuilder("last_name: Smith");
var kqlQueryRight = kqlQueryBuilder("emp_no > 10010");
@@ -1004,7 +1006,7 @@ public void testIsNotNullPushdownFilter() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var expected = QueryBuilders.existsQuery("emp_no");
assertThat(query.query().toString(), is(expected.toString()));
}
@@ -1028,7 +1030,7 @@ public void testIsNullPushdownFilter() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var expected = QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery("emp_no"));
assertThat(query.query().toString(), is(expected.toString()));
}
@@ -1599,7 +1601,7 @@ public void testTermFunction() {
var project = as(exchange.child(), ProjectExec.class);
var field = as(project.child(), FieldExtractExec.class);
var query = as(field.child(), EsQueryExec.class);
- assertThat(query.limit().fold(), is(1000));
+ assertThat(as(query.limit(), Literal.class).value(), is(1000));
var expected = QueryBuilders.termQuery("last_name", "Smith");
assertThat(query.query().toString(), is(expected.toString()));
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
index 4d175dea05071..2aed259e7ad0b 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
@@ -33,6 +33,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@@ -151,6 +152,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze;
@@ -188,6 +190,7 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
private static EsqlParser parser;
private static Analyzer analyzer;
+ private static LogicalOptimizerContext logicalOptimizerCtx;
private static LogicalPlanOptimizer logicalOptimizer;
private static Map mapping;
private static Map mappingAirports;
@@ -203,7 +206,7 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
private static Analyzer metricsAnalyzer;
private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer {
- static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+ static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(unboundLogicalOptimizerContext());
SubstitutionOnlyOptimizer(LogicalOptimizerContext optimizerContext) {
super(optimizerContext);
@@ -218,7 +221,8 @@ protected List> batches() {
@BeforeClass
public static void init() {
parser = new EsqlParser();
- logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+ logicalOptimizerCtx = unboundLogicalOptimizerContext();
+ logicalOptimizer = new LogicalPlanOptimizer(logicalOptimizerCtx);
enrichResolution = new EnrichResolution();
AnalyzerTestUtils.loadEnrichPolicyResolution(enrichResolution, "languages_idx", "id", "languages_idx", "mapping-languages.json");
@@ -325,7 +329,7 @@ public void testEmptyProjectInStatWithEval() {
assertThat(exprs.size(), equalTo(1));
var alias = as(exprs.get(0), Alias.class);
assertThat(alias.name(), equalTo("x"));
- assertThat(alias.child().fold(), equalTo(1));
+ assertThat(alias.child().fold(FoldContext.small()), equalTo(1));
}
/**
@@ -361,11 +365,11 @@ public void testEmptyProjectInStatWithGroupAndEval() {
assertThat(exprs.size(), equalTo(1));
var alias = as(exprs.get(0), Alias.class);
assertThat(alias.name(), equalTo("x"));
- assertThat(alias.child().fold(), equalTo(1));
+ assertThat(alias.child().fold(FoldContext.small()), equalTo(1));
var filterCondition = as(filter.condition(), GreaterThan.class);
assertThat(Expressions.name(filterCondition.left()), equalTo("languages"));
- assertThat(filterCondition.right().fold(), equalTo(1));
+ assertThat(filterCondition.right().fold(FoldContext.small()), equalTo(1));
}
public void testCombineProjections() {
@@ -625,7 +629,7 @@ public void testReplaceStatsFilteredAggWithEvalSingleAggWithExpression() {
assertThat(alias.name(), is("sum(salary) + 1 where false"));
var add = as(alias.child(), Add.class);
var literal = as(add.right(), Literal.class);
- assertThat(literal.fold(), is(1));
+ assertThat(literal.value(), is(1));
var limit = as(eval.child(), Limit.class);
var source = as(limit.child(), LocalRelation.class);
@@ -658,7 +662,7 @@ public void testReplaceStatsFilteredAggWithEvalMixedFilterAndNoFilter() {
var alias = as(eval.fields().getFirst(), Alias.class);
assertTrue(alias.child().foldable());
- assertThat(alias.child().fold(), nullValue());
+ assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(LONG));
alias = as(eval.fields().getLast(), Alias.class);
@@ -695,7 +699,7 @@ public void testReplaceStatsFilteredAggWithEvalFilterFalseAndNull() {
var alias = as(eval.fields().getFirst(), Alias.class);
assertTrue(alias.child().foldable());
- assertThat(alias.child().fold(), nullValue());
+ assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(LONG));
alias = as(eval.fields().get(1), Alias.class);
@@ -703,7 +707,7 @@ public void testReplaceStatsFilteredAggWithEvalFilterFalseAndNull() {
alias = as(eval.fields().getLast(), Alias.class);
assertTrue(alias.child().foldable());
- assertThat(alias.child().fold(), nullValue());
+ assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(LONG));
var limit = as(eval.child(), Limit.class);
@@ -752,7 +756,7 @@ public void testReplaceStatsFilteredAggWithEvalCountDistinctInExpression() {
assertThat(alias.name(), is("count_distinct(salary + 2) + 3 where false"));
var add = as(alias.child(), Add.class);
var literal = as(add.right(), Literal.class);
- assertThat(literal.fold(), is(3));
+ assertThat(literal.value(), is(3));
var limit = as(eval.child(), Limit.class);
var source = as(limit.child(), LocalRelation.class);
@@ -788,13 +792,13 @@ public void testReplaceStatsFilteredAggWithEvalSameAggWithAndWithoutFilter() {
var alias = as(eval.fields().getFirst(), Alias.class);
assertThat(Expressions.name(alias), containsString("max_a"));
assertTrue(alias.child().foldable());
- assertThat(alias.child().fold(), nullValue());
+ assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(INTEGER));
alias = as(eval.fields().getLast(), Alias.class);
assertThat(Expressions.name(alias), containsString("min_a"));
assertTrue(alias.child().foldable());
- assertThat(alias.child().fold(), nullValue());
+ assertThat(alias.child().fold(FoldContext.small()), nullValue());
assertThat(alias.child().dataType(), is(INTEGER));
var limit = as(eval.child(), Limit.class);
@@ -933,7 +937,7 @@ public void testExtractStatsCommonFilterUsingJustOneAlias() {
var gt = as(filter.condition(), GreaterThan.class);
assertThat(Expressions.name(gt.left()), is("emp_no"));
assertTrue(gt.right().foldable());
- assertThat(gt.right().fold(), is(1));
+ assertThat(gt.right().fold(FoldContext.small()), is(1));
var source = as(filter.child(), EsRelation.class);
}
@@ -1053,7 +1057,7 @@ public void testExtractStatsCommonFilterInConjunction() {
var gt = as(filter.condition(), GreaterThan.class); // name is "emp_no > 1 + 1"
assertThat(Expressions.name(gt.left()), is("emp_no"));
assertTrue(gt.right().foldable());
- assertThat(gt.right().fold(), is(2));
+ assertThat(gt.right().fold(FoldContext.small()), is(2));
var source = as(filter.child(), EsRelation.class);
}
@@ -1083,12 +1087,12 @@ public void testExtractStatsCommonFilterInConjunctionWithMultipleCommonConjuncti
var lt = as(and.left(), LessThan.class);
assertThat(Expressions.name(lt.left()), is("emp_no"));
assertTrue(lt.right().foldable());
- assertThat(lt.right().fold(), is(10));
+ assertThat(lt.right().fold(FoldContext.small()), is(10));
var equals = as(and.right(), Equals.class);
assertThat(Expressions.name(equals.left()), is("last_name"));
assertTrue(equals.right().foldable());
- assertThat(equals.right().fold(), is(BytesRefs.toBytesRef("Doe")));
+ assertThat(equals.right().fold(FoldContext.small()), is(BytesRefs.toBytesRef("Doe")));
var source = as(filter.child(), EsRelation.class);
}
@@ -1303,7 +1307,7 @@ public void testCombineLimits() {
var anotherLimit = new Limit(EMPTY, L(limitValues[secondLimit]), oneLimit);
assertEquals(
new Limit(EMPTY, L(Math.min(limitValues[0], limitValues[1])), emptySource()),
- new PushDownAndCombineLimits().rule(anotherLimit)
+ new PushDownAndCombineLimits().rule(anotherLimit, logicalOptimizerCtx)
);
}
@@ -1322,7 +1326,7 @@ public void testPushdownLimitsPastLeftJoin() {
var limit = new Limit(EMPTY, L(10), join);
- var optimizedPlan = new PushDownAndCombineLimits().rule(limit);
+ var optimizedPlan = new PushDownAndCombineLimits().rule(limit, logicalOptimizerCtx);
assertEquals(join.replaceChildren(limit.replaceChild(join.left()), join.right()), optimizedPlan);
}
@@ -1340,10 +1344,7 @@ public void testMultipleCombineLimits() {
var value = i == limitWithMinimum ? minimum : randomIntBetween(100, 1000);
plan = new Limit(EMPTY, L(value), plan);
}
- assertEquals(
- new Limit(EMPTY, L(minimum), relation),
- new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)).optimize(plan)
- );
+ assertEquals(new Limit(EMPTY, L(minimum), relation), logicalOptimizer.optimize(plan));
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/115311")
@@ -1864,7 +1865,7 @@ public void testCopyDefaultLimitPastMvExpand() {
assertThat(mvExpand.limit(), equalTo(1000));
var keep = as(mvExpand.child(), EsqlProject.class);
var limitPastMvExpand = as(keep.child(), Limit.class);
- assertThat(limitPastMvExpand.limit().fold(), equalTo(1000));
+ assertThat(limitPastMvExpand.limit().fold(FoldContext.small()), equalTo(1000));
as(limitPastMvExpand.child(), EsRelation.class);
}
@@ -1887,7 +1888,7 @@ public void testDontPushDownLimitPastMvExpand() {
assertThat(mvExpand.limit(), equalTo(10));
var project = as(mvExpand.child(), EsqlProject.class);
var limit = as(project.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(1));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1));
as(limit.child(), EsRelation.class);
}
@@ -1921,7 +1922,7 @@ public void testMultipleMvExpandWithSortAndLimit() {
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(5));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(5));
assertThat(orderNames(topN), contains("salary"));
var mvExp = as(topN.child(), MvExpand.class);
assertThat(mvExp.limit(), equalTo(5));
@@ -1931,7 +1932,7 @@ public void testMultipleMvExpandWithSortAndLimit() {
mvExp = as(filter.child(), MvExpand.class);
assertThat(mvExp.limit(), equalTo(10));
topN = as(mvExp.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(10));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10));
filter = as(topN.child(), Filter.class);
as(filter.child(), EsRelation.class);
}
@@ -1955,11 +1956,11 @@ public void testPushDownLimitThroughMultipleSort_AfterMvExpand() {
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(5));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(5));
assertThat(orderNames(topN), contains("salary", "first_name"));
var mvExp = as(topN.child(), MvExpand.class);
topN = as(mvExp.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
assertThat(orderNames(topN), contains("emp_no"));
as(topN.child(), EsRelation.class);
}
@@ -1985,14 +1986,14 @@ public void testPushDownLimitThroughMultipleSort_AfterMvExpand2() {
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(5));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(5));
assertThat(orderNames(topN), contains("first_name"));
topN = as(topN.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(5));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(5));
assertThat(orderNames(topN), contains("salary"));
var mvExp = as(topN.child(), MvExpand.class);
topN = as(mvExp.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
assertThat(orderNames(topN), contains("emp_no"));
as(topN.child(), EsRelation.class);
}
@@ -2021,11 +2022,11 @@ public void testDontPushDownLimitPastAggregate_AndMvExpand() {
var limit = as(plan, Limit.class);
var filter = as(limit.child(), Filter.class);
- assertThat(limit.limit().fold(), equalTo(5));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(5));
var agg = as(filter.child(), Aggregate.class);
var mvExp = as(agg.child(), MvExpand.class);
var topN = as(mvExp.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(50));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(50));
assertThat(orderNames(topN), contains("emp_no"));
as(topN.child(), EsRelation.class);
}
@@ -2052,13 +2053,13 @@ public void testPushDown_TheRightLimit_PastMvExpand() {
| limit 5""");
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(5));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(5));
var filter = as(limit.child(), Filter.class);
var agg = as(filter.child(), Aggregate.class);
var mvExp = as(agg.child(), MvExpand.class);
assertThat(mvExp.limit(), equalTo(50));
limit = as(mvExp.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(50));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(50));
as(limit.child(), EsRelation.class);
}
@@ -2083,12 +2084,12 @@ public void testPushDownLimit_PastEvalAndMvExpand() {
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(5));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(5));
assertThat(orderNames(topN), contains("salary"));
var eval = as(topN.child(), Eval.class);
var mvExp = as(eval.child(), MvExpand.class);
topN = as(mvExp.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
assertThat(orderNames(topN), contains("first_name"));
as(topN.child(), EsRelation.class);
}
@@ -2114,7 +2115,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedField_ResultT
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(1000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(1000));
assertThat(orderNames(topN), contains("salary", "first_name"));
var filter = as(topN.child(), Filter.class);
assertThat(filter.condition(), instanceOf(And.class));
@@ -2143,7 +2144,7 @@ public void testFilterWithSortBeforeMvExpand() {
var mvExp = as(plan, MvExpand.class);
assertThat(mvExp.limit(), equalTo(10));
var topN = as(mvExp.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(10));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10));
assertThat(orderNames(topN), contains("emp_no"));
var filter = as(topN.child(), Filter.class);
as(filter.child(), EsRelation.class);
@@ -2168,7 +2169,7 @@ public void testMultiMvExpand_SortDownBelow() {
| sort first_name""");
var topN = as(plan, TopN.class);
- assertThat(topN.limit().fold(), equalTo(1000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(1000));
assertThat(orderNames(topN), contains("first_name"));
var mvExpand = as(topN.child(), MvExpand.class);
var filter = as(mvExpand.child(), Filter.class);
@@ -2200,11 +2201,11 @@ public void testLimitThenSortBeforeMvExpand() {
assertThat(mvExpand.limit(), equalTo(10000));
var project = as(mvExpand.child(), EsqlProject.class);
var topN = as(project.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(7300));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(7300));
assertThat(orderNames(topN), contains("a"));
mvExpand = as(topN.child(), MvExpand.class);
var limit = as(mvExpand.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(7300));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(7300));
as(limit.child(), LocalRelation.class);
}
@@ -2224,7 +2225,7 @@ public void testRemoveUnusedSortBeforeMvExpand_DefaultLimit10000() {
var topN = as(plan, TopN.class);
assertThat(orderNames(topN), contains("first_name"));
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
var mvExpand = as(topN.child(), MvExpand.class);
var topN2 = as(mvExpand.child(), TopN.class); // TODO is it correct? Double-check AddDefaultTopN rule
as(topN2.child(), EsRelation.class);
@@ -2252,7 +2253,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedField() {
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(15));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(15));
assertThat(orderNames(topN), contains("salary", "first_name"));
var filter = as(topN.child(), Filter.class);
assertThat(filter.condition(), instanceOf(And.class));
@@ -2260,7 +2261,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedField() {
topN = as(mvExp.child(), TopN.class);
// the filter acts on first_name (the one used in mv_expand), so the limit 15 is not pushed down past mv_expand
// instead the default limit is added
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
assertThat(orderNames(topN), contains("emp_no"));
as(topN.child(), EsRelation.class);
}
@@ -2287,7 +2288,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilter_NOT_OnExpandedField()
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(15));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(15));
assertThat(orderNames(topN), contains("salary", "first_name"));
var filter = as(topN.child(), Filter.class);
assertThat(filter.condition(), instanceOf(And.class));
@@ -2295,7 +2296,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilter_NOT_OnExpandedField()
topN = as(mvExp.child(), TopN.class);
// the filters after mv_expand do not act on the expanded field values, as such the limit 15 is the one being pushed down
// otherwise that limit wouldn't have pushed down and the default limit was instead being added by default before mv_expanded
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
assertThat(orderNames(topN), contains("emp_no"));
as(topN.child(), EsRelation.class);
}
@@ -2323,14 +2324,14 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedFieldAlias()
var keep = as(plan, EsqlProject.class);
var topN = as(keep.child(), TopN.class);
- assertThat(topN.limit().fold(), equalTo(15));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(15));
assertThat(orderNames(topN), contains("salary", "first_name"));
var filter = as(topN.child(), Filter.class);
assertThat(filter.condition(), instanceOf(And.class));
var mvExp = as(filter.child(), MvExpand.class);
topN = as(mvExp.child(), TopN.class);
// the filter uses an alias ("x") to the expanded field ("first_name"), so the default limit is used and not the one provided
- assertThat(topN.limit().fold(), equalTo(10000));
+ assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000));
assertThat(orderNames(topN), contains("gender"));
as(topN.child(), EsRelation.class);
}
@@ -2369,7 +2370,7 @@ public void testSortMvExpandLimit() {
var expand = as(plan, MvExpand.class);
assertThat(expand.limit(), equalTo(20));
var topN = as(expand.child(), TopN.class);
- assertThat(topN.limit().fold(), is(20));
+ assertThat(topN.limit().fold(FoldContext.small()), is(20));
var row = as(topN.child(), EsRelation.class);
}
@@ -2390,7 +2391,7 @@ public void testWhereMvExpand() {
var expand = as(plan, MvExpand.class);
assertThat(expand.limit(), equalTo(1000));
var limit2 = as(expand.child(), Limit.class);
- assertThat(limit2.limit().fold(), is(1000));
+ assertThat(limit2.limit().fold(FoldContext.small()), is(1000));
var row = as(limit2.child(), LocalRelation.class);
}
@@ -2583,7 +2584,7 @@ public void testSimplifyLikeNoWildcard() {
assertTrue(filter.condition() instanceof Equals);
Equals equals = as(filter.condition(), Equals.class);
- assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold());
+ assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold(FoldContext.small()));
assertTrue(filter.child() instanceof EsRelation);
}
@@ -2609,7 +2610,7 @@ public void testSimplifyRLikeNoWildcard() {
assertTrue(filter.condition() instanceof Equals);
Equals equals = as(filter.condition(), Equals.class);
- assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold());
+ assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold(FoldContext.small()));
assertTrue(filter.child() instanceof EsRelation);
}
@@ -2773,7 +2774,7 @@ public void testEnrich() {
""");
var enrich = as(plan, Enrich.class);
assertTrue(enrich.policyName().resolved());
- assertThat(enrich.policyName().fold(), is(BytesRefs.toBytesRef("languages_idx")));
+ assertThat(enrich.policyName().fold(FoldContext.small()), is(BytesRefs.toBytesRef("languages_idx")));
var eval = as(enrich.child(), Eval.class);
var limit = as(eval.child(), Limit.class);
as(limit.child(), EsRelation.class);
@@ -2819,7 +2820,7 @@ public void testEnrichNotNullFilter() {
var filter = as(limit.child(), Filter.class);
var enrich = as(filter.child(), Enrich.class);
assertTrue(enrich.policyName().resolved());
- assertThat(enrich.policyName().fold(), is(BytesRefs.toBytesRef("languages_idx")));
+ assertThat(enrich.policyName().fold(FoldContext.small()), is(BytesRefs.toBytesRef("languages_idx")));
var eval = as(enrich.child(), Eval.class);
as(eval.child(), EsRelation.class);
}
@@ -2940,7 +2941,7 @@ public void testMedianReplacement() {
var a = as(aggs.get(0), Alias.class);
var per = as(a.child(), Percentile.class);
var literal = as(per.percentile(), Literal.class);
- assertThat((int) QuantileStates.MEDIAN, is(literal.fold()));
+ assertThat((int) QuantileStates.MEDIAN, is(literal.value()));
assertThat(Expressions.names(agg.groupings()), contains("last_name"));
}
@@ -2949,7 +2950,7 @@ public void testSplittingInWithFoldableValue() {
FieldAttribute fa = getFieldAttribute("foo");
In in = new In(EMPTY, ONE, List.of(TWO, THREE, fa, L(null)));
Or expected = new Or(EMPTY, new In(EMPTY, ONE, List.of(TWO, THREE)), new In(EMPTY, ONE, List.of(fa, L(null))));
- assertThat(new SplitInWithFoldableValue().rule(in), equalTo(expected));
+ assertThat(new SplitInWithFoldableValue().rule(in, logicalOptimizerCtx), equalTo(expected));
}
public void testReplaceFilterWithExact() {
@@ -3706,7 +3707,7 @@ private void aggFieldName(Expression exp, Class
var alias = as(exp, Alias.class);
var af = as(alias.child(), aggType);
var field = af.field();
- var name = field.foldable() ? BytesRefs.toString(field.fold()) : Expressions.name(field);
+ var name = field.foldable() ? BytesRefs.toString(field.fold(FoldContext.small())) : Expressions.name(field);
assertThat(name, is(fieldName));
}
@@ -4118,7 +4119,7 @@ public void testNestedExpressionsWithGroupingKeyInAggs() {
var value = Alias.unwrap(fields.get(0));
var math = as(value, Mod.class);
assertThat(Expressions.name(math.left()), is("emp_no"));
- assertThat(math.right().fold(), is(2));
+ assertThat(math.right().fold(FoldContext.small()), is(2));
// languages + emp_no % 2
var add = as(Alias.unwrap(fields.get(1).canonical()), Add.class);
if (add.left() instanceof Mod mod) {
@@ -4127,7 +4128,7 @@ public void testNestedExpressionsWithGroupingKeyInAggs() {
assertThat(Expressions.name(add.left()), is("languages"));
var mod = as(add.right().canonical(), Mod.class);
assertThat(Expressions.name(mod.left()), is("emp_no"));
- assertThat(mod.right().fold(), is(2));
+ assertThat(mod.right().fold(FoldContext.small()), is(2));
}
/**
@@ -4156,7 +4157,7 @@ public void testNestedExpressionsWithMultiGrouping() {
var value = Alias.unwrap(fields.get(0).canonical());
var math = as(value, Mod.class);
assertThat(Expressions.name(math.left()), is("emp_no"));
- assertThat(math.right().fold(), is(2));
+ assertThat(math.right().fold(FoldContext.small()), is(2));
// languages + salary
var add = as(Alias.unwrap(fields.get(1).canonical()), Add.class);
assertThat(Expressions.name(add.left()), anyOf(is("languages"), is("salary")));
@@ -4173,7 +4174,7 @@ public void testNestedExpressionsWithMultiGrouping() {
assertThat(Expressions.name(add3.right()), anyOf(is("salary"), is("languages")));
// emp_no % 2
assertThat(Expressions.name(mod.left()), is("emp_no"));
- assertThat(mod.right().fold(), is(2));
+ assertThat(mod.right().fold(FoldContext.small()), is(2));
}
/**
@@ -4611,8 +4612,8 @@ public void testCountOfLiteral() {
var mvCoalesce = as(mul.left(), Coalesce.class);
assertThat(mvCoalesce.children().size(), equalTo(2));
var mvCount = as(mvCoalesce.children().get(0), MvCount.class);
- assertThat(mvCount.fold(), equalTo(2));
- assertThat(mvCoalesce.children().get(1).fold(), equalTo(0));
+ assertThat(mvCount.fold(FoldContext.small()), equalTo(2));
+ assertThat(mvCoalesce.children().get(1).fold(FoldContext.small()), equalTo(0));
var count = as(mul.right(), ReferenceAttribute.class);
assertThat(count.name(), equalTo("$$COUNT$s$0"));
@@ -4623,8 +4624,8 @@ public void testCountOfLiteral() {
var mvCoalesce_expr = as(mul_expr.left(), Coalesce.class);
assertThat(mvCoalesce_expr.children().size(), equalTo(2));
var mvCount_expr = as(mvCoalesce_expr.children().get(0), MvCount.class);
- assertThat(mvCount_expr.fold(), equalTo(1));
- assertThat(mvCoalesce_expr.children().get(1).fold(), equalTo(0));
+ assertThat(mvCount_expr.fold(FoldContext.small()), equalTo(1));
+ assertThat(mvCoalesce_expr.children().get(1).fold(FoldContext.small()), equalTo(0));
var count_expr = as(mul_expr.right(), ReferenceAttribute.class);
assertThat(count_expr.name(), equalTo("$$COUNT$s$0"));
@@ -4636,7 +4637,7 @@ public void testCountOfLiteral() {
assertThat(mvCoalesce_null.children().size(), equalTo(2));
var mvCount_null = as(mvCoalesce_null.children().get(0), MvCount.class);
assertThat(mvCount_null.field(), equalTo(NULL));
- assertThat(mvCoalesce_null.children().get(1).fold(), equalTo(0));
+ assertThat(mvCoalesce_null.children().get(1).fold(FoldContext.small()), equalTo(0));
var count_null = as(mul_null.right(), ReferenceAttribute.class);
assertThat(count_null.name(), equalTo("$$COUNT$s$0"));
}
@@ -4675,7 +4676,7 @@ public void testSumOfLiteral() {
assertThat(s.name(), equalTo("s"));
var mul = as(s.child(), Mul.class);
var mvSum = as(mul.left(), MvSum.class);
- assertThat(mvSum.fold(), equalTo(3));
+ assertThat(mvSum.fold(FoldContext.small()), equalTo(3));
var count = as(mul.right(), ReferenceAttribute.class);
assertThat(count.name(), equalTo("$$COUNT$s$0"));
@@ -4684,7 +4685,7 @@ public void testSumOfLiteral() {
assertThat(s_expr.name(), equalTo("s_expr"));
var mul_expr = as(s_expr.child(), Mul.class);
var mvSum_expr = as(mul_expr.left(), MvSum.class);
- assertThat(mvSum_expr.fold(), equalTo(3.14));
+ assertThat(mvSum_expr.fold(FoldContext.small()), equalTo(3.14));
var count_expr = as(mul_expr.right(), ReferenceAttribute.class);
assertThat(count_expr.name(), equalTo("$$COUNT$s$0"));
@@ -4833,7 +4834,7 @@ private static void assertAggOfConstExprs(AggOfLiteralTestCase testCase, List x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x);
+ be -> LITERALS_ON_THE_RIGHT.rule(be, logicalOptimizerCtx)
+ ).transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(FoldContext.small()), x.dataType()) : x);
List resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute);
for (Expression field : resolvedFields) {
@@ -5679,8 +5680,7 @@ public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() {
}
private void assertNullLiteral(Expression expression) {
- assertEquals(Literal.class, expression.getClass());
- assertNull(expression.fold());
+ assertNull(as(expression, Literal.class).value());
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519")
@@ -5776,7 +5776,7 @@ public void testReplaceStringCasingWithInsensitiveEqualsEquals() {
var filter = as(limit.child(), Filter.class);
var insensitive = as(filter.condition(), InsensitiveEquals.class);
as(insensitive.left(), FieldAttribute.class);
- var bRef = as(insensitive.right().fold(), BytesRef.class);
+ var bRef = as(insensitive.right().fold(FoldContext.small()), BytesRef.class);
assertThat(bRef.utf8ToString(), is(value));
as(filter.child(), EsRelation.class);
}
@@ -5792,7 +5792,7 @@ public void testReplaceStringCasingWithInsensitiveEqualsNotEquals() {
var not = as(filter.condition(), Not.class);
var insensitive = as(not.field(), InsensitiveEquals.class);
as(insensitive.left(), FieldAttribute.class);
- var bRef = as(insensitive.right().fold(), BytesRef.class);
+ var bRef = as(insensitive.right().fold(FoldContext.small()), BytesRef.class);
assertThat(bRef.utf8ToString(), is(value));
as(filter.child(), EsRelation.class);
}
@@ -5805,7 +5805,7 @@ public void testReplaceStringCasingWithInsensitiveEqualsUnwrap() {
var insensitive = as(filter.condition(), InsensitiveEquals.class);
var field = as(insensitive.left(), FieldAttribute.class);
assertThat(field.fieldName(), is("first_name"));
- var bRef = as(insensitive.right().fold(), BytesRef.class);
+ var bRef = as(insensitive.right().fold(FoldContext.small()), BytesRef.class);
assertThat(bRef.utf8ToString(), is("VALÜ"));
as(filter.child(), EsRelation.class);
}
@@ -5856,7 +5856,7 @@ public void testLookupSimple() {
var left = as(join.left(), EsqlProject.class);
assertThat(left.output().toString(), containsString("int{r}"));
var limit = as(left.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
assertThat(join.config().type(), equalTo(JoinTypes.LEFT));
assertThat(join.config().matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("int{r}")));
@@ -5925,7 +5925,7 @@ public void testLookupStats() {
}
var plan = optimizedPlan(query);
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
var agg = as(limit.child(), Aggregate.class);
assertMap(
@@ -6017,7 +6017,7 @@ public void testLookupJoinPushDownFilterOnJoinKeyWithRename() {
assertThat(join.config().type(), equalTo(JoinTypes.LEFT));
var project = as(join.left(), Project.class);
var limit = as(project.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
var filter = as(limit.child(), Filter.class);
// assert that the rename has been undone
var op = as(filter.condition(), GreaterThan.class);
@@ -6061,7 +6061,7 @@ public void testLookupJoinPushDownFilterOnLeftSideField() {
var project = as(join.left(), Project.class);
var limit = as(project.child(), Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
var filter = as(limit.child(), Filter.class);
var op = as(filter.condition(), GreaterThan.class);
var field = as(op.left(), FieldAttribute.class);
@@ -6100,7 +6100,7 @@ public void testLookupJoinPushDownDisabledForLookupField() {
var plan = optimizedPlan(query);
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
var filter = as(limit.child(), Filter.class);
var op = as(filter.condition(), Equals.class);
@@ -6144,7 +6144,7 @@ public void testLookupJoinPushDownSeparatedForConjunctionBetweenLeftAndRightFiel
var plan = optimizedPlan(query);
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
// filter kept in place, working on the right side
var filter = as(limit.child(), Filter.class);
EsqlBinaryComparison op = as(filter.condition(), Equals.class);
@@ -6195,7 +6195,7 @@ public void testLookupJoinPushDownDisabledForDisjunctionBetweenLeftAndRightField
var plan = optimizedPlan(query);
var limit = as(plan, Limit.class);
- assertThat(limit.limit().fold(), equalTo(1000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000));
var filter = as(limit.child(), Filter.class);
var or = as(filter.condition(), Or.class);
@@ -6289,7 +6289,7 @@ public void testTranslateMixedAggsWithMathWithoutGrouping() {
as(addEval.child(), EsRelation.class);
assertThat(Expressions.attribute(mul.left()).id(), equalTo(finalAggs.aggregates().get(1).id()));
- assertThat(mul.right().fold(), equalTo(1.1));
+ assertThat(mul.right().fold(FoldContext.small()), equalTo(1.1));
assertThat(finalAggs.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD));
Max maxRate = as(Alias.unwrap(finalAggs.aggregates().get(0)), Max.class);
@@ -6304,7 +6304,7 @@ public void testTranslateMixedAggsWithMathWithoutGrouping() {
ToPartial toPartialMaxCost = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), ToPartial.class);
assertThat(Expressions.attribute(toPartialMaxCost.field()).id(), equalTo(addEval.fields().get(0).id()));
assertThat(Expressions.attribute(add.left()).name(), equalTo("network.cost"));
- assertThat(add.right().fold(), equalTo(0.2));
+ assertThat(add.right().fold(FoldContext.small()), equalTo(0.2));
}
public void testTranslateMetricsGroupedByOneDimension() {
@@ -6533,7 +6533,7 @@ METRICS k8s avg(round(1.05 * rate(network.total_bytes_in))) BY bucket(@timestamp
assertThat(Expressions.attribute(finalAgg.groupings().get(1)).id(), equalTo(aggsByTsid.aggregates().get(1).id()));
assertThat(Expressions.attribute(mul.left()).id(), equalTo(aggsByTsid.aggregates().get(0).id()));
- assertThat(mul.right().fold(), equalTo(1.05));
+ assertThat(mul.right().fold(FoldContext.small()), equalTo(1.05));
assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS));
Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class);
assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in"));
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
index ff710a90e8154..66891210a1e47 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
@@ -51,6 +51,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
@@ -164,6 +165,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization;
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze;
@@ -245,7 +247,7 @@ public PhysicalPlanOptimizerTests(String name, Configuration config) {
@Before
public void init() {
parser = new EsqlParser();
- logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+ logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(config));
EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry();
mapper = new Mapper();
@@ -1117,7 +1119,7 @@ public void testLimit() {
var fieldExtract = as(project.child(), FieldExtractExec.class);
var source = source(fieldExtract.child());
assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES));
- assertThat(source.limit().fold(), is(10));
+ assertThat(source.limit().fold(FoldContext.small()), is(10));
}
/**
@@ -1199,7 +1201,7 @@ public void testPushLimitToSource() {
var leaves = extract.collectLeaves();
assertEquals(1, leaves.size());
var source = as(leaves.get(0), EsQueryExec.class);
- assertThat(source.limit().fold(), is(10));
+ assertThat(source.limit().fold(FoldContext.small()), is(10));
// extra ints for doc id and emp_no_10
assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES * 2));
}
@@ -1246,7 +1248,7 @@ public void testPushLimitAndFilterToSource() {
var source = source(extract.child());
assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES * 2));
- assertThat(source.limit().fold(), is(10));
+ assertThat(source.limit().fold(FoldContext.small()), is(10));
var rq = as(sv(source.query(), "emp_no"), RangeQueryBuilder.class);
assertThat(rq.fieldName(), equalTo("emp_no"));
assertThat(rq.from(), equalTo(0));
@@ -2902,14 +2904,14 @@ public void testAvgSurrogateFunctionAfterRenameAndLimit() {
var eval = as(project.child(), EvalExec.class);
var limit = as(eval.child(), LimitExec.class);
assertThat(limit.limit(), instanceOf(Literal.class));
- assertThat(limit.limit().fold(), equalTo(10000));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(10000));
var aggFinal = as(limit.child(), AggregateExec.class);
assertThat(aggFinal.getMode(), equalTo(FINAL));
var aggPartial = as(aggFinal.child(), AggregateExec.class);
assertThat(aggPartial.getMode(), equalTo(INITIAL));
limit = as(aggPartial.child(), LimitExec.class);
assertThat(limit.limit(), instanceOf(Literal.class));
- assertThat(limit.limit().fold(), equalTo(10));
+ assertThat(limit.limit().fold(FoldContext.small()), equalTo(10));
var exchange = as(limit.child(), ExchangeExec.class);
project = as(exchange.child(), ProjectExec.class);
@@ -2918,7 +2920,7 @@ public void testAvgSurrogateFunctionAfterRenameAndLimit() {
var fieldExtract = as(project.child(), FieldExtractExec.class);
assertThat(Expressions.names(fieldExtract.attributesToExtract()), is(expectedFields));
var source = source(fieldExtract.child());
- assertThat(source.limit().fold(), equalTo(10));
+ assertThat(source.limit().fold(FoldContext.small()), equalTo(10));
}
/**
@@ -4849,15 +4851,15 @@ public void testPushSpatialDistanceMultiEvalToSource() {
var alias1 = as(eval.fields().get(0), Alias.class);
assertThat(alias1.name(), is("poi"));
var poi = as(alias1.child(), Literal.class);
- assertThat(poi.fold(), instanceOf(BytesRef.class));
+ assertThat(poi.value(), instanceOf(BytesRef.class));
var alias2 = as(eval.fields().get(1), Alias.class);
assertThat(alias2.name(), is("distance"));
var stDistance = as(alias2.child(), StDistance.class);
var location = as(stDistance.left(), FieldAttribute.class);
assertThat(location.fieldName(), is("location"));
var poiRef = as(stDistance.right(), Literal.class);
- assertThat(poiRef.fold(), instanceOf(BytesRef.class));
- assertThat(poiRef.fold().toString(), is(poi.fold().toString()));
+ assertThat(poiRef.value(), instanceOf(BytesRef.class));
+ assertThat(poiRef.value().toString(), is(poi.value().toString()));
// Validate the filter condition
var and = as(filter.condition(), And.class);
@@ -6205,15 +6207,15 @@ public void testPushCompoundTopNDistanceWithCompoundFilterAndCompoundEvalToSourc
var alias1 = as(evalExec.fields().get(0), Alias.class);
assertThat(alias1.name(), is("poi"));
var poi = as(alias1.child(), Literal.class);
- assertThat(poi.fold(), instanceOf(BytesRef.class));
+ assertThat(poi.value(), instanceOf(BytesRef.class));
var alias2 = as(evalExec.fields().get(1), Alias.class);
assertThat(alias2.name(), is("distance"));
var stDistance = as(alias2.child(), StDistance.class);
var location = as(stDistance.left(), FieldAttribute.class);
assertThat(location.fieldName(), is("location"));
var poiRef = as(stDistance.right(), Literal.class);
- assertThat(poiRef.fold(), instanceOf(BytesRef.class));
- assertThat(poiRef.fold().toString(), is(poi.fold().toString()));
+ assertThat(poiRef.value(), instanceOf(BytesRef.class));
+ assertThat(poiRef.value().toString(), is(poi.value().toString()));
extract = as(evalExec.child(), FieldExtractExec.class);
assertThat(names(extract.attributesToExtract()), contains("location"));
var source = source(extract.child());
@@ -6294,7 +6296,7 @@ public void testPushCompoundTopNDistanceWithDeeplyNestedCompoundEvalToSource() {
var alias1 = as(evalExec.fields().get(0), Alias.class);
assertThat(alias1.name(), is("poi"));
var poi = as(alias1.child(), Literal.class);
- assertThat(poi.fold(), instanceOf(BytesRef.class));
+ assertThat(poi.value(), instanceOf(BytesRef.class));
var alias4 = as(evalExec.fields().get(3), Alias.class);
assertThat(alias4.name(), is("loc2"));
as(alias4.child(), FieldAttribute.class);
@@ -6307,8 +6309,8 @@ public void testPushCompoundTopNDistanceWithDeeplyNestedCompoundEvalToSource() {
var refLocation = as(stDistance.left(), ReferenceAttribute.class);
assertThat(refLocation.name(), is("loc3"));
var poiRef = as(stDistance.right(), Literal.class);
- assertThat(poiRef.fold(), instanceOf(BytesRef.class));
- assertThat(poiRef.fold().toString(), is(poi.fold().toString()));
+ assertThat(poiRef.value(), instanceOf(BytesRef.class));
+ assertThat(poiRef.value().toString(), is(poi.value().toString()));
var alias7 = as(evalExec.fields().get(6), Alias.class);
assertThat(alias7.name(), is("distance"));
as(alias7.child(), ReferenceAttribute.class);
@@ -6391,15 +6393,15 @@ public void testPushCompoundTopNDistanceWithCompoundFilterAndNestedCompoundEvalT
var alias1 = as(evalExec.fields().get(0), Alias.class);
assertThat(alias1.name(), is("poi"));
var poi = as(alias1.child(), Literal.class);
- assertThat(poi.fold(), instanceOf(BytesRef.class));
+ assertThat(poi.value(), instanceOf(BytesRef.class));
var alias2 = as(evalExec.fields().get(1), Alias.class);
assertThat(alias2.name(), is("distance"));
var stDistance = as(alias2.child(), StDistance.class);
var location = as(stDistance.left(), FieldAttribute.class);
assertThat(location.fieldName(), is("location"));
var poiRef = as(stDistance.right(), Literal.class);
- assertThat(poiRef.fold(), instanceOf(BytesRef.class));
- assertThat(poiRef.fold().toString(), is(poi.fold().toString()));
+ assertThat(poiRef.value(), instanceOf(BytesRef.class));
+ assertThat(poiRef.value().toString(), is(poi.value().toString()));
extract = as(evalExec.child(), FieldExtractExec.class);
assertThat(names(extract.attributesToExtract()), contains("location"));
var source = source(extract.child());
@@ -6931,7 +6933,7 @@ public void testManyEnrich() {
var fragment = as(exchange.child(), FragmentExec.class);
var partialTopN = as(fragment.fragment(), TopN.class);
var enrich2 = as(partialTopN.child(), Enrich.class);
- assertThat(BytesRefs.toString(enrich2.policyName().fold()), equalTo("departments"));
+ assertThat(BytesRefs.toString(enrich2.policyName().fold(FoldContext.small())), equalTo("departments"));
assertThat(enrich2.mode(), equalTo(Enrich.Mode.ANY));
var eval = as(enrich2.child(), Eval.class);
as(eval.child(), EsRelation.class);
@@ -6957,7 +6959,7 @@ public void testManyEnrich() {
var fragment = as(exchange.child(), FragmentExec.class);
var partialTopN = as(fragment.fragment(), TopN.class);
var enrich2 = as(partialTopN.child(), Enrich.class);
- assertThat(BytesRefs.toString(enrich2.policyName().fold()), equalTo("departments"));
+ assertThat(BytesRefs.toString(enrich2.policyName().fold(FoldContext.small())), equalTo("departments"));
assertThat(enrich2.mode(), equalTo(Enrich.Mode.ANY));
var eval = as(enrich2.child(), Eval.class);
as(eval.child(), EsRelation.class);
@@ -7549,7 +7551,7 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP
// The TopN needs an estimated row size for the planner to work
var plans = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode(EstimatesRowSize.estimateRowSize(0, plan), config);
plan = useDataNodePlan ? plans.v2() : plans.v1();
- plan = PlannerUtils.localPlan(List.of(), config, plan);
+ plan = PlannerUtils.localPlan(List.of(), config, FoldContext.small(), plan);
LocalExecutionPlanner planner = new LocalExecutionPlanner(
"test",
"",
@@ -7562,10 +7564,10 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP
new ExchangeSinkHandler(null, 10, () -> 10),
null,
null,
- new EsPhysicalOperationProviders(List.of(), null)
+ new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null)
);
- return planner.plan(plan);
+ return planner.plan(FoldContext.small(), plan);
}
private List> findFieldNamesInLookupJoinDescription(LocalExecutionPlanner.LocalExecutionPlan physicalOperations) {
@@ -7662,7 +7664,7 @@ public void testReductionPlanForTopN() {
PhysicalPlan reduction = PlannerUtils.reductionPlan(plans.v2());
TopNExec reductionTopN = as(reduction, TopNExec.class);
assertThat(reductionTopN.estimatedRowSize(), equalTo(allFieldRowSize));
- assertThat(reductionTopN.limit().fold(), equalTo(limit));
+ assertThat(reductionTopN.limit().fold(FoldContext.small()), equalTo(limit));
}
public void testReductionPlanForAggs() {
@@ -7818,7 +7820,7 @@ private PhysicalPlan optimizedPlan(PhysicalPlan plan, SearchStats searchStats) {
// individually hence why here the plan is kept as is
var l = p.transformUp(FragmentExec.class, fragment -> {
- var localPlan = PlannerUtils.localPlan(config, fragment, searchStats);
+ var localPlan = PlannerUtils.localPlan(config, FoldContext.small(), fragment, searchStats);
return EstimatesRowSize.estimateRowSize(fragment.estimatedRowSize(), localPlan);
});
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java
index 9fe479dbb8625..e6a7d110f8c09 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java
@@ -9,6 +9,7 @@
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
@@ -26,7 +27,7 @@ public class TestPlannerOptimizer {
private final Configuration config;
public TestPlannerOptimizer(Configuration config, Analyzer analyzer) {
- this(config, analyzer, new LogicalPlanOptimizer(new LogicalOptimizerContext(config)));
+ this(config, analyzer, new LogicalPlanOptimizer(new LogicalOptimizerContext(config, FoldContext.small())));
}
public TestPlannerOptimizer(Configuration config, Analyzer analyzer, LogicalPlanOptimizer logicalOptimizer) {
@@ -61,8 +62,13 @@ private PhysicalPlan optimizedPlan(PhysicalPlan plan, SearchStats searchStats) {
// this is of no use in the unit tests, which checks the plan as a whole instead of each
// individually hence why here the plan is kept as is
- var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(config, searchStats));
- var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(config, searchStats), true);
+ var logicalTestOptimizer = new LocalLogicalPlanOptimizer(
+ new LocalLogicalOptimizerContext(config, FoldContext.small(), searchStats)
+ );
+ var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(
+ new LocalPhysicalOptimizerContext(config, FoldContext.small(), searchStats),
+ true
+ );
var l = PlannerUtils.localPlan(physicalPlan, logicalTestOptimizer, physicalTestOptimizer);
// handle local reduction alignment
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java
new file mode 100644
index 0000000000000..5d2fec0fc8181
--- /dev/null
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.EqualsHashCodeTestUtils;
+import org.elasticsearch.xpack.esql.ConfigurationTestUtils;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
+import org.elasticsearch.xpack.esql.session.Configuration;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class LogicalOptimizerContextTests extends ESTestCase {
+ public void testToString() {
+ // Random looking numbers for FoldContext are indeed random. Just so we have consistent numbers to assert on in toString.
+ LogicalOptimizerContext ctx = new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, new FoldContext(102));
+ ctx.foldCtx().trackAllocation(Source.EMPTY, 99);
+ assertThat(
+ ctx.toString(),
+ equalTo("LogicalOptimizerContext[configuration=" + EsqlTestUtils.TEST_CFG + ", foldCtx=FoldContext[3/102]]")
+ );
+ }
+
+ public void testEqualsAndHashCode() {
+ EqualsHashCodeTestUtils.checkEqualsAndHashCode(randomLogicalOptimizerContext(), this::copy, this::mutate);
+ }
+
+ private LogicalOptimizerContext randomLogicalOptimizerContext() {
+ return new LogicalOptimizerContext(ConfigurationTestUtils.randomConfiguration(), randomFoldContext());
+ }
+
+ private LogicalOptimizerContext copy(LogicalOptimizerContext c) {
+ return new LogicalOptimizerContext(c.configuration(), c.foldCtx());
+ }
+
+ private LogicalOptimizerContext mutate(LogicalOptimizerContext c) {
+ Configuration configuration = c.configuration();
+ FoldContext foldCtx = c.foldCtx();
+ if (randomBoolean()) {
+ configuration = randomValueOtherThan(configuration, ConfigurationTestUtils::randomConfiguration);
+ } else {
+ foldCtx = randomValueOtherThan(foldCtx, this::randomFoldContext);
+ }
+ return new LogicalOptimizerContext(configuration, foldCtx);
+ }
+
+ private FoldContext randomFoldContext() {
+ FoldContext ctx = new FoldContext(randomNonNegativeLong());
+ if (randomBoolean()) {
+ ctx.trackAllocation(Source.EMPTY, randomLongBetween(0, ctx.initialAllowedBytes()));
+ }
+ return ctx;
+ }
+}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java
index 08c8612d8097c..c0c145aee5382 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java
@@ -22,25 +22,26 @@
import static java.util.Arrays.asList;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
public class BooleanFunctionEqualsEliminationTests extends ESTestCase {
+ private Expression booleanFunctionEqualElimination(BinaryComparison e) {
+ return new BooleanFunctionEqualsElimination().rule(e, unboundLogicalOptimizerContext());
+ }
public void testBoolEqualsSimplificationOnExpressions() {
- BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination();
Expression exp = new GreaterThan(EMPTY, getFieldAttribute(), new Literal(EMPTY, 0, DataType.INTEGER), null);
- assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE)));
+ assertEquals(exp, booleanFunctionEqualElimination(new Equals(EMPTY, exp, TRUE)));
// TODO: Replace use of QL Not with ESQL Not
- assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE)));
+ assertEquals(new Not(EMPTY, exp), booleanFunctionEqualElimination(new Equals(EMPTY, exp, FALSE)));
}
public void testBoolEqualsSimplificationOnFields() {
- BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination();
-
FieldAttribute field = getFieldAttribute();
List extends BinaryComparison> comparisons = asList(
@@ -55,7 +56,7 @@ public void testBoolEqualsSimplificationOnFields() {
);
for (BinaryComparison comparison : comparisons) {
- assertEquals(comparison, s.rule(comparison));
+ assertEquals(comparison, booleanFunctionEqualElimination(comparison));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java
index 3b1f8cfc83af3..5b4bf806518de 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java
@@ -9,9 +9,11 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
@@ -20,33 +22,31 @@ public class BooleanSimplificationTests extends ESTestCase {
private static final Expression DUMMY_EXPRESSION =
new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0);
- public void testBoolSimplifyOr() {
- BooleanSimplification simplification = new BooleanSimplification();
+ private Expression booleanSimplification(ScalarFunction e) {
+ return new BooleanSimplification().rule(e, unboundLogicalOptimizerContext());
+ }
- assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, TRUE)));
- assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, DUMMY_EXPRESSION)));
- assertEquals(TRUE, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, TRUE)));
+ public void testBoolSimplifyOr() {
+ assertEquals(TRUE, booleanSimplification(new Or(EMPTY, TRUE, TRUE)));
+ assertEquals(TRUE, booleanSimplification(new Or(EMPTY, TRUE, DUMMY_EXPRESSION)));
+ assertEquals(TRUE, booleanSimplification(new Or(EMPTY, DUMMY_EXPRESSION, TRUE)));
- assertEquals(FALSE, simplification.rule(new Or(EMPTY, FALSE, FALSE)));
- assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, FALSE, DUMMY_EXPRESSION)));
- assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, FALSE)));
+ assertEquals(FALSE, booleanSimplification(new Or(EMPTY, FALSE, FALSE)));
+ assertEquals(DUMMY_EXPRESSION, booleanSimplification(new Or(EMPTY, FALSE, DUMMY_EXPRESSION)));
+ assertEquals(DUMMY_EXPRESSION, booleanSimplification(new Or(EMPTY, DUMMY_EXPRESSION, FALSE)));
}
public void testBoolSimplifyAnd() {
- BooleanSimplification simplification = new BooleanSimplification();
+ assertEquals(TRUE, booleanSimplification(new And(EMPTY, TRUE, TRUE)));
+ assertEquals(DUMMY_EXPRESSION, booleanSimplification(new And(EMPTY, TRUE, DUMMY_EXPRESSION)));
+ assertEquals(DUMMY_EXPRESSION, booleanSimplification(new And(EMPTY, DUMMY_EXPRESSION, TRUE)));
- assertEquals(TRUE, simplification.rule(new And(EMPTY, TRUE, TRUE)));
- assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, TRUE, DUMMY_EXPRESSION)));
- assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, TRUE)));
-
- assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, FALSE)));
- assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, DUMMY_EXPRESSION)));
- assertEquals(FALSE, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, FALSE)));
+ assertEquals(FALSE, booleanSimplification(new And(EMPTY, FALSE, FALSE)));
+ assertEquals(FALSE, booleanSimplification(new And(EMPTY, FALSE, DUMMY_EXPRESSION)));
+ assertEquals(FALSE, booleanSimplification(new And(EMPTY, DUMMY_EXPRESSION, FALSE)));
}
public void testBoolCommonFactorExtraction() {
- BooleanSimplification simplification = new BooleanSimplification();
-
Expression a1 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1);
Expression a2 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1);
Expression b = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 2);
@@ -55,7 +55,7 @@ public void testBoolCommonFactorExtraction() {
Or actual = new Or(EMPTY, new And(EMPTY, a1, b), new And(EMPTY, a2, c));
And expected = new And(EMPTY, a1, new Or(EMPTY, b, c));
- assertEquals(expected, simplification.rule(actual));
+ assertEquals(expected, booleanSimplification(actual));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java
index d388369e0b167..a0d23731ae82d 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java
@@ -12,6 +12,7 @@
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
+import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
@@ -37,6 +38,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
@@ -45,19 +47,17 @@
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
public class CombineBinaryComparisonsTests extends ESTestCase {
-
- private static final Expression DUMMY_EXPRESSION =
- new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0);
+ private Expression combine(BinaryLogic e) {
+ return new CombineBinaryComparisons().rule(e, unboundLogicalOptimizerContext());
+ }
public void testCombineBinaryComparisonsNotComparable() {
FieldAttribute fa = getFieldAttribute();
LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX);
LessThan lt = lessThanOf(fa, FALSE);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
-
And and = new And(EMPTY, lte, lt);
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(exp, and);
}
@@ -67,9 +67,7 @@ public void testCombineBinaryComparisonsUpper() {
LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX);
LessThan lt = lessThanOf(fa, FIVE);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
-
- Expression exp = rule.rule(new And(EMPTY, lte, lt));
+ Expression exp = combine(new And(EMPTY, lte, lt));
assertEquals(LessThan.class, exp.getClass());
LessThan r = (LessThan) exp;
assertEquals(FIVE, r.right());
@@ -81,9 +79,7 @@ public void testCombineBinaryComparisonsLower() {
GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, SIX);
GreaterThan gt = greaterThanOf(fa, FIVE);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
-
- Expression exp = rule.rule(new And(EMPTY, gte, gt));
+ Expression exp = combine(new And(EMPTY, gte, gt));
assertEquals(GreaterThanOrEqual.class, exp.getClass());
GreaterThanOrEqual r = (GreaterThanOrEqual) exp;
assertEquals(SIX, r.right());
@@ -95,9 +91,7 @@ public void testCombineBinaryComparisonsInclude() {
GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, FIVE);
GreaterThan gt = greaterThanOf(fa, FIVE);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
-
- Expression exp = rule.rule(new And(EMPTY, gte, gt));
+ Expression exp = combine(new And(EMPTY, gte, gt));
assertEquals(GreaterThan.class, exp.getClass());
GreaterThan r = (GreaterThan) exp;
assertEquals(FIVE, r.right());
@@ -111,9 +105,7 @@ public void testCombineMultipleBinaryComparisons() {
LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7));
LessThan lt = lessThanOf(fa, SIX);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
-
- Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, gt, new And(EMPTY, lt, lte))));
+ Expression exp = combine(new And(EMPTY, gte, new And(EMPTY, gt, new And(EMPTY, lt, lte))));
assertEquals(And.class, exp.getClass());
And and = (And) exp;
assertEquals(gt, and.left());
@@ -128,10 +120,8 @@ public void testCombineMixedMultipleBinaryComparisons() {
LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7));
Expression ne = notEqualsOf(fa, FIVE);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
-
// TRUE AND a != 5 AND 4 < a <= 7
- Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, TRUE, new And(EMPTY, gt, new And(EMPTY, ne, lte)))));
+ Expression exp = combine(new And(EMPTY, gte, new And(EMPTY, TRUE, new And(EMPTY, gt, new And(EMPTY, ne, lte)))));
assertEquals(And.class, exp.getClass());
And and = ((And) exp);
assertEquals(And.class, and.right().getClass());
@@ -150,8 +140,7 @@ public void testCombineComparisonsIntoRange() {
GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE);
LessThan lt = lessThanOf(fa, FIVE);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(new And(EMPTY, gte, lt));
+ Expression exp = combine(new And(EMPTY, gte, lt));
assertEquals(And.class, exp.getClass());
And and = (And) exp;
@@ -167,8 +156,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndGt3() {
GreaterThan gt = greaterThanOf(fa, THREE);
And and = new And(EMPTY, neq, gt);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(gt, exp);
}
@@ -180,8 +168,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndGte2() {
GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO);
And and = new And(EMPTY, neq, gte);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(GreaterThan.class, exp.getClass());
GreaterThan gt = (GreaterThan) exp;
assertEquals(TWO, gt.right());
@@ -195,8 +182,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndGte1() {
GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE);
And and = new And(EMPTY, neq, gte);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(And.class, exp.getClass()); // can't optimize
}
@@ -208,8 +194,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndLte3() {
LessThanOrEqual lte = lessThanOrEqualOf(fa, THREE);
And and = new And(EMPTY, neq, lte);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(and, exp); // can't optimize
}
@@ -221,8 +206,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndLte2() {
LessThanOrEqual lte = lessThanOrEqualOf(fa, TWO);
And and = new And(EMPTY, neq, lte);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(LessThan.class, exp.getClass());
LessThan lt = (LessThan) exp;
assertEquals(TWO, lt.right());
@@ -236,8 +220,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndLte1() {
LessThanOrEqual lte = lessThanOrEqualOf(fa, ONE);
And and = new And(EMPTY, neq, lte);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals(lte, exp);
}
@@ -251,8 +234,7 @@ public void testCombineBinaryComparisonsDisjunctionNotComparable() {
Or or = new Or(EMPTY, gt1, gt2);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(or);
+ Expression exp = combine(or);
assertEquals(exp, or);
}
@@ -266,8 +248,7 @@ public void testCombineBinaryComparisonsDisjunctionLowerBound() {
Or or = new Or(EMPTY, gt1, new Or(EMPTY, gt2, gt3));
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(or);
+ Expression exp = combine(or);
assertEquals(GreaterThan.class, exp.getClass());
GreaterThan gt = (GreaterThan) exp;
@@ -284,8 +265,7 @@ public void testCombineBinaryComparisonsDisjunctionIncludeLowerBounds() {
Or or = new Or(EMPTY, new Or(EMPTY, gt1, gt2), gte3);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(or);
+ Expression exp = combine(or);
assertEquals(GreaterThan.class, exp.getClass());
GreaterThan gt = (GreaterThan) exp;
@@ -302,8 +282,7 @@ public void testCombineBinaryComparisonsDisjunctionUpperBound() {
Or or = new Or(EMPTY, new Or(EMPTY, lt1, lt2), lt3);
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(or);
+ Expression exp = combine(or);
assertEquals(LessThan.class, exp.getClass());
LessThan lt = (LessThan) exp;
@@ -320,8 +299,7 @@ public void testCombineBinaryComparisonsDisjunctionIncludeUpperBounds() {
Or or = new Or(EMPTY, lt2, new Or(EMPTY, lte2, lt1));
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(or);
+ Expression exp = combine(or);
assertEquals(LessThanOrEqual.class, exp.getClass());
LessThanOrEqual lte = (LessThanOrEqual) exp;
@@ -340,8 +318,7 @@ public void testCombineBinaryComparisonsDisjunctionOfLowerAndUpperBounds() {
Or or = new Or(EMPTY, new Or(EMPTY, lt2, gt3), new Or(EMPTY, lt1, gt4));
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(or);
+ Expression exp = combine(or);
assertEquals(Or.class, exp.getClass());
Or ro = (Or) exp;
@@ -367,7 +344,7 @@ public void testBooleanSimplificationCommonExpressionSubstraction() {
And right = new And(EMPTY, a2, common);
Or or = new Or(EMPTY, left, right);
- Expression exp = new BooleanSimplification().rule(or);
+ Expression exp = new BooleanSimplification().rule(or, unboundLogicalOptimizerContext());
assertEquals(new And(EMPTY, common, new Or(EMPTY, a1, a2)), exp);
}
@@ -391,8 +368,7 @@ public void testBinaryComparisonAndOutOfRangeNotEqualsDifferentFields() {
);
for (And and : testCases) {
- CombineBinaryComparisons rule = new CombineBinaryComparisons();
- Expression exp = rule.rule(and);
+ Expression exp = combine(and);
assertEquals("Rule should not have transformed [" + and.nodeString() + "]", and, exp);
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java
index 043d18dac9fd4..bb5f2fd3505e9 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java
@@ -38,16 +38,24 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.hamcrest.Matchers.contains;
public class CombineDisjunctionsTests extends ESTestCase {
+ private Expression combineDisjunctions(Or e) {
+ return new CombineDisjunctions().rule(e, unboundLogicalOptimizerContext());
+ }
+
+ private LogicalPlan combineDisjunctions(LogicalPlan l) {
+ return new CombineDisjunctions().apply(l, unboundLogicalOptimizerContext());
+ }
public void testTwoEqualsWithOr() {
FieldAttribute fa = getFieldAttribute();
Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO));
- Expression e = new CombineDisjunctions().rule(or);
+ Expression e = combineDisjunctions(or);
assertEquals(In.class, e.getClass());
In in = (In) e;
assertEquals(fa, in.value());
@@ -58,7 +66,7 @@ public void testTwoEqualsWithSameValue() {
FieldAttribute fa = getFieldAttribute();
Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE));
- Expression e = new CombineDisjunctions().rule(or);
+ Expression e = combineDisjunctions(or);
assertEquals(Equals.class, e.getClass());
Equals eq = (Equals) e;
assertEquals(fa, eq.left());
@@ -69,7 +77,7 @@ public void testOneEqualsOneIn() {
FieldAttribute fa = getFieldAttribute();
Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, List.of(TWO)));
- Expression e = new CombineDisjunctions().rule(or);
+ Expression e = combineDisjunctions(or);
assertEquals(In.class, e.getClass());
In in = (In) e;
assertEquals(fa, in.value());
@@ -80,7 +88,7 @@ public void testOneEqualsOneInWithSameValue() {
FieldAttribute fa = getFieldAttribute();
Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO)));
- Expression e = new CombineDisjunctions().rule(or);
+ Expression e = combineDisjunctions(or);
assertEquals(In.class, e.getClass());
In in = (In) e;
assertEquals(fa, in.value());
@@ -92,7 +100,7 @@ public void testSingleValueInToEquals() {
Equals equals = equalsOf(fa, ONE);
Or or = new Or(EMPTY, equals, new In(EMPTY, fa, List.of(ONE)));
- Expression e = new CombineDisjunctions().rule(or);
+ Expression e = combineDisjunctions(or);
assertEquals(equals, e);
}
@@ -101,7 +109,7 @@ public void testEqualsBehindAnd() {
And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO));
Filter dummy = new Filter(EMPTY, relation(), and);
- LogicalPlan transformed = new CombineDisjunctions().apply(dummy);
+ LogicalPlan transformed = combineDisjunctions(dummy);
assertSame(dummy, transformed);
assertEquals(and, ((Filter) transformed).condition());
}
@@ -111,7 +119,7 @@ public void testTwoEqualsDifferentFields() {
FieldAttribute fieldTwo = getFieldAttribute("TWO");
Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO));
- Expression e = new CombineDisjunctions().rule(or);
+ Expression e = combineDisjunctions(or);
assertEquals(or, e);
}
@@ -120,7 +128,7 @@ public void testMultipleIn() {
Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TWO)));
Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE)));
- Expression e = new CombineDisjunctions().rule(secondOr);
+ Expression e = combineDisjunctions(secondOr);
assertEquals(In.class, e.getClass());
In in = (In) e;
assertEquals(fa, in.value());
@@ -132,7 +140,7 @@ public void testOrWithNonCombinableExpressions() {
Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), lessThanOf(fa, TWO));
Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE)));
- Expression e = new CombineDisjunctions().rule(secondOr);
+ Expression e = combineDisjunctions(secondOr);
assertEquals(Or.class, e.getClass());
Or or = (Or) e;
assertEquals(or.left(), firstOr.right());
@@ -160,7 +168,7 @@ public void testCombineCIDRMatch() {
cidrs.add(new CIDRMatch(EMPTY, faa, ipa2));
cidrs.add(new CIDRMatch(EMPTY, fab, ipb2));
Or oldOr = (Or) Predicates.combineOr(cidrs);
- Expression e = new CombineDisjunctions().rule(oldOr);
+ Expression e = combineDisjunctions(oldOr);
assertEquals(Or.class, e.getClass());
Or newOr = (Or) e;
assertEquals(CIDRMatch.class, newOr.left().getClass());
@@ -211,7 +219,7 @@ public void testCombineCIDRMatchEqualsIns() {
Or oldOr = (Or) Predicates.combineOr(all);
- Expression e = new CombineDisjunctions().rule(oldOr);
+ Expression e = combineDisjunctions(oldOr);
assertEquals(Or.class, e.getClass());
Or newOr = (Or) e;
assertEquals(Or.class, newOr.left().getClass());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java
index 01af91271e1ba..8a8585b8d0ab5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java
@@ -11,6 +11,7 @@
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator;
@@ -45,68 +46,77 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.of;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
public class ConstantFoldingTests extends ESTestCase {
+ private Expression constantFolding(Expression e) {
+ return new ConstantFolding().rule(e, unboundLogicalOptimizerContext());
+ }
public void testConstantFolding() {
Expression exp = new Add(EMPTY, TWO, THREE);
assertTrue(exp.foldable());
- Expression result = new ConstantFolding().rule(exp);
- assertTrue(result instanceof Literal);
- assertEquals(5, ((Literal) result).value());
+ Expression result = constantFolding(exp);
+ assertEquals(5, as(result, Literal.class).value());
// check now with an alias
- result = new ConstantFolding().rule(new Alias(EMPTY, "a", exp));
+ result = constantFolding(new Alias(EMPTY, "a", exp));
assertEquals("a", Expressions.name(result));
assertEquals(Alias.class, result.getClass());
}
public void testConstantFoldingBinaryComparison() {
- assertEquals(FALSE, new ConstantFolding().rule(greaterThanOf(TWO, THREE)).canonical());
- assertEquals(FALSE, new ConstantFolding().rule(greaterThanOrEqualOf(TWO, THREE)).canonical());
- assertEquals(FALSE, new ConstantFolding().rule(equalsOf(TWO, THREE)).canonical());
- assertEquals(TRUE, new ConstantFolding().rule(notEqualsOf(TWO, THREE)).canonical());
- assertEquals(TRUE, new ConstantFolding().rule(lessThanOrEqualOf(TWO, THREE)).canonical());
- assertEquals(TRUE, new ConstantFolding().rule(lessThanOf(TWO, THREE)).canonical());
+ assertEquals(FALSE, constantFolding(greaterThanOf(TWO, THREE)).canonical());
+ assertEquals(FALSE, constantFolding(greaterThanOrEqualOf(TWO, THREE)).canonical());
+ assertEquals(FALSE, constantFolding(equalsOf(TWO, THREE)).canonical());
+ assertEquals(TRUE, constantFolding(notEqualsOf(TWO, THREE)).canonical());
+ assertEquals(TRUE, constantFolding(lessThanOrEqualOf(TWO, THREE)).canonical());
+ assertEquals(TRUE, constantFolding(lessThanOf(TWO, THREE)).canonical());
}
public void testConstantFoldingBinaryLogic() {
- assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical());
- assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical());
+ assertEquals(FALSE, constantFolding(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical());
+ assertEquals(TRUE, constantFolding(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical());
}
public void testConstantFoldingBinaryLogic_WithNullHandling() {
- assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable());
- assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable());
- assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical());
- assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical());
- assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable());
-
- assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical());
- assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical());
- assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable());
- assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable());
- assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable());
+ assertEquals(Nullability.TRUE, constantFolding(new And(EMPTY, NULL, TRUE)).canonical().nullable());
+ assertEquals(Nullability.TRUE, constantFolding(new And(EMPTY, TRUE, NULL)).canonical().nullable());
+ assertEquals(FALSE, constantFolding(new And(EMPTY, NULL, FALSE)).canonical());
+ assertEquals(FALSE, constantFolding(new And(EMPTY, FALSE, NULL)).canonical());
+ assertEquals(Nullability.TRUE, constantFolding(new And(EMPTY, NULL, NULL)).canonical().nullable());
+
+ assertEquals(TRUE, constantFolding(new Or(EMPTY, NULL, TRUE)).canonical());
+ assertEquals(TRUE, constantFolding(new Or(EMPTY, TRUE, NULL)).canonical());
+ assertEquals(Nullability.TRUE, constantFolding(new Or(EMPTY, NULL, FALSE)).canonical().nullable());
+ assertEquals(Nullability.TRUE, constantFolding(new Or(EMPTY, FALSE, NULL)).canonical().nullable());
+ assertEquals(Nullability.TRUE, constantFolding(new Or(EMPTY, NULL, NULL)).canonical().nullable());
}
public void testConstantFoldingRange() {
- assertEquals(true, new ConstantFolding().rule(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold());
- assertEquals(false, new ConstantFolding().rule(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold());
+ assertEquals(
+ true,
+ constantFolding(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold(FoldContext.small())
+ );
+ assertEquals(
+ false,
+ constantFolding(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold(FoldContext.small())
+ );
}
public void testConstantNot() {
- assertEquals(FALSE, new ConstantFolding().rule(new Not(EMPTY, TRUE)));
- assertEquals(TRUE, new ConstantFolding().rule(new Not(EMPTY, FALSE)));
+ assertEquals(FALSE, constantFolding(new Not(EMPTY, TRUE)));
+ assertEquals(TRUE, constantFolding(new Not(EMPTY, FALSE)));
}
public void testConstantFoldingLikes() {
- assertEquals(TRUE, new ConstantFolding().rule(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical());
- assertEquals(TRUE, new ConstantFolding().rule(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical());
+ assertEquals(TRUE, constantFolding(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical());
+ assertEquals(TRUE, constantFolding(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical());
}
public void testArithmeticFolding() {
@@ -125,7 +135,7 @@ public void testFoldRange() {
Expression value = new Literal(EMPTY, 12, DataType.INTEGER);
Range range = new Range(EMPTY, value, lowerBound, randomBoolean(), upperBound, randomBoolean(), randomZone());
- Expression folded = new ConstantFolding().rule(range);
+ Expression folded = constantFolding(range);
assertTrue((Boolean) as(folded, Literal.class).value());
}
@@ -156,16 +166,15 @@ public void testFoldRangeWithInvalidBoundaries() {
// Just applying this to the range directly won't perform a transformDown.
LogicalPlan filter = new Filter(EMPTY, emptySource(), range);
- Filter foldedOnce = as(new ConstantFolding().apply(filter), Filter.class);
+ Filter foldedOnce = as(new ConstantFolding().apply(filter, unboundLogicalOptimizerContext()), Filter.class);
// We need to run the rule twice, because during the first run only the boundaries can be folded - the range doesn't know it's
// foldable, yet.
- Filter foldedTwice = as(new ConstantFolding().apply(foldedOnce), Filter.class);
+ Filter foldedTwice = as(new ConstantFolding().apply(foldedOnce, unboundLogicalOptimizerContext()), Filter.class);
assertFalse((Boolean) as(foldedTwice.condition(), Literal.class).value());
}
- private static Object foldOperator(BinaryOperator, ?, ?, ?> b) {
- return ((Literal) new ConstantFolding().rule(b)).value();
+ private Object foldOperator(BinaryOperator, ?, ?, ?> b) {
+ return ((Literal) constantFolding(b)).value();
}
-
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java
index ae31576184938..252b25a214bb8 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java
@@ -9,6 +9,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
@@ -67,9 +68,11 @@
import java.util.List;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.L;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
@@ -85,28 +88,27 @@
import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION;
public class FoldNullTests extends ESTestCase {
+ private Expression foldNull(Expression e) {
+ return new FoldNull().rule(e, unboundLogicalOptimizerContext());
+ }
public void testBasicNullFolding() {
- FoldNull rule = new FoldNull();
- assertNullLiteral(rule.rule(new Add(EMPTY, L(randomInt()), Literal.NULL)));
- assertNullLiteral(rule.rule(new Round(EMPTY, Literal.NULL, null)));
- assertNullLiteral(rule.rule(new Pow(EMPTY, Literal.NULL, Literal.NULL)));
- assertNullLiteral(rule.rule(new DateFormat(EMPTY, Literal.NULL, Literal.NULL, null)));
- assertNullLiteral(rule.rule(new DateParse(EMPTY, Literal.NULL, Literal.NULL)));
- assertNullLiteral(rule.rule(new DateTrunc(EMPTY, Literal.NULL, Literal.NULL)));
- assertNullLiteral(rule.rule(new Substring(EMPTY, Literal.NULL, Literal.NULL, Literal.NULL)));
+ assertNullLiteral(foldNull(new Add(EMPTY, L(randomInt()), Literal.NULL)));
+ assertNullLiteral(foldNull(new Round(EMPTY, Literal.NULL, null)));
+ assertNullLiteral(foldNull(new Pow(EMPTY, Literal.NULL, Literal.NULL)));
+ assertNullLiteral(foldNull(new DateFormat(EMPTY, Literal.NULL, Literal.NULL, null)));
+ assertNullLiteral(foldNull(new DateParse(EMPTY, Literal.NULL, Literal.NULL)));
+ assertNullLiteral(foldNull(new DateTrunc(EMPTY, Literal.NULL, Literal.NULL)));
+ assertNullLiteral(foldNull(new Substring(EMPTY, Literal.NULL, Literal.NULL, Literal.NULL)));
}
public void testNullFoldingIsNotNull() {
- FoldNull foldNull = new FoldNull();
- assertEquals(true, foldNull.rule(new IsNotNull(EMPTY, TRUE)).fold());
- assertEquals(false, foldNull.rule(new IsNotNull(EMPTY, NULL)).fold());
+ assertEquals(true, foldNull(new IsNotNull(EMPTY, TRUE)).fold(FoldContext.small()));
+ assertEquals(false, foldNull(new IsNotNull(EMPTY, NULL)).fold(FoldContext.small()));
}
@SuppressWarnings("unchecked")
public void testNullFoldingDoesNotApplyOnAbstractMultivalueFunction() throws Exception {
- FoldNull rule = new FoldNull();
-
List> items = List.of(
MvDedupe.class,
MvFirst.class,
@@ -119,119 +121,112 @@ public void testNullFoldingDoesNotApplyOnAbstractMultivalueFunction() throws Exc
for (Class extends AbstractMultivalueFunction> clazz : items) {
Constructor extends AbstractMultivalueFunction> ctor = clazz.getConstructor(Source.class, Expression.class);
AbstractMultivalueFunction conditionalFunction = ctor.newInstance(EMPTY, getFieldAttribute("a"));
- assertEquals(conditionalFunction, rule.rule(conditionalFunction));
+ assertEquals(conditionalFunction, foldNull(conditionalFunction));
conditionalFunction = ctor.newInstance(EMPTY, NULL);
- assertEquals(NULL, rule.rule(conditionalFunction));
+ assertEquals(NULL, foldNull(conditionalFunction));
}
// avg and count ar different just because they know the return type in advance (all the others infer the type from the input)
MvAvg avg = new MvAvg(EMPTY, getFieldAttribute("a"));
- assertEquals(avg, rule.rule(avg));
+ assertEquals(avg, foldNull(avg));
avg = new MvAvg(EMPTY, NULL);
- assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(avg));
+ assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(avg));
MvCount count = new MvCount(EMPTY, getFieldAttribute("a"));
- assertEquals(count, rule.rule(count));
+ assertEquals(count, foldNull(count));
count = new MvCount(EMPTY, NULL);
- assertEquals(new Literal(EMPTY, null, INTEGER), rule.rule(count));
+ assertEquals(new Literal(EMPTY, null, INTEGER), foldNull(count));
}
public void testNullFoldingIsNull() {
- FoldNull foldNull = new FoldNull();
- assertEquals(true, foldNull.rule(new IsNull(EMPTY, NULL)).fold());
- assertEquals(false, foldNull.rule(new IsNull(EMPTY, TRUE)).fold());
+ assertEquals(true, foldNull(new IsNull(EMPTY, NULL)).fold(FoldContext.small()));
+ assertEquals(false, foldNull(new IsNull(EMPTY, TRUE)).fold(FoldContext.small()));
}
public void testGenericNullableExpression() {
FoldNull rule = new FoldNull();
// arithmetic
- assertNullLiteral(rule.rule(new Add(EMPTY, getFieldAttribute("a"), NULL)));
+ assertNullLiteral(foldNull(new Add(EMPTY, getFieldAttribute("a"), NULL)));
// comparison
- assertNullLiteral(rule.rule(greaterThanOf(getFieldAttribute("a"), NULL)));
+ assertNullLiteral(foldNull(greaterThanOf(getFieldAttribute("a"), NULL)));
// regex
- assertNullLiteral(rule.rule(new RLike(EMPTY, NULL, new RLikePattern("123"))));
+ assertNullLiteral(foldNull(new RLike(EMPTY, NULL, new RLikePattern("123"))));
// date functions
- assertNullLiteral(rule.rule(new DateExtract(EMPTY, NULL, NULL, configuration(""))));
+ assertNullLiteral(foldNull(new DateExtract(EMPTY, NULL, NULL, configuration(""))));
// math functions
- assertNullLiteral(rule.rule(new Cos(EMPTY, NULL)));
+ assertNullLiteral(foldNull(new Cos(EMPTY, NULL)));
// string functions
- assertNullLiteral(rule.rule(new LTrim(EMPTY, NULL)));
+ assertNullLiteral(foldNull(new LTrim(EMPTY, NULL)));
// spatial
- assertNullLiteral(rule.rule(new SpatialCentroid(EMPTY, NULL)));
+ assertNullLiteral(foldNull(new SpatialCentroid(EMPTY, NULL)));
// ip
- assertNullLiteral(rule.rule(new CIDRMatch(EMPTY, NULL, List.of(NULL))));
+ assertNullLiteral(foldNull(new CIDRMatch(EMPTY, NULL, List.of(NULL))));
// conversion
- assertNullLiteral(rule.rule(new ToString(EMPTY, NULL)));
+ assertNullLiteral(foldNull(new ToString(EMPTY, NULL)));
}
public void testNullFoldingDoesNotApplyOnLogicalExpressions() {
- FoldNull rule = new FoldNull();
-
Or or = new Or(EMPTY, NULL, TRUE);
- assertEquals(or, rule.rule(or));
+ assertEquals(or, foldNull(or));
or = new Or(EMPTY, NULL, NULL);
- assertEquals(or, rule.rule(or));
+ assertEquals(or, foldNull(or));
And and = new And(EMPTY, NULL, TRUE);
- assertEquals(and, rule.rule(and));
+ assertEquals(and, foldNull(and));
and = new And(EMPTY, NULL, NULL);
- assertEquals(and, rule.rule(and));
+ assertEquals(and, foldNull(and));
}
@SuppressWarnings("unchecked")
public void testNullFoldingDoesNotApplyOnAggregate() throws Exception {
- FoldNull rule = new FoldNull();
-
List> items = List.of(Max.class, Min.class);
for (Class extends AggregateFunction> clazz : items) {
Constructor extends AggregateFunction> ctor = clazz.getConstructor(Source.class, Expression.class);
AggregateFunction conditionalFunction = ctor.newInstance(EMPTY, getFieldAttribute("a"));
- assertEquals(conditionalFunction, rule.rule(conditionalFunction));
+ assertEquals(conditionalFunction, foldNull(conditionalFunction));
conditionalFunction = ctor.newInstance(EMPTY, NULL);
- assertEquals(NULL, rule.rule(conditionalFunction));
+ assertEquals(NULL, foldNull(conditionalFunction));
}
Avg avg = new Avg(EMPTY, getFieldAttribute("a"));
- assertEquals(avg, rule.rule(avg));
+ assertEquals(avg, foldNull(avg));
avg = new Avg(EMPTY, NULL);
- assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(avg));
+ assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(avg));
Count count = new Count(EMPTY, getFieldAttribute("a"));
- assertEquals(count, rule.rule(count));
+ assertEquals(count, foldNull(count));
count = new Count(EMPTY, NULL);
- assertEquals(count, rule.rule(count));
+ assertEquals(count, foldNull(count));
CountDistinct countd = new CountDistinct(EMPTY, getFieldAttribute("a"), getFieldAttribute("a"));
- assertEquals(countd, rule.rule(countd));
+ assertEquals(countd, foldNull(countd));
countd = new CountDistinct(EMPTY, NULL, NULL);
- assertEquals(new Literal(EMPTY, null, LONG), rule.rule(countd));
+ assertEquals(new Literal(EMPTY, null, LONG), foldNull(countd));
Median median = new Median(EMPTY, getFieldAttribute("a"));
- assertEquals(median, rule.rule(median));
+ assertEquals(median, foldNull(median));
median = new Median(EMPTY, NULL);
- assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(median));
+ assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(median));
MedianAbsoluteDeviation medianad = new MedianAbsoluteDeviation(EMPTY, getFieldAttribute("a"));
- assertEquals(medianad, rule.rule(medianad));
+ assertEquals(medianad, foldNull(medianad));
medianad = new MedianAbsoluteDeviation(EMPTY, NULL);
- assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(medianad));
+ assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(medianad));
Percentile percentile = new Percentile(EMPTY, getFieldAttribute("a"), getFieldAttribute("a"));
- assertEquals(percentile, rule.rule(percentile));
+ assertEquals(percentile, foldNull(percentile));
percentile = new Percentile(EMPTY, NULL, NULL);
- assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(percentile));
+ assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(percentile));
Sum sum = new Sum(EMPTY, getFieldAttribute("a"));
- assertEquals(sum, rule.rule(sum));
+ assertEquals(sum, foldNull(sum));
sum = new Sum(EMPTY, NULL);
- assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(sum));
+ assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(sum));
}
public void testNullFoldableDoesNotApplyToIsNullAndNotNull() {
- FoldNull rule = new FoldNull();
-
DataType numericType = randomFrom(INTEGER, LONG, DOUBLE);
DataType genericType = randomFrom(INTEGER, LONG, DOUBLE, UNSIGNED_LONG, KEYWORD, TEXT, GEO_POINT, GEO_SHAPE, VERSION, IP);
List items = List.of(
@@ -260,29 +255,26 @@ public void testNullFoldableDoesNotApplyToIsNullAndNotNull() {
);
for (Expression item : items) {
Expression isNull = new IsNull(EMPTY, item);
- Expression transformed = rule.rule(isNull);
+ Expression transformed = foldNull(isNull);
assertEquals(isNull, transformed);
IsNotNull isNotNull = new IsNotNull(EMPTY, item);
- transformed = rule.rule(isNotNull);
+ transformed = foldNull(isNotNull);
assertEquals(isNotNull, transformed);
}
}
public void testNullBucketGetsFolded() {
- FoldNull foldNull = new FoldNull();
- assertEquals(NULL, foldNull.rule(new Bucket(EMPTY, NULL, NULL, NULL, NULL)));
+ assertEquals(NULL, foldNull(new Bucket(EMPTY, NULL, NULL, NULL, NULL)));
}
public void testNullCategorizeGroupingNotFolded() {
- FoldNull foldNull = new FoldNull();
Categorize categorize = new Categorize(EMPTY, NULL);
- assertEquals(categorize, foldNull.rule(categorize));
+ assertEquals(categorize, foldNull(categorize));
}
private void assertNullLiteral(Expression expression) {
- assertEquals(Literal.class, expression.getClass());
- assertNull(expression.fold());
+ assertNull(as(expression, Literal.class).value());
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java
index 17e69e81444c5..1664e9f4653bb 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java
@@ -15,6 +15,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
@@ -22,7 +23,7 @@ public class LiteralsOnTheRightTests extends ESTestCase {
public void testLiteralsOnTheRight() {
Alias a = new Alias(EMPTY, "a", new Literal(EMPTY, 10, INTEGER));
- Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a));
+ Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a), unboundLogicalOptimizerContext());
assertTrue(result instanceof Equals);
Equals eq = (Equals) result;
assertEquals(a, eq.left());
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java
index 55091653e75d4..a6c0d838b2c21 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java
@@ -14,6 +14,7 @@
import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates;
import org.elasticsearch.xpack.esql.core.expression.predicate.Range;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
+import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
@@ -37,11 +38,15 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
public class PropagateEqualsTests extends ESTestCase {
+ private Expression propagateEquals(BinaryLogic e) {
+ return new PropagateEquals().rule(e, unboundLogicalOptimizerContext());
+ }
// a == 1 AND a == 2 -> FALSE
public void testDualEqualsConjunction() {
@@ -49,8 +54,7 @@ public void testDualEqualsConjunction() {
Equals eq1 = equalsOf(fa, ONE);
Equals eq2 = equalsOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq1, eq2));
+ Expression exp = propagateEquals(new And(EMPTY, eq1, eq2));
assertEquals(FALSE, exp);
}
@@ -60,8 +64,7 @@ public void testEliminateRangeByEqualsOutsideInterval() {
Equals eq1 = equalsOf(fa, new Literal(EMPTY, 10, DataType.INTEGER));
Range r = rangeOf(fa, ONE, false, new Literal(EMPTY, 10, DataType.INTEGER), false);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq1, r));
+ Expression exp = propagateEquals(new And(EMPTY, eq1, r));
assertEquals(FALSE, exp);
}
@@ -71,8 +74,7 @@ public void testPropagateEquals_VarNeq3AndVarEq3() {
NotEquals neq = notEqualsOf(fa, THREE);
Equals eq = equalsOf(fa, THREE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, neq, eq));
+ Expression exp = propagateEquals(new And(EMPTY, neq, eq));
assertEquals(FALSE, exp);
}
@@ -82,8 +84,7 @@ public void testPropagateEquals_VarNeq4AndVarEq3() {
NotEquals neq = notEqualsOf(fa, FOUR);
Equals eq = equalsOf(fa, THREE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, neq, eq));
+ Expression exp = propagateEquals(new And(EMPTY, neq, eq));
assertEquals(Equals.class, exp.getClass());
assertEquals(eq, exp);
}
@@ -94,8 +95,7 @@ public void testPropagateEquals_VarEq2AndVarLt2() {
Equals eq = equalsOf(fa, TWO);
LessThan lt = lessThanOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq, lt));
+ Expression exp = propagateEquals(new And(EMPTY, eq, lt));
assertEquals(FALSE, exp);
}
@@ -105,8 +105,7 @@ public void testPropagateEquals_VarEq2AndVarLte2() {
Equals eq = equalsOf(fa, TWO);
LessThanOrEqual lt = lessThanOrEqualOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq, lt));
+ Expression exp = propagateEquals(new And(EMPTY, eq, lt));
assertEquals(eq, exp);
}
@@ -116,8 +115,7 @@ public void testPropagateEquals_VarEq2AndVarLte1() {
Equals eq = equalsOf(fa, TWO);
LessThanOrEqual lt = lessThanOrEqualOf(fa, ONE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq, lt));
+ Expression exp = propagateEquals(new And(EMPTY, eq, lt));
assertEquals(FALSE, exp);
}
@@ -127,8 +125,7 @@ public void testPropagateEquals_VarEq2AndVarGt2() {
Equals eq = equalsOf(fa, TWO);
GreaterThan gt = greaterThanOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq, gt));
+ Expression exp = propagateEquals(new And(EMPTY, eq, gt));
assertEquals(FALSE, exp);
}
@@ -138,8 +135,7 @@ public void testPropagateEquals_VarEq2AndVarGte2() {
Equals eq = equalsOf(fa, TWO);
GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq, gte));
+ Expression exp = propagateEquals(new And(EMPTY, eq, gte));
assertEquals(eq, exp);
}
@@ -149,8 +145,7 @@ public void testPropagateEquals_VarEq2AndVarLt3() {
Equals eq = equalsOf(fa, TWO);
GreaterThan gt = greaterThanOf(fa, THREE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq, gt));
+ Expression exp = propagateEquals(new And(EMPTY, eq, gt));
assertEquals(FALSE, exp);
}
@@ -162,9 +157,8 @@ public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() {
GreaterThan gt = greaterThanOf(fa, ONE);
NotEquals neq = notEqualsOf(fa, FOUR);
- PropagateEquals rule = new PropagateEquals();
Expression and = Predicates.combineAnd(asList(eq, lt, gt, neq));
- Expression exp = rule.rule((And) and);
+ Expression exp = propagateEquals((And) and);
assertEquals(eq, exp);
}
@@ -176,9 +170,8 @@ public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() {
GreaterThan gt = greaterThanOf(fa, new Literal(EMPTY, 0, DataType.INTEGER));
NotEquals neq = notEqualsOf(fa, FOUR);
- PropagateEquals rule = new PropagateEquals();
Expression and = Predicates.combineAnd(asList(eq, range, gt, neq));
- Expression exp = rule.rule((And) and);
+ Expression exp = propagateEquals((And) and);
assertEquals(eq, exp);
}
@@ -188,8 +181,7 @@ public void testPropagateEquals_VarEq2OrVarGt1() {
Equals eq = equalsOf(fa, TWO);
GreaterThan gt = greaterThanOf(fa, ONE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, gt));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, gt));
assertEquals(gt, exp);
}
@@ -199,8 +191,7 @@ public void testPropagateEquals_VarEq2OrVarGte2() {
Equals eq = equalsOf(fa, TWO);
GreaterThan gt = greaterThanOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, gt));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, gt));
assertEquals(GreaterThanOrEqual.class, exp.getClass());
GreaterThanOrEqual gte = (GreaterThanOrEqual) exp;
assertEquals(TWO, gte.right());
@@ -212,8 +203,7 @@ public void testPropagateEquals_VarEq2OrVarLt3() {
Equals eq = equalsOf(fa, TWO);
LessThan lt = lessThanOf(fa, THREE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, lt));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, lt));
assertEquals(lt, exp);
}
@@ -223,8 +213,7 @@ public void testPropagateEquals_VarEq3OrVarLt3() {
Equals eq = equalsOf(fa, THREE);
LessThan lt = lessThanOf(fa, THREE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, lt));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, lt));
assertEquals(LessThanOrEqual.class, exp.getClass());
LessThanOrEqual lte = (LessThanOrEqual) exp;
assertEquals(THREE, lte.right());
@@ -236,8 +225,7 @@ public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() {
Equals eq = equalsOf(fa, TWO);
Range range = rangeOf(fa, ONE, false, THREE, false);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, range));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, range));
assertEquals(range, exp);
}
@@ -247,8 +235,7 @@ public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() {
Equals eq = equalsOf(fa, TWO);
Range range = rangeOf(fa, TWO, false, THREE, false);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, range));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, range));
assertEquals(Range.class, exp.getClass());
Range r = (Range) exp;
assertEquals(TWO, r.lower());
@@ -263,8 +250,7 @@ public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() {
Equals eq = equalsOf(fa, THREE);
Range range = rangeOf(fa, TWO, false, THREE, false);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, range));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, range));
assertEquals(Range.class, exp.getClass());
Range r = (Range) exp;
assertEquals(TWO, r.lower());
@@ -279,8 +265,7 @@ public void testPropagateEquals_VarEq2OrVarNeq2() {
Equals eq = equalsOf(fa, TWO);
NotEquals neq = notEqualsOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, neq));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, neq));
assertEquals(TRUE, exp);
}
@@ -290,8 +275,7 @@ public void testPropagateEquals_VarEq2OrVarNeq5() {
Equals eq = equalsOf(fa, TWO);
NotEquals neq = notEqualsOf(fa, FIVE);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new Or(EMPTY, eq, neq));
+ Expression exp = propagateEquals(new Or(EMPTY, eq, neq));
assertEquals(NotEquals.class, exp.getClass());
NotEquals ne = (NotEquals) exp;
assertEquals(FIVE, ne.right());
@@ -305,8 +289,7 @@ public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() {
GreaterThan gt = greaterThanOf(fa, TWO);
NotEquals neq = notEqualsOf(fa, TWO);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule((Or) Predicates.combineOr(asList(eq, range, neq, gt)));
+ Expression exp = propagateEquals((Or) Predicates.combineOr(asList(eq, range, neq, gt)));
assertEquals(TRUE, exp);
}
@@ -317,8 +300,7 @@ public void testPropagateEquals_ignoreDateTimeFields() {
Equals eq2 = equalsOf(fa, TWO);
And and = new And(EMPTY, eq1, eq2);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(and);
+ Expression exp = propagateEquals(and);
assertEquals(and, exp);
}
@@ -328,8 +310,7 @@ public void testEliminateRangeByEqualsInInterval() {
Equals eq1 = equalsOf(fa, ONE);
Range r = rangeOf(fa, ONE, true, new Literal(EMPTY, 10, DataType.INTEGER), false);
- PropagateEquals rule = new PropagateEquals();
- Expression exp = rule.rule(new And(EMPTY, eq1, r));
+ Expression exp = propagateEquals(new And(EMPTY, eq1, r));
assertEquals(eq1, exp);
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java
index d1d6a7fbaa208..d35890e5b56bb 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java
@@ -31,11 +31,20 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
public class PropagateNullableTests extends ESTestCase {
+ private Expression propagateNullable(And e) {
+ return new PropagateNullable().rule(e, unboundLogicalOptimizerContext());
+ }
+
+ private LogicalPlan propagateNullable(LogicalPlan p) {
+ return new PropagateNullable().apply(p, unboundLogicalOptimizerContext());
+ }
+
private Literal nullOf(DataType dataType) {
return new Literal(Source.EMPTY, null, dataType);
}
@@ -45,7 +54,7 @@ public void testIsNullAndNotNull() {
FieldAttribute fa = getFieldAttribute();
And and = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa));
- assertEquals(FALSE, new PropagateNullable().rule(and));
+ assertEquals(FALSE, propagateNullable(and));
}
// a IS NULL AND b IS NOT NULL AND c IS NULL AND d IS NOT NULL AND e IS NULL AND a IS NOT NULL => false
@@ -58,7 +67,7 @@ public void testIsNullAndNotNullMultiField() {
And and = new And(EMPTY, andOne, new And(EMPTY, andThree, andTwo));
- assertEquals(FALSE, new PropagateNullable().rule(and));
+ assertEquals(FALSE, propagateNullable(and));
}
// a IS NULL AND a > 1 => a IS NULL AND false
@@ -67,7 +76,7 @@ public void testIsNullAndComparison() {
IsNull isNull = new IsNull(EMPTY, fa);
And and = new And(EMPTY, isNull, greaterThanOf(fa, ONE));
- assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), new PropagateNullable().rule(and));
+ assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), propagateNullable(and));
}
// a IS NULL AND b < 1 AND c < 1 AND a < 1 => a IS NULL AND b < 1 AND c < 1 => a IS NULL AND b < 1 AND c < 1
@@ -79,7 +88,7 @@ public void testIsNullAndMultipleComparison() {
And and = new And(EMPTY, isNull, nestedAnd);
And top = new And(EMPTY, and, lessThanOf(fa, ONE));
- Expression optimized = new PropagateNullable().rule(top);
+ Expression optimized = propagateNullable(top);
Expression expected = new And(EMPTY, and, nullOf(BOOLEAN));
assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized));
}
@@ -97,7 +106,7 @@ public void testIsNullAndDeeplyNestedExpression() {
Expression kept = new And(EMPTY, isNull, lessThanOf(getFieldAttribute("b"), THREE));
And and = new And(EMPTY, nullified, kept);
- Expression optimized = new PropagateNullable().rule(and);
+ Expression optimized = propagateNullable(and);
Expression expected = new And(EMPTY, new And(EMPTY, nullOf(BOOLEAN), nullOf(BOOLEAN)), kept);
assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized));
@@ -110,13 +119,13 @@ public void testIsNullInDisjunction() {
Or or = new Or(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa));
Filter dummy = new Filter(EMPTY, relation(), or);
- LogicalPlan transformed = new PropagateNullable().apply(dummy);
+ LogicalPlan transformed = propagateNullable(dummy);
assertSame(dummy, transformed);
assertEquals(or, ((Filter) transformed).condition());
or = new Or(EMPTY, new IsNull(EMPTY, fa), greaterThanOf(fa, ONE));
dummy = new Filter(EMPTY, relation(), or);
- transformed = new PropagateNullable().apply(dummy);
+ transformed = propagateNullable(dummy);
assertSame(dummy, transformed);
assertEquals(or, ((Filter) transformed).condition());
}
@@ -129,7 +138,7 @@ public void testIsNullDisjunction() {
Or or = new Or(EMPTY, isNull, greaterThanOf(fa, THREE));
And and = new And(EMPTY, new Add(EMPTY, fa, ONE), or);
- assertEquals(and, new PropagateNullable().rule(and));
+ assertEquals(and, propagateNullable(and));
}
public void testDoNotOptimizeIsNullAndMultipleComparisonWithConstants() {
@@ -141,7 +150,7 @@ public void testDoNotOptimizeIsNullAndMultipleComparisonWithConstants() {
And aIsNull_AND_bLT1_AND_cLT1 = new And(EMPTY, aIsNull, bLT1_AND_cLT1);
And aIsNull_AND_bLT1_AND_cLT1_AND_aLT1 = new And(EMPTY, aIsNull_AND_bLT1_AND_cLT1, lessThanOf(a, ONE));
- Expression optimized = new PropagateNullable().rule(aIsNull_AND_bLT1_AND_cLT1_AND_aLT1);
+ Expression optimized = propagateNullable(aIsNull_AND_bLT1_AND_cLT1_AND_aLT1);
Literal nullLiteral = new Literal(EMPTY, null, BOOLEAN);
assertEquals(asList(aIsNull, nullLiteral, nullLiteral, nullLiteral), Predicates.splitAnd(optimized));
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java
index c7206c6971bde..b9ffc39e5e130 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java
@@ -10,8 +10,10 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern;
+import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexMatch;
import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern;
import org.elasticsearch.xpack.esql.core.util.StringUtils;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike;
@@ -20,16 +22,20 @@
import static java.util.Arrays.asList;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
public class ReplaceRegexMatchTests extends ESTestCase {
+ private Expression replaceRegexMatch(RegexMatch> e) {
+ return new ReplaceRegexMatch().rule(e, unboundLogicalOptimizerContext());
+ }
public void testMatchAllWildcardLikeToExist() {
for (String s : asList("*", "**", "***")) {
WildcardPattern pattern = new WildcardPattern(s);
FieldAttribute fa = getFieldAttribute();
WildcardLike l = new WildcardLike(EMPTY, fa, pattern);
- Expression e = new ReplaceRegexMatch().rule(l);
+ Expression e = replaceRegexMatch(l);
assertEquals(IsNotNull.class, e.getClass());
IsNotNull inn = (IsNotNull) e;
assertEquals(fa, inn.field());
@@ -40,7 +46,7 @@ public void testMatchAllRLikeToExist() {
RLikePattern pattern = new RLikePattern(".*");
FieldAttribute fa = getFieldAttribute();
RLike l = new RLike(EMPTY, fa, pattern);
- Expression e = new ReplaceRegexMatch().rule(l);
+ Expression e = replaceRegexMatch(l);
assertEquals(IsNotNull.class, e.getClass());
IsNotNull inn = (IsNotNull) e;
assertEquals(fa, inn.field());
@@ -51,11 +57,11 @@ public void testExactMatchWildcardLike() {
WildcardPattern pattern = new WildcardPattern(s);
FieldAttribute fa = getFieldAttribute();
WildcardLike l = new WildcardLike(EMPTY, fa, pattern);
- Expression e = new ReplaceRegexMatch().rule(l);
+ Expression e = replaceRegexMatch(l);
assertEquals(Equals.class, e.getClass());
Equals eq = (Equals) e;
assertEquals(fa, eq.left());
- assertEquals(s.replace("\\", StringUtils.EMPTY), eq.right().fold());
+ assertEquals(s.replace("\\", StringUtils.EMPTY), eq.right().fold(FoldContext.small()));
}
}
@@ -63,11 +69,11 @@ public void testExactMatchRLike() {
RLikePattern pattern = new RLikePattern("abc");
FieldAttribute fa = getFieldAttribute();
RLike l = new RLike(EMPTY, fa, pattern);
- Expression e = new ReplaceRegexMatch().rule(l);
+ Expression e = replaceRegexMatch(l);
assertEquals(Equals.class, e.getClass());
Equals eq = (Equals) e;
assertEquals(fa, eq.left());
- assertEquals("abc", eq.right().fold());
+ assertEquals("abc", eq.right().fold(FoldContext.small()));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java
index 2429bcb1a1b04..90c8ae1032325 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java
@@ -19,6 +19,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.Nullability;
@@ -417,7 +418,7 @@ private static void assertNoPushdownSort(TestPhysicalPlanBuilder builder, String
private static PhysicalPlan pushTopNToSource(TopNExec topNExec) {
var configuration = EsqlTestUtils.configuration("from test");
- var ctx = new LocalPhysicalOptimizerContext(configuration, SearchStats.EMPTY);
+ var ctx = new LocalPhysicalOptimizerContext(configuration, FoldContext.small(), SearchStats.EMPTY);
var pushTopNToSource = new PushTopNToSource();
return pushTopNToSource.rule(topNExec, ctx);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
index 710637c05a900..85d4017b166fa 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java
@@ -10,6 +10,7 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar;
@@ -617,16 +618,14 @@ public void testSimplifyInWithSingleElementList() {
Equals eq = (Equals) e;
assertThat(eq.left(), instanceOf(UnresolvedAttribute.class));
assertThat(((UnresolvedAttribute) eq.left()).name(), equalTo("a"));
- assertThat(eq.right(), instanceOf(Literal.class));
- assertThat(eq.right().fold(), equalTo(1));
+ assertThat(as(eq.right(), Literal.class).value(), equalTo(1));
e = whereExpression("1 IN (a)");
assertThat(e, instanceOf(Equals.class));
eq = (Equals) e;
assertThat(eq.right(), instanceOf(UnresolvedAttribute.class));
assertThat(((UnresolvedAttribute) eq.right()).name(), equalTo("a"));
- assertThat(eq.left(), instanceOf(Literal.class));
- assertThat(eq.left().fold(), equalTo(1));
+ assertThat(eq.left().fold(FoldContext.small()), equalTo(1));
e = whereExpression("1 NOT IN (a)");
assertThat(e, instanceOf(Not.class));
@@ -635,9 +634,7 @@ public void testSimplifyInWithSingleElementList() {
eq = (Equals) e;
assertThat(eq.right(), instanceOf(UnresolvedAttribute.class));
assertThat(((UnresolvedAttribute) eq.right()).name(), equalTo("a"));
- assertThat(eq.left(), instanceOf(Literal.class));
- assertThat(eq.left().fold(), equalTo(1));
-
+ assertThat(eq.left().fold(FoldContext.small()), equalTo(1));
}
private Expression whereExpression(String e) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
index b83892ea47049..49f03e9b8bc2f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
@@ -16,6 +16,7 @@
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
@@ -1129,51 +1130,51 @@ public void testInputParams() {
assertThat(field.name(), is("x"));
assertThat(field, instanceOf(Alias.class));
Alias alias = (Alias) field;
- assertThat(alias.child().fold(), is(1));
+ assertThat(alias.child().fold(FoldContext.small()), is(1));
field = row.fields().get(1);
assertThat(field.name(), is("y"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is("2"));
+ assertThat(alias.child().fold(FoldContext.small()), is("2"));
field = row.fields().get(2);
assertThat(field.name(), is("a"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is("2 days"));
+ assertThat(alias.child().fold(FoldContext.small()), is("2 days"));
field = row.fields().get(3);
assertThat(field.name(), is("b"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is("4 hours"));
+ assertThat(alias.child().fold(FoldContext.small()), is("4 hours"));
field = row.fields().get(4);
assertThat(field.name(), is("c"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold().getClass(), is(String.class));
- assertThat(alias.child().fold().toString(), is("1.2.3"));
+ assertThat(alias.child().fold(FoldContext.small()).getClass(), is(String.class));
+ assertThat(alias.child().fold(FoldContext.small()).toString(), is("1.2.3"));
field = row.fields().get(5);
assertThat(field.name(), is("d"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold().getClass(), is(String.class));
- assertThat(alias.child().fold().toString(), is("127.0.0.1"));
+ assertThat(alias.child().fold(FoldContext.small()).getClass(), is(String.class));
+ assertThat(alias.child().fold(FoldContext.small()).toString(), is("127.0.0.1"));
field = row.fields().get(6);
assertThat(field.name(), is("e"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is(9));
+ assertThat(alias.child().fold(FoldContext.small()), is(9));
field = row.fields().get(7);
assertThat(field.name(), is("f"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is(11));
+ assertThat(alias.child().fold(FoldContext.small()), is(11));
}
public void testMissingInputParams() {
@@ -1190,13 +1191,13 @@ public void testNamedParams() {
assertThat(field.name(), is("x"));
assertThat(field, instanceOf(Alias.class));
Alias alias = (Alias) field;
- assertThat(alias.child().fold(), is(1));
+ assertThat(alias.child().fold(FoldContext.small()), is(1));
field = row.fields().get(1);
assertThat(field.name(), is("y"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is(1));
+ assertThat(alias.child().fold(FoldContext.small()), is(1));
}
public void testInvalidNamedParams() {
@@ -1237,13 +1238,13 @@ public void testPositionalParams() {
assertThat(field.name(), is("x"));
assertThat(field, instanceOf(Alias.class));
Alias alias = (Alias) field;
- assertThat(alias.child().fold(), is(1));
+ assertThat(alias.child().fold(FoldContext.small()), is(1));
field = row.fields().get(1);
assertThat(field.name(), is("y"));
assertThat(field, instanceOf(Alias.class));
alias = (Alias) field;
- assertThat(alias.child().fold(), is(1));
+ assertThat(alias.child().fold(FoldContext.small()), is(1));
}
public void testInvalidPositionalParams() {
@@ -2057,7 +2058,7 @@ private void assertStringAsLookupIndexPattern(String string, String statement) {
var plan = statement(statement);
var lookup = as(plan, Lookup.class);
var tableName = as(lookup.tableName(), Literal.class);
- assertThat(tableName.fold(), equalTo(string));
+ assertThat(tableName.fold(FoldContext.small()), equalTo(string));
}
public void testIdPatternUnquoted() throws Exception {
@@ -2125,7 +2126,7 @@ public void testLookup() {
var plan = statement(query);
var lookup = as(plan, Lookup.class);
var tableName = as(lookup.tableName(), Literal.class);
- assertThat(tableName.fold(), equalTo("t"));
+ assertThat(tableName.fold(FoldContext.small()), equalTo("t"));
assertThat(lookup.matchFields(), hasSize(1));
var matchField = as(lookup.matchFields().get(0), UnresolvedAttribute.class);
assertThat(matchField.name(), equalTo("j"));
@@ -2306,7 +2307,7 @@ public void testMatchOperatorConstantQueryString() {
var match = (Match) filter.condition();
var matchField = (UnresolvedAttribute) match.field();
assertThat(matchField.name(), equalTo("field"));
- assertThat(match.query().fold(), equalTo("value"));
+ assertThat(match.query().fold(FoldContext.small()), equalTo("value"));
}
public void testInvalidMatchOperator() {
@@ -2341,7 +2342,7 @@ public void testMatchFunctionFieldCasting() {
var toInteger = (ToInteger) function.children().get(0);
var matchField = (UnresolvedAttribute) toInteger.field();
assertThat(matchField.name(), equalTo("field"));
- assertThat(function.children().get(1).fold(), equalTo("value"));
+ assertThat(function.children().get(1).fold(FoldContext.small()), equalTo("value"));
}
public void testMatchOperatorFieldCasting() {
@@ -2351,6 +2352,6 @@ public void testMatchOperatorFieldCasting() {
var toInteger = (ToInteger) match.field();
var matchField = (UnresolvedAttribute) toInteger.field();
assertThat(matchField.name(), equalTo("field"));
- assertThat(match.query().fold(), equalTo("value"));
+ assertThat(match.query().fold(FoldContext.small()), equalTo("value"));
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java
index a254207865ad5..2f47a672a68d0 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java
@@ -11,6 +11,7 @@
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
@@ -42,7 +43,7 @@ public void testTransformWithExpressionTopLevel() throws Exception {
assertEquals(Limit.class, transformed.getClass());
Limit l = (Limit) transformed;
- assertEquals(24, l.limit().fold());
+ assertEquals(24, l.limit().fold(FoldContext.small()));
}
public void testTransformWithExpressionTree() throws Exception {
@@ -53,7 +54,7 @@ public void testTransformWithExpressionTree() throws Exception {
assertEquals(OrderBy.class, transformed.getClass());
OrderBy order = (OrderBy) transformed;
assertEquals(Limit.class, order.child().getClass());
- assertEquals(24, ((Limit) order.child()).limit().fold());
+ assertEquals(24, ((Limit) order.child()).limit().fold(FoldContext.small()));
}
public void testTransformWithExpressionTopLevelInCollection() throws Exception {
@@ -83,12 +84,12 @@ public void testForEachWithExpressionTopLevel() throws Exception {
List list = new ArrayList<>();
project.forEachExpression(Literal.class, l -> {
- if (l.fold().equals(42)) {
- list.add(l.fold());
+ if (l.value().equals(42)) {
+ list.add(l.value());
}
});
- assertEquals(singletonList(one.child().fold()), list);
+ assertEquals(singletonList(one.child().fold(FoldContext.small())), list);
}
public void testForEachWithExpressionTree() throws Exception {
@@ -97,12 +98,12 @@ public void testForEachWithExpressionTree() throws Exception {
List list = new ArrayList<>();
o.forEachExpressionDown(Literal.class, l -> {
- if (l.fold().equals(42)) {
- list.add(l.fold());
+ if (l.value().equals(42)) {
+ list.add(l.value());
}
});
- assertEquals(singletonList(limit.limit().fold()), list);
+ assertEquals(singletonList(limit.limit().fold(FoldContext.small())), list);
}
public void testForEachWithExpressionTopLevelInCollection() throws Exception {
@@ -129,12 +130,12 @@ public void testForEachWithExpressionTreeInCollection() throws Exception {
List list = new ArrayList<>();
project.forEachExpression(Literal.class, l -> {
- if (l.fold().equals(42)) {
- list.add(l.fold());
+ if (l.value().equals(42)) {
+ list.add(l.value());
}
});
- assertEquals(singletonList(one.child().fold()), list);
+ assertEquals(singletonList(one.child().fold(FoldContext.small())), list);
}
public void testPlanExpressions() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java
index 5a7547d011c0f..e2eb05b0c14d3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java
@@ -20,6 +20,7 @@
import org.elasticsearch.xpack.esql.TestBlockFactory;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And;
import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not;
@@ -145,7 +146,7 @@ public void testEvaluatorSuppliers() {
lb.append(LONG);
Layout layout = lb.build();
- var supplier = EvalMapper.toEvaluator(expression, layout);
+ var supplier = EvalMapper.toEvaluator(FoldContext.small(), expression, layout);
EvalOperator.ExpressionEvaluator evaluator1 = supplier.get(driverContext());
EvalOperator.ExpressionEvaluator evaluator2 = supplier.get(driverContext());
assertNotNull(evaluator1);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java
index 55f32d07fc2cb..4191f42f08237 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java
@@ -30,7 +30,6 @@
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
-import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer;
@@ -52,6 +51,7 @@
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization;
import static org.elasticsearch.xpack.esql.core.util.Queries.Clause.FILTER;
@@ -78,7 +78,7 @@ public static void init() {
Map mapping = loadMapping("mapping-basic.json");
EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(test);
- logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG));
+ logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(EsqlTestUtils.TEST_CFG));
mapper = new Mapper();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java
index 5d8da21c6faad..a1648c67d9bd4 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java
@@ -32,6 +32,7 @@
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.xpack.esql.TestBlockFactory;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
@@ -83,6 +84,7 @@ public void closeIndex() throws IOException {
public void testLuceneSourceOperatorHugeRowSize() throws IOException {
int estimatedRowSize = randomEstimatedRowSize(estimatedRowSizeIsHuge);
LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan(
+ FoldContext.small(),
new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, null, null, estimatedRowSize)
);
assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency()));
@@ -98,6 +100,7 @@ public void testLuceneTopNSourceOperator() throws IOException {
EsQueryExec.FieldSort sort = new EsQueryExec.FieldSort(sortField, Order.OrderDirection.ASC, Order.NullsPosition.LAST);
Literal limit = new Literal(Source.EMPTY, 10, DataType.INTEGER);
LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan(
+ FoldContext.small(),
new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, limit, List.of(sort), estimatedRowSize)
);
assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency()));
@@ -113,6 +116,7 @@ public void testLuceneTopNSourceOperatorDistanceSort() throws IOException {
EsQueryExec.GeoDistanceSort sort = new EsQueryExec.GeoDistanceSort(sortField, Order.OrderDirection.ASC, 1, -1);
Literal limit = new Literal(Source.EMPTY, 10, DataType.INTEGER);
LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan(
+ FoldContext.small(),
new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, limit, List.of(sort), estimatedRowSize)
);
assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency()));
@@ -187,7 +191,7 @@ private EsPhysicalOperationProviders esPhysicalOperationProviders() throws IOExc
);
}
releasables.add(searcher);
- return new EsPhysicalOperationProviders(shardContexts, null);
+ return new EsPhysicalOperationProviders(FoldContext.small(), shardContexts, null);
}
private IndexReader reader() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java
index 01dd4db123ee2..628737aa36c6c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java
@@ -41,6 +41,7 @@
import org.elasticsearch.xpack.esql.TestBlockFactory;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField;
import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes;
@@ -71,13 +72,13 @@
public class TestPhysicalOperationProviders extends AbstractPhysicalOperationProviders {
private final List indexPages;
- private TestPhysicalOperationProviders(List indexPages, AnalysisRegistry analysisRegistry) {
- super(analysisRegistry);
+ private TestPhysicalOperationProviders(FoldContext foldContext, List indexPages, AnalysisRegistry analysisRegistry) {
+ super(foldContext, analysisRegistry);
this.indexPages = indexPages;
}
- public static TestPhysicalOperationProviders create(List indexPages) throws IOException {
- return new TestPhysicalOperationProviders(indexPages, createAnalysisRegistry());
+ public static TestPhysicalOperationProviders create(FoldContext foldContext, List indexPages) throws IOException {
+ return new TestPhysicalOperationProviders(foldContext, indexPages, createAnalysisRegistry());
}
public record IndexPage(String index, Page page, List columnNames) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
index f2a619f0dbd89..f3b1d84e507a5 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java
@@ -26,7 +26,6 @@
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
-import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -39,10 +38,10 @@
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration;
import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables;
-import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
+import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
import static org.hamcrest.Matchers.equalTo;
@@ -187,7 +186,7 @@ static LogicalPlan parse(String query) {
Map mapping = loadMapping("mapping-basic.json");
EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(test);
- var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG));
+ var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext());
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
TEST_VERIFIER
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java
index 2cc733c2ea2e3..fac3495697da8 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java
@@ -21,6 +21,7 @@
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.analysis.Analyzer;
import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
import org.elasticsearch.xpack.esql.index.EsIndex;
@@ -289,7 +290,7 @@ static LogicalPlan parse(String query) {
Map mapping = loadMapping("mapping-basic.json");
EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
IndexResolution getIndexResult = IndexResolution.valid(test);
- var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG));
+ var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG, FoldContext.small()));
var analyzer = new Analyzer(
new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()),
TEST_VERIFIER
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java
index 539cd0314a4d1..a3c5cd9168b4f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java
@@ -28,6 +28,7 @@
import org.elasticsearch.xpack.esql.action.EsqlQueryRequest;
import org.elasticsearch.xpack.esql.action.EsqlResolveFieldsAction;
import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver;
import org.elasticsearch.xpack.esql.execution.PlanExecutor;
import org.elasticsearch.xpack.esql.session.EsqlSession;
@@ -119,6 +120,7 @@ public void testFailedMetric() {
request,
randomAlphaOfLength(10),
EsqlTestUtils.TEST_CFG,
+ FoldContext.small(),
enrichResolver,
new EsqlExecutionInfo(randomBoolean()),
groupIndicesByCluster,
@@ -149,6 +151,7 @@ public void onFailure(Exception e) {
request,
randomAlphaOfLength(10),
EsqlTestUtils.TEST_CFG,
+ FoldContext.small(),
enrichResolver,
new EsqlExecutionInfo(randomBoolean()),
groupIndicesByCluster,
From d6133235a339eb99a146a6ff3ad6958fe0d63cf6 Mon Sep 17 00:00:00 2001
From: Craig Taverner
Date: Mon, 13 Jan 2025 16:12:34 +0100
Subject: [PATCH 12/44] Add index name validation rule for empty index names
(#119960)
* Add index name validation rule for empty index names
* Created negative tests for ES|QL index names with only exclusion character
* Added test for `*-`
---
.../cluster/metadata/MetadataCreateIndexService.java | 3 +++
.../cluster/metadata/MetadataCreateIndexServiceTests.java | 3 +++
.../xpack/esql/parser/AbstractStatementParserTests.java | 8 ++++++--
.../xpack/esql/parser/StatementParserTests.java | 3 +++
4 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java
index 7fc8a8693dcdc..fb19a3f04fcec 100644
--- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java
+++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexService.java
@@ -237,6 +237,9 @@ public SystemIndices getSystemIndices() {
* Validate the name for an index or alias against some static rules.
*/
public static void validateIndexOrAliasName(String index, BiFunction exceptionCtor) {
+ if (index == null || index.isEmpty()) {
+ throw exceptionCtor.apply(index, "must not be empty");
+ }
if (Strings.validFileName(index) == false) {
throw exceptionCtor.apply(index, "must not contain the following characters " + Strings.INVALID_FILENAME_CHARS);
}
diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
index 3623683532c59..3ed74392f746e 100644
--- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
@@ -561,6 +561,9 @@ public void testValidateIndexName() throws Exception {
validateIndexName(checkerService, "..", "must not be '.' or '..'");
validateIndexName(checkerService, "foo:bar", "must not contain ':'");
+
+ validateIndexName(checkerService, "", "must not be empty");
+ validateIndexName(checkerService, null, "must not be empty");
}));
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java
index e6fef186721a0..99a04b6ed8f10 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java
@@ -151,8 +151,12 @@ void expectInvalidIndexNameErrorWithLineNumber(String query, String indexString,
expectInvalidIndexNameErrorWithLineNumber(query, "\"" + indexString + "\"", lineNumber, indexString);
}
- void expectInvalidIndexNameErrorWithLineNumber(String query, String indexString, String lineNumber, String error) {
- expectError(LoggerMessageFormat.format(null, query, indexString), lineNumber + "Invalid index name [" + error);
+ void expectInvalidIndexNameErrorWithLineNumber(String query, String indexString, String lineNumber, String name) {
+ expectError(LoggerMessageFormat.format(null, query, indexString), lineNumber + "Invalid index name [" + name);
+ }
+
+ void expectInvalidIndexNameErrorWithLineNumber(String query, String indexString, String lineNumber, String name, String error) {
+ expectError(LoggerMessageFormat.format(null, query, indexString), lineNumber + "Invalid index name [" + name + "], " + error);
}
void expectDateMathErrorWithLineNumber(String query, String arg, String lineNumber, String error) {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
index 49f03e9b8bc2f..a4712ae77b5d8 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java
@@ -578,6 +578,9 @@ public void testInvalidCharacterInIndexPattern() {
expectInvalidIndexNameErrorWithLineNumber(command, "indexpattern, --indexpattern", lineNumber, "-indexpattern");
expectInvalidIndexNameErrorWithLineNumber(command, "indexpattern, \"--indexpattern\"", lineNumber, "-indexpattern");
expectInvalidIndexNameErrorWithLineNumber(command, "\"indexpattern, --indexpattern\"", commands.get(command), "-indexpattern");
+ expectInvalidIndexNameErrorWithLineNumber(command, "\"- , -\"", commands.get(command), "", "must not be empty");
+ expectInvalidIndexNameErrorWithLineNumber(command, "\"indexpattern,-\"", commands.get(command), "", "must not be empty");
+ clustersAndIndices(command, "indexpattern", "*-");
clustersAndIndices(command, "indexpattern", "-indexpattern");
}
From a8d2680cc203b79a583b18ec9df4c49b7c219dd2 Mon Sep 17 00:00:00 2001
From: Ioana Tagirta
Date: Mon, 13 Jan 2025 16:13:21 +0100
Subject: [PATCH 13/44] Unmute LocalPhysicalPlanOptimizerTests (#119884)
* Unmute tests
* Mute tests
---
muted-tests.yml | 9 ---------
1 file changed, 9 deletions(-)
diff --git a/muted-tests.yml b/muted-tests.yml
index d9fcebb7f2916..cecff79b002a2 100644
--- a/muted-tests.yml
+++ b/muted-tests.yml
@@ -227,15 +227,6 @@ tests:
- class: org.elasticsearch.search.profile.dfs.DfsProfilerIT
method: testProfileDfs
issue: https://github.com/elastic/elasticsearch/issues/119711
-- class: org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizerTests
- method: testSingleMatchFunctionFilterPushdownWithStringValues {default}
- issue: https://github.com/elastic/elasticsearch/issues/119720
-- class: org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizerTests
- method: testSingleMatchFunctionPushdownWithCasting {default}
- issue: https://github.com/elastic/elasticsearch/issues/119722
-- class: org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizerTests
- method: testSingleMatchOperatorFilterPushdownWithStringValues {default}
- issue: https://github.com/elastic/elasticsearch/issues/119721
- class: org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilterIT
method: testBulkOperations {p0=false}
issue: https://github.com/elastic/elasticsearch/issues/119901
From 53773cbf2404e8b426e0c78f062b3463e604f642 Mon Sep 17 00:00:00 2001
From: Svilen Mihaylov
Date: Mon, 13 Jan 2025 10:34:29 -0500
Subject: [PATCH 14/44] Fix for issue 119723 (#120001)
Adjust random parameter generator to exclude newly valid value (-1)
Closes #119723
---
.../search/fetch/subphase/highlight/HighlightBuilderTests.java | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilderTests.java b/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilderTests.java
index 0f73c367ff2ef..1bcc89b68c141 100644
--- a/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilderTests.java
+++ b/server/src/test/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilderTests.java
@@ -574,11 +574,10 @@ public void testPreTagsWithoutPostTags() throws IOException {
assertEquals("pre_tags are set but post_tags are not set", e.getCause().getCause().getMessage());
}
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/119723")
public void testInvalidMaxAnalyzedOffset() throws IOException {
XContentParseException e = expectParseThrows(
XContentParseException.class,
- "{ \"max_analyzed_offset\" : " + randomIntBetween(-100, -1) + "}"
+ "{ \"max_analyzed_offset\" : " + randomIntBetween(-100, -2) + "}"
);
assertThat(e.getMessage(), containsString("[highlight] failed to parse field [" + MAX_ANALYZED_OFFSET_FIELD.toString() + "]"));
assertThat(e.getCause().getMessage(), containsString("[max_analyzed_offset] must be a positive integer, or -1"));
From de0e7c9d80488e491b1d32927d87aa4de9111de5 Mon Sep 17 00:00:00 2001
From: Luke Whiting
Date: Mon, 13 Jan 2025 16:01:50 +0000
Subject: [PATCH 15/44] Match dot prefix of migrated DS backing index with the
source index (#120042)
* Match dot prefix of migrated index with the source index
* Update docs/changelog/120042.yaml
---
docs/changelog/120042.yaml | 5 +++
...indexDatastreamIndexTransportActionIT.java | 44 ++++++++++++++++++-
...ReindexDataStreamIndexTransportAction.java | 8 +++-
3 files changed, 54 insertions(+), 3 deletions(-)
create mode 100644 docs/changelog/120042.yaml
diff --git a/docs/changelog/120042.yaml b/docs/changelog/120042.yaml
new file mode 100644
index 0000000000000..0093068ae9894
--- /dev/null
+++ b/docs/changelog/120042.yaml
@@ -0,0 +1,5 @@
+pr: 120042
+summary: Match dot prefix of migrated DS backing index with the source index
+area: Data streams
+type: bug
+issues: []
diff --git a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDatastreamIndexTransportActionIT.java b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDatastreamIndexTransportActionIT.java
index 0902f6ce6468a..46af8ab2fb4c2 100644
--- a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDatastreamIndexTransportActionIT.java
+++ b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDatastreamIndexTransportActionIT.java
@@ -96,7 +96,7 @@ public void testDestIndexDeletedIfExists() throws Exception {
assertHitCount(prepareSearch(destIndex).setSize(0), 0);
}
- public void testDestIndexNameSet() throws Exception {
+ public void testDestIndexNameSet_noDotPrefix() throws Exception {
assumeTrue("requires the migration reindex feature flag", REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled());
var sourceIndex = randomAlphaOfLength(20).toLowerCase(Locale.ROOT);
@@ -110,6 +110,20 @@ public void testDestIndexNameSet() throws Exception {
assertEquals(expectedDestIndexName, response.getDestIndex());
}
+ public void testDestIndexNameSet_withDotPrefix() throws Exception {
+ assumeTrue("requires the migration reindex feature flag", REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled());
+
+ var sourceIndex = "." + randomAlphaOfLength(20).toLowerCase(Locale.ROOT);
+ indicesAdmin().create(new CreateIndexRequest(sourceIndex)).get();
+
+ // call reindex
+ var response = client().execute(ReindexDataStreamIndexAction.INSTANCE, new ReindexDataStreamIndexAction.Request(sourceIndex))
+ .actionGet();
+
+ var expectedDestIndexName = ReindexDataStreamIndexTransportAction.generateDestIndexName(sourceIndex);
+ assertEquals(expectedDestIndexName, response.getDestIndex());
+ }
+
public void testDestIndexContainsDocs() throws Exception {
assumeTrue("requires the migration reindex feature flag", REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled());
@@ -446,4 +460,32 @@ private static String getIndexUUID(String index) {
.get(index)
.get(IndexMetadata.SETTING_INDEX_UUID);
}
+
+ public void testGenerateDestIndexName_noDotPrefix() {
+ String sourceIndex = "sourceindex";
+ String expectedDestIndex = "migrated-sourceindex";
+ String actualDestIndex = ReindexDataStreamIndexTransportAction.generateDestIndexName(sourceIndex);
+ assertEquals(expectedDestIndex, actualDestIndex);
+ }
+
+ public void testGenerateDestIndexName_withDotPrefix() {
+ String sourceIndex = ".sourceindex";
+ String expectedDestIndex = ".migrated-sourceindex";
+ String actualDestIndex = ReindexDataStreamIndexTransportAction.generateDestIndexName(sourceIndex);
+ assertEquals(expectedDestIndex, actualDestIndex);
+ }
+
+ public void testGenerateDestIndexName_withHyphen() {
+ String sourceIndex = "source-index";
+ String expectedDestIndex = "migrated-source-index";
+ String actualDestIndex = ReindexDataStreamIndexTransportAction.generateDestIndexName(sourceIndex);
+ assertEquals(expectedDestIndex, actualDestIndex);
+ }
+
+ public void testGenerateDestIndexName_withUnderscore() {
+ String sourceIndex = "source_index";
+ String expectedDestIndex = "migrated-source_index";
+ String actualDestIndex = ReindexDataStreamIndexTransportAction.generateDestIndexName(sourceIndex);
+ assertEquals(expectedDestIndex, actualDestIndex);
+ }
}
diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamIndexTransportAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamIndexTransportAction.java
index d86885ce0fbe4..ff350429dae01 100644
--- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamIndexTransportAction.java
+++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamIndexTransportAction.java
@@ -221,8 +221,12 @@ private static void copySettingOrUnset(Settings settingsBefore, Settings.Builder
}
}
- public static String generateDestIndexName(String sourceIndex) {
- return "migrated-" + sourceIndex;
+ static String generateDestIndexName(String sourceIndex) {
+ String prefix = "migrated-";
+ if (sourceIndex.startsWith(".")) {
+ return "." + prefix + sourceIndex.substring(1);
+ }
+ return prefix + sourceIndex;
}
private static ActionListener failIfNotAcknowledged(
From d79edcbec8058b957c851541fa74e8bb3e3fa46a Mon Sep 17 00:00:00 2001
From: Dan Rubinstein
Date: Mon, 13 Jan 2025 11:38:26 -0500
Subject: [PATCH 16/44] Add enterprise license check for Inference API actions
(#119893)
* Add enterprise license check for Inference API actions
* Update docs/changelog/119893.yaml
* Adding missing plugin to ModelRegistryIT and removing license check from get inference services API
* Fix tests
* Fix basic license test
* Removing unused feature flag from InferenceUpgradeTestCase
---------
Co-authored-by: Elastic Machine
---
docs/changelog/119893.yaml | 5 ++
.../inference/InferenceBasicLicenseIT.java | 84 +++++++++++++++++++
.../InferenceLicenseBaseRestTest.java | 43 ++++++++++
.../inference/InferenceTrialLicenseIT.java | 84 +++++++++++++++++++
.../application/InferenceUpgradeTestCase.java | 21 ++++-
.../ShardBulkInferenceActionFilterIT.java | 3 +-
.../integration/ModelRegistryIT.java | 3 +-
.../xpack/inference/InferencePlugin.java | 8 ++
.../action/BaseTransportInferenceAction.java | 12 +++
.../action/TransportInferenceAction.java | 3 +
.../TransportPutInferenceModelAction.java | 12 +++
...sportUnifiedCompletionInferenceAction.java | 3 +
.../TransportUpdateInferenceModelAction.java | 12 +++
.../SemanticTextClusterMetadataTests.java | 3 +-
.../BaseTransportInferenceActionTestCase.java | 21 ++++-
.../action/TransportInferenceActionTests.java | 3 +
...TransportUnifiedCompletionActionTests.java | 3 +
...emanticTextNonDynamicFieldMapperTests.java | 3 +-
.../TextSimilarityRankMultiNodeTests.java | 3 +-
.../TextSimilarityRankTests.java | 3 +-
20 files changed, 323 insertions(+), 9 deletions(-)
create mode 100644 docs/changelog/119893.yaml
create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java
create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceLicenseBaseRestTest.java
create mode 100644 x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceTrialLicenseIT.java
diff --git a/docs/changelog/119893.yaml b/docs/changelog/119893.yaml
new file mode 100644
index 0000000000000..35a46ce0940d3
--- /dev/null
+++ b/docs/changelog/119893.yaml
@@ -0,0 +1,5 @@
+pr: 119893
+summary: Add enterprise license check for Inference API actions
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java
new file mode 100644
index 0000000000000..4400ad8bbb538
--- /dev/null
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.cluster.local.distribution.DistributionType;
+import org.junit.ClassRule;
+
+import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.mockSparseServiceModelConfig;
+
+public class InferenceBasicLicenseIT extends InferenceLicenseBaseRestTest {
+ @ClassRule
+ public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
+ .distribution(DistributionType.DEFAULT)
+ .setting("xpack.license.self_generated.type", "basic")
+ .setting("xpack.security.enabled", "true")
+ .user("x_pack_rest_user", "x-pack-test-password")
+ .plugin("inference-service-test")
+ .build();
+
+ @Override
+ protected String getTestRestCluster() {
+ return cluster.getHttpAddresses();
+ }
+
+ @Override
+ protected Settings restClientSettings() {
+ String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
+ return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
+ }
+
+ public void testPutModel_RestrictedWithBasicLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ var modelConfig = mockSparseServiceModelConfig(null, true);
+ sendRestrictedRequest("PUT", endpoint, modelConfig);
+ }
+
+ public void testUpdateModel_RestrictedWithBasicLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ var requestBody = """
+ {
+ "task_settings": {
+ "num_threads": 2
+ }
+ }
+ """;
+ sendRestrictedRequest("PUT", endpoint, requestBody);
+ }
+
+ public void testPerformInference_RestrictedWithBasicLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ var requestBody = """
+ {
+ "input": ["washing", "machine"]
+ }
+ """;
+ sendRestrictedRequest("POST", endpoint, requestBody);
+ }
+
+ public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
+ var endpoint = "_inference/_services";
+ sendNonRestrictedRequest("GET", endpoint, null, 200, false);
+ }
+
+ public void testGetModels_NonRestrictedWithBasicLicense() throws Exception {
+ var endpoint = "_inference/_all";
+ sendNonRestrictedRequest("GET", endpoint, null, 200, false);
+ }
+
+ public void testDeleteModel_NonRestrictedWithBasicLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ sendNonRestrictedRequest("DELETE", endpoint, null, 404, true);
+ }
+}
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceLicenseBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceLicenseBaseRestTest.java
new file mode 100644
index 0000000000000..43183bae73252
--- /dev/null
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceLicenseBaseRestTest.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.test.rest.ESRestTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class InferenceLicenseBaseRestTest extends ESRestTestCase {
+ protected void sendRestrictedRequest(String method, String endpoint, String body) throws IOException {
+ var request = new Request(method, endpoint);
+ request.setJsonEntity(body);
+
+ var exception = assertThrows(ResponseException.class, () -> client().performRequest(request));
+ assertEquals(403, exception.getResponse().getStatusLine().getStatusCode());
+ assertThat(exception.getMessage(), containsString("current license is non-compliant for [inference]"));
+ }
+
+ protected void sendNonRestrictedRequest(String method, String endpoint, String body, int expectedStatusCode, boolean exceptionExpected)
+ throws IOException {
+ var request = new Request(method, endpoint);
+ request.setJsonEntity(body);
+
+ int actualStatusCode;
+ if (exceptionExpected) {
+ var exception = assertThrows(ResponseException.class, () -> client().performRequest(request));
+ actualStatusCode = exception.getResponse().getStatusLine().getStatusCode();
+ } else {
+ var response = client().performRequest(request);
+ actualStatusCode = response.getStatusLine().getStatusCode();
+ }
+ assertEquals(expectedStatusCode, actualStatusCode);
+ }
+}
diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceTrialLicenseIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceTrialLicenseIT.java
new file mode 100644
index 0000000000000..c7066a827fa7e
--- /dev/null
+++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceTrialLicenseIT.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.cluster.local.distribution.DistributionType;
+import org.junit.ClassRule;
+
+import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.mockSparseServiceModelConfig;
+
+public class InferenceTrialLicenseIT extends InferenceLicenseBaseRestTest {
+ @ClassRule
+ public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
+ .distribution(DistributionType.DEFAULT)
+ .setting("xpack.license.self_generated.type", "trial")
+ .setting("xpack.security.enabled", "true")
+ .user("x_pack_rest_user", "x-pack-test-password")
+ .plugin("inference-service-test")
+ .build();
+
+ @Override
+ protected String getTestRestCluster() {
+ return cluster.getHttpAddresses();
+ }
+
+ @Override
+ protected Settings restClientSettings() {
+ String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
+ return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
+ }
+
+ public void testPutModel_NonRestrictedWithTrialLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ var modelConfig = mockSparseServiceModelConfig(null, true);
+ sendNonRestrictedRequest("PUT", endpoint, modelConfig, 200, false);
+ }
+
+ public void testUpdateModel_NonRestrictedWithTrialLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ var requestBody = """
+ {
+ "task_settings": {
+ "num_threads": 2
+ }
+ }
+ """;
+ sendNonRestrictedRequest("PUT", endpoint, requestBody, 404, true);
+ }
+
+ public void testPerformInference_NonRestrictedWithTrialLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ var requestBody = """
+ {
+ "input": ["washing", "machine"]
+ }
+ """;
+ sendNonRestrictedRequest("POST", endpoint, requestBody, 404, true);
+ }
+
+ public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
+ var endpoint = "_inference/_services";
+ sendNonRestrictedRequest("GET", endpoint, null, 200, false);
+ }
+
+ public void testGetModels_NonRestrictedWithBasicLicense() throws Exception {
+ var endpoint = "_inference/_all";
+ sendNonRestrictedRequest("GET", endpoint, null, 200, false);
+ }
+
+ public void testDeleteModel_NonRestrictedWithBasicLicense() throws Exception {
+ var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
+ sendNonRestrictedRequest("DELETE", endpoint, null, 404, true);
+ }
+}
diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java
index d38503a884092..880557c59f11c 100644
--- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java
+++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java
@@ -12,8 +12,11 @@
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.http.MockWebServer;
-import org.elasticsearch.upgrades.AbstractRollingUpgradeTestCase;
+import org.elasticsearch.upgrades.ParameterizedRollingUpgradeTestCase;
+import org.junit.ClassRule;
import java.io.IOException;
import java.util.LinkedList;
@@ -22,7 +25,7 @@
import static org.elasticsearch.core.Strings.format;
-public class InferenceUpgradeTestCase extends AbstractRollingUpgradeTestCase {
+public class InferenceUpgradeTestCase extends ParameterizedRollingUpgradeTestCase {
static final String MODELS_RENAMED_TO_ENDPOINTS = "8.15.0";
@@ -30,6 +33,20 @@ public InferenceUpgradeTestCase(@Name("upgradedNodes") int upgradedNodes) {
super(upgradedNodes);
}
+ @ClassRule
+ public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
+ .distribution(DistributionType.DEFAULT)
+ .version(getOldClusterTestVersion())
+ .nodes(NODE_NUM)
+ .setting("xpack.security.enabled", "false")
+ .setting("xpack.license.self_generated.type", "trial")
+ .build();
+
+ @Override
+ protected ElasticsearchCluster getUpgradeCluster() {
+ return cluster;
+ }
+
protected static String getUrl(MockWebServer webServer) {
return format("http://%s:%s", webServer.getHostName(), webServer.getPort());
}
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
index 90a9dd3355b3c..3f7a51cc4e9f9 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java
@@ -27,6 +27,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
@@ -73,7 +74,7 @@ public void setup() throws Exception {
@Override
protected Collection> nodePlugins() {
- return Arrays.asList(Utils.TestInferencePlugin.class);
+ return Arrays.asList(Utils.TestInferencePlugin.class, LocalStateCompositeXPackPlugin.class);
}
@Override
diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java
index be6b3725b0f35..24585318b15b3 100644
--- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java
+++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java
@@ -31,6 +31,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -76,7 +77,7 @@ public void createComponents() {
@Override
protected Collection> getPlugins() {
- return pluginList(ReindexPlugin.class, InferencePlugin.class);
+ return pluginList(ReindexPlugin.class, InferencePlugin.class, LocalStateCompositeXPackPlugin.class);
}
public void testStoreModel() throws Exception {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index b16c53a428d73..d303ead4d9188 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -29,6 +29,8 @@
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
@@ -150,6 +152,12 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
Setting.Property.Dynamic
);
+ public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary(
+ "inference",
+ "api",
+ License.OperationMode.ENTERPRISE
+ );
+
public static final String NAME = "inference";
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
index 2a0e8e1775279..b6c7d26b36f9a 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
@@ -23,9 +23,12 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
@@ -38,6 +41,7 @@
import java.util.stream.Collectors;
import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
@@ -48,6 +52,7 @@ public abstract class BaseTransportInferenceAction requestReader
) {
super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
+ this.licenseState = licenseState;
this.modelRegistry = modelRegistry;
this.serviceRegistry = serviceRegistry;
this.inferenceStats = inferenceStats;
@@ -72,6 +79,11 @@ public BaseTransportInferenceAction(
@Override
protected void doExecute(Task task, Request request, ActionListener listener) {
+ if (INFERENCE_API_FEATURE.check(licenseState) == false) {
+ listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
+ return;
+ }
+
var timer = InferenceTimer.start();
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
index 08e6d869a553d..24ef0d7d610d0 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
@@ -16,6 +16,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
+import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
@@ -28,6 +29,7 @@ public class TransportInferenceAction extends BaseTransportInferenceAction listener
) throws Exception {
+ if (INFERENCE_API_FEATURE.check(licenseState) == false) {
+ listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
+ return;
+ }
+
var requestAsMap = requestToMap(request);
var resolvedTaskType = ServiceUtils.resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
index f0906231d8f42..1478130f6a6c8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
@@ -17,6 +17,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
+import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
@@ -30,6 +31,7 @@ public class TransportUnifiedCompletionInferenceAction extends BaseTransportInfe
public TransportUnifiedCompletionInferenceAction(
TransportService transportService,
ActionFilters actionFilters,
+ XPackLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
@@ -39,6 +41,7 @@ public TransportUnifiedCompletionInferenceAction(
UnifiedCompletionAction.NAME,
transportService,
actionFilters,
+ licenseState,
modelRegistry,
serviceRegistry,
inferenceStats,
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java
index 3c47de1ad64d1..4a90f73722438 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java
@@ -34,12 +34,15 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
@@ -57,6 +60,7 @@
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
+import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS;
@@ -66,6 +70,7 @@ public class TransportUpdateInferenceModelAction extends TransportMasterNodeActi
private static final Logger logger = LogManager.getLogger(TransportUpdateInferenceModelAction.class);
+ private final XPackLicenseState licenseState;
private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry serviceRegistry;
private final Client client;
@@ -77,6 +82,7 @@ public TransportUpdateInferenceModelAction(
ThreadPool threadPool,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
+ XPackLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
Client client
@@ -92,6 +98,7 @@ public TransportUpdateInferenceModelAction(
UpdateInferenceModelAction.Response::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
+ this.licenseState = licenseState;
this.modelRegistry = modelRegistry;
this.serviceRegistry = serviceRegistry;
this.client = client;
@@ -104,6 +111,11 @@ protected void masterOperation(
ClusterState state,
ActionListener masterListener
) {
+ if (INFERENCE_API_FEATURE.check(licenseState) == false) {
+ masterListener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
+ return;
+ }
+
var bodyTaskType = request.getContentAsSettings().taskType();
var resolvedTaskType = resolveTaskType(request.getTaskType(), bodyTaskType != null ? bodyTaskType.toString() : null);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
index bfec2d5ac3484..a595134ecd548 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/cluster/metadata/SemanticTextClusterMetadataTests.java
@@ -14,6 +14,7 @@
import org.elasticsearch.index.IndexService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.hamcrest.Matchers;
@@ -28,7 +29,7 @@ public class SemanticTextClusterMetadataTests extends ESSingleNodeTestCase {
@Override
protected Collection> getPlugins() {
- return List.of(InferencePlugin.class);
+ return List.of(InferencePlugin.class, LocalStateCompositeXPackPlugin.class);
}
public void testCreateIndexWithSemanticTextField() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
index 47f3a0e0b57aa..a723e5a9dffdf 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
@@ -18,11 +18,13 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
@@ -49,6 +51,7 @@
import static org.mockito.Mockito.when;
public abstract class BaseTransportInferenceActionTestCase extends ESTestCase {
+ private MockLicenseState licenseState;
private ModelRegistry modelRegistry;
private StreamingTaskManager streamingTaskManager;
private BaseTransportInferenceAction action;
@@ -64,16 +67,28 @@ public void setUp() throws Exception {
super.setUp();
TransportService transportService = mock();
ActionFilters actionFilters = mock();
+ licenseState = mock();
modelRegistry = mock();
serviceRegistry = mock();
inferenceStats = new InferenceStats(mock(), mock());
streamingTaskManager = mock();
- action = createAction(transportService, actionFilters, modelRegistry, serviceRegistry, inferenceStats, streamingTaskManager);
+ action = createAction(
+ transportService,
+ actionFilters,
+ licenseState,
+ modelRegistry,
+ serviceRegistry,
+ inferenceStats,
+ streamingTaskManager
+ );
+
+ mockValidLicenseState();
}
protected abstract BaseTransportInferenceAction createAction(
TransportService transportService,
ActionFilters actionFilters,
+ MockLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
@@ -361,4 +376,8 @@ protected void mockModelAndServiceRegistry(InferenceService service) {
when(serviceRegistry.getService(any())).thenReturn(Optional.of(service));
}
+
+ protected void mockValidLicenseState(){
+ when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true);
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java
index e54175cb27009..a5efe04c22c04 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java
@@ -9,6 +9,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
@@ -23,6 +24,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
protected BaseTransportInferenceAction createAction(
TransportService transportService,
ActionFilters actionFilters,
+ MockLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
@@ -31,6 +33,7 @@ protected BaseTransportInferenceAction createAction(
return new TransportInferenceAction(
transportService,
actionFilters,
+ licenseState,
modelRegistry,
serviceRegistry,
inferenceStats,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
index 4c943599ce523..3856a3d111b6e 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
@@ -11,6 +11,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
@@ -36,6 +37,7 @@ public class TransportUnifiedCompletionActionTests extends BaseTransportInferenc
protected BaseTransportInferenceAction createAction(
TransportService transportService,
ActionFilters actionFilters,
+ MockLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
@@ -44,6 +46,7 @@ protected BaseTransportInferenceAction createAc
return new TransportUnifiedCompletionInferenceAction(
transportService,
actionFilters,
+ licenseState,
modelRegistry,
serviceRegistry,
inferenceStats,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java
index 1f58c4165056d..0025b3a53a69f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java
@@ -9,6 +9,7 @@
import org.elasticsearch.index.mapper.NonDynamicFieldMapperTests;
import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.junit.Before;
@@ -26,7 +27,7 @@ public void setup() throws Exception {
@Override
protected Collection> getPlugins() {
- return List.of(Utils.TestInferencePlugin.class);
+ return List.of(Utils.TestInferencePlugin.class, LocalStateCompositeXPackPlugin.class);
}
@Override
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java
index 6d6403b69ea11..69b1e19fa91de 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java
@@ -10,6 +10,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rank.rerank.AbstractRerankerIT;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.InferencePlugin;
import java.util.Collection;
@@ -40,7 +41,7 @@ protected RankBuilder getThrowingRankBuilder(int rankWindowSize, String rankFeat
@Override
protected Collection> pluginsNeeded() {
- return List.of(InferencePlugin.class, TextSimilarityTestPlugin.class);
+ return List.of(InferencePlugin.class, TextSimilarityTestPlugin.class, LocalStateCompositeXPackPlugin.class);
}
public void testQueryPhaseShardThrowingAllShardsFail() throws Exception {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java
index a042fca44fdb5..a6a4ce2b2ffdf 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java
@@ -19,6 +19,7 @@
import org.elasticsearch.search.rank.rerank.AbstractRerankerIT;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.junit.Before;
@@ -108,7 +109,7 @@ protected InferenceAction.Request generateRequest(List docFeatures) {
@Override
protected Collection> getPlugins() {
- return List.of(InferencePlugin.class, TextSimilarityTestPlugin.class);
+ return List.of(InferencePlugin.class, TextSimilarityTestPlugin.class, LocalStateCompositeXPackPlugin.class);
}
@Before
From 3e12016a9795f1307e848d4d295bd64725d35376 Mon Sep 17 00:00:00 2001
From: Ignacio Vera
Date: Mon, 13 Jan 2025 17:46:08 +0100
Subject: [PATCH 17/44] Improve error handling when parsing retrievers
(#120047)
PARSER#declareNamedObject does not handle empty objects or objects with extra stuff.
Therefore we do it manually.
---
.../random/RandomRankRetrieverBuilder.java | 24 ++++++-
.../RandomRankRetrieverBuilderTests.java | 63 +++++++++++++++++++
2 files changed, 86 insertions(+), 1 deletion(-)
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java
index 7236b0141a86d..503000c31f7e7 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java
@@ -14,6 +14,7 @@
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
@@ -48,7 +49,12 @@ public class RandomRankRetrieverBuilder extends RetrieverBuilder {
});
static {
- PARSER.declareNamedObject(constructorArg(), (p, c, n) -> p.namedObject(RetrieverBuilder.class, n, c), RETRIEVER_FIELD);
+ PARSER.declareField(
+ constructorArg(),
+ RandomRankRetrieverBuilder::parseRetrieverBuilder,
+ RETRIEVER_FIELD,
+ ObjectParser.ValueType.OBJECT
+ );
PARSER.declareString(optionalConstructorArg(), FIELD_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(optionalConstructorArg(), SEED_FIELD);
@@ -63,6 +69,22 @@ public static RandomRankRetrieverBuilder fromXContent(XContentParser parser, Ret
return PARSER.apply(parser, context);
}
+ private static RetrieverBuilder parseRetrieverBuilder(XContentParser parser, RetrieverParserContext context) throws IOException {
+ assert parser.currentToken() == XContentParser.Token.START_OBJECT;
+ parser.nextToken();
+ if (parser.currentToken() == XContentParser.Token.END_OBJECT) {
+ throw new ParsingException(parser.getTokenLocation(), "empty [" + RETRIEVER_FIELD + "] object");
+ }
+ assert parser.currentToken() == XContentParser.Token.FIELD_NAME;
+ final RetrieverBuilder builder = parser.namedObject(RetrieverBuilder.class, parser.currentName(), context);
+ parser.nextToken();
+ if (parser.currentToken() == XContentParser.Token.FIELD_NAME) {
+ throw new ParsingException(parser.getTokenLocation(), "unexpected field [" + parser.currentName() + "]");
+ }
+ assert parser.currentToken() == XContentParser.Token.END_OBJECT;
+ return builder;
+ }
+
private final RetrieverBuilder retrieverBuilder;
private final String field;
private final int rankWindowSize;
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java
index c0ef4e45f101f..deb5fb47ab939 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java
@@ -15,6 +15,7 @@
import org.elasticsearch.usage.SearchUsage;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.json.JsonXContent;
@@ -23,6 +24,7 @@
import java.util.List;
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
+import static org.hamcrest.Matchers.containsString;
public class RandomRankRetrieverBuilderTests extends AbstractXContentTestCase {
@@ -99,4 +101,65 @@ public void testParserDefaults() throws IOException {
}
}
+ public void testParserEmptyRetriever() throws IOException {
+ String json = """
+ {
+ "retriever": {
+ },
+ "field": "my-field"
+ }""";
+
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
+ XContentParseException ex = expectThrows(
+ XContentParseException.class,
+ () -> RandomRankRetrieverBuilder.PARSER.parse(parser, null)
+ );
+ assertThat(ex.getMessage(), containsString("[random_reranker] failed to parse field [retriever]"));
+ assertThat(ex.getCause().getMessage(), containsString("empty [retriever] object"));
+ }
+ }
+
+ public void testParserWrongRetrieverName() throws IOException {
+ String json = """
+ {
+ "retriever": {
+ "test2": {
+ "value": "my-test-retriever"
+ }
+ },
+ "field": "my-field"
+ }""";
+
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
+ XContentParseException ex = expectThrows(
+ XContentParseException.class,
+ () -> RandomRankRetrieverBuilder.PARSER.parse(parser, null)
+ );
+ assertThat(ex.getMessage(), containsString("[random_reranker] failed to parse field [retriever]"));
+ assertThat(ex.getCause().getMessage(), containsString("unknown field [test2]"));
+ }
+ }
+
+ public void testExtraContent() throws IOException {
+ String json = """
+ {
+ "retriever": {
+ "test": {
+ "value": "my-test-retriever"
+ },
+ "field2": "my-field"
+ },
+ "field": "my-field"
+ }""";
+
+ try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
+ XContentParseException ex = expectThrows(
+ XContentParseException.class,
+ () -> RandomRankRetrieverBuilder.PARSER.parse(parser, null)
+ );
+ assertThat(ex.getMessage(), containsString("[random_reranker] failed to parse field [retriever]"));
+ assertThat(ex.getCause().getMessage(), containsString("unexpected field [field2]"));
+ }
+ }
+
}
From 7a267ee9e2bd2554246c7efe64ec84010f9ad55b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Iv=C3=A1n=20Cea=20Fontenla?=
Date: Mon, 13 Jan 2025 17:55:39 +0100
Subject: [PATCH 18/44] ESQL: Function tests standardization (#119506)
Made a quick review of the ESQL function tests, using `parameterSuppliersFromTypedDataWithDefaultChecks` when possible, and documenting a bit if not.
Some tests didn't have the error checks for types/nulls, so it added them.
Some functions couldn't be updated, and found some bugs and 500s in the process, which will be fixed in separate PR(s).
---
.../AbstractScalarFunctionTestCase.java | 2 +-
.../scalar/convert/ToDatePeriodTests.java | 2 +-
.../scalar/convert/ToTimeDurationTests.java | 2 +-
.../function/scalar/date/DateDiffTests.java | 2 +-
.../function/scalar/date/NowTests.java | 6 +-
.../function/scalar/math/AcosTests.java | 2 +-
.../function/scalar/math/AsinTests.java | 2 +-
.../function/scalar/math/CbrtTests.java | 2 +-
.../function/scalar/math/ETests.java | 24 +++--
.../function/scalar/math/ExpTests.java | 2 +-
.../function/scalar/math/Log10Tests.java | 2 +-
.../function/scalar/math/LogTests.java | 6 +-
.../function/scalar/math/PiTests.java | 24 +++--
.../function/scalar/math/PowTests.java | 2 +-
.../function/scalar/math/SignumTests.java | 2 +-
.../function/scalar/math/SqrtTests.java | 2 +-
.../function/scalar/math/TauTests.java | 24 +++--
.../multivalue/MvPSeriesWeightedSumTests.java | 12 +--
.../scalar/multivalue/MvSortTests.java | 93 +++++--------------
.../scalar/multivalue/MvSumTests.java | 2 +-
.../function/scalar/string/LocateTests.java | 4 +-
.../function/scalar/string/RepeatTests.java | 4 +-
.../function/scalar/string/ReverseTests.java | 2 +-
.../operator/arithmetic/AddTests.java | 39 +++-----
.../operator/arithmetic/DivTests.java | 5 +-
.../operator/arithmetic/ModTests.java | 5 +-
.../operator/arithmetic/MulTests.java | 16 +++-
.../operator/arithmetic/SubTests.java | 44 ++++-----
.../operator/comparison/EqualsTests.java | 2 +-
.../comparison/GreaterThanOrEqualTests.java | 2 +-
.../operator/comparison/GreaterThanTests.java | 2 +-
.../comparison/LessThanOrEqualTests.java | 2 +-
.../operator/comparison/LessThanTests.java | 2 +-
.../operator/comparison/NotEqualsTests.java | 3 +-
34 files changed, 160 insertions(+), 187 deletions(-)
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
index 64086334b7251..944515e54af75 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
@@ -421,7 +421,7 @@ protected static TestCaseSupplier arithmeticExceptionOverflowCase(
String typeNameOverflow = dataType.typeName().toLowerCase(Locale.ROOT) + " overflow";
return new TestCaseSupplier(
"<" + typeNameOverflow + ">",
- List.of(dataType),
+ List.of(dataType, dataType),
() -> new TestCaseSupplier.TestCase(
List.of(
new TestCaseSupplier.TypedData(lhsSupplier.get(), dataType, "lhs"),
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDatePeriodTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDatePeriodTests.java
index 9abbfbd61c1a8..8060326365d5c 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDatePeriodTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDatePeriodTests.java
@@ -71,7 +71,7 @@ public static Iterable parameters() {
}));
}
}
- return parameterSuppliersFromTypedData(anyNullIsNull(true, suppliers));
+ return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToTimeDurationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToTimeDurationTests.java
index 6486f6efcdd3c..980eed1af9109 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToTimeDurationTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToTimeDurationTests.java
@@ -70,7 +70,7 @@ public static Iterable parameters() {
}));
}
}
- return parameterSuppliersFromTypedData(anyNullIsNull(true, suppliers));
+ return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffTests.java
index da069e3c37cc4..b4a37b0297571 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiffTests.java
@@ -99,7 +99,7 @@ public static Iterable parameters() {
equalTo(0)
);
}));
- return parameterSuppliersFromTypedData(anyNullIsNull(false, suppliers));
+ return parameterSuppliersFromTypedDataWithDefaultChecksNoErrors(true, suppliers);
}
public void testDateDiffFunction() {
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/NowTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/NowTests.java
index ce6ee1702ee66..c667747a8ba75 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/NowTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/NowTests.java
@@ -32,7 +32,8 @@ public NowTests(@Name("TestCase") Supplier testCaseSu
@ParametersFactory
public static Iterable parameters() {
- return parameterSuppliersFromTypedData(
+ return parameterSuppliersFromTypedDataWithDefaultChecks(
+ true,
List.of(
new TestCaseSupplier(
"Now Test",
@@ -44,7 +45,8 @@ public static Iterable parameters() {
equalTo(EsqlTestUtils.TEST_CFG.now().toInstant().toEpochMilli())
)
)
- )
+ ),
+ (valid, position) -> ""
);
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AcosTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AcosTests.java
index 278c9123e30b1..6531e7bee90ab 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AcosTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AcosTests.java
@@ -56,7 +56,7 @@ public static Iterable parameters() {
)
)
);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AsinTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AsinTests.java
index 04fec5a20b438..410dc61ec5fa6 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AsinTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/AsinTests.java
@@ -56,7 +56,7 @@ public static Iterable parameters() {
)
)
);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CbrtTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CbrtTests.java
index bfe35a08b8ba1..d702e28baf9d8 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CbrtTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/CbrtTests.java
@@ -72,7 +72,7 @@ public static Iterable parameters() {
);
suppliers = anyNullIsNull(true, suppliers);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ETests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ETests.java
index 50ed71262e5df..f3922a355180d 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ETests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ETests.java
@@ -29,14 +29,22 @@ public ETests(@Name("TestCase") Supplier testCaseSupp
@ParametersFactory
public static Iterable parameters() {
- return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("E Test", List.of(DataType.INTEGER), () -> {
- return new TestCaseSupplier.TestCase(
- List.of(new TestCaseSupplier.TypedData(1, DataType.INTEGER, "foo")),
- "LiteralsEvaluator[lit=2.718281828459045]",
- DataType.DOUBLE,
- equalTo(Math.E)
- );
- })));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(
+ true,
+ List.of(
+ new TestCaseSupplier(
+ "E Test",
+ List.of(),
+ () -> new TestCaseSupplier.TestCase(
+ List.of(),
+ "LiteralsEvaluator[lit=2.718281828459045]",
+ DataType.DOUBLE,
+ equalTo(Math.E)
+ )
+ )
+ ),
+ (v, p) -> ""
+ );
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ExpTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ExpTests.java
index d42f4ffde0609..bc5faf1b2560d 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ExpTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/ExpTests.java
@@ -75,7 +75,7 @@ public static Iterable parameters() {
suppliers = anyNullIsNull(true, suppliers);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log10Tests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log10Tests.java
index 44ad4547481d6..7942320656f3f 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log10Tests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Log10Tests.java
@@ -124,7 +124,7 @@ public static Iterable parameters() {
)
);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java
index 671cffe9e7f9e..0ee277dbcadb2 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/LogTests.java
@@ -187,11 +187,7 @@ public static Iterable parameters() {
)
);
- // Add null cases before the rest of the error cases, so messages are correct.
- suppliers = anyNullIsNull(true, suppliers);
-
- // Negative cases
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiTests.java
index e93bc7b43d98e..79742952dbf59 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PiTests.java
@@ -29,14 +29,22 @@ public PiTests(@Name("TestCase") Supplier testCaseSup
@ParametersFactory
public static Iterable parameters() {
- return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("Pi Test", List.of(DataType.INTEGER), () -> {
- return new TestCaseSupplier.TestCase(
- List.of(new TestCaseSupplier.TypedData(1, DataType.INTEGER, "foo")),
- "LiteralsEvaluator[lit=3.141592653589793]",
- DataType.DOUBLE,
- equalTo(Math.PI)
- );
- })));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(
+ true,
+ List.of(
+ new TestCaseSupplier(
+ "Pi Test",
+ List.of(),
+ () -> new TestCaseSupplier.TestCase(
+ List.of(),
+ "LiteralsEvaluator[lit=3.141592653589793]",
+ DataType.DOUBLE,
+ equalTo(Math.PI)
+ )
+ )
+ ),
+ (v, p) -> "numeric"
+ );
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java
index 9d8b87bab8878..2fc139a5458c3 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/PowTests.java
@@ -77,7 +77,7 @@ public static Iterable parameters() {
)
)
);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumTests.java
index 8c612e5e664e0..4bf1351969d79 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SignumTests.java
@@ -72,7 +72,7 @@ public static Iterable parameters() {
suppliers = anyNullIsNull(true, suppliers);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SqrtTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SqrtTests.java
index 23f2adc6c02e0..7cba5d6d57d45 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SqrtTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/SqrtTests.java
@@ -109,7 +109,7 @@ public static Iterable parameters() {
"Line -1:-1: java.lang.ArithmeticException: Square root of negative"
)
);
- return parameterSuppliersFromTypedData(errorsForCasesWithoutExamples(suppliers, (v, p) -> "numeric"));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(true, suppliers, (v, p) -> "numeric");
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauTests.java
index 1a622b25b3353..40e66333f953e 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/TauTests.java
@@ -29,14 +29,22 @@ public TauTests(@Name("TestCase") Supplier testCaseSu
@ParametersFactory
public static Iterable parameters() {
- return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("Tau Test", List.of(DataType.INTEGER), () -> {
- return new TestCaseSupplier.TestCase(
- List.of(new TestCaseSupplier.TypedData(1, DataType.INTEGER, "foo")),
- "LiteralsEvaluator[lit=6.283185307179586]",
- DataType.DOUBLE,
- equalTo(Tau.TAU)
- );
- })));
+ return parameterSuppliersFromTypedDataWithDefaultChecks(
+ true,
+ List.of(
+ new TestCaseSupplier(
+ "Tau Test",
+ List.of(),
+ () -> new TestCaseSupplier.TestCase(
+ List.of(),
+ "LiteralsEvaluator[lit=6.283185307179586]",
+ DataType.DOUBLE,
+ equalTo(Tau.TAU)
+ )
+ )
+ ),
+ (v, p) -> "numeric"
+ );
}
@Override
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java
index 156fc4bfe7c36..0c905b28ac931 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java
@@ -32,22 +32,20 @@ public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier parameters() {
List