diff --git a/core/build.gradle b/core/build.gradle index a338b8f368..624c10fd6b 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -57,6 +57,7 @@ dependencies { testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.12.4' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: '3.12.4' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4' } diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 29c0e4050a..aef7de69a8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -49,6 +49,7 @@ import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -87,6 +88,7 @@ import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRareTopN; @@ -563,6 +565,12 @@ public LogicalPlan visitML(ML node, AnalysisContext context) { return new LogicalML(child, node.getArguments()); } + @Override + public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) { + LogicalPlan child = paginate.getChild().get(0).accept(this, context); + return new LogicalPaginate(paginate.getPageSize(), List.of(child)); + } + /** * The first argument is always "asc", others are optional. * Given nullFirst argument, use its value. Otherwise just use DEFAULT_ASC/DESC. 
diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index d2ebb9eb99..9c283d95f6 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -48,6 +48,7 @@ import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; import org.opensearch.sql.ast.tree.RareTopN; @@ -294,4 +295,8 @@ public T visitQuery(Query node, C context) { public T visitExplain(Explain node, C context) { return visitStatement(node, context); } + + public T visitPaginate(Paginate paginate, C context) { + return visitChildren(paginate, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/statement/Query.java b/core/src/main/java/org/opensearch/sql/ast/statement/Query.java index 17682cd47b..82efdde4dd 100644 --- a/core/src/main/java/org/opensearch/sql/ast/statement/Query.java +++ b/core/src/main/java/org/opensearch/sql/ast/statement/Query.java @@ -27,6 +27,7 @@ public class Query extends Statement { protected final UnresolvedPlan plan; + protected final int fetchSize; @Override public R accept(AbstractNodeVisitor visitor, C context) { diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Paginate.java b/core/src/main/java/org/opensearch/sql/ast/tree/Paginate.java new file mode 100644 index 0000000000..55e0e8c7a6 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Paginate.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; + +/** + * AST node to represent the pagination operation. + * Essentially a wrapper over the AST. + */ +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +@ToString +public class Paginate extends UnresolvedPlan { + @Getter + private final int pageSize; + private UnresolvedPlan child; + + public Paginate(int pageSize, UnresolvedPlan child) { + this.pageSize = pageSize; + this.child = child; + } + + @Override + public List getChild() { + return List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitPaginate(this, context); + } + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } +} diff --git a/core/src/main/java/org/opensearch/sql/exception/NoCursorException.java b/core/src/main/java/org/opensearch/sql/exception/NoCursorException.java new file mode 100644 index 0000000000..9383bece57 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/exception/NoCursorException.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.exception; + +/** + * This should be thrown on serialization of a PhysicalPlan tree if paging is finished. + * Processing of this exception should result in responding to the user without a cursor.
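+ * <p>A minimal sketch of the intended contract ({@code reachedEnd} is an illustrative field,
+ * not part of this change): a {@link org.opensearch.sql.planner.SerializablePlan} whose paging
+ * is exhausted may throw this from {@code writeExternal}, and
+ * {@code PlanSerializer.convertToCursor} then responds with {@code Cursor.None}.
+ * <pre>{@code
+ * public void writeExternal(ObjectOutput out) throws IOException {
+ *   if (reachedEnd) {          // illustrative flag, not part of this change
+ *     throw new NoCursorException();
+ *   }
+ *   // ... otherwise write the state needed to resume paging
+ * }
+ * }</pre>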
+ */ +public class NoCursorException extends RuntimeException { +} diff --git a/core/src/main/java/org/opensearch/sql/exception/UnsupportedCursorRequestException.java b/core/src/main/java/org/opensearch/sql/exception/UnsupportedCursorRequestException.java new file mode 100644 index 0000000000..6ed8e02e5f --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/exception/UnsupportedCursorRequestException.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.exception; + +/** + * This should be thrown by the V2 engine to trigger the fallback scenario. + */ +public class UnsupportedCursorRequestException extends RuntimeException { +} diff --git a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java index 1936a0f517..8d87bd9b14 100644 --- a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java +++ b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java @@ -14,6 +14,7 @@ import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.physical.PhysicalPlan; /** @@ -53,6 +54,8 @@ void execute(PhysicalPlan plan, ExecutionContext context, class QueryResponse { private final Schema schema; private final List results; + private final long total; + private final Cursor cursor; } @Data diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java index 94e7081920..a4cd1982cd 100644 --- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java +++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java @@ -46,6 +46,14 @@ public void execute(UnresolvedPlan plan, } } + /** + * Execute a physical plan without analyzing or planning anything. + */ + public void executePlan(PhysicalPlan plan, + ResponseListener listener) { + executionEngine.execute(plan, ExecutionContext.emptyExecutionContext(), listener); + } + /** * Execute the {@link UnresolvedPlan}, with {@link PlanContext} and using {@link ResponseListener} * to get response. diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlan.java new file mode 100644 index 0000000000..eda65aba2d --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlan.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.execution; + +import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.QueryId; +import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +/** + * ContinuePaginatedPlan represents a cursor request. + * It returns subsequent pages to the user (the second page and every page after it).
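+ * <p>Rough flow, for illustration (it mirrors {@code execute()} below; the cursor string is
+ * expected to come from {@code PlanSerializer.convertToCursor} on a previous page):
+ * <pre>{@code
+ * PhysicalPlan plan = planSerializer.convertToPlan(cursor); // deserialize the "n:..." cursor
+ * queryService.executePlan(plan, queryResponseListener);    // run it and return the next page
+ * }</pre>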
+ */ +public class ContinuePaginatedPlan extends AbstractPlan { + + private final String cursor; + private final QueryService queryService; + private final PlanSerializer planSerializer; + + private final ResponseListener queryResponseListener; + + + /** + * Create an abstract plan that can continue paginating a given cursor. + */ + public ContinuePaginatedPlan(QueryId queryId, String cursor, QueryService queryService, + PlanSerializer planCache, + ResponseListener + queryResponseListener) { + super(queryId); + this.cursor = cursor; + this.planSerializer = planCache; + this.queryService = queryService; + this.queryResponseListener = queryResponseListener; + } + + @Override + public void execute() { + try { + PhysicalPlan plan = planSerializer.convertToPlan(cursor); + queryService.executePlan(plan, queryResponseListener); + } catch (Exception e) { + queryResponseListener.onFailure(e); + } + } + + @Override + public void explain(ResponseListener listener) { + listener.onFailure(new UnsupportedOperationException( + "Explain of a paged query continuation is not supported. " + + "Use `explain` for the initial query request.")); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java index af5c032d49..df9bc0c734 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlan.java @@ -8,6 +8,9 @@ package org.opensearch.sql.executor.execution; +import java.util.Optional; +import org.apache.commons.lang3.NotImplementedException; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.executor.ExecutionEngine; @@ -33,25 +36,51 @@ public class QueryPlan extends AbstractPlan { protected final ResponseListener listener; - /** constructor. */ + protected final Optional pageSize; + + /** Constructor. */ + public QueryPlan( + QueryId queryId, + UnresolvedPlan plan, + QueryService queryService, + ResponseListener listener) { + super(queryId); + this.plan = plan; + this.queryService = queryService; + this.listener = listener; + this.pageSize = Optional.empty(); + } + + /** Constructor with page size. 
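+ * <p>Illustrative use (page size 10 is an arbitrary example):
+ * <pre>{@code
+ * new QueryPlan(QueryId.queryId(), unresolved, 10, queryService, listener).execute();
+ * // execute() wraps the plan as new Paginate(10, unresolved) before analysis
+ * }</pre>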
*/ public QueryPlan( QueryId queryId, UnresolvedPlan plan, + int pageSize, QueryService queryService, ResponseListener listener) { super(queryId); this.plan = plan; this.queryService = queryService; this.listener = listener; + this.pageSize = Optional.of(pageSize); } @Override public void execute() { - queryService.execute(plan, listener); + if (pageSize.isPresent()) { + queryService.execute(new Paginate(pageSize.get(), plan), listener); + } else { + queryService.execute(plan, listener); + } } @Override public void explain(ResponseListener listener) { - queryService.explain(plan, listener); + if (pageSize.isPresent()) { + listener.onFailure(new NotImplementedException( + "`explain` feature for paginated requests is not implemented yet.")); + } else { + queryService.explain(plan, listener); + } } } diff --git a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java index 851381cc7a..18455c2a02 100644 --- a/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java +++ b/core/src/main/java/org/opensearch/sql/executor/execution/QueryPlanFactory.java @@ -18,9 +18,11 @@ import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.PlanSerializer; /** * QueryExecution Factory. @@ -37,9 +39,10 @@ public class QueryPlanFactory * Query Service. */ private final QueryService queryService; + private final PlanSerializer planSerializer; /** - * NO_CONSUMER_RESPONSE_LISTENER should never been called. It is only used as constructor + * NO_CONSUMER_RESPONSE_LISTENER should never be called. It is only used as constructor * parameter of {@link QueryPlan}. */ @VisibleForTesting @@ -62,39 +65,62 @@ public void onFailure(Exception e) { /** * Create QueryExecution from Statement. */ - public AbstractPlan create( + public AbstractPlan createContinuePaginatedPlan( Statement statement, Optional> queryListener, Optional> explainListener) { return statement.accept(this, Pair.of(queryListener, explainListener)); } + /** + * Creates a ContinuePaginatedPlan from a cursor. + */ + public AbstractPlan createContinuePaginatedPlan(String cursor, boolean isExplain, + ResponseListener queryResponseListener, + ResponseListener explainListener) { + QueryId queryId = QueryId.queryId(); + var plan = new ContinuePaginatedPlan(queryId, cursor, queryService, + planSerializer, queryResponseListener); + return isExplain ? new ExplainPlan(queryId, plan, explainListener) : plan; + } + @Override public AbstractPlan visitQuery( Query node, - Pair< - Optional>, - Optional>> + Pair>, + Optional>> context) { Preconditions.checkArgument( context.getLeft().isPresent(), "[BUG] query listener must be not null"); - return new QueryPlan(QueryId.queryId(), node.getPlan(), queryService, context.getLeft().get()); + if (node.getFetchSize() > 0) { + if (planSerializer.canConvertToCursor(node.getPlan())) { + return new QueryPlan(QueryId.queryId(), node.getPlan(), node.getFetchSize(), + queryService, + context.getLeft().get()); + } else { + // This should be picked up by the legacy engine. 
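+ // Assumption about the caller (not shown in this diff): the SQL service layer is expected to
+ // catch UnsupportedCursorRequestException and re-route the request to the legacy engine.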
+ throw new UnsupportedCursorRequestException(); + } + } else { + return new QueryPlan(QueryId.queryId(), node.getPlan(), queryService, + context.getLeft().get()); + } } @Override public AbstractPlan visitExplain( Explain node, - Pair< - Optional>, - Optional>> + Pair>, + Optional>> context) { Preconditions.checkArgument( context.getRight().isPresent(), "[BUG] explain listener must be not null"); return new ExplainPlan( QueryId.queryId(), - create(node.getStatement(), Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), + createContinuePaginatedPlan(node.getStatement(), + Optional.of(NO_CONSUMER_RESPONSE_LISTENER), Optional.empty()), context.getRight().get()); } } diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java new file mode 100644 index 0000000000..3164794abb --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/CanPaginateVisitor.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.Relation; + +/** + * Use this unresolved plan visitor to check if a plan can be serialized by PaginatedPlanCache. + * If plan.accept(new CanPaginateVisitor(...)) returns true, + * then PaginatedPlanCache.convertToCursor will succeed. Otherwise, it will fail. + * The purpose of this visitor is to activate legacy engine fallback mechanism. + * Currently, the conditions are: + * - only projection of a relation is supported. + * - projection only has * (a.k.a. allFields). + * - Relation only scans one table + * - The table is an open search index. + * So it accepts only queries like `select * from $index` + * See PaginatedPlanCache.canConvertToCursor for usage. + */ +public class CanPaginateVisitor extends AbstractNodeVisitor { + + @Override + public Boolean visitRelation(Relation node, Object context) { + if (!node.getChild().isEmpty()) { + // Relation instance should never have a child, but check just in case. + return Boolean.FALSE; + } + + return Boolean.TRUE; + } + + @Override + public Boolean visitChildren(Node node, Object context) { + return Boolean.FALSE; + } + + @Override + public Boolean visitProject(Project node, Object context) { + // Allow queries with 'SELECT *' only. Those restriction could be removed, but consider + // in-memory aggregation performed by window function (see WindowOperator). + // SELECT max(age) OVER (PARTITION BY city) ... 
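+ // Examples, for illustration: `SELECT * FROM idx` passes this check, while
+ // `SELECT name FROM idx` or `SELECT * FROM idx LIMIT 5` do not and trigger the
+ // legacy engine fallback (see CanPaginateVisitorTest for the full set of cases).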
+ var projections = node.getProjectList(); + if (projections.size() != 1) { + return Boolean.FALSE; + } + + if (!(projections.get(0) instanceof AllFields)) { + return Boolean.FALSE; + } + + var children = node.getChild(); + if (children.size() != 1) { + return Boolean.FALSE; + } + + return children.get(0).accept(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java b/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java new file mode 100644 index 0000000000..bb320f5c67 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/Cursor.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +@EqualsAndHashCode +@RequiredArgsConstructor +public class Cursor { + public static final Cursor None = new Cursor(null); + + @Getter + private final String data; + + public String toString() { + return data; + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java new file mode 100644 index 0000000000..d6d10ee89c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/executor/pagination/PlanSerializer.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import com.google.common.hash.HashCode; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.NotSerializableException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.zip.Deflater; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; + +/** + * This class is the entry point for paged requests. It is responsible for cursor serialization + * and deserialization. + */ +@RequiredArgsConstructor +public class PlanSerializer { + public static final String CURSOR_PREFIX = "n:"; + + private final StorageEngine engine; + + public boolean canConvertToCursor(UnresolvedPlan plan) { + return plan.accept(new CanPaginateVisitor(), null); + } + + /** + * Converts a physical plan tree to a cursor. + */ + public Cursor convertToCursor(PhysicalPlan plan) { + try { + return new Cursor(CURSOR_PREFIX + + serialize(((SerializablePlan) plan).getPlanForSerialization())); + // ClassCastException thrown when a plan in the tree doesn't implement SerializablePlan + } catch (NotSerializableException | ClassCastException | NoCursorException e) { + return Cursor.None; + } + } + + /** + * Serializes and compresses the object. + * @param object The object. + * @return Encoded binary data.
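+ * <p>Expected round trip, shown for illustration only:
+ * <pre>{@code
+ * String code = serialize(somePlan);        // hex string of the gzipped object stream
+ * Serializable copy = deserialize(code);    // equal to somePlan
+ * }</pre>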
+ */ + protected String serialize(Serializable object) throws NotSerializableException { + try { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(object); + objectOutput.flush(); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + // GZIP provides 35-45%, lzma from apache commons-compress has few % better compression + GZIPOutputStream gzip = new GZIPOutputStream(out) { { + this.def.setLevel(Deflater.BEST_COMPRESSION); + } }; + gzip.write(output.toByteArray()); + gzip.close(); + + return HashCode.fromBytes(out.toByteArray()).toString(); + } catch (NotSerializableException e) { + throw e; + } catch (IOException e) { + throw new IllegalStateException("Failed to serialize: " + object, e); + } + } + + /** + * Decompresses and deserializes the binary data. + * @param code Encoded binary data. + * @return An object. + */ + protected Serializable deserialize(String code) { + try { + GZIPInputStream gzip = new GZIPInputStream( + new ByteArrayInputStream(HashCode.fromString(code).asBytes())); + ObjectInputStream objectInput = new CursorDeserializationStream( + new ByteArrayInputStream(gzip.readAllBytes())); + return (Serializable) objectInput.readObject(); + } catch (Exception e) { + throw new IllegalStateException("Failed to deserialize object", e); + } + } + + /** + * Converts a cursor to a physical plan tree. + */ + public PhysicalPlan convertToPlan(String cursor) { + if (!cursor.startsWith(CURSOR_PREFIX)) { + throw new UnsupportedOperationException("Unsupported cursor"); + } + try { + return (PhysicalPlan) deserialize(cursor.substring(CURSOR_PREFIX.length())); + } catch (Exception e) { + throw new UnsupportedOperationException("Unsupported cursor", e); + } + } + + /** + * This function is used in testing only, to get access to {@link CursorDeserializationStream}. + */ + public CursorDeserializationStream getCursorDeserializationStream(InputStream in) + throws IOException { + return new CursorDeserializationStream(in); + } + + public class CursorDeserializationStream extends ObjectInputStream { + public CursorDeserializationStream(InputStream in) throws IOException { + super(in); + } + + @Override + public Object resolveObject(Object obj) throws IOException { + return obj.equals("engine") ? 
engine : obj; + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index d4cdb528fa..9bde4ab647 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -152,5 +152,4 @@ protected PhysicalPlan visitChild(LogicalPlan node, C context) { // Logical operators visited here must have a single child return node.getChild().get(0).accept(this, context); } - } diff --git a/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java new file mode 100644 index 0000000000..487b1da6bd --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/SerializablePlan.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner; + +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import org.opensearch.sql.executor.pagination.PlanSerializer; + +/** + * All subtypes of PhysicalPlan which needs to be serialized (in cursor, for pagination feature) + * should follow one of the following options. + *
+ * <ul>
+ *   <li>Both:
+ *     <ul>
+ *       <li>Override both methods from {@link Externalizable}.</li>
+ *       <li>Define a public no-arg constructor.</li>
+ *     </ul>
+ *   </li>
+ *   <li>
+ *     Overwrite {@link #getPlanForSerialization} to return
+ *     another instance of {@link SerializablePlan}.
+ *   </li>
+ * </ul>
+ */
+public interface SerializablePlan extends Externalizable {
+
+  /**
+   * Argument is an instance of {@link PlanSerializer.CursorDeserializationStream}.
+   */
+  @Override
+  void readExternal(ObjectInput in) throws IOException, ClassNotFoundException;
+
+  /**
+   * Each plan which has a child plan should do the following:
+   * <pre>{@code
+   * out.writeObject(input.getPlanForSerialization());
+   * }</pre>
+   */
+  @Override
+  void writeExternal(ObjectOutput out) throws IOException;
+
+  /**
+   * Override to return a child or delegated plan, so the parent plan skips this one
+   * for serialization, but still tries to serialize the grandchild plan.
+   * Imagine a plan structure like this:
+   * <pre>
+   *    A         -> this
+   *    `- B      -> child
+   *      `- C    -> this
+   * </pre>
+ * In that case only plans A and C should be attempted to serialize. + * It is needed to skip a `ResourceMonitorPlan` instance only, actually. + * @return Next plan for serialization. + */ + default SerializablePlan getPlanForSerialization() { + return this; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPaginate.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPaginate.java new file mode 100644 index 0000000000..372f9dcf0b --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPaginate.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; + +/** + * LogicalPaginate represents pagination operation for underlying plan. + */ +@ToString +@EqualsAndHashCode(callSuper = false) +public class LogicalPaginate extends LogicalPlan { + @Getter + private final int pageSize; + + public LogicalPaginate(int pageSize, List childPlans) { + super(childPlans); + this.pageSize = pageSize; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitPaginate(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 411d9a51be..e95e47a013 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -54,6 +54,10 @@ public static LogicalPlan rename( return new LogicalRename(input, renameMap); } + public static LogicalPlan paginate(LogicalPlan input, int fetchSize) { + return new LogicalPaginate(fetchSize, List.of(input)); + } + public static LogicalPlan project(LogicalPlan input, NamedExpression... 
fields) { return new LogicalProject(input, Arrays.asList(fields), ImmutableList.of()); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index d7ab75f869..b3d63e843f 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -104,4 +104,8 @@ public R visitML(LogicalML plan, C context) { public R visitAD(LogicalAD plan, C context) { return visitNode(plan, context); } + + public R visitPaginate(LogicalPaginate plan, C context) { + return visitNode(plan, context); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index 097c5ff8ce..afe86d0cb1 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -13,6 +13,7 @@ import java.util.List; import java.util.stream.Collectors; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.rule.CreatePagingTableScanBuilder; import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter; import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; @@ -51,6 +52,7 @@ public static LogicalPlanOptimizer create() { * Phase 2: Transformations that rely on data source push down capability */ new CreateTableScanBuilder(), + new CreatePagingTableScanBuilder(), TableScanPushDown.PUSH_DOWN_FILTER, TableScanPushDown.PUSH_DOWN_AGGREGATION, TableScanPushDown.PUSH_DOWN_SORT, diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilder.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilder.java new file mode 100644 index 0000000000..c635400c33 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilder.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; +import lombok.Getter; +import lombok.experimental.Accessors; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.optimizer.Rule; + +/** + * Rule to create a paged TableScanBuilder in pagination request. + */ +public class CreatePagingTableScanBuilder implements Rule { + /** Capture the table inside matched logical paginate operator. */ + private LogicalPlan relationParent = null; + /** Pattern that matches logical relation operator. */ + @Accessors(fluent = true) + @Getter + private final Pattern pattern; + + /** + * Constructor. + */ + public CreatePagingTableScanBuilder() { + this.pattern = Pattern.typeOf(LogicalPaginate.class).matching(this::findLogicalRelation); + } + + /** + * Finds an instance of LogicalRelation and saves a reference in relationParent variable. 
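+ * <p>Sketch of what the breadth-first search records, for an illustrative tree:
+ * <pre>
+ *   LogicalPaginate            <- pattern match starts here
+ *    `- LogicalProject         <- saved as relationParent
+ *        `- LogicalRelation    <- relation found, search stops
+ * </pre>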
+ * @param logicalPaginate An instance of LogicalPaginate + * @return true if {@link LogicalRelation} node was found among the descendents of + * {@link this.logicalPaginate}, false otherwise. + */ + private boolean findLogicalRelation(LogicalPaginate logicalPaginate) { + Deque plans = new ArrayDeque<>(); + plans.add(logicalPaginate); + do { + final var plan = plans.removeFirst(); + final var children = plan.getChild(); + if (children.stream().anyMatch(LogicalRelation.class::isInstance)) { + if (children.size() > 1) { + throw new UnsupportedOperationException( + "Unsupported plan: relation operator cannot have siblings"); + } + relationParent = plan; + return true; + } + plans.addAll(children); + } while (!plans.isEmpty()); + return false; + } + + + @Override + public LogicalPlan apply(LogicalPaginate plan, Captures captures) { + var logicalRelation = (LogicalRelation) relationParent.getChild().get(0); + var scan = logicalRelation.getTable().createPagedScanBuilder(plan.getPageSize()); + relationParent.replaceChildPlans(List.of(scan)); + + return plan.getChild().get(0); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java index 86cd411a2d..a9c7597c3e 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java @@ -17,8 +17,9 @@ import org.opensearch.sql.storage.bindingtuple.BindingTuple; /** - * The Filter operator use the conditions to evaluate the input {@link BindingTuple}. - * The Filter operator only return the results that evaluated to true. + * The Filter operator represents WHERE clause and + * uses the conditions to evaluate the input {@link BindingTuple}. + * The Filter operator only returns the results that evaluated to true. * The NULL and MISSING are handled by the logic defined in {@link BinaryPredicateOperator}. */ @EqualsAndHashCode(callSuper = false) @@ -29,7 +30,9 @@ public class FilterOperator extends PhysicalPlan { private final PhysicalPlan input; @Getter private final Expression conditions; - @ToString.Exclude private ExprValue next = null; + @ToString.Exclude + private ExprValue next = null; + private long totalHits = 0; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -48,6 +51,7 @@ public boolean hasNext() { ExprValue exprValue = conditions.valueOf(inputValue.bindingTuples()); if (!(exprValue.isNull() || exprValue.isMissing()) && (exprValue.booleanValue())) { next = inputValue; + totalHits++; return true; } } @@ -58,4 +62,10 @@ public boolean hasNext() { public ExprValue next() { return next; } + + @Override + public long getTotalHits() { + // ignore `input.getTotalHits()`, because it returns wrong (unfiltered) value + return totalHits; + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java index 049e9fd16e..cea8ae6c14 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/NestedOperator.java @@ -47,6 +47,8 @@ public class NestedOperator extends PhysicalPlan { @EqualsAndHashCode.Exclude private ListIterator> flattenedResult = result.listIterator(); + private long totalHits = 0; + /** * Constructor for NestedOperator with list of map as arg. * @param input : PhysicalPlan input. 
@@ -99,7 +101,6 @@ public boolean hasNext() { return input.hasNext() || flattenedResult.hasNext(); } - @Override public ExprValue next() { if (!flattenedResult.hasNext()) { @@ -120,11 +121,13 @@ public ExprValue next() { if (result.isEmpty()) { flattenedResult = result.listIterator(); + totalHits++; return new ExprTupleValue(new LinkedHashMap<>()); } flattenedResult = result.listIterator(); } + totalHits++; return new ExprTupleValue(new LinkedHashMap<>(flattenedResult.next())); } @@ -233,7 +236,6 @@ boolean containSamePath(Map newMap) { return false; } - /** * Retrieve nested field(s) in row. * @@ -281,4 +283,9 @@ private void getNested( row, ret, currentObj); } } + + @Override + public long getTotalHits() { + return totalHits; + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java index b476b01557..b4547a63b0 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java @@ -15,9 +15,8 @@ /** * Physical plan. */ -public abstract class PhysicalPlan implements PlanNode, - Iterator, - AutoCloseable { +public abstract class PhysicalPlan + implements PlanNode, Iterator, AutoCloseable { /** * Accept the {@link PhysicalPlanNodeVisitor}. * @@ -43,6 +42,17 @@ public void add(Split split) { public ExecutionEngine.Schema schema() { throw new IllegalStateException(String.format("[BUG] schema can been only applied to " - + "ProjectOperator, instead of %s", toString())); + + "ProjectOperator, instead of %s", this.getClass().getSimpleName())); + } + + /** + * Returns Total hits matched the search criteria. Note: query may return less if limited. + * {@see Settings#QUERY_SIZE_LIMIT}. + * Any plan which adds/removes rows to the response should overwrite it to provide valid values. + * + * @return Total hits matched the search criteria. + */ + public long getTotalHits() { + return getChild().stream().mapToLong(PhysicalPlan::getTotalHits).max().orElse(0); } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java index 496e4e6ddb..1699c97c15 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java @@ -8,13 +8,16 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap.Builder; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -22,20 +25,21 @@ import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.parse.ParseExpression; +import org.opensearch.sql.planner.SerializablePlan; /** * Project the fields specified in {@link ProjectOperator#projectList} from input. 
*/ @ToString @EqualsAndHashCode(callSuper = false) -@RequiredArgsConstructor -public class ProjectOperator extends PhysicalPlan { +@AllArgsConstructor +public class ProjectOperator extends PhysicalPlan implements SerializablePlan { @Getter - private final PhysicalPlan input; + private PhysicalPlan input; @Getter - private final List projectList; + private List projectList; @Getter - private final List namedParseExpressions; + private List namedParseExpressions; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -94,4 +98,24 @@ public ExecutionEngine.Schema schema() { .map(expr -> new ExecutionEngine.Schema.Column(expr.getName(), expr.getAlias(), expr.type())).collect(Collectors.toList())); } + + /** Don't use, it is for deserialization needs only. */ + @Deprecated + public ProjectOperator() { + } + + @SuppressWarnings("unchecked") + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + projectList = (List) in.readObject(); + // note: namedParseExpressions aren't serialized and deserialized + namedParseExpressions = List.of(); + input = (PhysicalPlan) in.readObject(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(projectList); + out.writeObject(((SerializablePlan) input).getPlanForSerialization()); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java index 51d2850df7..45884830e1 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/ValuesOperator.java @@ -15,6 +15,7 @@ import lombok.ToString; import org.opensearch.sql.data.model.ExprCollectionValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; /** @@ -55,10 +56,17 @@ public boolean hasNext() { return valuesIterator.hasNext(); } + @Override + public long getTotalHits() { + // ValuesOperator used for queries without `FROM` clause, e.g. `select 1`. + // Such query always returns 1 row. 
+ return 1; + } + @Override public ExprValue next() { List values = valuesIterator.next().stream() - .map(expr -> expr.valueOf()) + .map(Expression::valueOf) .collect(Collectors.toList()); return new ExprCollectionValue(values); } diff --git a/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java b/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java index 246a50ea09..ffcc0911de 100644 --- a/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java +++ b/core/src/main/java/org/opensearch/sql/storage/StorageEngine.java @@ -29,5 +29,4 @@ public interface StorageEngine { default Collection getFunctions() { return Collections.emptyList(); } - } diff --git a/core/src/main/java/org/opensearch/sql/storage/Table.java b/core/src/main/java/org/opensearch/sql/storage/Table.java index e2586ed22c..0194f1d03e 100644 --- a/core/src/main/java/org/opensearch/sql/storage/Table.java +++ b/core/src/main/java/org/opensearch/sql/storage/Table.java @@ -99,4 +99,9 @@ default TableWriteBuilder createWriteBuilder(LogicalWrite plan) { default StreamingSource asStreamingSource() { throw new UnsupportedOperationException(); } + + default TableScanBuilder createPagedScanBuilder(int pageSize) { + var error = String.format("'%s' does not support pagination", getClass().toString()); + throw new UnsupportedOperationException(error); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 5a2b37c017..20927f262c 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -77,6 +77,7 @@ import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; +import org.opensearch.sql.ast.tree.Paginate; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -91,6 +92,7 @@ import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalMLCommons; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalProject; @@ -1632,4 +1634,11 @@ public void ml_relation_predict_rcf_without_time_field() { assertTrue(((LogicalProject) actual).getProjectList() .contains(DSL.named(RCF_ANOMALOUS, DSL.ref(RCF_ANOMALOUS, BOOLEAN)))); } + + @Test + public void visit_paginate() { + LogicalPlan actual = analyze(new Paginate(10, AstDSL.relation("dummy"))); + assertTrue(actual instanceof LogicalPaginate); + assertEquals(10, ((LogicalPaginate) actual).getPageSize()); + } } diff --git a/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java b/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java index 4df38027f4..525de79afc 100644 --- a/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java @@ -15,11 +15,9 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.lenient; -import static org.mockito.Mockito.when; import java.util.Collections; import java.util.Optional; -import org.junit.jupiter.api.BeforeEach; import 
org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -27,6 +25,7 @@ import org.opensearch.sql.analysis.Analyzer; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.PlanContext; import org.opensearch.sql.planner.Planner; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -134,7 +133,8 @@ Helper executeSuccess(Split split) { invocation -> { ResponseListener listener = invocation.getArgument(2); listener.onResponse( - new ExecutionEngine.QueryResponse(schema, Collections.emptyList())); + new ExecutionEngine.QueryResponse(schema, Collections.emptyList(), 0, + Cursor.None)); return null; }) .when(executionEngine) diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlanTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlanTest.java new file mode 100644 index 0000000000..3e08280acb --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/execution/ContinuePaginatedPlanTest.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.execution; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.DefaultExecutionEngine; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.QueryId; +import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.StorageEngine; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class ContinuePaginatedPlanTest { + + private static PlanSerializer planSerializer; + + private static QueryService queryService; + + /** + * Initialize the mocks. 
+ */ + @BeforeAll + public static void setUp() { + var storageEngine = mock(StorageEngine.class); + planSerializer = new PlanSerializer(storageEngine); + queryService = new QueryService(null, new DefaultExecutionEngine(), null); + } + + @Test + public void can_execute_plan() { + var planSerializer = mock(PlanSerializer.class); + when(planSerializer.convertToPlan(anyString())).thenReturn(mock(PhysicalPlan.class)); + var listener = new ResponseListener() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + assertNotNull(response); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }; + var plan = new ContinuePaginatedPlan(QueryId.queryId(), "", + queryService, planSerializer, listener); + plan.execute(); + } + + @Test + public void can_handle_error_while_executing_plan() { + var listener = new ResponseListener() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + fail(); + } + + @Override + public void onFailure(Exception e) { + assertNotNull(e); + } + }; + var plan = new ContinuePaginatedPlan(QueryId.queryId(), "", queryService, + planSerializer, listener); + plan.execute(); + } + + @Test + public void explain_is_not_supported() { + var listener = mock(ResponseListener.class); + mock(ContinuePaginatedPlan.class, withSettings().defaultAnswer(CALLS_REAL_METHODS)) + .explain(listener); + verify(listener).onFailure(any(UnsupportedOperationException.class)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java index cc4bf070fb..6bdbf1c4c9 100644 --- a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanFactoryTest.java @@ -8,9 +8,11 @@ package org.opensearch.sql.executor.execution; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; import static org.opensearch.sql.executor.execution.QueryPlanFactory.NO_CONSUMER_RESPONSE_LISTENER; import java.util.Optional; @@ -24,8 +26,10 @@ import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.PlanSerializer; @ExtendWith(MockitoExtension.class) class QueryPlanFactoryTest { @@ -45,46 +49,60 @@ class QueryPlanFactoryTest { @Mock private ExecutionEngine.QueryResponse queryResponse; + @Mock + private PlanSerializer planSerializer; private QueryPlanFactory factory; @BeforeEach void init() { - factory = new QueryPlanFactory(queryService); + factory = new QueryPlanFactory(queryService, planSerializer); } @Test public void createFromQueryShouldSuccess() { - Statement query = new Query(plan); + Statement query = new Query(plan, 0); AbstractPlan queryExecution = - factory.create(query, Optional.of(queryListener), Optional.empty()); + factory.createContinuePaginatedPlan(query, Optional.of(queryListener), Optional.empty()); assertTrue(queryExecution instanceof QueryPlan); } @Test public void 
createFromExplainShouldSuccess() { - Statement query = new Explain(new Query(plan)); + Statement query = new Explain(new Query(plan, 0)); AbstractPlan queryExecution = - factory.create(query, Optional.empty(), Optional.of(explainListener)); + factory.createContinuePaginatedPlan(query, Optional.empty(), Optional.of(explainListener)); assertTrue(queryExecution instanceof ExplainPlan); } + @Test + public void createFromCursorShouldSuccess() { + AbstractPlan queryExecution = factory.createContinuePaginatedPlan("", false, + queryListener, explainListener); + AbstractPlan explainExecution = factory.createContinuePaginatedPlan("", true, + queryListener, explainListener); + assertAll( + () -> assertTrue(queryExecution instanceof ContinuePaginatedPlan), + () -> assertTrue(explainExecution instanceof ExplainPlan) + ); + } + @Test public void createFromQueryWithoutQueryListenerShouldThrowException() { - Statement query = new Query(plan); + Statement query = new Query(plan, 0); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> factory.create(query, - Optional.empty(), Optional.empty())); + assertThrows(IllegalArgumentException.class, () -> factory.createContinuePaginatedPlan( + query, Optional.empty(), Optional.empty())); assertEquals("[BUG] query listener must be not null", exception.getMessage()); } @Test public void createFromExplainWithoutExplainListenerShouldThrowException() { - Statement query = new Explain(new Query(plan)); + Statement query = new Explain(new Query(plan, 0)); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> factory.create(query, - Optional.empty(), Optional.empty())); + assertThrows(IllegalArgumentException.class, () -> factory.createContinuePaginatedPlan( + query, Optional.empty(), Optional.empty())); assertEquals("[BUG] explain listener must be not null", exception.getMessage()); } @@ -104,4 +122,24 @@ public void noConsumerResponseChannel() { assertEquals( "[BUG] exception response should not sent to unexpected channel", exception.getMessage()); } + + @Test + public void createQueryWithFetchSizeWhichCanBePaged() { + when(planSerializer.canConvertToCursor(plan)).thenReturn(true); + factory = new QueryPlanFactory(queryService, planSerializer); + Statement query = new Query(plan, 10); + AbstractPlan queryExecution = + factory.createContinuePaginatedPlan(query, Optional.of(queryListener), Optional.empty()); + assertTrue(queryExecution instanceof QueryPlan); + } + + @Test + public void createQueryWithFetchSizeWhichCannotBePaged() { + when(planSerializer.canConvertToCursor(plan)).thenReturn(false); + factory = new QueryPlanFactory(queryService, planSerializer); + Statement query = new Query(plan, 10); + assertThrows(UnsupportedCursorRequestException.class, + () -> factory.createContinuePaginatedPlan(query, + Optional.of(queryListener), Optional.empty())); + } } diff --git a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java index 834db76996..a0a98e2be7 100644 --- a/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/execution/QueryPlanTest.java @@ -8,21 +8,30 @@ package org.opensearch.sql.executor.execution; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static 
org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import org.apache.commons.lang3.NotImplementedException; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.executor.DefaultExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryId; import org.opensearch.sql.executor.QueryService; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class QueryPlanTest { @Mock @@ -41,7 +50,7 @@ class QueryPlanTest { private ResponseListener queryListener; @Test - public void execute() { + public void execute_no_page_size() { QueryPlan query = new QueryPlan(queryId, plan, queryService, queryListener); query.execute(); @@ -49,10 +58,62 @@ public void execute() { } @Test - public void explain() { + public void explain_no_page_size() { QueryPlan query = new QueryPlan(queryId, plan, queryService, queryListener); query.explain(explainListener); verify(queryService, times(1)).explain(plan, explainListener); } + + @Test + public void can_execute_paginated_plan() { + var listener = new ResponseListener() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + assertNotNull(response); + } + + @Override + public void onFailure(Exception e) { + fail(); + } + }; + var plan = new QueryPlan(QueryId.queryId(), mock(UnresolvedPlan.class), 10, + queryService, listener); + plan.execute(); + } + + @Test + // Same as previous test, but with incomplete QueryService + public void can_handle_error_while_executing_plan() { + var listener = new ResponseListener() { + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + fail(); + } + + @Override + public void onFailure(Exception e) { + assertNotNull(e); + } + }; + var plan = new QueryPlan(QueryId.queryId(), mock(UnresolvedPlan.class), 10, + new QueryService(null, new DefaultExecutionEngine(), null), listener); + plan.execute(); + } + + @Test + public void explain_is_not_supported_for_pagination() { + new QueryPlan(null, null, 0, null, null).explain(new ResponseListener<>() { + @Override + public void onResponse(ExecutionEngine.ExplainResponse response) { + fail(); + } + + @Override + public void onFailure(Exception e) { + assertTrue(e instanceof NotImplementedException); + } + }); + } } diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java new file mode 100644 index 0000000000..02a0dbc05e --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/CanPaginateVisitorTest.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import java.util.List; +import 
org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.executor.pagination.CanPaginateVisitor; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class CanPaginateVisitorTest { + + static final CanPaginateVisitor visitor = new CanPaginateVisitor(); + + @Test + // select * from y + public void accept_query_with_select_star_and_from() { + var plan = AstDSL.project(AstDSL.relation("dummy"), AstDSL.allFields()); + assertTrue(plan.accept(visitor, null)); + } + + @Test + // select x from y + public void reject_query_with_select_field_and_from() { + var plan = AstDSL.project(AstDSL.relation("dummy"), AstDSL.field("pewpew")); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select x,z from y + public void reject_query_with_select_fields_and_from() { + var plan = AstDSL.project(AstDSL.relation("dummy"), + AstDSL.field("pewpew"), AstDSL.field("pewpew")); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select x + public void reject_query_without_from() { + var plan = AstDSL.project(AstDSL.values(List.of(AstDSL.intLiteral(1))), + AstDSL.alias("1",AstDSL.intLiteral(1))); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y limit z + public void reject_query_with_limit() { + var plan = AstDSL.project(AstDSL.limit(AstDSL.relation("dummy"), 1, 2), AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y where z + public void reject_query_with_where() { + var plan = AstDSL.project(AstDSL.filter(AstDSL.relation("dummy"), + AstDSL.booleanLiteral(true)), AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y order by z + public void reject_query_with_order_by() { + var plan = AstDSL.project(AstDSL.sort(AstDSL.relation("dummy"), AstDSL.field("1")), + AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y group by z + public void reject_query_with_group_by() { + var plan = AstDSL.project(AstDSL.agg( + AstDSL.relation("dummy"), List.of(), List.of(), List.of(AstDSL.field("1")), List.of()), + AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select agg(x) from y + public void reject_query_with_aggregation_function() { + var plan = AstDSL.project(AstDSL.agg( + AstDSL.relation("dummy"), + List.of(AstDSL.alias("agg", AstDSL.aggregate("func", AstDSL.field("pewpew")))), + List.of(), List.of(), List.of()), + AstDSL.allFields()); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select window(x) from y + public void reject_query_with_window_function() { + var plan = AstDSL.project(AstDSL.relation("dummy"), + AstDSL.alias("pewpew", + AstDSL.window( + AstDSL.aggregate("func", AstDSL.field("pewpew")), + List.of(AstDSL.qualifiedName("1")), List.of()))); + assertFalse(plan.accept(visitor, null)); + } + + @Test + // select * from y, z + public void reject_query_with_select_from_multiple_indices() { + var plan = mock(Project.class); + when(plan.getChild()).thenReturn(List.of(AstDSL.relation("dummy"), AstDSL.relation("pummy"))); + when(plan.getProjectList()).thenReturn(List.of(AstDSL.allFields())); + assertFalse(visitor.visitProject(plan, null)); + } + + @Test + // unreal case, added for coverage only + public void 
reject_project_when_relation_has_child() { + var relation = mock(Relation.class, withSettings().useConstructor(AstDSL.qualifiedName("42"))); + when(relation.getChild()).thenReturn(List.of(AstDSL.relation("pewpew"))); + when(relation.accept(visitor, null)).thenCallRealMethod(); + var plan = mock(Project.class); + when(plan.getChild()).thenReturn(List.of(relation)); + when(plan.getProjectList()).thenReturn(List.of(AstDSL.allFields())); + assertFalse(visitor.visitProject((Project) plan, null)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/CursorTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/CursorTest.java new file mode 100644 index 0000000000..e3e2c8cf33 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/CursorTest.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.executor.pagination.Cursor; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class CursorTest { + + @Test + void empty_array_is_none() { + Assertions.assertEquals(Cursor.None, new Cursor(null)); + } + + @Test + void toString_is_array_value() { + String cursorTxt = "This is a test"; + Assertions.assertEquals(cursorTxt, new Cursor(cursorTxt).toString()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java b/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java new file mode 100644 index 0000000000..b1e97920c8 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/executor/pagination/PlanSerializerTest.java @@ -0,0 +1,256 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.executor.pagination; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.List; +import lombok.SneakyThrows; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; +import org.opensearch.sql.storage.StorageEngine; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class 
PlanSerializerTest { + + StorageEngine storageEngine; + + PlanSerializer planCache; + + @BeforeEach + void setUp() { + storageEngine = mock(StorageEngine.class); + planCache = new PlanSerializer(storageEngine); + } + + @Test + void canConvertToCursor_relation() { + assertTrue(planCache.canConvertToCursor(AstDSL.relation("Table"))); + } + + @Test + void canConvertToCursor_project_allFields_relation() { + var unresolvedPlan = AstDSL.project(AstDSL.relation("table"), AstDSL.allFields()); + assertTrue(planCache.canConvertToCursor(unresolvedPlan)); + } + + @Test + void canConvertToCursor_project_some_fields_relation() { + var unresolvedPlan = AstDSL.project(AstDSL.relation("table"), AstDSL.field("rando")); + Assertions.assertFalse(planCache.canConvertToCursor(unresolvedPlan)); + } + + @ParameterizedTest + @ValueSource(strings = {"pewpew", "asdkfhashdfjkgakgfwuigfaijkb", "ajdhfgajklghadfjkhgjkadhgad" + + "kadfhgadhjgfjklahdgqheygvskjfbvgsdklgfuirehiluANUIfgauighbahfuasdlhfnhaughsdlfhaughaggf" + + "and_some_other_funny_stuff_which_could_be_generated_while_sleeping_on_the_keyboard"}) + void serialize_deserialize_str(String input) { + var compressed = serialize(input); + assertEquals(input, deserialize(compressed)); + if (input.length() > 200) { + // Compression of short strings isn't profitable, because encoding into string and gzip + // headers add more bytes than input string has. + assertTrue(compressed.length() < input.length()); + } + } + + public static class SerializableTestClass implements Serializable { + public int field; + + @Override + public boolean equals(Object obj) { + return field == ((SerializableTestClass) obj).field; + } + } + + // Can't serialize private classes because they are not accessible + private class NotSerializableTestClass implements Serializable { + public int field; + + @Override + public boolean equals(Object obj) { + return field == ((SerializableTestClass) obj).field; + } + } + + @Test + void serialize_deserialize_obj() { + var obj = new SerializableTestClass(); + obj.field = 42; + assertEquals(obj, deserialize(serialize(obj))); + assertNotSame(obj, deserialize(serialize(obj))); + } + + @Test + void serialize_throws() { + assertThrows(Throwable.class, () -> serialize(new NotSerializableTestClass())); + var testObj = new TestOperator(); + testObj.throwIoOnWrite = true; + assertThrows(Throwable.class, () -> serialize(testObj)); + } + + @Test + void deserialize_throws() { + assertAll( + // from gzip - damaged header + () -> assertThrows(Throwable.class, () -> deserialize("00")), + // from HashCode::fromString + () -> assertThrows(Throwable.class, () -> deserialize("000")) + ); + } + + @Test + @SneakyThrows + void convertToCursor_returns_no_cursor_if_cant_serialize() { + var plan = new TestOperator(42); + plan.throwNoCursorOnWrite = true; + assertAll( + () -> assertThrows(NoCursorException.class, () -> serialize(plan)), + () -> assertEquals(Cursor.None, planCache.convertToCursor(plan)) + ); + } + + @Test + @SneakyThrows + void convertToCursor_returns_no_cursor_if_plan_is_not_paginate() { + var plan = mock(PhysicalPlan.class); + assertEquals(Cursor.None, planCache.convertToCursor(plan)); + } + + @Test + void convertToPlan_throws_cursor_has_no_prefix() { + assertThrows(UnsupportedOperationException.class, () -> + planCache.convertToPlan("abc")); + } + + @Test + void convertToPlan_throws_if_failed_to_deserialize() { + assertThrows(UnsupportedOperationException.class, () -> + planCache.convertToPlan("n:" + serialize(mock(Serializable.class)))); + } + + @Test + 
@SneakyThrows + void serialize_and_deserialize() { + var plan = new TestOperator(42); + var roundTripPlan = planCache.deserialize(planCache.serialize(plan)); + assertEquals(roundTripPlan, plan); + assertNotSame(roundTripPlan, plan); + } + + @Test + void convertToCursor_and_convertToPlan() { + var plan = new TestOperator(100500); + var roundTripPlan = (SerializablePlan) + planCache.convertToPlan(planCache.convertToCursor(plan).toString()); + assertEquals(plan, roundTripPlan); + assertNotSame(plan, roundTripPlan); + } + + @Test + @SneakyThrows + void resolveObject() { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject("Hello, world!"); + objectOutput.flush(); + + var cds = planCache.getCursorDeserializationStream( + new ByteArrayInputStream(output.toByteArray())); + assertEquals(storageEngine, cds.resolveObject("engine")); + var object = new Object(); + assertSame(object, cds.resolveObject(object)); + } + + // Helpers and auxiliary classes section below + + public static class TestOperator extends PhysicalPlan implements SerializablePlan { + private int field; + private boolean throwNoCursorOnWrite = false; + private boolean throwIoOnWrite = false; + + public TestOperator() { + } + + public TestOperator(int value) { + field = value; + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + field = in.readInt(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + if (throwNoCursorOnWrite) { + throw new NoCursorException(); + } + if (throwIoOnWrite) { + throw new IOException(); + } + out.writeInt(field); + } + + @Override + public boolean equals(Object o) { + return field == ((TestOperator) o).field; + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return null; + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public List getChild() { + return null; + } + } + + @SneakyThrows + private String serialize(Serializable input) { + return new PlanSerializer(null).serialize(input); + } + + private Serializable deserialize(String input) { + return new PlanSerializer(null).deserialize(input); + } +} diff --git a/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java b/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java index 1a2b6e3f2a..ceb53b756a 100644 --- a/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.PlanContext; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.storage.split.Split; @@ -169,7 +170,8 @@ Helper executeSuccess(Long... 
offsets) { ResponseListener listener = invocation.getArgument(2); listener.onResponse( - new ExecutionEngine.QueryResponse(null, Collections.emptyList())); + new ExecutionEngine.QueryResponse(null, Collections.emptyList(), 0, + Cursor.None)); PlanContext planContext = invocation.getArgument(1); assertTrue(planContext.getSplit().isPresent()); diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index a717c4ed8f..bf1464f5f6 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -35,6 +35,8 @@ import java.util.Set; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -52,6 +54,7 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.expression.window.ranking.RowNumberFunction; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalRelation; @@ -64,24 +67,16 @@ import org.opensearch.sql.storage.write.TableWriteOperator; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class DefaultImplementorTest { - @Mock - private Expression filter; - - @Mock - private NamedAggregator aggregator; - - @Mock - private NamedExpression groupBy; - @Mock private Table table; private final DefaultImplementor implementor = new DefaultImplementor<>(); @Test - public void visitShouldReturnDefaultPhysicalOperator() { + public void visit_should_return_default_physical_operator() { String indexName = "test"; NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); @@ -181,14 +176,14 @@ public void visitShouldReturnDefaultPhysicalOperator() { } @Test - public void visitRelationShouldThrowException() { + public void visitRelation_should_throw_an_exception() { assertThrows(UnsupportedOperationException.class, () -> new LogicalRelation("test", table).accept(implementor, null)); } @SuppressWarnings({"rawtypes", "unchecked"}) @Test - public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { + public void visitWindowOperator_should_return_PhysicalWindowOperator() { NamedExpression windowFunction = named(new RowNumberFunction()); WindowDefinition windowDefinition = new WindowDefinition( Collections.singletonList(ref("state", STRING)), @@ -228,7 +223,7 @@ public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { } @Test - public void visitTableScanBuilderShouldBuildTableScanOperator() { + public void visitTableScanBuilder_should_build_TableScanOperator() { TableScanOperator tableScanOperator = Mockito.mock(TableScanOperator.class); TableScanBuilder tableScanBuilder = new TableScanBuilder() { @Override @@ -240,7 +235,7 @@ public TableScanOperator build() { } @Test - public void visitTableWriteBuilderShouldBuildTableWriteOperator() { + public void visitTableWriteBuilder_should_build_TableWriteOperator() { LogicalPlan child = values(); TableWriteOperator 
tableWriteOperator = Mockito.mock(TableWriteOperator.class); TableWriteBuilder logicalPlan = new TableWriteBuilder(child) { diff --git a/core/src/test/java/org/opensearch/sql/planner/SerializablePlanTest.java b/core/src/test/java/org/opensearch/sql/planner/SerializablePlanTest.java new file mode 100644 index 0000000000..8073445dc0 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/SerializablePlanTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Answers.CALLS_REAL_METHODS; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class SerializablePlanTest { + @Mock(answer = CALLS_REAL_METHODS) + SerializablePlan plan; + + @Test + void getPlanForSerialization_defaults_to_self() { + assertSame(plan, plan.getPlanForSerialization()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index fe76589066..34e0e39d87 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -8,23 +8,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.named; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.Literal; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValueUtils; @@ -45,20 +46,24 @@ /** * Todo. Temporary added for UT coverage, Will be removed. 
*/ -@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class LogicalPlanNodeVisitorTest { - @Mock - Expression expression; - @Mock - ReferenceExpression ref; - @Mock - Aggregator aggregator; - @Mock - Table table; + static Expression expression; + static ReferenceExpression ref; + static Aggregator aggregator; + static Table table; + + @BeforeAll + private static void initMocks() { + expression = mock(Expression.class); + ref = mock(ReferenceExpression.class); + aggregator = mock(Aggregator.class); + table = mock(Table.class); + } @Test - public void logicalPlanShouldTraversable() { + public void logical_plan_should_be_traversable() { LogicalPlan logicalPlan = LogicalPlanDSL.rename( LogicalPlanDSL.aggregation( @@ -75,85 +80,42 @@ public void logicalPlanShouldTraversable() { assertEquals(5, result); } - @Test - public void testAbstractPlanNodeVisitorShouldReturnNull() { + @SuppressWarnings("unchecked") + private static Stream getLogicalPlansForVisitorTest() { LogicalPlan relation = LogicalPlanDSL.relation("schema", table); - assertNull(relation.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan tableScanBuilder = new TableScanBuilder() { @Override public TableScanOperator build() { return null; } }; - assertNull(tableScanBuilder.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan write = LogicalPlanDSL.write(null, table, Collections.emptyList()); - assertNull(write.accept(new LogicalPlanNodeVisitor() { - }, null)); - TableWriteBuilder tableWriteBuilder = new TableWriteBuilder(null) { @Override public TableWriteOperator build(PhysicalPlan child) { return null; } }; - assertNull(tableWriteBuilder.accept(new LogicalPlanNodeVisitor() { - }, null)); - + LogicalPlan write = LogicalPlanDSL.write(null, table, Collections.emptyList()); LogicalPlan filter = LogicalPlanDSL.filter(relation, expression); - assertNull(filter.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan aggregation = - LogicalPlanDSL.aggregation( - filter, ImmutableList.of(DSL.named("avg", aggregator)), ImmutableList.of(DSL.named( - "group", expression))); - assertNull(aggregation.accept(new LogicalPlanNodeVisitor() { - }, null)); - + LogicalPlan aggregation = LogicalPlanDSL.aggregation( + filter, ImmutableList.of(DSL.named("avg", aggregator)), ImmutableList.of(DSL.named( + "group", expression))); LogicalPlan rename = LogicalPlanDSL.rename(aggregation, ImmutableMap.of(ref, ref)); - assertNull(rename.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan project = LogicalPlanDSL.project(relation, named("ref", ref)); - assertNull(project.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan remove = LogicalPlanDSL.remove(relation, ref); - assertNull(remove.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan eval = LogicalPlanDSL.eval(relation, Pair.of(ref, expression)); - assertNull(eval.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan sort = LogicalPlanDSL.sort(relation, - Pair.of(SortOption.DEFAULT_ASC, expression)); - assertNull(sort.accept(new LogicalPlanNodeVisitor() { - }, null)); - + LogicalPlan sort = LogicalPlanDSL.sort(relation, Pair.of(SortOption.DEFAULT_ASC, expression)); LogicalPlan dedup = LogicalPlanDSL.dedupe(relation, 1, false, false, expression); - assertNull(dedup.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan window = LogicalPlanDSL.window(relation, named(expression), new WindowDefinition( ImmutableList.of(ref), 
ImmutableList.of(Pair.of(SortOption.DEFAULT_ASC, expression)))); - assertNull(window.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan rareTopN = LogicalPlanDSL.rareTopN( relation, CommandType.TOP, ImmutableList.of(expression), expression); - assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { - }, null)); - - Map args = new HashMap<>(); LogicalPlan highlight = new LogicalHighlight(filter, - new LiteralExpression(ExprValueUtils.stringValue("fieldA")), args); - assertNull(highlight.accept(new LogicalPlanNodeVisitor() { - }, null)); + new LiteralExpression(ExprValueUtils.stringValue("fieldA")), Map.of()); + LogicalPlan mlCommons = new LogicalMLCommons(relation, "kmeans", Map.of()); + LogicalPlan ad = new LogicalAD(relation, Map.of()); + LogicalPlan ml = new LogicalML(relation, Map.of()); + LogicalPlan paginate = new LogicalPaginate(42, List.of(relation)); List> nestedArgs = List.of( Map.of( @@ -167,42 +129,21 @@ public TableWriteOperator build(PhysicalPlan child) { ); LogicalNested nested = new LogicalNested(null, nestedArgs, projectList); - assertNull(nested.accept(new LogicalPlanNodeVisitor() { - }, null)); - LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema", table), - "kmeans", - ImmutableMap.builder() - .put("centroids", new Literal(3, DataType.INTEGER)) - .put("iterations", new Literal(3, DataType.DOUBLE)) - .put("distance_type", new Literal(null, DataType.STRING)) - .build()); - assertNull(mlCommons.accept(new LogicalPlanNodeVisitor() { - }, null)); - - LogicalPlan ad = new LogicalAD(LogicalPlanDSL.relation("schema", table), - new HashMap() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); - } - }); - assertNull(ad.accept(new LogicalPlanNodeVisitor() { - }, null)); + return Stream.of( + relation, tableScanBuilder, write, tableWriteBuilder, filter, aggregation, rename, project, + remove, eval, sort, dedup, window, rareTopN, highlight, mlCommons, ad, ml, paginate, nested + ).map(Arguments::of); + } - LogicalPlan ml = new LogicalML(LogicalPlanDSL.relation("schema", table), - new HashMap() {{ - put("action", new Literal("train", DataType.STRING)); - put("algorithm", new Literal("rcf", DataType.STRING)); - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal(null, DataType.STRING)); - } - }); - assertNull(ml.accept(new LogicalPlanNodeVisitor() { + @ParameterizedTest + @MethodSource("getLogicalPlansForVisitorTest") + public void abstract_plan_node_visitor_should_return_null(LogicalPlan plan) { + assertNull(plan.accept(new LogicalPlanNodeVisitor() { }, null)); } + private static class NodesCount extends LogicalPlanNodeVisitor { @Override public Integer visitRelation(LogicalRelation plan, Object context) { @@ -213,32 +154,28 @@ public Integer visitRelation(LogicalRelation plan, Object context) { public Integer visitFilter(LogicalFilter plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } @Override public Integer visitAggregation(LogicalAggregation plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, 
context)).mapToInt(Integer::intValue).sum(); } @Override public Integer visitRename(LogicalRename plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } @Override public Integer visitRareTopN(LogicalRareTopN plan, Object context) { return 1 + plan.getChild().stream() - .map(child -> child.accept(this, context)) - .collect(Collectors.summingInt(Integer::intValue)); + .map(child -> child.accept(this, context)).mapToInt(Integer::intValue).sum(); } } } diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index d220f599f8..543b261d9e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -9,6 +9,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.longValue; @@ -20,6 +24,7 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; @@ -32,6 +37,8 @@ import java.util.Map; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -43,13 +50,18 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.optimizer.rule.CreatePagingTableScanBuilder; +import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; import org.opensearch.sql.storage.read.TableScanBuilder; import org.opensearch.sql.storage.write.TableWriteBuilder; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class LogicalPlanOptimizerTest { @Mock @@ -58,9 +70,13 @@ class LogicalPlanOptimizerTest { @Spy private TableScanBuilder tableScanBuilder; + @Spy + private TableScanBuilder pagedTableScanBuilder; + @BeforeEach void setUp() { - 
when(table.createScanBuilder()).thenReturn(tableScanBuilder); + lenient().when(table.createScanBuilder()).thenReturn(tableScanBuilder); + lenient().when(table.createPagedScanBuilder(anyInt())).thenReturn(pagedTableScanBuilder); } /** @@ -279,7 +295,6 @@ void table_scan_builder_support_nested_push_down_can_apply_its_rule() { @Test void table_not_support_scan_builder_should_not_be_impact() { - Mockito.reset(table, tableScanBuilder); Table table = new Table() { @Override public Map getFieldTypes() { @@ -300,7 +315,6 @@ public PhysicalPlan implement(LogicalPlan plan) { @Test void table_support_write_builder_should_be_replaced() { - Mockito.reset(table, tableScanBuilder); TableWriteBuilder writeBuilder = Mockito.mock(TableWriteBuilder.class); when(table.createWriteBuilder(any())).thenReturn(writeBuilder); @@ -312,7 +326,6 @@ void table_support_write_builder_should_be_replaced() { @Test void table_not_support_write_builder_should_report_error() { - Mockito.reset(table, tableScanBuilder); Table table = new Table() { @Override public Map getFieldTypes() { @@ -329,6 +342,68 @@ public PhysicalPlan implement(LogicalPlan plan) { () -> table.createWriteBuilder(null)); } + @Test + void paged_table_scan_builder_support_project_push_down_can_apply_its_rule() { + + var relation = relation("schema", table); + + assertEquals( + project(pagedTableScanBuilder), + LogicalPlanOptimizer.create().optimize(paginate(project(relation), 4))); + } + + + @Test + void push_page_size_noop_if_no_relation() { + var paginate = new LogicalPaginate(42, List.of(project(values()))); + assertEquals(paginate, LogicalPlanOptimizer.create().optimize(paginate)); + } + + @Test + void pagination_optimizer_simple_query() { + var projectPlan = project(relation("schema", table), DSL.named(DSL.ref("intV", INTEGER))); + + var optimizer = new LogicalPlanOptimizer( + List.of(new CreateTableScanBuilder(), new CreatePagingTableScanBuilder())); + + { + optimizer.optimize(projectPlan); + verify(table).createScanBuilder(); + verify(table, never()).createPagedScanBuilder(anyInt()); + } + } + + @Test + void pagination_optimizer_paged_query() { + var relation = new LogicalRelation("schema", table); + var projectPlan = project(relation, DSL.named(DSL.ref("intV", INTEGER))); + var pagedPlan = new LogicalPaginate(10, List.of(projectPlan)); + + var optimizer = new LogicalPlanOptimizer( + List.of(new CreateTableScanBuilder(), new CreatePagingTableScanBuilder())); + var optimized = optimizer.optimize(pagedPlan); + verify(table).createPagedScanBuilder(anyInt()); + } + + @Test + void push_page_size_noop_if_no_sub_plans() { + var paginate = new LogicalPaginate(42, List.of()); + assertEquals(paginate, + LogicalPlanOptimizer.create().optimize(paginate)); + } + + @Test + void table_scan_builder_support_offset_push_down_can_apply_its_rule() { + when(table.createPagedScanBuilder(anyInt())).thenReturn(pagedTableScanBuilder); + + var relation = new LogicalRelation("schema", table); + var optimized = LogicalPlanOptimizer.create() + .optimize(new LogicalPaginate(42, List.of(project(relation)))); + // `optimized` structure: LogicalPaginate -> LogicalProject -> TableScanBuilder + // LogicalRelation replaced by a TableScanBuilder instance + assertEquals(project(pagedTableScanBuilder), optimized); + } + private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); final LogicalPlan optimize = optimizer.optimize(plan); diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java 
b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java index 9f90fd8d05..ef310e3b0e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java @@ -6,35 +6,39 @@ package org.opensearch.sql.planner.optimizer.pattern; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.util.Collections; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; -@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class PatternsTest { - @Mock - LogicalPlan plan; - @Test void source_is_empty() { + var plan = mock(LogicalPlan.class); when(plan.getChild()).thenReturn(Collections.emptyList()); - assertFalse(Patterns.source().getFunction().apply(plan).isPresent()); - assertFalse(Patterns.source(null).getProperty().getFunction().apply(plan).isPresent()); + assertAll( + () -> assertFalse(Patterns.source().getFunction().apply(plan).isPresent()), + () -> assertFalse(Patterns.source(null).getProperty().getFunction().apply(plan).isPresent()) + ); } @Test void table_is_empty() { - plan = Mockito.mock(LogicalFilter.class); - assertFalse(Patterns.table().getFunction().apply(plan).isPresent()); - assertFalse(Patterns.writeTable().getFunction().apply(plan).isPresent()); + var plan = mock(LogicalFilter.class); + assertAll( + () -> assertFalse(Patterns.table().getFunction().apply(plan).isPresent()), + () -> assertFalse(Patterns.writeTable().getFunction().apply(plan).isPresent()) + ); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilderTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilderTest.java new file mode 100644 index 0000000000..79c7b55c60 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/rule/CreatePagingTableScanBuilderTest.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule; + +import static com.facebook.presto.matching.DefaultMatcher.DEFAULT_MATCHER; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.paginate; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; + +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.storage.Table; + +@ExtendWith(MockitoExtension.class) +class CreatePagingTableScanBuilderTest { + + @Mock + LogicalPlan multiRelationPaginate; + + @Mock + 
Table table; + + @BeforeEach + public void setUp() { + when(multiRelationPaginate.getChild()) + .thenReturn( + List.of(relation("t1", table), relation("t2", table))); + } + + @Test + void throws_when_mutliple_children() { + final var pattern = new CreatePagingTableScanBuilder().pattern(); + final var plan = paginate(multiRelationPaginate, 42); + assertThrows(UnsupportedOperationException.class, + () -> DEFAULT_MATCHER.match(pattern, plan)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java index be8080ad3c..247cfe6a1d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java @@ -17,22 +17,30 @@ import com.google.common.collect.ImmutableMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.AdditionalAnswers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class FilterOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; @Test - public void filterTest() { + public void filter_test() { FilterOperator plan = new FilterOperator(new TestScan(), DSL.and(DSL.notequal(DSL.ref("response", INTEGER), DSL.literal(200)), DSL.notequal(DSL.ref("response", INTEGER), DSL.literal(500)))); @@ -42,10 +50,11 @@ public void filterTest() { .tupleValue(ImmutableMap .of("ip", "209.160.24.63", "action", "GET", "response", 404, "referer", "www.amazon.com")))); + assertEquals(1, plan.getTotalHits()); } @Test - public void nullValueShouldBeenIgnored() { + public void null_value_should_been_ignored() { LinkedHashMap value = new LinkedHashMap<>(); value.put("response", LITERAL_NULL); when(inputPlan.hasNext()).thenReturn(true, false); @@ -55,10 +64,11 @@ public void nullValueShouldBeenIgnored() { DSL.equal(DSL.ref("response", INTEGER), DSL.literal(404))); List result = execute(plan); assertEquals(0, result.size()); + assertEquals(0, plan.getTotalHits()); } @Test - public void missingValueShouldBeenIgnored() { + public void missing_value_should_been_ignored() { LinkedHashMap value = new LinkedHashMap<>(); value.put("response", LITERAL_MISSING); when(inputPlan.hasNext()).thenReturn(true, false); @@ -68,5 +78,21 @@ public void missingValueShouldBeenIgnored() { DSL.equal(DSL.ref("response", INTEGER), DSL.literal(404))); List result = execute(plan); assertEquals(0, result.size()); + assertEquals(0, plan.getTotalHits()); + } + + @Test + public void totalHits() { + when(inputPlan.hasNext()).thenReturn(true, true, true, true, true, false); + var answers = Stream.of(200, 240, 300, 403, 404).map(c -> + new ExprTupleValue(new LinkedHashMap<>(Map.of("response", new ExprIntegerValue(c))))) + .collect(Collectors.toList()); + 
when(inputPlan.next()).thenAnswer(AdditionalAnswers.returnsElementsOf(answers)); + + FilterOperator plan = new FilterOperator(inputPlan, + DSL.less(DSL.ref("response", INTEGER), DSL.literal(400))); + List result = execute(plan); + assertEquals(3, result.size()); + assertEquals(3, plan.getTotalHits()); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java index 5d8b893869..9024ae50c9 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/NestedOperatorTest.java @@ -7,6 +7,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.collectionValue; @@ -91,8 +92,10 @@ public void nested_one_nested_field() { Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); + assertThat( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)), + execute(nested), contains( tupleValue( new LinkedHashMap<>() {{ @@ -159,6 +162,7 @@ public void nested_one_nested_field() { ) ) ); + assertEquals(3, nested.getTotalHits()); } @Test @@ -176,8 +180,10 @@ public void nested_two_nested_field() { "field", new ReferenceExpression("comment.data", STRING), "path", new ReferenceExpression("comment", STRING)) ); + var nested = new NestedOperator(inputPlan, fields); + assertThat( - execute(new NestedOperator(inputPlan, fields)), + execute(nested), contains( tupleValue( new LinkedHashMap<>() {{ @@ -235,6 +241,7 @@ public void nested_two_nested_field() { ) ) ); + assertEquals(9, nested.getTotalHits()); } @Test @@ -252,8 +259,10 @@ public void nested_two_nested_fields_with_same_path() { "field", new ReferenceExpression("message.id", STRING), "path", new ReferenceExpression("message", STRING)) ); + var nested = new NestedOperator(inputPlan, fields); + assertThat( - execute(new NestedOperator(inputPlan, fields)), + execute(nested), contains( tupleValue( new LinkedHashMap<>() {{ @@ -275,6 +284,7 @@ public void nested_two_nested_fields_with_same_path() { ) ) ); + assertEquals(3, nested.getTotalHits()); } @Test @@ -286,12 +296,15 @@ public void non_nested_field_tests() { Set fields = Set.of("message"); Map> groupedFieldsByPath = Map.of("message", List.of("message.info")); + + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); assertThat( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)), + execute(nested), contains( tupleValue(new LinkedHashMap<>(Map.of("message", "val"))) ) ); + assertEquals(1, nested.getTotalHits()); } @Test @@ -302,12 +315,15 @@ public void nested_missing_tuple_field() { Set fields = Set.of("message.val"); Map> groupedFieldsByPath = Map.of("message", List.of("message.val")); + + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); assertThat( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)), + execute(nested), contains( tupleValue(new LinkedHashMap<>(Map.of("message.val", ExprNullValue.of()))) ) ); + assertEquals(1, nested.getTotalHits()); } @Test @@ -318,11 +334,12 @@ public void nested_missing_array_field() { Set fields = Set.of("missing.data"); Map> 
groupedFieldsByPath = Map.of("message", List.of("message.data")); - assertTrue( - execute(new NestedOperator(inputPlan, fields, groupedFieldsByPath)) - .get(0) - .tupleValue() - .size() == 0 - ); + + var nested = new NestedOperator(inputPlan, fields, groupedFieldsByPath); + assertEquals(0, execute(nested) + .get(0) + .tupleValue() + .size()); + assertEquals(1, nested.getTotalHits()); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java index 0a93c96bbb..2c67994d2e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java @@ -5,9 +5,19 @@ package org.opensearch.sql.planner.physical; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.List; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -16,6 +26,7 @@ import org.opensearch.sql.storage.split.Split; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class PhysicalPlanTest { @Mock Split split; @@ -46,8 +57,25 @@ public List getChild() { }; @Test - void addSplitToChildByDefault() { + void add_split_to_child_by_default() { testPlan.add(split); verify(child).add(split); } + + @Test + void get_total_hits_from_child() { + var plan = mock(PhysicalPlan.class); + when(child.getTotalHits()).thenReturn(42L); + when(plan.getChild()).thenReturn(List.of(child)); + when(plan.getTotalHits()).then(CALLS_REAL_METHODS); + assertEquals(42, plan.getTotalHits()); + verify(child).getTotalHits(); + } + + @Test + void get_total_hits_uses_default_value() { + var plan = mock(PhysicalPlan.class); + when(plan.getTotalHits()).then(CALLS_REAL_METHODS); + assertEquals(0, plan.getTotalHits()); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java index 24be5eb2b8..77fcb7a505 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/ProjectOperatorTest.java @@ -11,6 +11,7 @@ import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.iterableWithSize; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING; import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; @@ -20,7 +21,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectInputStream; +import java.io.ObjectOutput; +import 
java.io.ObjectOutputStream; import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.SneakyThrows; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -30,11 +40,12 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.planner.SerializablePlan; @ExtendWith(MockitoExtension.class) class ProjectOperatorTest extends PhysicalPlanTestBase { - @Mock + @Mock(serializable = true) private PhysicalPlan inputPlan; @Test @@ -206,4 +217,53 @@ public void project_parse_missing_will_fallback() { ExprValueUtils.tupleValue(ImmutableMap.of("action", "GET", "response", "200")), ExprValueUtils.tupleValue(ImmutableMap.of("action", "POST"))))); } + + @Test + @SneakyThrows + public void serializable() { + var projects = List.of(DSL.named("action", DSL.ref("action", STRING))); + var project = new ProjectOperator(new TestOperator(), projects, List.of()); + + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(project); + objectOutput.flush(); + + ObjectInputStream objectInput = new ObjectInputStream( + new ByteArrayInputStream(output.toByteArray())); + var roundTripPlan = (ProjectOperator) objectInput.readObject(); + assertEquals(project, roundTripPlan); + } + + @EqualsAndHashCode(callSuper = false) + public static class TestOperator extends PhysicalPlan implements SerializablePlan { + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return null; + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public ExprValue next() { + return null; + } + + @Override + public List getChild() { + return null; + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + } + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java index bf046bf0a6..ec950e6016 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/RemoveOperatorTest.java @@ -113,12 +113,11 @@ public void remove_nothing_with_none_tuple_value() { @Test public void invalid_to_retrieve_schema_from_remove() { - PhysicalPlan plan = remove(inputPlan, DSL.ref("response", STRING), DSL.ref("referer", STRING)); + PhysicalPlan plan = remove(inputPlan); IllegalStateException exception = assertThrows(IllegalStateException.class, () -> plan.schema()); assertEquals( - "[BUG] schema can been only applied to ProjectOperator, " - + "instead of RemoveOperator(input=inputPlan, removeList=[response, referer])", + "[BUG] schema can been only applied to ProjectOperator, instead of RemoveOperator", exception.getMessage()); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/ValuesOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/ValuesOperatorTest.java index 9acab03d2b..bf6d28a23c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/ValuesOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/ValuesOperatorTest.java @@ -9,6 +9,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static 
org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.opensearch.sql.data.model.ExprValueUtils.collectionValue; import static org.opensearch.sql.expression.DSL.literal; @@ -44,6 +45,7 @@ public void iterateSingleRow() { results, contains(collectionValue(Arrays.asList(1, "abc"))) ); + assertThat(values.getTotalHits(), equalTo(1L)); } } diff --git a/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java b/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java index 0e969c6dac..67014b76bd 100644 --- a/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java +++ b/core/src/test/java/org/opensearch/sql/storage/StorageEngineTest.java @@ -13,11 +13,9 @@ public class StorageEngineTest { - @Test void testFunctionsMethod() { StorageEngine k = (dataSourceSchemaName, tableName) -> null; Assertions.assertEquals(Collections.emptyList(), k.getFunctions()); } - } diff --git a/core/src/test/java/org/opensearch/sql/storage/TableTest.java b/core/src/test/java/org/opensearch/sql/storage/TableTest.java new file mode 100644 index 0000000000..a96ee71af0 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/storage/TableTest.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.storage; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.withSettings; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.mockito.invocation.InvocationOnMock; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class TableTest { + + @Test + public void createPagedScanBuilder_throws() { + var table = mock(Table.class, withSettings().defaultAnswer(InvocationOnMock::callRealMethod)); + assertThrows(Throwable.class, () -> table.createPagedScanBuilder(4)); + } +} diff --git a/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java b/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java index e4f9a185a3..3849d686a6 100644 --- a/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java +++ b/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java @@ -9,6 +9,7 @@ import java.util.List; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.planner.physical.PhysicalPlan; /** @@ -32,7 +33,8 @@ public void execute( while (plan.hasNext()) { result.add(plan.next()); } - QueryResponse response = new QueryResponse(new Schema(new ArrayList<>()), new ArrayList<>()); + QueryResponse response = new QueryResponse(new Schema(new ArrayList<>()), new ArrayList<>(), + 0, Cursor.None); listener.onResponse(response); } catch (Exception e) { listener.onFailure(e); diff --git a/docs/dev/Pagination-v2.md b/docs/dev/Pagination-v2.md new file mode 100644 index 0000000000..6e2f3f36d8 --- /dev/null +++ b/docs/dev/Pagination-v2.md @@ -0,0 +1,287 @@ +# Pagination in v2 Engine + +Pagination allows a SQL plugin client to retrieve arbitrarily large result sets one subset at a time. + +A cursor is a SQL abstraction for pagination.
A client can open a cursor, retrieve a subset of data given a cursor, and close a cursor. + +Currently, the SQL plugin does not provide SQL cursor syntax. However, the SQL REST endpoint can return results one page at a time. This feature is used by the JDBC and ODBC drivers. + + +# Scope +Currently, the V2 engine supports pagination only for simple `SELECT * FROM <index>` queries without any other clauses like `WHERE` or `ORDER BY`. + +# Demo +https://user-images.githubusercontent.com/88679692/224208630-8d38d833-abf8-4035-8d15-d5fb4382deca.mp4 + +# REST API +## Initial Query Request +```json +POST /_plugins/_sql +{ + "query" : "...", + "fetch_size": N +} +``` + +Response: +```json +{ + "cursor": /* cursor_id */, + "datarows": [ + // ... + ], + "schema" : [ + // ... + ] +} +``` +`query` is a DQL statement. `fetch_size` is a positive integer, indicating the number of rows to return in each page. + +If `query` is a DML statement, pagination does not apply: the `fetch_size` parameter is ignored and a cursor is not created. This is the existing behaviour in the v1 engine. + +The client receives an [error response](#error-response) if: +- `fetch_size` is not a positive integer, or +- evaluating `query` results in a server-side error. + +## Next Page Request +```json +POST /_plugins/_sql +{ + "cursor": "<cursor_id>" +} +``` +As in the v1 engine, the response object is the same as the initial response if this is not the last page. + +`cursor_id` will be different with each request. + +If this is the last page, the `cursor` property is omitted. The cursor is closed automatically. + +The client will receive an [error response](#error-response) if executing this request results in an OpenSearch or SQL plug-in error. + +## Cursor Keep Alive Timeout +Each cursor has a keep alive timer associated with it. When the timer runs out, the cursor is closed by OpenSearch. + +This timer is reset every time a page is retrieved. + +The client will receive an [error response](#error-response) if it sends a cursor request for an expired cursor. + +## Error Response +The client will receive an error response if any of the above REST calls result in a server-side error. + +The response object has the following format: +```json +{ + "error": { + "details": "<string>", + "reason": "<string>", + "type": "<string>" + }, + "status": <integer> +} +``` + +The `details`, `reason`, and `type` properties are string values. The exact values will depend on the error state encountered. +`status` is an HTTP status code. + +## OpenSearch Data Retrieval Strategy + +OpenSearch provides several data retrieval APIs that are optimized for different use cases. + +At this time, the SQL plugin uses the simple search API and the scroll API. + +The simple retrieval API returns at most `max_result_window` documents. `max_result_window` is an index setting. + +Scroll API requests return all documents but can incur high memory costs on the OpenSearch coordinating node. + +An efficient implementation of pagination needs to be aware of the retrieval API used. Each retrieval strategy will be considered separately. + +The discussion below uses *under max_result_window* to refer to scenarios that can be implemented with the simple retrieval API and *over max_result_window* for scenarios that require the scroll API. + +## SQL Node Load Balancing +The V2 SQL engine supports *sql node load balancing* -- a cursor request can be routed to any SQL node in a cluster. This is achieved by encoding all data necessary to retrieve the next page in the `cursor_id`. + +## Design Diagrams +New code workflows are highlighted.
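+
+Before the diagrams, here is a hypothetical request/response exchange for the REST API described above. The index name (`accounts`), schema, and rows are made up for illustration, the cursor values are shortened placeholders (real cursor ids are long encoded strings), and exact response contents depend on the index queried:
+
+```json
+POST /_plugins/_sql
+{
+  "query" : "SELECT * FROM accounts",
+  "fetch_size": 2
+}
+```
+
+First page response; a cursor is returned because more rows are available:
+```json
+{
+  "schema": [
+    { "name": "firstname", "type": "text" },
+    { "name": "age", "type": "long" }
+  ],
+  "datarows": [
+    ["Amber", 32],
+    ["Hattie", 36]
+  ],
+  "cursor": "n:..."
+}
+```
+
+Subsequent pages are fetched by sending only the cursor back; the last page omits the `cursor` property:
+```json
+POST /_plugins/_sql
+{
+  "cursor": "n:..."
+}
+```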
+ +### First page +```mermaid +sequenceDiagram + participant SQLService + participant QueryPlanFactory + participant CanPaginateVisitor + participant QueryService + participant Planner + participant CreatePagingTableScanBuilder + participant OpenSearchExecutionEngine + participant PlanSerializer + participant Physical Plan Tree + +SQLService->>+QueryPlanFactory: execute + critical + QueryPlanFactory->>+CanPaginateVisitor: canConvertToCursor + CanPaginateVisitor-->>-QueryPlanFactory: true + end + QueryPlanFactory->>+QueryService: execute + QueryService->>+Planner: optimize + critical + Planner->>+CreatePagingTableScanBuilder: apply + CreatePagingTableScanBuilder-->>-Planner: paged index scan + end + Planner-->>-QueryService: Logical Plan Tree + QueryService->>+OpenSearchExecutionEngine: execute + Note over OpenSearchExecutionEngine: iterate result set + critical Serialization + OpenSearchExecutionEngine->>+PlanSerializer: convertToCursor + PlanSerializer-->>-OpenSearchExecutionEngine: cursor + end + critical + OpenSearchExecutionEngine->>+Physical Plan Tree: getTotalHits + Physical Plan Tree-->>-OpenSearchExecutionEngine: total hits + end + OpenSearchExecutionEngine-->>-QueryService: execution completed + QueryService-->>-QueryPlanFactory: execution completed + QueryPlanFactory-->>-SQLService: execution completed +``` + +### Second page +```mermaid +sequenceDiagram + participant SQLService + participant QueryPlanFactory + participant QueryService + participant OpenSearchExecutionEngine + participant PlanSerializer + participant Physical Plan Tree + +SQLService->>+QueryPlanFactory: execute + QueryPlanFactory->>+QueryService: execute + critical Deserialization + QueryService->>+PlanSerializer: convertToPlan + PlanSerializer-->>-QueryService: Physical plan tree + end + Note over QueryService: Planner, Optimizer and Implementor
are skipped + QueryService->>+OpenSearchExecutionEngine: execute + Note over OpenSearchExecutionEngine: iterate result set + critical Serialization + OpenSearchExecutionEngine->>+PlanSerializer: convertToCursor + PlanSerializer-->>-OpenSearchExecutionEngine: cursor + end + critical + OpenSearchExecutionEngine->>+Physical Plan Tree: getTotalHits + Physical Plan Tree-->>-OpenSearchExecutionEngine: total hits + end + OpenSearchExecutionEngine-->>-QueryService: execution completed + QueryService-->>-QueryPlanFactory: execution completed + QueryPlanFactory-->>-SQLService: execution completed +``` +### Legacy Engine Fallback +```mermaid +sequenceDiagram + participant RestSQLQueryAction + participant Legacy Engine + participant SQLService + participant QueryPlanFactory + participant CanPaginateVisitor + +RestSQLQueryAction->>+SQLService: prepareRequest + SQLService->>+QueryPlanFactory: execute + critical V2 support check + QueryPlanFactory->>+CanPaginateVisitor: canConvertToCursor + CanPaginateVisitor-->>-QueryPlanFactory: false + QueryPlanFactory-->>-RestSQLQueryAction: UnsupportedCursorRequestException + deactivate SQLService + end + RestSQLQueryAction->>Legacy Engine: accept + Note over Legacy Engine: Processing in Legacy engine + Legacy Engine-->>RestSQLQueryAction:complete +``` + +### Serialization +```mermaid +sequenceDiagram + participant PlanSerializer + participant ProjectOperator + participant ResourceMonitorPlan + participant OpenSearchPagedIndexScan + participant OpenSearchScrollRequest + participant ContinuePageRequest + +PlanSerializer->>+ProjectOperator: getPlanForSerialization + ProjectOperator-->>-PlanSerializer: this +PlanSerializer->>+ProjectOperator: serialize + Note over ProjectOperator: dump private fields + ProjectOperator->>+ResourceMonitorPlan: getPlanForSerialization + ResourceMonitorPlan-->>-ProjectOperator: delegate + Note over ResourceMonitorPlan: ResourceMonitorPlan
is not serialized + ProjectOperator->>+OpenSearchPagedIndexScan: serialize + alt First page + OpenSearchPagedIndexScan->>+OpenSearchScrollRequest: toCursor + OpenSearchScrollRequest-->>-OpenSearchPagedIndexScan: scroll ID + else Subsequent page + OpenSearchPagedIndexScan->>+ContinuePageRequest: toCursor + ContinuePageRequest-->>-OpenSearchPagedIndexScan: scroll ID + end + Note over OpenSearchPagedIndexScan: dump private fields + OpenSearchPagedIndexScan-->>-ProjectOperator: serialized + ProjectOperator-->>-PlanSerializer: serialized +Note over PlanSerializer: Zip to reduce size +``` + +### Deserialization +```mermaid +sequenceDiagram + participant PlanSerializer + participant Deserialization Stream + participant ProjectOperator + participant OpenSearchPagedIndexScan + participant ContinuePageRequest + +Note over PlanSerializer: Unzip +PlanSerializer->>+Deserialization Stream: deserialize + Deserialization Stream->>+ProjectOperator: create new + Note over ProjectOperator: load private fields + ProjectOperator-->>Deserialization Stream: deserialize input + activate Deserialization Stream + Deserialization Stream->>+OpenSearchPagedIndexScan: create new + deactivate Deserialization Stream + OpenSearchPagedIndexScan-->>+Deserialization Stream: resolve engine + Deserialization Stream->>-OpenSearchPagedIndexScan: OpenSearchStorageEngine + Note over OpenSearchPagedIndexScan: load private fields + OpenSearchPagedIndexScan->>+ContinuePageRequest: create new + ContinuePageRequest-->>-OpenSearchPagedIndexScan: created + OpenSearchPagedIndexScan-->>-ProjectOperator: deserialized + ProjectOperator-->>-PlanSerializer: deserialized + deactivate Deserialization Stream +``` + +### Total Hits + +Total hits is the number of rows matching the search criteria; for `SELECT *` queries it equals the number of rows (docs) in the table (index). +Example: +Paging through `SELECT * FROM calcs` (17 rows) with `fetch_size = 5` returns: + +* Page 1: total hits = 17, result size = 5, cursor +* Page 2: total hits = 17, result size = 5, cursor +* Page 3: total hits = 17, result size = 5, cursor +* Page 4: total hits = 17, result size = 2, cursor +* Page 5: total hits = 0, result size = 0 + +The default implementation of `getTotalHits` in a physical plan iterates over the child plans down the tree and takes the maximum value, defaulting to 0. + +```mermaid +sequenceDiagram + participant OpenSearchExecutionEngine + participant ProjectOperator + participant ResourceMonitorPlan + participant OpenSearchPagedIndexScan + +OpenSearchExecutionEngine->>+ProjectOperator: getTotalHits + Note over ProjectOperator: default implementation + ProjectOperator->>+ResourceMonitorPlan: getTotalHits + Note over ResourceMonitorPlan: call to delegate + ResourceMonitorPlan->>+OpenSearchPagedIndexScan: getTotalHits + Note over OpenSearchPagedIndexScan: use stored value from the search response + OpenSearchPagedIndexScan-->>-ResourceMonitorPlan: value + ResourceMonitorPlan-->>-ProjectOperator: value + ProjectOperator-->>-OpenSearchExecutionEngine: value +``` diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 0a30e057ad..6e13d02782 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -125,6 +125,11 @@ compileTestJava { testClusters.all { testDistribution = 'archive' + + // debug with command, ./gradlew opensearch-sql:run -DdebugJVM. --debug-jvm does not work with keystore.
+ if (System.getProperty("debugJVM") != null) { + jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:5005' + } } testClusters.integTest { @@ -229,10 +234,16 @@ integTest { // Tell the test JVM if the cluster JVM is running under a debugger so that tests can use longer timeouts for // requests. The 'doFirst' delays reading the debug setting on the cluster till execution time. - doFirst { systemProperty 'cluster.debug', getDebug() } + doFirst { + if (System.getProperty("debug-jvm") != null) { + setDebug(true); + } + systemProperty 'cluster.debug', getDebug() + } + if (System.getProperty("test.debug") != null) { - jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005' + jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5006' } if (System.getProperty("tests.rest.bwcsuite") == null) { diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java index 113a19885a..5b9a583d04 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/CursorIT.java @@ -123,11 +123,16 @@ public void validNumberOfPages() throws IOException { String selectQuery = StringUtils.format("SELECT firstname, state FROM %s", TEST_INDEX_ACCOUNT); JSONObject response = new JSONObject(executeFetchQuery(selectQuery, 50, JDBC)); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); + int pageCount = 1; while (!cursor.isEmpty()) { //this condition also checks that there is no cursor on last page response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + if (!cursor.isEmpty()) { + verifyIsV1Cursor(cursor); + } pageCount++; } @@ -136,12 +141,16 @@ public void validNumberOfPages() throws IOException { // using random value here, with fetch size of 28 we should get 36 pages (ceil of 1000/28) response = new JSONObject(executeFetchQuery(selectQuery, 28, JDBC)); cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); System.out.println(response); pageCount = 1; while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + if (!cursor.isEmpty()) { + verifyIsV1Cursor(cursor); + } pageCount++; } assertThat(pageCount, equalTo(36)); @@ -223,6 +232,7 @@ public void testCursorWithPreparedStatement() throws IOException { "}", TestsConstants.TEST_INDEX_ACCOUNT)); assertTrue(response.has(CURSOR)); + verifyIsV1Cursor(response.getString(CURSOR)); } @Test @@ -244,11 +254,13 @@ public void testRegressionOnDateFormatChange() throws IOException { StringUtils.format("SELECT login_time FROM %s LIMIT 500", TEST_INDEX_DATE_TIME); JSONObject response = new JSONObject(executeFetchQuery(selectQuery, 1, JDBC)); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); actualDateList.add(response.getJSONArray(DATAROWS).getJSONArray(0).getString(0)); while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + verifyIsV1Cursor(cursor); actualDateList.add(response.getJSONArray(DATAROWS).getJSONArray(0).getString(0)); } @@ -274,7 +286,6 @@ public void defaultBehaviorWhenCursorSettingIsDisabled() throws IOException { query = StringUtils.format("SELECT firstname, email, state FROM %s", TEST_INDEX_ACCOUNT); response = new JSONObject(executeFetchQuery(query, 100, JDBC)); assertTrue(response.has(CURSOR)); - wipeAllClusterSettings(); } @@ -305,12 +316,14 @@ public void 
testDefaultFetchSizeFromClusterSettings() throws IOException { JSONObject response = new JSONObject(executeFetchLessQuery(query, JDBC)); JSONArray datawRows = response.optJSONArray(DATAROWS); assertThat(datawRows.length(), equalTo(1000)); + verifyIsV1Cursor(response.getString(CURSOR)); updateClusterSettings(new ClusterSetting(TRANSIENT, "opensearch.sql.cursor.fetch_size", "786")); response = new JSONObject(executeFetchLessQuery(query, JDBC)); datawRows = response.optJSONArray(DATAROWS); assertThat(datawRows.length(), equalTo(786)); assertTrue(response.has(CURSOR)); + verifyIsV1Cursor(response.getString(CURSOR)); wipeAllClusterSettings(); } @@ -323,11 +336,12 @@ public void testCursorCloseAPI() throws IOException { "SELECT firstname, state FROM %s WHERE balance > 100 and age < 40", TEST_INDEX_ACCOUNT); JSONObject result = new JSONObject(executeFetchQuery(selectQuery, 50, JDBC)); String cursor = result.getString(CURSOR); - + verifyIsV1Cursor(cursor); // Retrieving next 10 pages out of remaining 19 pages for (int i = 0; i < 10; i++) { result = executeCursorQuery(cursor); cursor = result.optString(CURSOR); + verifyIsV1Cursor(cursor); } //Closing the cursor JSONObject closeResp = executeCursorCloseQuery(cursor); @@ -386,12 +400,14 @@ public void respectLimitPassedInSelectClause() throws IOException { StringUtils.format("SELECT age, balance FROM %s LIMIT %s", TEST_INDEX_ACCOUNT, limit); JSONObject response = new JSONObject(executeFetchQuery(selectQuery, 50, JDBC)); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); int actualDataRowCount = response.getJSONArray(DATAROWS).length(); int pageCount = 1; while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); cursor = response.optString(CURSOR); + verifyIsV1Cursor(cursor); actualDataRowCount += response.getJSONArray(DATAROWS).length(); pageCount++; } @@ -432,10 +448,12 @@ public void verifyWithAndWithoutPaginationResponse(String sqlQuery, String curso response.optJSONArray(DATAROWS).forEach(dataRows::put); String cursor = response.getString(CURSOR); + verifyIsV1Cursor(cursor); while (!cursor.isEmpty()) { response = executeCursorQuery(cursor); response.optJSONArray(DATAROWS).forEach(dataRows::put); cursor = response.optString(CURSOR); + verifyIsV1Cursor(cursor); } verifySchema(withoutCursorResponse.optJSONArray(SCHEMA), @@ -465,6 +483,13 @@ public String executeFetchAsStringQuery(String query, String fetchSize, String r return responseString; } + private void verifyIsV1Cursor(String cursor) { + if (cursor.isEmpty()) { + return; + } + assertTrue("The cursor '" + cursor + "' is not from v1 engine.", cursor.startsWith("d:")); + } + private String makeRequest(String query, String fetch_size) { return String.format("{" + " \"fetch_size\": \"%s\"," + diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java index 35ae5d3675..7b4ec6e561 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java @@ -260,6 +260,17 @@ protected String executeFetchQuery(String query, int fetchSize, String requestTy return responseString; } + protected JSONObject executeQueryTemplate(String queryTemplate, String index, int fetchSize) + throws IOException { + var query = String.format(queryTemplate, index); + return new JSONObject(executeFetchQuery(query, fetchSize, "jdbc")); + } + + protected JSONObject executeQueryTemplate(String 
queryTemplate, String index) throws IOException { + var query = String.format(queryTemplate, index); + return executeQueryTemplate(queryTemplate, index, 4); + } + protected String executeFetchLessQuery(String query, String requestType) throws IOException { String endpoint = "/_plugins/_sql?format=" + requestType; diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index cca7833d66..595fd8acd5 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -41,28 +41,29 @@ import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.monitor.AlwaysHealthyMonitor; import org.opensearch.sql.monitor.ResourceMonitor; -import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.client.OpenSearchRestClient; import org.opensearch.sql.opensearch.executor.OpenSearchExecutionEngine; import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.security.SecurityAccess; -import org.opensearch.sql.opensearch.storage.OpenSearchDataSourceFactory; import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; import org.opensearch.sql.planner.Planner; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; -import org.opensearch.sql.ppl.domain.PPLQueryRequest; -import org.opensearch.sql.protocol.response.QueryResult; -import org.opensearch.sql.protocol.response.format.SimpleJsonResponseFormatter; import org.opensearch.sql.sql.SQLService; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; -import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.util.ExecuteOnCallerThreadQueryManager; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchRestClient; +import org.opensearch.sql.opensearch.security.SecurityAccess; +import org.opensearch.sql.opensearch.storage.OpenSearchDataSourceFactory; +import org.opensearch.sql.ppl.domain.PPLQueryRequest; +import org.opensearch.sql.protocol.response.QueryResult; +import org.opensearch.sql.protocol.response.format.SimpleJsonResponseFormatter; +import org.opensearch.sql.storage.DataSourceFactory; /** * Run PPL with query engine outside OpenSearch cluster. 
This IT doesn't require our plugin @@ -71,13 +72,11 @@ */ public class StandaloneIT extends PPLIntegTestCase { - private RestHighLevelClient restClient; - private PPLService pplService; @Override public void init() { - restClient = new InternalRestHighLevelClient(client()); + RestHighLevelClient restClient = new InternalRestHighLevelClient(client()); OpenSearchClient client = new OpenSearchRestClient(restClient); DataSourceService dataSourceService = new DataSourceServiceImpl( new ImmutableSet.Builder() @@ -198,8 +197,9 @@ public StorageEngine storageEngine(OpenSearchClient client) { } @Provides - public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector) { - return new OpenSearchExecutionEngine(client, protector); + public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector, + PlanSerializer planSerializer) { + return new OpenSearchExecutionEngine(client, protector, planSerializer); } @Provides @@ -229,17 +229,23 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl } @Provides - public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine) { + public PlanSerializer paginatedPlanCache(StorageEngine storageEngine) { + return new PlanSerializer(storageEngine); + } + + @Provides + public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine, + PlanSerializer planSerializer) { Analyzer analyzer = new Analyzer( new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); Planner planner = new Planner(LogicalPlanOptimizer.create()); - return new QueryPlanFactory(new QueryService(analyzer, executionEngine, planner)); + QueryService queryService = new QueryService(analyzer, executionEngine, planner); + return new QueryPlanFactory(queryService, planSerializer); } } - - private DataSourceMetadataStorage getDataSourceMetadataStorage() { + public static DataSourceMetadataStorage getDataSourceMetadataStorage() { return new DataSourceMetadataStorage() { @Override public List getDataSourceMetadata() { @@ -268,7 +274,7 @@ public void deleteDataSourceMetadata(String datasourceName) { }; } - private DataSourceUserAuthorizationHelper getDataSourceUserRoleHelper() { + public static DataSourceUserAuthorizationHelper getDataSourceUserRoleHelper() { return new DataSourceUserAuthorizationHelper() { @Override public void authorizeDataSource(DataSourceMetadata dataSourceMetadata) { @@ -276,5 +282,4 @@ public void authorizeDataSource(DataSourceMetadata dataSourceMetadata) { } }; } - } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java index 809e2dc7c5..0ab6d5c70f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java @@ -64,7 +64,7 @@ public void highlight_multiple_optional_arguments_test() { schema("highlight(Body, pre_tags='', " + "post_tags='')", null, "nested")); - assertEquals(1, response.getInt("total")); + assertEquals(1, response.getInt("size")); verifyDataRows(response, rows(new JSONArray(List.of("What are the differences between an IPA" + " and its variants?")), diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java new file mode 100644 index 0000000000..d8213b1fe4 --- /dev/null +++ 
b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationBlackboxIT.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestUtils.getResponseBody; +import static org.opensearch.sql.legacy.TestUtils.isIndexExist; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ONLINE; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.SneakyThrows; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.opensearch.client.Request; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +// This class has only one test case, because it is parametrized and takes significant time +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class PaginationBlackboxIT extends SQLIntegTestCase { + + private final String index; + private final Integer pageSize; + + public PaginationBlackboxIT(@Name("index") String index, + @Name("pageSize") Integer pageSize) { + this.index = index; + this.pageSize = pageSize; + } + + @ParametersFactory(argumentFormatting = "index = %1$s, page_size = %2$d") + public static Iterable compareTwoDates() { + var indices = new PaginationBlackboxHelper().getIndices(); + var pageSizes = List.of(5, 10, 100, 1000); + var testData = new ArrayList(); + for (var index : indices) { + for (var pageSize : pageSizes) { + testData.add(new Object[] { index, pageSize }); + } + } + return testData; + } + + @Test + @SneakyThrows + public void test_pagination_blackbox() { + var response = executeJdbcRequest(String.format("select * from %s", index)); + var indexSize = response.getInt("total"); + var rows = response.getJSONArray("datarows"); + var schema = response.getJSONArray("schema"); + var testReportPrefix = String.format("index: %s, page size: %d || ", index, pageSize); + var rowsPaged = new JSONArray(); + var rowsReturned = 0; + response = new JSONObject(executeFetchQuery( + String.format("select * from %s", index), pageSize, "jdbc")); + var responseCounter = 1; + this.logger.info(testReportPrefix + "first response"); + while (response.has("cursor")) { + assertEquals(indexSize, response.getInt("total")); + assertTrue("Paged response schema doesn't match to non-paged", + schema.similar(response.getJSONArray("schema"))); + var cursor = response.getString("cursor"); + assertTrue(testReportPrefix + "Cursor returned from legacy engine", + cursor.startsWith("n:")); + rowsReturned += response.getInt("size"); + var datarows = response.getJSONArray("datarows"); + for (int i = 0; i < datarows.length(); i++) { + rowsPaged.put(datarows.get(i)); + } + response = executeCursorQuery(cursor); + this.logger.info(testReportPrefix + + String.format("subsequent response %d/%d", responseCounter++, (indexSize / pageSize) + 1)); + } + assertTrue("Paged response schema doesn't match to non-paged", + schema.similar(response.getJSONArray("schema"))); + assertEquals(0, response.getInt("total")); + + assertEquals(testReportPrefix + "Last page is not empty", + 0, response.getInt("size")); + assertEquals(testReportPrefix + "Last page is not empty", + 0, response.getJSONArray("datarows").length()); + 
assertEquals(testReportPrefix + "Paged responses return another row count that non-paged", + indexSize, rowsReturned); + assertTrue(testReportPrefix + "Paged accumulated result has other rows than non-paged", + rows.similar(rowsPaged)); + } + + // A dummy class created, because accessing to `client()` isn't available from a static context, + // but it is needed before an instance of `PaginationBlackboxIT` is created. + private static class PaginationBlackboxHelper extends SQLIntegTestCase { + + @SneakyThrows + private List getIndices() { + initClient(); + loadIndex(Index.ACCOUNT); + loadIndex(Index.BEER); + loadIndex(Index.BANK); + if (!isIndexExist(client(), "empty")) { + executeRequest(new Request("PUT", "/empty")); + } + return Arrays.stream(getResponseBody(client().performRequest(new Request("GET", "_cat/indices?h=i")), true).split("\n")) + // exclude this index, because it is too big and extends test time too long (almost 10k docs) + .map(String::trim).filter(i -> !i.equals(TEST_INDEX_ONLINE)).collect(Collectors.toList()); + } + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFallbackIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFallbackIT.java new file mode 100644 index 0000000000..33d9c5f6a8 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationFallbackIT.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ONLINE; +import static org.opensearch.sql.util.TestUtils.verifyIsV1Cursor; +import static org.opensearch.sql.util.TestUtils.verifyIsV2Cursor; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.util.TestUtils; + +public class PaginationFallbackIT extends SQLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.PHRASE); + loadIndex(Index.ONLINE); + } + + @Test + public void testWhereClause() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s WHERE 1 = 1", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectAll() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_ONLINE); + verifyIsV2Cursor(response); + } + + @Test + public void testSelectWithOpenSearchFuncInFilter() throws IOException { + var response = executeQueryTemplate( + "SELECT * FROM %s WHERE `11` = match_phrase('96')", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectWithHighlight() throws IOException { + var response = executeQueryTemplate( + "SELECT highlight(`11`) FROM %s WHERE match_query(`11`, '96')", TEST_INDEX_ONLINE); + // As of 2023-03-08, WHERE clause sends the query to legacy engine and legacy engine + // does not support highlight as an expression. 
+ assertTrue(response.has("error")); + } + + @Test + public void testSelectWithFullTextSearch() throws IOException { + var response = executeQueryTemplate( + "SELECT * FROM %s WHERE match_phrase(`11`, '96')", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectFromIndexWildcard() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s*", TEST_INDEX); + verifyIsV2Cursor(response); + } + + @Test + public void testSelectFromDataSource() throws IOException { + var response = executeQueryTemplate("SELECT * FROM @opensearch.%s", + TEST_INDEX_ONLINE); + verifyIsV2Cursor(response); + } + + @Test + public void testSelectColumnReference() throws IOException { + var response = executeQueryTemplate("SELECT `107` from %s", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSubquery() throws IOException { + var response = executeQueryTemplate("SELECT `107` from (SELECT * FROM %s)", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testSelectExpression() throws IOException { + var response = executeQueryTemplate("SELECT 1 + 1 - `107` from %s", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testGroupBy() throws IOException { + // GROUP BY is not paged by either engine. + var response = executeQueryTemplate("SELECT * FROM %s GROUP BY `107`", + TEST_INDEX_ONLINE); + TestUtils.verifyNoCursor(response); + } + + @Test + public void testGroupByHaving() throws IOException { + // GROUP BY is not paged by either engine. + var response = executeQueryTemplate("SELECT * FROM %s GROUP BY `107` HAVING `107` > 400", + TEST_INDEX_ONLINE); + TestUtils.verifyNoCursor(response); + } + + @Test + public void testLimit() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s LIMIT 8", TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testLimitOffset() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s LIMIT 8 OFFSET 4", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + @Test + public void testOrderBy() throws IOException { + var response = executeQueryTemplate("SELECT * FROM %s ORDER By `107`", + TEST_INDEX_ONLINE); + verifyIsV1Cursor(response); + } + + +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationIT.java new file mode 100644 index 0000000000..a1d353cde8 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationIT.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_CALCS; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ONLINE; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.Ignore; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.util.TestUtils; + +public class PaginationIT extends SQLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.CALCS); + loadIndex(Index.ONLINE); + } + + @Test + public void testSmallDataSet() throws IOException { + var query = "SELECT * from " + TEST_INDEX_CALCS; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + 
assertTrue(response.has("cursor")); + assertEquals(4, response.getInt("size")); + TestUtils.verifyIsV2Cursor(response); + } + + @Test + public void testLargeDataSetV1() throws IOException { + var v1query = "SELECT * from " + TEST_INDEX_ONLINE + " WHERE 1 = 1"; + var v1response = new JSONObject(executeFetchQuery(v1query, 4, "jdbc")); + assertEquals(4, v1response.getInt("size")); + TestUtils.verifyIsV1Cursor(v1response); + } + + @Test + public void testLargeDataSetV2() throws IOException { + var query = "SELECT * from " + TEST_INDEX_ONLINE; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + assertEquals(4, response.getInt("size")); + TestUtils.verifyIsV2Cursor(response); + } + + @Ignore("Scroll may not expire after timeout") + // Scroll keep alive parameter guarantees that scroll context would be kept for that time, + // but doesn't define how fast it will be expired after time out. + // With KA = 1s scroll may be kept up to 30 sec or more. We can't test exact expiration. + // I disable the test to prevent it waiting for a minute and delay all CI. + public void testCursorTimeout() throws IOException, InterruptedException { + updateClusterSettings( + new ClusterSetting(PERSISTENT, Settings.Key.SQL_CURSOR_KEEP_ALIVE.getKeyValue(), "1s")); + + var query = "SELECT * from " + TEST_INDEX_CALCS; + var response = new JSONObject(executeFetchQuery(query, 4, "jdbc")); + assertTrue(response.has("cursor")); + var cursor = response.getString("cursor"); + Thread.sleep(2222L); // > 1s + + ResponseException exception = + expectThrows(ResponseException.class, () -> executeCursorQuery(cursor)); + response = new JSONObject(TestUtils.getResponseBody(exception.getResponse())); + assertEquals(response.getJSONObject("error").getString("reason"), + "Error occurred in OpenSearch engine: all shards failed"); + assertTrue(response.getJSONObject("error").getString("details") + .contains("SearchContextMissingException[No search context found for id")); + assertEquals(response.getJSONObject("error").getString("type"), + "SearchPhaseExecutionException"); + + wipeAllClusterSettings(); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java new file mode 100644 index 0000000000..724451ef65 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PaginationWindowIT.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_PHRASE; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.After; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +public class PaginationWindowIT extends SQLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.PHRASE); + } + + @After + void resetParams() throws IOException { + resetMaxResultWindow(TEST_INDEX_PHRASE); + resetQuerySizeLimit(); + } + + @Test + public void testFetchSizeLessThanMaxResultWindow() throws IOException { + setMaxResultWindow(TEST_INDEX_PHRASE, 6); + JSONObject response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 5); + + String cursor = ""; + int numRows = 0; + do { + // Process response + cursor = response.getString("cursor"); + numRows += response.getJSONArray("datarows").length(); + response = executeCursorQuery(cursor); + } while 
(response.has("cursor")); + + var countRows = executeJdbcRequest("SELECT COUNT(*) FROM " + TEST_INDEX_PHRASE) + .getJSONArray("datarows") + .getJSONArray(0) + .get(0); + assertEquals(countRows, numRows); + } + + @Test + public void testQuerySizeLimitDoesNotEffectTotalRowsReturned() throws IOException { + int querySizeLimit = 4; + setQuerySizeLimit(querySizeLimit); + JSONObject response = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 5); + assertTrue(response.getInt("size") > querySizeLimit); + + String cursor = ""; + int numRows = 0; + do { + // Process response + cursor = response.getString("cursor"); + numRows += response.getJSONArray("datarows").length(); + response = executeCursorQuery(cursor); + } while (response.has("cursor")); + + var countRows = executeJdbcRequest("SELECT COUNT(*) FROM " + TEST_INDEX_PHRASE) + .getJSONArray("datarows") + .getJSONArray(0) + .get(0); + assertEquals(countRows, numRows); + assertTrue(numRows > querySizeLimit); + } + + @Test + public void testQuerySizeLimitDoesNotEffectPageSize() throws IOException { + setQuerySizeLimit(3); + setMaxResultWindow(TEST_INDEX_PHRASE, 4); + var response + = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 4); + assertEquals(4, response.getInt("size")); + + var response2 + = executeQueryTemplate("SELECT * FROM %s", TEST_INDEX_PHRASE, 2); + assertEquals(2, response2.getInt("size")); + } + + @Test + public void testFetchSizeLargerThanResultWindowFails() throws IOException { + final int window = 2; + setMaxResultWindow(TEST_INDEX_PHRASE, 2); + assertThrows(ResponseException.class, + () -> executeQueryTemplate("SELECT * FROM %s", + TEST_INDEX_PHRASE, window + 1)); + resetMaxResultWindow(TEST_INDEX_PHRASE); + } + + +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java new file mode 100644 index 0000000000..0095bec7ca --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/StandalonePaginationIT.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata; +import static org.opensearch.sql.ppl.StandaloneIT.getDataSourceMetadataStorage; +import static org.opensearch.sql.ppl.StandaloneIT.getDataSourceUserRoleHelper; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import lombok.Getter; +import lombok.SneakyThrows; +import org.json.JSONObject; +import org.junit.Test; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.opensearch.client.Request; +import org.opensearch.client.ResponseException; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.common.response.ResponseListener; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import 
org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchRestClient; +import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.opensearch.storage.OpenSearchDataSourceFactory; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.planner.PlanContext; +import org.opensearch.sql.planner.logical.LogicalPaginate; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.DataSourceFactory; +import org.opensearch.sql.util.InternalRestHighLevelClient; +import org.opensearch.sql.util.StandaloneModule; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class StandalonePaginationIT extends SQLIntegTestCase { + + private QueryService queryService; + + private PlanSerializer planSerializer; + + private OpenSearchClient client; + + @Override + @SneakyThrows + public void init() { + RestHighLevelClient restClient = new InternalRestHighLevelClient(client()); + client = new OpenSearchRestClient(restClient); + DataSourceService dataSourceService = new DataSourceServiceImpl( + new ImmutableSet.Builder() + .add(new OpenSearchDataSourceFactory(client, defaultSettings())) + .build(), + getDataSourceMetadataStorage(), + getDataSourceUserRoleHelper() + ); + dataSourceService.createDataSource(defaultOpenSearchDataSourceMetadata()); + + ModulesBuilder modules = new ModulesBuilder(); + modules.add(new StandaloneModule(new InternalRestHighLevelClient(client()), defaultSettings(), dataSourceService)); + Injector injector = modules.createInjector(); + + queryService = injector.getInstance(QueryService.class); + planSerializer = injector.getInstance(PlanSerializer.class); + } + + @Test + public void test_pagination_whitebox() throws IOException { + class TestResponder + implements ResponseListener { + @Getter + Cursor cursor = Cursor.None; + @Override + public void onResponse(ExecutionEngine.QueryResponse response) { + cursor = response.getCursor(); + } + + @Override + public void onFailure(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + }; + + // arrange + { + Request request1 = new Request("PUT", "/test/_doc/1?refresh=true"); + request1.setJsonEntity("{\"name\": \"hello\", \"age\": 20}"); + client().performRequest(request1); + Request request2 = new Request("PUT", "/test/_doc/2?refresh=true"); + request2.setJsonEntity("{\"name\": \"world\", \"age\": 30}"); + client().performRequest(request2); + } + + // act 1, asserts in firstResponder + var t = new OpenSearchIndex(client, defaultSettings(), "test"); + LogicalPlan p = new LogicalPaginate(1, List.of( + new LogicalProject( + new LogicalRelation("test", t), List.of( + DSL.named("name", DSL.ref("name", ExprCoreType.STRING)), + DSL.named("age", DSL.ref("age", ExprCoreType.LONG))), + List.of() + ))); + var firstResponder = new TestResponder(); + queryService.executePlan(p, PlanContext.emptyPlanContext(), firstResponder); + + // act 2, asserts in secondResponder + + PhysicalPlan plan = planSerializer.convertToPlan(firstResponder.getCursor().toString()); + var secondResponder = new TestResponder(); + queryService.executePlan(plan, secondResponder); + + // act 3: 
confirm that there's no cursor. + } + + @Test + @SneakyThrows + public void test_explain_not_supported() { + var request = new Request("POST", "_plugins/_sql/_explain"); + // Request should be rejected before index names are resolved + request.setJsonEntity("{ \"query\": \"select * from something\", \"fetch_size\": 10 }"); + var exception = assertThrows(ResponseException.class, () -> client().performRequest(request)); + var response = new JSONObject(new String(exception.getResponse().getEntity().getContent().readAllBytes())); + assertEquals("`explain` feature for paginated requests is not implemented yet.", + response.getJSONObject("error").getString("details")); + + // Request should be rejected before cursor parsed + request.setJsonEntity("{ \"cursor\" : \"n:0000\" }"); + exception = assertThrows(ResponseException.class, () -> client().performRequest(request)); + response = new JSONObject(new String(exception.getResponse().getEntity().getContent().readAllBytes())); + assertEquals("Explain of a paged query continuation is not supported. Use `explain` for the initial query request.", + response.getJSONObject("error").getString("details")); + } + + private Settings defaultSettings() { + return new Settings() { + private final Map defaultSettings = new ImmutableMap.Builder() + .put(Key.QUERY_SIZE_LIMIT, 200) + .put(Key.SQL_CURSOR_KEEP_ALIVE, TimeValue.timeValueMinutes(1)) + .build(); + + @Override + public T getSettingValue(Key key) { + return (T) defaultSettings.get(key); + } + + @Override + public List getSettings() { + return (List) defaultSettings; + } + }; + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/InternalRestHighLevelClient.java b/integ-test/src/test/java/org/opensearch/sql/util/InternalRestHighLevelClient.java new file mode 100644 index 0000000000..57726089ae --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/util/InternalRestHighLevelClient.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.util; + +import java.util.Collections; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestHighLevelClient; + +/** + * Internal RestHighLevelClient only for testing purpose. 
+ */ +public class InternalRestHighLevelClient extends RestHighLevelClient { + public InternalRestHighLevelClient(RestClient restClient) { + super(restClient, RestClient::close, Collections.emptyList()); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java new file mode 100644 index 0000000000..a86f251377 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/util/StandaloneModule.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.util; + +import lombok.RequiredArgsConstructor; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.inject.AbstractModule; +import org.opensearch.common.inject.Provides; +import org.opensearch.common.inject.Singleton; +import org.opensearch.sql.analysis.Analyzer; +import org.opensearch.sql.analysis.ExpressionAnalyzer; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.executor.QueryManager; +import org.opensearch.sql.executor.QueryService; +import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.monitor.AlwaysHealthyMonitor; +import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.client.OpenSearchRestClient; +import org.opensearch.sql.opensearch.executor.OpenSearchExecutionEngine; +import org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; +import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; +import org.opensearch.sql.planner.Planner; +import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.ppl.PPLService; +import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; +import org.opensearch.sql.sql.SQLService; +import org.opensearch.sql.sql.antlr.SQLSyntaxParser; +import org.opensearch.sql.storage.StorageEngine; + +/** + * A utility class which registers SQL engine singletons as `OpenSearchPluginModule` does. + * It is needed to get access to those instances in test and validate their behavior. 
+ */ +@RequiredArgsConstructor +public class StandaloneModule extends AbstractModule { + + private final RestHighLevelClient client; + + private final Settings settings; + + private final DataSourceService dataSourceService; + + private final BuiltinFunctionRepository functionRepository = + BuiltinFunctionRepository.getInstance(); + + @Override + protected void configure() { + } + + @Provides + public OpenSearchClient openSearchClient() { + return new OpenSearchRestClient(client); + } + + @Provides + public StorageEngine storageEngine(OpenSearchClient client) { + return new OpenSearchStorageEngine(client, settings); + } + + @Provides + public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector, + PlanSerializer planSerializer) { + return new OpenSearchExecutionEngine(client, protector, planSerializer); + } + + @Provides + public ResourceMonitor resourceMonitor() { + return new AlwaysHealthyMonitor(); + } + + @Provides + public ExecutionProtector protector(ResourceMonitor resourceMonitor) { + return new OpenSearchExecutionProtector(resourceMonitor); + } + + @Provides + @Singleton + public QueryManager queryManager() { + return new ExecuteOnCallerThreadQueryManager(); + } + + @Provides + public PPLService pplService(QueryManager queryManager, QueryPlanFactory queryPlanFactory) { + return new PPLService(new PPLSyntaxParser(), queryManager, queryPlanFactory); + } + + @Provides + public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPlanFactory) { + return new SQLService(new SQLSyntaxParser(), queryManager, queryPlanFactory); + } + + @Provides + public PlanSerializer paginatedPlanCache(StorageEngine storageEngine) { + return new PlanSerializer(storageEngine); + } + + @Provides + public QueryPlanFactory queryPlanFactory(ExecutionEngine executionEngine, + PlanSerializer planSerializer, + QueryService qs) { + + return new QueryPlanFactory(qs, planSerializer); + } + + @Provides + public QueryService queryService(ExecutionEngine executionEngine) { + Analyzer analyzer = + new Analyzer( + new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); + Planner planner = new Planner(LogicalPlanOptimizer.create()); + return new QueryService(analyzer, executionEngine, planner); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java index bd75ead43b..69f1649190 100644 --- a/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java +++ b/integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java @@ -7,6 +7,8 @@ package org.opensearch.sql.util; import static com.google.common.base.Strings.isNullOrEmpty; +import static org.junit.Assert.assertTrue; +import static org.opensearch.sql.executor.pagination.PlanSerializer.CURSOR_PREFIX; import java.io.BufferedReader; import java.io.File; @@ -20,22 +22,21 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.stream.Collectors; import org.json.JSONObject; -import org.junit.Assert; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.client.Client; import org.opensearch.client.Request; -import org.opensearch.client.RequestOptions; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import 
org.opensearch.common.xcontent.XContentType; -import org.opensearch.rest.RestStatus; +import org.opensearch.sql.legacy.cursor.CursorType; public class TestUtils { @@ -839,4 +840,28 @@ public static List> getPermutations(final List items) { return result; } + + public static void verifyIsV1Cursor(JSONObject response) { + var legacyCursorPrefixes = Arrays.stream(CursorType.values()) + .map(c -> c.getId() + ":").collect(Collectors.toList()); + verifyCursor(response, legacyCursorPrefixes, "v1"); + } + + + public static void verifyIsV2Cursor(JSONObject response) { + verifyCursor(response, List.of(CURSOR_PREFIX), "v2"); + } + + private static void verifyCursor(JSONObject response, List validCursorPrefix, String engineName) { + assertTrue("'cursor' property does not exist", response.has("cursor")); + + var cursor = response.getString("cursor"); + assertTrue("'cursor' property is empty", !cursor.isEmpty()); + assertTrue("The cursor '" + cursor + "' is not from " + engineName + " engine.", + validCursorPrefix.stream().anyMatch(cursor::startsWith)); + } + + public static void verifyNoCursor(JSONObject response) { + assertTrue(!response.has("cursor")); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java index bc97f71b47..cbbc8c7b9c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java @@ -24,6 +24,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.utils.QueryContext; +import org.opensearch.sql.exception.UnsupportedCursorRequestException; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; @@ -119,14 +120,14 @@ private ResponseListener fallBackListener( return new ResponseListener() { @Override public void onResponse(T response) { - LOG.error("[{}] Request is handled by new SQL query engine", + LOG.info("[{}] Request is handled by new SQL query engine", QueryContext.getRequestId()); next.onResponse(response); } @Override public void onFailure(Exception e) { - if (e instanceof SyntaxCheckException) { + if (e instanceof SyntaxCheckException || e instanceof UnsupportedCursorRequestException) { fallBackHandler.accept(channel, e); } else { next.onFailure(e); @@ -172,7 +173,8 @@ private ResponseListener createQueryResponseListener( @Override public void onResponse(QueryResponse response) { sendResponse(channel, OK, - formatter.format(new QueryResult(response.getSchema(), response.getResults()))); + formatter.format(new QueryResult(response.getSchema(), response.getResults(), + response.getCursor(), response.getTotal()))); } @Override diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index 88ed42010b..e1c72f0f1e 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -42,6 +42,7 @@ import org.opensearch.sql.legacy.antlr.SqlAnalysisConfig; import org.opensearch.sql.legacy.antlr.SqlAnalysisException; import org.opensearch.sql.legacy.antlr.semantic.types.Type; +import org.opensearch.sql.legacy.cursor.CursorType; import 
org.opensearch.sql.legacy.domain.ColumnTypeProvider; import org.opensearch.sql.legacy.domain.QueryActionRequest; import org.opensearch.sql.legacy.esdomain.LocalClusterState; @@ -132,7 +133,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } final SqlRequest sqlRequest = SqlRequestFactory.getSqlRequest(request); - if (sqlRequest.cursor() != null) { + if (isLegacyCursor(sqlRequest)) { if (isExplainRequest(request)) { throw new IllegalArgumentException("Invalid request. Cannot explain cursor"); } else { @@ -148,14 +149,14 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli // Route request to new query engine if it's supported already SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(), - sqlRequest.getSql(), request.path(), request.params()); + sqlRequest.getSql(), request.path(), request.params(), sqlRequest.cursor()); return newSqlQueryHandler.prepareRequest(newSqlRequest, (restChannel, exception) -> { try{ if (newSqlRequest.isExplainRequest()) { LOG.info("Request is falling back to old SQL engine due to: " + exception.getMessage()); } - LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine", + LOG.info("[{}] Request {} is not supported and falling back to old SQL engine", QueryContext.getRequestId(), newSqlRequest); QueryAction queryAction = explainRequest(client, sqlRequest, format); executeSqlRequest(request, queryAction, client, restChannel); @@ -175,6 +176,17 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } + + /** + * @param sqlRequest client request + * @return true if this cursor was generated by the legacy engine, false otherwise. + */ + private static boolean isLegacyCursor(SqlRequest sqlRequest) { + String cursor = sqlRequest.cursor(); + return cursor != null + && CursorType.getById(cursor.substring(0, 1)) != CursorType.NULL; + } + @Override protected Set responseParams() { Set responseParams = new HashSet<>(super.responseParams()); diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java new file mode 100644 index 0000000000..a11f4c47d7 --- /dev/null +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionCursorFallbackTest.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.legacy.plugin; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.Strings; +import org.opensearch.common.inject.Injector; +import org.opensearch.common.inject.ModulesBuilder; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import 
org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.executor.QueryManager; +import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.sql.SQLService; +import org.opensearch.sql.sql.antlr.SQLSyntaxParser; +import org.opensearch.sql.sql.domain.SQLQueryRequest; +import org.opensearch.threadpool.ThreadPool; + +/** + * A test suite that verifies fallback behaviour of cursor queries. + */ +@RunWith(MockitoJUnitRunner.class) +public class RestSQLQueryActionCursorFallbackTest extends BaseRestHandler { + + private NodeClient nodeClient; + + @Mock + private ThreadPool threadPool; + + @Mock + private QueryManager queryManager; + + @Mock + private QueryPlanFactory factory; + + @Mock + private RestChannel restChannel; + + private Injector injector; + + @Before + public void setup() { + nodeClient = new NodeClient(org.opensearch.common.settings.Settings.EMPTY, threadPool); + ModulesBuilder modules = new ModulesBuilder(); + modules.add(b -> { + b.bind(SQLService.class).toInstance(new SQLService(new SQLSyntaxParser(), queryManager, factory)); + }); + injector = modules.createInjector(); + Mockito.lenient().when(threadPool.getThreadContext()) + .thenReturn(new ThreadContext(org.opensearch.common.settings.Settings.EMPTY)); + } + + // Initial page request test cases + + @Test + public void no_fallback_with_column_reference() throws Exception { + String query = "SELECT name FROM test1"; + SQLQueryRequest request = createSqlQueryRequest(query, Optional.empty(), + Optional.of(5)); + + assertFalse(doesQueryFallback(request)); + } + + private static SQLQueryRequest createSqlQueryRequest(String query, Optional cursorId, + Optional fetchSize) throws IOException { + var builder = XContentFactory.jsonBuilder() + .startObject() + .field("query").value(query); + if (cursorId.isPresent()) { + builder.field("cursor").value(cursorId.get()); + } + + if (fetchSize.isPresent()) { + builder.field("fetch_size").value(fetchSize.get()); + } + builder.endObject(); + JSONObject jsonContent = new JSONObject(Strings.toString(builder)); + + return new SQLQueryRequest(jsonContent, query, QUERY_API_ENDPOINT, + Map.of("format", "jdbc"), cursorId.orElse("")); + } + + boolean doesQueryFallback(SQLQueryRequest request) throws Exception { + AtomicBoolean fallback = new AtomicBoolean(false); + RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); + queryAction.prepareRequest(request, (channel, exception) -> { + fallback.set(true); + }, (channel, exception) -> { + }).accept(restChannel); + return fallback.get(); + } + + @Override + public String getName() { + // do nothing, RestChannelConsumer is protected which required to extend BaseRestHandler + return null; + } + + @Override + protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) + { + // do nothing, RestChannelConsumer is protected which required to extend BaseRestHandler + return null; + } +} diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java index 1bc34edf50..be572f3dfb 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/plugin/RestSQLQueryActionTest.java @@ -74,7 +74,7 @@ public void handleQueryThatCanSupport() throws Exception { new JSONObject("{\"query\": \"SELECT -123\"}"), "SELECT -123", QUERY_API_ENDPOINT, - ""); + "jdbc"); 
RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); queryAction.prepareRequest(request, (channel, exception) -> { @@ -90,7 +90,7 @@ public void handleExplainThatCanSupport() throws Exception { new JSONObject("{\"query\": \"SELECT -123\"}"), "SELECT -123", EXPLAIN_API_ENDPOINT, - ""); + "jdbc"); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); queryAction.prepareRequest(request, (channel, exception) -> { @@ -107,7 +107,7 @@ public void queryThatNotSupportIsHandledByFallbackHandler() throws Exception { "{\"query\": \"SELECT name FROM test1 JOIN test2 ON test1.name = test2.name\"}"), "SELECT name FROM test1 JOIN test2 ON test1.name = test2.name", QUERY_API_ENDPOINT, - ""); + "jdbc"); AtomicBoolean fallback = new AtomicBoolean(false); RestSQLQueryAction queryAction = new RestSQLQueryAction(injector); @@ -128,7 +128,7 @@ public void queryExecutionFailedIsHandledByExecutionErrorHandler() throws Except "{\"query\": \"SELECT -123\"}"), "SELECT -123", QUERY_API_ENDPOINT, - ""); + "jdbc"); doThrow(new IllegalStateException("execution exception")) .when(queryManager) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java index d6af4ca1e9..f9715ec1c3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClient.java @@ -42,7 +42,7 @@ public class OpenSearchNodeClient implements OpenSearchClient { private final NodeClient client; /** - * Constructor of ElasticsearchNodeClient. + * Constructor of OpenSearchNodeClient. */ public OpenSearchNodeClient(NodeClient client) { this.client = client; @@ -171,7 +171,14 @@ public Map meta() { @Override public void cleanup(OpenSearchRequest request) { - request.clean(scrollId -> client.prepareClearScroll().addScrollId(scrollId).get()); + request.clean(scrollId -> { + try { + client.prepareClearScroll().addScrollId(scrollId).get(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to clean up resources for search request " + request, e); + } + }); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java index d9f9dbbe5d..757ea99c1b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/client/OpenSearchRestClient.java @@ -184,7 +184,6 @@ public void cleanup(OpenSearchRequest request) { "Failed to clean up resources for search request " + request, e); } }); - } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index 9a136a3bec..bfc29b02d2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -15,6 +15,7 @@ import org.opensearch.sql.executor.ExecutionContext; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.Explain; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; import 
org.opensearch.sql.opensearch.executor.protector.ExecutionProtector; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -27,6 +28,7 @@ public class OpenSearchExecutionEngine implements ExecutionEngine { private final OpenSearchClient client; private final ExecutionProtector executionProtector; + private final PlanSerializer planSerializer; @Override public void execute(PhysicalPlan physicalPlan, ResponseListener listener) { @@ -49,7 +51,8 @@ public void execute(PhysicalPlan physicalPlan, ExecutionContext context, result.add(plan.next()); } - QueryResponse response = new QueryResponse(physicalPlan.schema(), result); + QueryResponse response = new QueryResponse(physicalPlan.schema(), result, + plan.getTotalHits(), planSerializer.convertToCursor(plan)); listener.onResponse(response); } catch (Exception e) { listener.onFailure(e); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java index 8fc7480dd1..0ec4d743b3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java @@ -6,12 +6,16 @@ package org.opensearch.sql.opensearch.executor.protector; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.List; import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.monitor.ResourceMonitor; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -21,7 +25,7 @@ @ToString @RequiredArgsConstructor @EqualsAndHashCode(callSuper = false) -public class ResourceMonitorPlan extends PhysicalPlan { +public class ResourceMonitorPlan extends PhysicalPlan implements SerializablePlan { /** * How many method calls to delegate's next() to perform resource check once. @@ -82,4 +86,28 @@ public ExprValue next() { } return delegate.next(); } + + @Override + public long getTotalHits() { + return delegate.getTotalHits(); + } + + @Override + public SerializablePlan getPlanForSerialization() { + return (SerializablePlan) delegate; + } + + /** + * These two methods should never be called. They are only called if a plan higher in the tree + * failed to call {@link #getPlanForSerialization}.
+ */ + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + throw new UnsupportedOperationException(); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + throw new UnsupportedOperationException(); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequest.java new file mode 100644 index 0000000000..4789a50896 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; + +/** + * Scroll (cursor) request is used to page the search. This request is not configurable and has + * no search query. It just handles paging through responses to the initial request. + * It is used for the second and subsequent pagination (cursor) requests. + * The first (initial) request is handled by {@link InitialPageRequestBuilder}. + */ +@EqualsAndHashCode +@RequiredArgsConstructor +public class ContinuePageRequest implements OpenSearchRequest { + private final String initialScrollId; + private final TimeValue scrollTimeout; + // ScrollId that OpenSearch returns after search. + private String responseScrollId; + + @EqualsAndHashCode.Exclude + @ToString.Exclude + @Getter + private final OpenSearchExprValueFactory exprValueFactory; + + @EqualsAndHashCode.Exclude + private boolean scrollFinished = false; + + @Override + public OpenSearchResponse search(Function searchAction, + Function scrollAction) { + SearchResponse openSearchResponse = scrollAction.apply(new SearchScrollRequest(initialScrollId) + .scroll(scrollTimeout)); + + // TODO if terminated_early - something went wrong, e.g. no scroll returned. + var response = new OpenSearchResponse(openSearchResponse, exprValueFactory, List.of()); + // on the last empty page, we should close the scroll + scrollFinished = response.isEmpty(); + responseScrollId = openSearchResponse.getScrollId(); + return response; + } + + @Override + public void clean(Consumer cleanAction) { + if (scrollFinished) { + cleanAction.accept(responseScrollId); + } + } + + @Override + public SearchSourceBuilder getSourceBuilder() { + throw new UnsupportedOperationException( + "SearchSourceBuilder is unavailable for ContinuePageRequest"); + } + + @Override + public String toCursor() { + // on the last page, we shouldn't return the scroll to the user; it is kept for closing (clean) + return scrollFinished ?
null : responseScrollId; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilder.java new file mode 100644 index 0000000000..b1a6589aca --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilder.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.Getter; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; + +/** + * Builds a {@link ContinuePageRequest} to handle subsequent pagination/scroll/cursor requests. + * The initial search request is handled by {@link InitialPageRequestBuilder}. + */ +public class ContinuePageRequestBuilder extends PagedRequestBuilder { + + @Getter + private final OpenSearchRequest.IndexName indexName; + @Getter + private final String scrollId; + private final TimeValue scrollTimeout; + private final OpenSearchExprValueFactory exprValueFactory; + + /** Constructor.
*/ + public ContinuePageRequestBuilder(OpenSearchRequest.IndexName indexName, + String scrollId, + Settings settings, + OpenSearchExprValueFactory exprValueFactory) { + this.indexName = indexName; + this.scrollId = scrollId; + this.scrollTimeout = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + this.exprValueFactory = exprValueFactory; + } + + @Override + public OpenSearchRequest build() { + return new ContinuePageRequest(scrollId, scrollTimeout, exprValueFactory); + } + + @Override + public void pushDownFilter(QueryBuilder query) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownAggregation(Pair, + OpenSearchAggregationResponseParser> aggregationBuilder) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownSort(List> sortBuilders) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownLimit(Integer limit, Integer offset) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownHighlight(String field, Map arguments) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownProjects(Set projects) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushTypeMapping(Map typeMapping) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownNested(List> nestedArgs) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } + + @Override + public void pushDownTrackedScore(boolean trackScores) { + throw new UnsupportedOperationException("Cursor requests don't support any push down"); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilder.java new file mode 100644 index 0000000000..25b7253eca --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilder.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.Getter; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; + +/** + * This builder assists creating the initial OpenSearch paging (scrolling) request. 
+ * It is used only on the first page (pagination request). + * Subsequent requests (cursor requests) use {@link ContinuePageRequestBuilder}. + */ +public class InitialPageRequestBuilder extends PagedRequestBuilder { + + @Getter + private final OpenSearchRequest.IndexName indexName; + private final SearchSourceBuilder sourceBuilder; + private final OpenSearchExprValueFactory exprValueFactory; + private final TimeValue scrollTimeout; + + /** + * Constructor. + * @param indexName index being scanned + * @param pageSize page size + * @param exprValueFactory value factory + */ + // TODO accept indexName as string (same way as `OpenSearchRequestBuilder` does)? + public InitialPageRequestBuilder(OpenSearchRequest.IndexName indexName, + int pageSize, + Settings settings, + OpenSearchExprValueFactory exprValueFactory) { + this.indexName = indexName; + this.exprValueFactory = exprValueFactory; + this.scrollTimeout = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + this.sourceBuilder = new SearchSourceBuilder() + .from(0) + .size(pageSize) + .timeout(DEFAULT_QUERY_TIMEOUT); + } + + @Override + public OpenSearchScrollRequest build() { + return new OpenSearchScrollRequest(indexName, scrollTimeout, sourceBuilder, exprValueFactory); + } + + @Override + public void pushDownFilter(QueryBuilder query) { + throw new UnsupportedOperationException("Pagination does not support filter (WHERE clause)"); + } + + @Override + public void pushDownAggregation(Pair, + OpenSearchAggregationResponseParser> aggregationBuilder) { + throw new UnsupportedOperationException("Pagination does not support aggregations"); + } + + @Override + public void pushDownSort(List> sortBuilders) { + throw new UnsupportedOperationException("Pagination does not support sort (ORDER BY clause)"); + } + + @Override + public void pushDownLimit(Integer limit, Integer offset) { + throw new UnsupportedOperationException("Pagination does not support limit (LIMIT clause)"); + } + + @Override + public void pushDownHighlight(String field, Map arguments) { + throw new UnsupportedOperationException("Pagination does not support highlight function"); + } + + /** + * Push down project expression to OpenSearch. 
+ */ + @Override + public void pushDownProjects(Set projects) { + sourceBuilder.fetchSource(projects.stream().map(ReferenceExpression::getAttr) + .distinct().toArray(String[]::new), new String[0]); + } + + @Override + public void pushTypeMapping(Map typeMapping) { + exprValueFactory.extendTypeMapping(typeMapping); + } + + @Override + public void pushDownNested(List> nestedArgs) { + throw new UnsupportedOperationException("Pagination does not support nested function"); + } + + @Override + public void pushDownTrackedScore(boolean trackScores) { + throw new UnsupportedOperationException("Pagination does not support score function"); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java index 3976f854fd..63aeed02f0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequest.java @@ -6,6 +6,8 @@ package org.opensearch.sql.opensearch.request; +import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; + import com.google.common.annotations.VisibleForTesting; import java.util.Arrays; import java.util.List; @@ -17,7 +19,6 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; -import org.opensearch.common.unit.TimeValue; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -35,11 +36,6 @@ @ToString public class OpenSearchQueryRequest implements OpenSearchRequest { - /** - * Default query timeout in minutes. - */ - public static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); - /** * {@link OpenSearchRequest.IndexName}. */ diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java index ce990780c1..c5b6d60af3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequest.java @@ -50,9 +50,13 @@ OpenSearchResponse search(Function searchAction, */ OpenSearchExprValueFactory getExprValueFactory(); + default String toCursor() { + return ""; + } + /** * OpenSearch Index Name. - * Indices are seperated by ",". + * Indices are separated by ",". */ @EqualsAndHashCode class IndexName { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 9f1b588af9..f8d62ad7ce 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -49,10 +49,10 @@ /** * OpenSearch search request builder. */ -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) @Getter @ToString -public class OpenSearchRequestBuilder { +public class OpenSearchRequestBuilder implements PushDownRequestBuilder { /** * Default query timeout in minutes. 
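The default toCursor() added to OpenSearchRequest above returns an empty string, while the paging requests return the current scroll id and ContinuePageRequest returns null once the scroll is exhausted; callers can therefore read "null or empty" as "no further pages". A minimal sketch of that convention, assuming only the OpenSearchRequest interface from this change set (the helper class and method name are made up for illustration):

import org.opensearch.sql.opensearch.request.OpenSearchRequest;

final class CursorAvailabilitySketch {
  // Hypothetical helper: decide whether a cursor can be handed back to the client.
  // It mirrors the null/empty convention used by toCursor() in this change set.
  static boolean hasMorePages(OpenSearchRequest request) {
    String cursor = request.toCursor();
    return cursor != null && !cursor.isEmpty();
  }
}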
@@ -82,15 +82,21 @@ public class OpenSearchRequestBuilder { private final OpenSearchExprValueFactory exprValueFactory; /** - * Query size of the request. + * Query size of the request -- how many rows will be returned. */ - private Integer querySize; + private int querySize; + + /** + * Scroll context life time. + */ + private final TimeValue scrollTimeout; public OpenSearchRequestBuilder(String indexName, Integer maxResultWindow, Settings settings, OpenSearchExprValueFactory exprValueFactory) { - this(new OpenSearchRequest.IndexName(indexName), maxResultWindow, settings, exprValueFactory); + this(new OpenSearchRequest.IndexName(indexName), maxResultWindow, settings, + exprValueFactory); } /** @@ -102,13 +108,14 @@ public OpenSearchRequestBuilder(OpenSearchRequest.IndexName indexName, OpenSearchExprValueFactory exprValueFactory) { this.indexName = indexName; this.maxResultWindow = maxResultWindow; - this.sourceBuilder = new SearchSourceBuilder(); this.exprValueFactory = exprValueFactory; + this.scrollTimeout = settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); this.querySize = settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT); - sourceBuilder.from(0); - sourceBuilder.size(querySize); - sourceBuilder.timeout(DEFAULT_QUERY_TIMEOUT); - sourceBuilder.trackScores(false); + this.sourceBuilder = new SearchSourceBuilder() + .from(0) + .size(querySize) + .timeout(DEFAULT_QUERY_TIMEOUT) + .trackScores(false); } /** @@ -120,11 +127,12 @@ public OpenSearchRequest build() { Integer from = sourceBuilder.from(); Integer size = sourceBuilder.size(); - if (from + size <= maxResultWindow) { - return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory); - } else { + if (from + size > maxResultWindow) { sourceBuilder.size(maxResultWindow - from); - return new OpenSearchScrollRequest(indexName, sourceBuilder, exprValueFactory); + return new OpenSearchScrollRequest( + indexName, scrollTimeout, sourceBuilder, exprValueFactory); + } else { + return new OpenSearchQueryRequest(indexName, sourceBuilder, exprValueFactory); } } @@ -133,7 +141,8 @@ public OpenSearchRequest build() { * * @param query query request */ - public void pushDown(QueryBuilder query) { + @Override + public void pushDownFilter(QueryBuilder query) { QueryBuilder current = sourceBuilder.query(); if (current == null) { @@ -158,6 +167,7 @@ public void pushDown(QueryBuilder query) { * * @param aggregationBuilder pair of aggregation query and aggregation parser. */ + @Override public void pushDownAggregation( Pair, OpenSearchAggregationResponseParser> aggregationBuilder) { aggregationBuilder.getLeft().forEach(builder -> sourceBuilder.aggregation(builder)); @@ -170,6 +180,7 @@ public void pushDownAggregation( * * @param sortBuilders sortBuilders. */ + @Override public void pushDownSort(List> sortBuilders) { // TODO: Sort by _doc is added when filter push down. Remove both logic once doctest fixed. if (isSortByDocOnly()) { @@ -184,11 +195,13 @@ public void pushDownSort(List> sortBuilders) { /** * Push down size (limit) and from (offset) to DSL request. */ + @Override public void pushDownLimit(Integer limit, Integer offset) { querySize = limit; sourceBuilder.from(offset).size(limit); } + @Override public void pushDownTrackedScore(boolean trackScores) { sourceBuilder.trackScores(trackScores); } @@ -197,6 +210,7 @@ public void pushDownTrackedScore(boolean trackScores) { * Add highlight to DSL requests. 
* @param field name of the field to highlight */ + @Override public void pushDownHighlight(String field, Map arguments) { String unquotedField = StringUtils.unquoteText(field); if (sourceBuilder.highlighter() != null) { @@ -227,22 +241,20 @@ public void pushDownHighlight(String field, Map arguments) { } /** - * Push down project list to DSL requets. + * Push down project list to DSL requests. */ + @Override public void pushDownProjects(Set projects) { final Set projectsSet = projects.stream().map(ReferenceExpression::getAttr).collect(Collectors.toSet()); sourceBuilder.fetchSource(projectsSet.toArray(new String[0]), new String[0]); } + @Override public void pushTypeMapping(Map typeMapping) { exprValueFactory.extendTypeMapping(typeMapping); } - private boolean isBoolFilterQuery(QueryBuilder current) { - return (current instanceof BoolQueryBuilder); - } - private boolean isSortByDocOnly() { List> sorts = sourceBuilder.sorts(); if (sorts != null) { @@ -255,6 +267,7 @@ private boolean isSortByDocOnly() { * Push down nested to sourceBuilder. * @param nestedArgs : Nested arguments to push down. */ + @Override public void pushDownNested(List> nestedArgs) { initBoolQueryFilter(); groupFieldNamesByPath(nestedArgs).forEach( diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java index dacbecc7b9..77c6a781fe 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequest.java @@ -13,7 +13,6 @@ import java.util.function.Function; import lombok.EqualsAndHashCode; import lombok.Getter; -import lombok.RequiredArgsConstructor; import lombok.Setter; import lombok.ToString; import org.opensearch.action.search.SearchRequest; @@ -36,8 +35,8 @@ @ToString public class OpenSearchScrollRequest implements OpenSearchRequest { - /** Default scroll context timeout in minutes. */ - public static final TimeValue DEFAULT_SCROLL_TIMEOUT = TimeValue.timeValueMinutes(1L); + /** Scroll context timeout. */ + private final TimeValue scrollTimeout; /** * {@link OpenSearchRequest.IndexName}. @@ -54,27 +53,21 @@ public class OpenSearchScrollRequest implements OpenSearchRequest { * multi-thread so this state has to be maintained here. */ @Setter + @Getter private String scrollId; + private boolean needClean = false; + /** Search request source builder. */ private final SearchSourceBuilder sourceBuilder; - /** Constructor. */ - public OpenSearchScrollRequest(IndexName indexName, OpenSearchExprValueFactory exprValueFactory) { - this.indexName = indexName; - this.sourceBuilder = new SearchSourceBuilder(); - this.exprValueFactory = exprValueFactory; - } - - public OpenSearchScrollRequest(String indexName, OpenSearchExprValueFactory exprValueFactory) { - this(new IndexName(indexName), exprValueFactory); - } - /** Constructor. 
*/ public OpenSearchScrollRequest(IndexName indexName, + TimeValue scrollTimeout, SearchSourceBuilder sourceBuilder, OpenSearchExprValueFactory exprValueFactory) { this.indexName = indexName; + this.scrollTimeout = scrollTimeout; this.sourceBuilder = sourceBuilder; this.exprValueFactory = exprValueFactory; } @@ -84,24 +77,30 @@ public OpenSearchScrollRequest(IndexName indexName, public OpenSearchResponse search(Function searchAction, Function scrollAction) { SearchResponse openSearchResponse; - if (isScrollStarted()) { + if (isScroll()) { openSearchResponse = scrollAction.apply(scrollRequest()); } else { openSearchResponse = searchAction.apply(searchRequest()); } - setScrollId(openSearchResponse.getScrollId()); FetchSourceContext fetchSource = this.sourceBuilder.fetchSource(); List includes = fetchSource != null && fetchSource.includes() != null ? Arrays.asList(this.sourceBuilder.fetchSource().includes()) : List.of(); - return new OpenSearchResponse(openSearchResponse, exprValueFactory, includes); + + var response = new OpenSearchResponse(openSearchResponse, exprValueFactory, includes); + if (!(needClean = response.isEmpty())) { + setScrollId(openSearchResponse.getScrollId()); + } + return response; } @Override public void clean(Consumer cleanAction) { try { - if (isScrollStarted()) { + // clean on the last page only, to prevent closing the scroll/cursor in the middle of paging. + if (needClean && isScroll()) { cleanAction.accept(getScrollId()); + setScrollId(null); } } finally { reset(); @@ -116,7 +115,7 @@ public void clean(Consumer cleanAction) { public SearchRequest searchRequest() { return new SearchRequest() .indices(indexName.getIndexNames()) - .scroll(DEFAULT_SCROLL_TIMEOUT) + .scroll(scrollTimeout) .source(sourceBuilder); } @@ -125,8 +124,8 @@ public SearchRequest searchRequest() { * * @return true if scroll started */ - public boolean isScrollStarted() { - return (scrollId != null); + public boolean isScroll() { + return scrollId != null; } /** @@ -136,7 +135,7 @@ public boolean isScrollStarted() { */ public SearchScrollRequest scrollRequest() { Objects.requireNonNull(scrollId, "Scroll id cannot be null"); - return new SearchScrollRequest().scroll(DEFAULT_SCROLL_TIMEOUT).scrollId(scrollId); + return new SearchScrollRequest().scroll(scrollTimeout).scrollId(scrollId); } /** @@ -146,4 +145,13 @@ public SearchScrollRequest scrollRequest() { public void reset() { scrollId = null; } + + /** + * Convert a scroll request to string that can be included in a cursor. + * @return a string representing the scroll request. 
+ */ + @Override + public String toCursor() { + return scrollId; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PagedRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PagedRequestBuilder.java new file mode 100644 index 0000000000..69309bd7c9 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PagedRequestBuilder.java @@ -0,0 +1,12 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +public abstract class PagedRequestBuilder implements PushDownRequestBuilder { + public abstract OpenSearchRequest build(); + + public abstract OpenSearchRequest.IndexName getIndexName(); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PushDownRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PushDownRequestBuilder.java new file mode 100644 index 0000000000..59aa1949b6 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PushDownRequestBuilder.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.Getter; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; + +public interface PushDownRequestBuilder { + + default boolean isBoolFilterQuery(QueryBuilder current) { + return (current instanceof BoolQueryBuilder); + } + + void pushDownFilter(QueryBuilder query); + + void pushDownAggregation(Pair, + OpenSearchAggregationResponseParser> aggregationBuilder); + + void pushDownSort(List> sortBuilders); + + void pushDownLimit(Integer limit, Integer offset); + + void pushDownHighlight(String field, Map arguments); + + void pushDownProjects(Set projects); + + void pushTypeMapping(Map typeMapping); + + void pushDownNested(List> nestedArgs); + + void pushDownTrackedScore(boolean trackScores); +} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index 204a6bca22..af43be1a38 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -57,13 +57,13 @@ public class OpenSearchResponse implements Iterable { private final List includes; /** - * ElasticsearchExprValueFactory used to build ExprValue from search result. + * OpenSearchExprValueFactory used to build ExprValue from search result. */ @EqualsAndHashCode.Exclude private final OpenSearchExprValueFactory exprValueFactory; /** - * Constructor of ElasticsearchResponse. + * Constructor of OpenSearchResponse. 
*/ public OpenSearchResponse(SearchResponse searchResponse, OpenSearchExprValueFactory exprValueFactory, @@ -75,7 +75,7 @@ public OpenSearchResponse(SearchResponse searchResponse, } /** - * Constructor of ElasticsearchResponse with SearchHits. + * Constructor of OpenSearchResponse with SearchHits. */ public OpenSearchResponse(SearchHits hits, OpenSearchExprValueFactory exprValueFactory, @@ -96,6 +96,10 @@ public boolean isEmpty() { return (hits.getHits() == null) || (hits.getHits().length == 0) && aggregations == null; } + public long getTotalHits() { + return hits.getTotalHits().value; + } + public boolean isAggregationResponse() { return aggregations != null; } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index ae5174d678..accd356041 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -99,8 +99,8 @@ public class OpenSearchSettings extends Settings { Setting.Property.Dynamic); /** - * Construct ElasticsearchSetting. - * The ElasticsearchSetting must be singleton. + * Construct OpenSearchSetting. + * The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index cf09b32de9..949b1e53ec 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -20,9 +20,14 @@ import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; +import org.opensearch.sql.opensearch.request.InitialPageRequestBuilder; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchPagedIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchPagedIndexScanBuilder; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; import org.opensearch.sql.planner.logical.LogicalML; @@ -179,6 +184,14 @@ public TableScanBuilder createScanBuilder() { return new OpenSearchIndexScanBuilder(indexScan); } + @Override + public TableScanBuilder createPagedScanBuilder(int pageSize) { + var requestBuilder = new InitialPageRequestBuilder(indexName, pageSize, settings, + new OpenSearchExprValueFactory(getFieldOpenSearchTypes())); + var indexScan = new OpenSearchPagedIndexScan(client, requestBuilder); + return new OpenSearchPagedIndexScanBuilder(indexScan); + } + @VisibleForTesting @RequiredArgsConstructor public static class OpenSearchDefaultImplementor diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java 
b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java index 4a3393abc9..c915fa549b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java @@ -8,6 +8,7 @@ import static org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex; +import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; @@ -21,8 +22,9 @@ public class OpenSearchStorageEngine implements StorageEngine { /** OpenSearch client connection. */ + @Getter private final OpenSearchClient client; - + @Getter private final Settings settings; @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java similarity index 93% rename from opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java rename to opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java index a26e64a809..2171fb564f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java @@ -4,7 +4,7 @@ */ -package org.opensearch.sql.opensearch.storage; +package org.opensearch.sql.opensearch.storage.scan; import java.util.Collections; import java.util.Iterator; @@ -104,6 +104,12 @@ public ExprValue next() { return iterator.next(); } + @Override + public long getTotalHits() { + // ignore response.getTotalHits(), because response returns entire index, regardless of LIMIT + return queryCount; + } + private void fetchNextBatch() { OpenSearchResponse response = client.search(request); if (!response.isEmpty()) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java index e52fc566cd..74be670dcc 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java @@ -16,7 +16,6 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; import org.opensearch.sql.planner.logical.LogicalAggregation; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 8e6c57d7d5..024331d267 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -8,7 +8,6 @@ import com.google.common.annotations.VisibleForTesting; 
import lombok.EqualsAndHashCode; import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalHighlight; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java index f20556ccc5..d9b4e6b4e0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -22,7 +22,6 @@ import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.function.OpenSearchFunctions; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; @@ -66,7 +65,7 @@ public boolean pushDownFilter(LogicalFilter filter) { new DefaultExpressionSerializer()); Expression queryCondition = filter.getCondition(); QueryBuilder query = queryBuilder.build(queryCondition); - indexScan.getRequestBuilder().pushDown(query); + indexScan.getRequestBuilder().pushDownFilter(query); indexScan.getRequestBuilder().pushDownTrackedScore( trackScoresFromOpenSearchFunction(queryCondition)); return true; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java new file mode 100644 index 0000000000..3667a3ffdf --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Collections; +import java.util.Iterator; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.apache.commons.lang3.NotImplementedException; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.ContinuePageRequestBuilder; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.PagedRequestBuilder; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; +import org.opensearch.sql.planner.SerializablePlan; +import org.opensearch.sql.storage.TableScanOperator; + +@EqualsAndHashCode(onlyExplicitlyIncluded = true, callSuper = false) +@ToString(onlyExplicitlyIncluded = true) +public class 
OpenSearchPagedIndexScan extends TableScanOperator implements SerializablePlan { + private OpenSearchClient client; + @Getter + private PagedRequestBuilder requestBuilder; + @EqualsAndHashCode.Include + @ToString.Include + private OpenSearchRequest request; + private Iterator iterator; + private long totalHits = 0; + + public OpenSearchPagedIndexScan(OpenSearchClient client, PagedRequestBuilder requestBuilder) { + this.client = client; + this.requestBuilder = requestBuilder; + } + + @Override + public String explain() { + throw new NotImplementedException("Implement OpenSearchPagedIndexScan.explain"); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public ExprValue next() { + return iterator.next(); + } + + @Override + public void open() { + super.open(); + request = requestBuilder.build(); + OpenSearchResponse response = client.search(request); + if (!response.isEmpty()) { + iterator = response.iterator(); + totalHits = response.getTotalHits(); + } else { + iterator = Collections.emptyIterator(); + } + } + + @Override + public void close() { + super.close(); + client.cleanup(request); + } + + @Override + public long getTotalHits() { + return totalHits; + } + + /** Don't use, it is for deserialization needs only. */ + @Deprecated + public OpenSearchPagedIndexScan() { + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + var engine = (OpenSearchStorageEngine) ((PlanSerializer.CursorDeserializationStream) in) + .resolveObject("engine"); + var indexName = (String) in.readUTF(); + var scrollId = (String) in.readUTF(); + client = engine.getClient(); + var index = new OpenSearchIndex(client, engine.getSettings(), indexName); + requestBuilder = new ContinuePageRequestBuilder( + new OpenSearchRequest.IndexName(indexName), + scrollId, engine.getSettings(), + new OpenSearchExprValueFactory(index.getFieldOpenSearchTypes())); + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + if (request.toCursor() == null || request.toCursor().isEmpty()) { + throw new NoCursorException(); + } + + out.writeUTF(requestBuilder.getIndexName().toString()); + out.writeUTF(request.toCursor()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanBuilder.java new file mode 100644 index 0000000000..779df4ebec --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanBuilder.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import lombok.EqualsAndHashCode; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Builder for a paged OpenSearch request. + * Override pushDown* methods from TableScanBuilder as more features + * support pagination. 
+ */ +public class OpenSearchPagedIndexScanBuilder extends TableScanBuilder { + @EqualsAndHashCode.Include + OpenSearchPagedIndexScan indexScan; + + public OpenSearchPagedIndexScanBuilder(OpenSearchPagedIndexScan indexScan) { + this.indexScan = indexScan; + } + + @Override + public TableScanOperator build() { + return indexScan; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java index 1efa5b65d5..8b1cb08cfa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilder.java @@ -24,8 +24,6 @@ import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.data.type.ExprCoreType; -import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.NamedExpression; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java index eb4cb865e2..eba5eb126d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScan.java @@ -31,9 +31,13 @@ public class OpenSearchSystemIndexScan extends TableScanOperator { */ private Iterator iterator; + private long totalHits = 0; + @Override public void open() { - iterator = request.search().iterator(); + var response = request.search(); + totalHits = response.size(); + iterator = response.iterator(); } @Override @@ -46,6 +50,11 @@ public ExprValue next() { return iterator.next(); } + @Override + public long getTotalHits() { + return totalHits; + } + @Override public String explain() { return request.toString(); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java index 4ceb18f26e..6978155e87 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchNodeClientTest.java @@ -34,8 +34,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import lombok.SneakyThrows; +import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InOrder; @@ -57,6 +61,7 @@ import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; 
import org.opensearch.core.xcontent.DeprecationHandler; @@ -65,6 +70,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -72,10 +78,12 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchNodeClientTest { private static final String TEST_MAPPING_FILE = "mappings/accounts.json"; @@ -107,7 +115,7 @@ void setUp() { } @Test - void isIndexExist() { + void is_index_exist() { when(nodeClient.admin().indices() .exists(any(IndicesExistsRequest.class)).actionGet()) .thenReturn(new IndicesExistsResponse(true)); @@ -116,7 +124,7 @@ void isIndexExist() { } @Test - void isIndexNotExist() { + void is_index_not_exist() { String indexName = "test"; when(nodeClient.admin().indices() .exists(any(IndicesExistsRequest.class)).actionGet()) @@ -126,14 +134,14 @@ void isIndexNotExist() { } @Test - void isIndexExistWithException() { + void is_index_exist_with_exception() { when(nodeClient.admin().indices().exists(any())).thenThrow(RuntimeException.class); assertThrows(IllegalStateException.class, () -> client.exists("test")); } @Test - void createIndex() { + void create_index() { String indexName = "test"; Map mappings = ImmutableMap.of( "properties", @@ -146,7 +154,7 @@ void createIndex() { } @Test - void createIndexWithException() { + void create_index_with_exception() { when(nodeClient.admin().indices().create(any())).thenThrow(RuntimeException.class); assertThrows(IllegalStateException.class, @@ -154,7 +162,7 @@ void createIndexWithException() { } @Test - void getIndexMappings() throws IOException { + void get_index_mappings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); String mappings = Resources.toString(url, Charsets.UTF_8); String indexName = "test"; @@ -225,7 +233,7 @@ void getIndexMappings() throws IOException { } @Test - void getIndexMappingsWithEmptyMapping() { + void get_index_mappings_with_empty_mapping() { String indexName = "test"; mockNodeClientIndicesMappings(indexName, ""); Map indexMappings = client.getIndexMappings(indexName); @@ -236,7 +244,7 @@ void getIndexMappingsWithEmptyMapping() { } @Test - void getIndexMappingsWithIOException() { + void get_index_mappings_with_IOException() { String indexName = "test"; when(nodeClient.admin().indices()).thenThrow(RuntimeException.class); @@ -244,7 +252,7 @@ void getIndexMappingsWithIOException() { } @Test - void getIndexMappingsWithNonExistIndex() { + void get_index_mappings_with_non_exist_index() { when(nodeClient.admin().indices() .prepareGetMappings(any()) .setLocal(anyBoolean()) @@ -255,7 +263,7 @@ void getIndexMappingsWithNonExistIndex() { } @Test - void getIndexMaxResultWindows() throws IOException { + void get_index_max_result_windows() throws IOException { URL url = Resources.getResource(TEST_MAPPING_SETTINGS_FILE); String indexMetadata = 
Resources.toString(url, Charsets.UTF_8); String indexName = "accounts"; @@ -269,7 +277,7 @@ void getIndexMaxResultWindows() throws IOException { } @Test - void getIndexMaxResultWindowsWithDefaultSettings() throws IOException { + void get_index_max_result_windows_with_default_settings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); String indexMetadata = Resources.toString(url, Charsets.UTF_8); String indexName = "accounts"; @@ -283,7 +291,7 @@ void getIndexMaxResultWindowsWithDefaultSettings() throws IOException { } @Test - void getIndexMaxResultWindowsWithIOException() { + void get_index_max_result_windows_with_IOException() { String indexName = "test"; when(nodeClient.admin().indices()).thenThrow(RuntimeException.class); @@ -292,7 +300,7 @@ void getIndexMaxResultWindowsWithIOException() { /** Jacoco enforce this constant lambda be tested. */ @Test - void testAllFieldsPredicate() { + void test_all_fields_predicate() { assertTrue(OpenSearchNodeClient.ALL_FIELDS.apply("any_index").test("any_field")); } @@ -315,11 +323,12 @@ void search() { // Mock second scroll request followed SearchResponse scrollResponse = mock(SearchResponse.class); when(nodeClient.searchScroll(any()).actionGet()).thenReturn(scrollResponse); - when(scrollResponse.getScrollId()).thenReturn("scroll456"); when(scrollResponse.getHits()).thenReturn(SearchHits.empty()); // Verify response for first scroll request - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); OpenSearchResponse response1 = client.search(request); assertFalse(response1.isEmpty()); @@ -329,6 +338,7 @@ void search() { assertFalse(hits.hasNext()); // Verify response for second scroll request + request.setScrollId("scroll123"); OpenSearchResponse response2 = client.search(request); assertTrue(response2.isEmpty()); } @@ -344,16 +354,21 @@ void schedule() { } @Test + @SneakyThrows void cleanup() { ClearScrollRequestBuilder requestBuilder = mock(ClearScrollRequestBuilder.class); when(nodeClient.prepareClearScroll()).thenReturn(requestBuilder); when(requestBuilder.addScrollId(any())).thenReturn(requestBuilder); when(requestBuilder.get()).thenReturn(null); - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); request.setScrollId("scroll123"); + // Enforce cleaning by setting a private field. 
+ FieldUtils.writeField(request, "needClean", true, true); client.cleanup(request); - assertFalse(request.isScrollStarted()); + assertFalse(request.isScroll()); InOrder inOrder = Mockito.inOrder(nodeClient, requestBuilder); inOrder.verify(nodeClient).prepareClearScroll(); @@ -362,14 +377,30 @@ void cleanup() { } @Test - void cleanupWithoutScrollId() { - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + void cleanup_without_scrollId() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); client.cleanup(request); verify(nodeClient, never()).prepareClearScroll(); } @Test - void getIndices() { + @SneakyThrows + void cleanup_rethrows_exception() { + when(nodeClient.prepareClearScroll()).thenThrow(new RuntimeException()); + + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); + request.setScrollId("scroll123"); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); + assertThrows(IllegalStateException.class, () -> client.cleanup(request)); + } + + @Test + void get_indices() { AliasMetadata aliasMetadata = mock(AliasMetadata.class); ImmutableOpenMap.Builder> builder = ImmutableOpenMap.builder(); builder.fPut("index",Arrays.asList(aliasMetadata)); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java index dd5bfd4e6f..ea463405b9 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/client/OpenSearchRestClientTest.java @@ -30,8 +30,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; +import lombok.SneakyThrows; +import org.apache.commons.lang3.reflect.FieldUtils; import org.apache.lucene.search.TotalHits; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -51,12 +55,14 @@ import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -64,10 +70,12 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchScrollRequest; import 
org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchRestClientTest { private static final String TEST_MAPPING_FILE = "mappings/accounts.json"; @@ -95,7 +103,7 @@ void setUp() { } @Test - void isIndexExist() throws IOException { + void is_index_exist() throws IOException { when(restClient.indices() .exists(any(), any())) // use any() because missing equals() in GetIndexRequest .thenReturn(true); @@ -104,7 +112,7 @@ void isIndexExist() throws IOException { } @Test - void isIndexNotExist() throws IOException { + void is_index_not_exist() throws IOException { when(restClient.indices() .exists(any(), any())) // use any() because missing equals() in GetIndexRequest .thenReturn(false); @@ -113,14 +121,14 @@ void isIndexNotExist() throws IOException { } @Test - void isIndexExistWithException() throws IOException { + void is_index_exist_with_exception() throws IOException { when(restClient.indices().exists(any(), any())).thenThrow(IOException.class); assertThrows(IllegalStateException.class, () -> client.exists("test")); } @Test - void createIndex() throws IOException { + void create_index() throws IOException { String indexName = "test"; Map mappings = ImmutableMap.of( "properties", @@ -133,7 +141,7 @@ void createIndex() throws IOException { } @Test - void createIndexWithIOException() throws IOException { + void create_index_with_IOException() throws IOException { when(restClient.indices().create(any(), any())).thenThrow(IOException.class); assertThrows(IllegalStateException.class, @@ -141,7 +149,7 @@ void createIndexWithIOException() throws IOException { } @Test - void getIndexMappings() throws IOException { + void get_index_mappings() throws IOException { URL url = Resources.getResource(TEST_MAPPING_FILE); String mappings = Resources.toString(url, Charsets.UTF_8); String indexName = "test"; @@ -216,14 +224,14 @@ void getIndexMappings() throws IOException { } @Test - void getIndexMappingsWithIOException() throws IOException { + void get_index_mappings_with_IOException() throws IOException { when(restClient.indices().getMapping(any(GetMappingsRequest.class), any())) .thenThrow(new IOException()); assertThrows(IllegalStateException.class, () -> client.getIndexMappings("test")); } @Test - void getIndexMaxResultWindowsSettings() throws IOException { + void get_index_max_result_windows_settings() throws IOException { String indexName = "test"; Integer maxResultWindow = 1000; @@ -247,7 +255,7 @@ void getIndexMaxResultWindowsSettings() throws IOException { } @Test - void getIndexMaxResultWindowsDefaultSettings() throws IOException { + void get_index_max_result_windows_default_settings() throws IOException { String indexName = "test"; Integer maxResultWindow = 10000; @@ -271,7 +279,7 @@ void getIndexMaxResultWindowsDefaultSettings() throws IOException { } @Test - void getIndexMaxResultWindowsWithIOException() throws IOException { + void get_index_max_result_windows_with_IOException() throws IOException { when(restClient.indices().getSettings(any(GetSettingsRequest.class), any())) .thenThrow(new IOException()); assertThrows(IllegalStateException.class, () -> client.getIndexMaxResultWindows("test")); @@ -296,11 +304,12 @@ void search() throws IOException { // Mock second scroll request followed SearchResponse scrollResponse = mock(SearchResponse.class); when(restClient.scroll(any(), any())).thenReturn(scrollResponse); - when(scrollResponse.getScrollId()).thenReturn("scroll456"); 
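
The snake_case renames throughout these test classes go together with the class-level @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) annotation added above: JUnit 5 then renders each underscore as a space in test reports. A minimal, self-contained sketch of that convention (the class and test names below are illustrative, not part of this change):

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;

// Illustrative only: shows how the generator turns method names into display names.
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class DisplayNameSketchTest {

  @Test
  void get_index_mappings_with_IOException() {
    // Reported as "get index mappings with IOException" instead of the raw method name.
    assertEquals(4, 2 + 2);
  }
}
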
when(scrollResponse.getHits()).thenReturn(SearchHits.empty()); // Verify response for first scroll request - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); OpenSearchResponse response1 = client.search(request); assertFalse(response1.isEmpty()); @@ -310,20 +319,23 @@ void search() throws IOException { assertFalse(hits.hasNext()); // Verify response for second scroll request + request.setScrollId("scroll123"); OpenSearchResponse response2 = client.search(request); assertTrue(response2.isEmpty()); } @Test - void searchWithIOException() throws IOException { + void search_with_IOException() throws IOException { when(restClient.search(any(), any())).thenThrow(new IOException()); assertThrows( IllegalStateException.class, - () -> client.search(new OpenSearchScrollRequest("test", factory))); + () -> client.search(new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory))); } @Test - void scrollWithIOException() throws IOException { + void scroll_with_IOException() throws IOException { // Mock first scroll request SearchResponse searchResponse = mock(SearchResponse.class); when(restClient.search(any(), any())).thenReturn(searchResponse); @@ -339,7 +351,9 @@ void scrollWithIOException() throws IOException { when(restClient.scroll(any(), any())).thenThrow(new IOException()); // First request run successfully - OpenSearchScrollRequest scrollRequest = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest scrollRequest = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); client.search(scrollRequest); assertThrows( IllegalStateException.class, () -> client.search(scrollRequest)); @@ -356,32 +370,44 @@ void schedule() { } @Test - void cleanup() throws IOException { - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + @SneakyThrows + void cleanup() { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); + // Enforce cleaning by setting a private field. 
+ FieldUtils.writeField(request, "needClean", true, true); request.setScrollId("scroll123"); client.cleanup(request); verify(restClient).clearScroll(any(), any()); - assertFalse(request.isScrollStarted()); + assertFalse(request.isScroll()); } @Test - void cleanupWithoutScrollId() throws IOException { - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + void cleanup_without_scrollId() throws IOException { + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); client.cleanup(request); verify(restClient, never()).clearScroll(any(), any()); } @Test - void cleanupWithIOException() throws IOException { + @SneakyThrows + void cleanup_with_IOException() { when(restClient.clearScroll(any(), any())).thenThrow(new IOException()); - OpenSearchScrollRequest request = new OpenSearchScrollRequest("test", factory); + OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "needClean", true, true); request.setScrollId("scroll123"); assertThrows(IllegalStateException.class, () -> client.cleanup(request)); } @Test - void getIndices() throws IOException { + void get_indices() throws IOException { when(restClient.indices().get(any(GetIndexRequest.class), any(RequestOptions.class))) .thenReturn(getIndexResponse); when(getIndexResponse.getIndices()).thenReturn(new String[] {"index"}); @@ -391,7 +417,7 @@ void getIndices() throws IOException { } @Test - void getIndicesWithIOException() throws IOException { + void get_indices_with_IOException() throws IOException { when(restClient.indices().get(any(GetIndexRequest.class), any(RequestOptions.class))) .thenThrow(new IOException()); assertThrows(IllegalStateException.class, () -> client.indices()); @@ -410,7 +436,7 @@ void meta() throws IOException { } @Test - void metaWithIOException() throws IOException { + void meta_with_IOException() throws IOException { when(restClient.cluster().getSettings(any(), any(RequestOptions.class))) .thenThrow(new IOException()); @@ -418,7 +444,7 @@ void metaWithIOException() throws IOException { } @Test - void mlWithException() { + void ml_with_exception() { assertThrows(UnsupportedOperationException.class, () -> client.getNodeClient()); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java index 4a0c6e24f1..c96782abea 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngineTest.java @@ -18,37 +18,46 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.sql.common.setting.Settings.Key.QUERY_SIZE_LIMIT; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import static org.opensearch.sql.executor.ExecutionEngine.QueryResponse; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; 
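
The cleanup tests above put the scroll request into a "needs cleanup" state by writing its private needClean flag reflectively before calling client.cleanup(request). A minimal sketch of that commons-lang3 pattern, assuming only a stand-in target class (the real tests write the same flag on OpenSearchScrollRequest):

import org.apache.commons.lang3.reflect.FieldUtils;

class PrivateFlagSketch {

  // Stand-in for the object under test; illustrative only.
  static class Target {
    private boolean needClean = false;

    boolean isNeedClean() {
      return needClean;
    }
  }

  public static void main(String[] args) throws IllegalAccessException {
    Target target = new Target();
    // The final 'true' (forceAccess) lets the write bypass the private modifier.
    FieldUtils.writeField(target, "needClean", true, true);
    System.out.println(target.isNeedClean()); // prints: true
  }
}

Keeping the flag private and flipping it reflectively in test setup avoids widening the production API solely for testing.
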
-import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import lombok.RequiredArgsConstructor; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.common.response.ResponseListener; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.executor.ExecutionContext; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.executor.protector.OpenSearchExecutionProtector; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.split.Split; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchExecutionEngineTest { @Mock private OpenSearchClient client; @@ -75,14 +84,15 @@ void setUp() { } @Test - void executeSuccessfully() { + void execute_successfully() { List expected = Arrays.asList( tupleValue(of("name", "John", "age", 20)), tupleValue(of("name", "Allen", "age", 30))); FakePhysicalPlan plan = new FakePhysicalPlan(expected.iterator()); when(protector.protect(plan)).thenReturn(plan); - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); List actual = new ArrayList<>(); executor.execute( plan, @@ -104,13 +114,43 @@ public void onFailure(Exception e) { } @Test - void executeWithFailure() { + void execute_with_cursor() { + List expected = + Arrays.asList( + tupleValue(of("name", "John", "age", 20)), tupleValue(of("name", "Allen", "age", 30))); + var plan = new FakePhysicalPlan(expected.iterator()); + when(protector.protect(plan)).thenReturn(plan); + + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); + List actual = new ArrayList<>(); + executor.execute( + plan, + new ResponseListener() { + @Override + public void onResponse(QueryResponse response) { + actual.addAll(response.getResults()); + assertTrue(response.getCursor().toString().startsWith("n:")); + } + + @Override + public void onFailure(Exception e) { + fail("Error occurred during execution", e); + } + }); + + assertEquals(expected, actual); + } + + @Test + void execute_with_failure() { PhysicalPlan plan = mock(PhysicalPlan.class); RuntimeException expected = new RuntimeException("Execution error"); when(plan.hasNext()).thenThrow(expected); when(protector.protect(plan)).thenReturn(plan); - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, 
protector, + new PlanSerializer(null)); AtomicReference actual = new AtomicReference<>(); executor.execute( plan, @@ -130,12 +170,16 @@ public void onFailure(Exception e) { } @Test - void explainSuccessfully() { - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + void explain_successfully() { + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); Settings settings = mock(Settings.class); when(settings.getSettingValue(QUERY_SIZE_LIMIT)).thenReturn(100); - PhysicalPlan plan = new OpenSearchIndexScan(mock(OpenSearchClient.class), - settings, "test", 10000, mock(OpenSearchExprValueFactory.class)); + when(settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + + PhysicalPlan plan = new OpenSearchIndexScan(mock(OpenSearchClient.class), settings, + "test", 10000, mock(OpenSearchExprValueFactory.class)); AtomicReference result = new AtomicReference<>(); executor.explain(plan, new ResponseListener() { @@ -154,8 +198,9 @@ public void onFailure(Exception e) { } @Test - void explainWithFailure() { - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + void explain_with_failure() { + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); PhysicalPlan plan = mock(PhysicalPlan.class); when(plan.accept(any(), any())).thenThrow(IllegalStateException.class); @@ -176,7 +221,7 @@ public void onFailure(Exception e) { } @Test - void callAddSplitAndOpenInOrder() { + void call_add_split_and_open_in_order() { List expected = Arrays.asList( tupleValue(of("name", "John", "age", 20)), tupleValue(of("name", "Allen", "age", 30))); @@ -184,7 +229,8 @@ void callAddSplitAndOpenInOrder() { when(protector.protect(plan)).thenReturn(plan); when(executionContext.getSplit()).thenReturn(Optional.of(split)); - OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector); + OpenSearchExecutionEngine executor = new OpenSearchExecutionEngine(client, protector, + new PlanSerializer(null)); List actual = new ArrayList<>(); executor.execute( plan, @@ -208,12 +254,20 @@ public void onFailure(Exception e) { } @RequiredArgsConstructor - private static class FakePhysicalPlan extends TableScanOperator { + private static class FakePhysicalPlan extends TableScanOperator implements SerializablePlan { private final Iterator it; private boolean hasOpen; private boolean hasClosed; private boolean hasSplit; + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + } + @Override public void open() { super.open(); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java index d4d987a7df..0b9f302ceb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java @@ -8,9 +8,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static 
org.mockito.Mockito.withSettings; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -19,6 +21,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.executor.protector.ResourceMonitorPlan; +import org.opensearch.sql.planner.SerializablePlan; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -107,4 +110,24 @@ void acceptSuccess() { monitorPlan.accept(visitor, context); verify(plan, times(1)).accept(visitor, context); } + + @Test + void getTotalHitsSuccess() { + monitorPlan.getTotalHits(); + verify(plan, times(1)).getTotalHits(); + } + + @Test + void getPlanForSerialization() { + plan = mock(PhysicalPlan.class, withSettings().extraInterfaces(SerializablePlan.class)); + monitorPlan = new ResourceMonitorPlan(plan, resourceMonitor); + assertEquals(plan, monitorPlan.getPlanForSerialization()); + } + + @Test + void notSerializable() { + // ResourceMonitorPlan shouldn't be serialized, attempt should throw an exception + assertThrows(UnsupportedOperationException.class, () -> monitorPlan.writeExternal(null)); + assertThrows(UnsupportedOperationException.class, () -> monitorPlan.readExternal(null)); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index f1fcaf677f..fe0077914e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -11,6 +11,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; +import static org.opensearch.sql.common.setting.Settings.Key.QUERY_SIZE_LIMIT; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; @@ -37,11 +39,11 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -59,7 +61,7 @@ import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.opensearch.setting.OpenSearchSettings; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.planner.physical.NestedOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; @@ -88,8 +90,9 @@ public void setup() { @Test public void testProtectIndexScan() { - 
when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - + when(settings.getSettingValue(QUERY_SIZE_LIMIT)).thenReturn(200); + when(settings.getSettingValue(SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); String indexName = "test"; Integer maxResultWindow = 10000; NamedExpression include = named("age", ref("age", INTEGER)); @@ -124,9 +127,10 @@ public void testProtectIndexScan() { PhysicalPlanDSL.agg( filter( resourceMonitor( - new OpenSearchIndexScan( - client, settings, indexName, - maxResultWindow, exprValueFactory)), + new OpenSearchIndexScan(client, settings, + indexName, + maxResultWindow, + exprValueFactory)), filterExpr), aggregators, groupByExprs), @@ -152,9 +156,10 @@ public void testProtectIndexScan() { PhysicalPlanDSL.rename( PhysicalPlanDSL.agg( filter( - new OpenSearchIndexScan( - client, settings, indexName, - maxResultWindow, exprValueFactory), + new OpenSearchIndexScan(client, settings, + indexName, + maxResultWindow, + exprValueFactory), filterExpr), aggregators, groupByExprs), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilderTest.java new file mode 100644 index 0000000000..5cabe1930d --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestBuilderTest.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +public class ContinuePageRequestBuilderTest { + + @Mock + private OpenSearchExprValueFactory exprValueFactory; + + @Mock + private Settings settings; + + private final OpenSearchRequest.IndexName indexName = new OpenSearchRequest.IndexName("test"); + private final String scrollId = "scroll"; + + private ContinuePageRequestBuilder requestBuilder; + + @BeforeEach + void setup() { + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + requestBuilder = new ContinuePageRequestBuilder( + indexName, scrollId, settings, exprValueFactory); + } + + @Test + public void build() { + assertEquals( + new ContinuePageRequest(scrollId, TimeValue.timeValueMinutes(1), exprValueFactory), + requestBuilder.build() + ); + } + + @Test + public void getIndexName() { + assertEquals(indexName, requestBuilder.getIndexName()); + } + + @Test + public void pushDown_not_supported() { + assertAll( + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownFilter(mock())), 
+ () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownAggregation(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownSort(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownLimit(1, 2)), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownHighlight("", Map.of())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownProjects(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushTypeMapping(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownNested(List.of())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownTrackedScore(true)) + ); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestTest.java new file mode 100644 index 0000000000..e991fc5787 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/ContinuePageRequestTest.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.function.Consumer; +import java.util.function.Function; +import lombok.SneakyThrows; +import org.apache.commons.lang3.reflect.FieldUtils; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class ContinuePageRequestTest { + + @Mock + private Function searchAction; + + @Mock + private Function scrollAction; + + @Mock + private Consumer cleanAction; + + @Mock + private SearchResponse searchResponse; + + @Mock + private SearchHits searchHits; + + @Mock + private SearchHit searchHit; + + @Mock + private OpenSearchExprValueFactory factory; + + private final String scroll = "scroll"; + private final String nextScroll = "nextScroll"; + + 
private final ContinuePageRequest request = new ContinuePageRequest( + scroll, TimeValue.timeValueMinutes(1), factory); + + @Test + public void search_with_non_empty_response() { + when(scrollAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(new SearchHit[] {searchHit}); + when(searchResponse.getScrollId()).thenReturn(nextScroll); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + assertAll( + () -> assertFalse(searchResponse.isEmpty()), + () -> assertEquals(nextScroll, request.toCursor()), + () -> verify(scrollAction, times(1)).apply(any()), + () -> verify(searchAction, never()).apply(any()) + ); + } + + @Test + // Empty response means scroll search is done and no cursor/scroll should be set + public void search_with_empty_response() { + when(scrollAction.apply(any())).thenReturn(searchResponse); + when(searchResponse.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(null); + lenient().when(searchResponse.getScrollId()).thenReturn(nextScroll); + + OpenSearchResponse searchResponse = request.search(searchAction, scrollAction); + assertAll( + () -> assertTrue(searchResponse.isEmpty()), + () -> assertNull(request.toCursor()), + () -> verify(scrollAction, times(1)).apply(any()), + () -> verify(searchAction, never()).apply(any()) + ); + } + + @Test + @SneakyThrows + public void clean() { + request.clean(cleanAction); + verify(cleanAction, never()).accept(any()); + // Enforce cleaning by setting a private field. + FieldUtils.writeField(request, "scrollFinished", true, true); + request.clean(cleanAction); + verify(cleanAction, times(1)).accept(any()); + } + + @Test + // Added for coverage only + public void getters() { + factory = mock(); + assertAll( + () -> assertThrows(Throwable.class, request::getSourceBuilder), + () -> assertSame(factory, new ContinuePageRequest("", null, factory).getExprValueFactory()) + ); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilderTest.java new file mode 100644 index 0000000000..ef850380d4 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/InitialPageRequestBuilderTest.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.common.setting.Settings; 
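
ContinuePageRequestTest above verifies not just the response but which of the two injected actions ran: the scroll action must be applied exactly once and the initial search action never. A minimal sketch of that Mockito verification pattern over plain java.util.function.Function mocks, with an illustrative continuePage helper standing in for the request under test:

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.function.Function;

class PagedSearchSketch {

  // Illustrative stand-in: a continued page only ever runs the scroll action.
  static Integer continuePage(Function<String, Integer> searchAction,
                              Function<String, Integer> scrollAction,
                              String scrollId) {
    return scrollAction.apply(scrollId);
  }

  @SuppressWarnings("unchecked")
  public static void main(String[] args) {
    Function<String, Integer> searchAction = mock(Function.class);
    Function<String, Integer> scrollAction = mock(Function.class);
    when(scrollAction.apply(any())).thenReturn(42);

    continuePage(searchAction, scrollAction, "scroll-id");

    // Mirrors the assertions above: one scroll call, no initial search call.
    verify(scrollAction, times(1)).apply(any());
    verify(searchAction, never()).apply(any());
  }
}
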
+import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +public class InitialPageRequestBuilderTest { + + @Mock + private OpenSearchExprValueFactory exprValueFactory; + + @Mock + private Settings settings; + + private final int pageSize = 42; + + private final OpenSearchRequest.IndexName indexName = new OpenSearchRequest.IndexName("test"); + + private InitialPageRequestBuilder requestBuilder; + + @BeforeEach + void setup() { + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + requestBuilder = new InitialPageRequestBuilder( + indexName, pageSize, settings, exprValueFactory); + } + + @Test + public void build() { + assertEquals( + new OpenSearchScrollRequest(indexName, TimeValue.timeValueMinutes(1), + new SearchSourceBuilder() + .from(0) + .size(pageSize) + .timeout(DEFAULT_QUERY_TIMEOUT), + exprValueFactory), + requestBuilder.build() + ); + } + + @Test + public void pushDown_not_supported() { + assertAll( + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownFilter(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownAggregation(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownSort(mock())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownLimit(1, 2)), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownHighlight("", Map.of())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownNested(List.of())), + () -> assertThrows(UnsupportedOperationException.class, + () -> requestBuilder.pushDownTrackedScore(true)) + ); + } + + @Test + public void pushTypeMapping() { + Map typeMapping = Map.of("intA", OpenSearchDataType.of(INTEGER)); + requestBuilder.pushTypeMapping(typeMapping); + + verify(exprValueFactory).extendTypeMapping(typeMapping); + } + + @Test + public void pushDownProject() { + Set references = Set.of(DSL.ref("intA", INTEGER)); + requestBuilder.pushDownProjects(references); + + assertEquals( + new OpenSearchScrollRequest(indexName, TimeValue.timeValueMinutes(1), + new SearchSourceBuilder() + .from(0) + .size(pageSize) + .timeout(DEFAULT_QUERY_TIMEOUT) + .fetchSource(new String[]{"intA"}, new String[0]), + exprValueFactory), + requestBuilder.build() + ); + } + + @Test + public void getIndexName() { + assertEquals(indexName, requestBuilder.getIndexName()); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java index be83622578..adb2a16a84 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchQueryRequestTest.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder.DEFAULT_QUERY_TIMEOUT; import 
java.util.Iterator; import java.util.function.Consumer; @@ -146,7 +147,7 @@ void searchRequest() { new SearchRequest() .indices("test") .source(new SearchSourceBuilder() - .timeout(OpenSearchQueryRequest.DEFAULT_QUERY_TIMEOUT) + .timeout(DEFAULT_QUERY_TIMEOUT) .from(0) .size(200) .query(QueryBuilders.termQuery("name", "John"))), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 187f319d44..94433c29b9 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -23,6 +23,8 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.join.ScoreMode; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -52,6 +54,7 @@ import org.opensearch.sql.planner.logical.LogicalNested; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) public class OpenSearchRequestBuilderTest { private static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); @@ -70,13 +73,15 @@ public class OpenSearchRequestBuilderTest { @BeforeEach void setup() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); requestBuilder = new OpenSearchRequestBuilder( "test", MAX_RESULT_WINDOW, settings, exprValueFactory); } @Test - void buildQueryRequest() { + void build_query_request() { Integer limit = 200; Integer offset = 0; requestBuilder.pushDownLimit(limit, offset); @@ -95,14 +100,14 @@ void buildQueryRequest() { } @Test - void buildScrollRequestWithCorrectSize() { + void build_scroll_request_with_correct_size() { Integer limit = 800; Integer offset = 10; requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchScrollRequest( - new OpenSearchRequest.IndexName("test"), + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), new SearchSourceBuilder() .from(offset) .size(MAX_RESULT_WINDOW - offset) @@ -112,9 +117,9 @@ void buildScrollRequestWithCorrectSize() { } @Test - void testPushDownQuery() { + void test_push_down_query() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); - requestBuilder.pushDown(query); + requestBuilder.pushDownFilter(query); assertEquals( new SearchSourceBuilder() @@ -128,7 +133,7 @@ void testPushDownQuery() { } @Test - void testPushDownAggregation() { + void test_push_down_aggregation() { AggregationBuilder aggBuilder = AggregationBuilders.composite( "composite_buckets", Collections.singletonList(new TermsValuesSourceBuilder("longA"))); @@ -149,9 +154,9 @@ void testPushDownAggregation() { } @Test - void testPushDownQueryAndSort() { + void test_push_down_query_and_sort() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); - requestBuilder.pushDown(query); + requestBuilder.pushDownFilter(query); FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); requestBuilder.pushDownSort(List.of(sortBuilder)); @@ -167,7 +172,7 @@ void testPushDownQueryAndSort() { } @Test - void testPushDownSort() { + void 
test_push_down_sort() { FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); requestBuilder.pushDownSort(List.of(sortBuilder)); @@ -181,7 +186,7 @@ void testPushDownSort() { } @Test - void testPushDownNonFieldSort() { + void test_push_down_non_field_sort() { ScoreSortBuilder sortBuilder = SortBuilders.scoreSort(); requestBuilder.pushDownSort(List.of(sortBuilder)); @@ -195,7 +200,7 @@ void testPushDownNonFieldSort() { } @Test - void testPushDownMultipleSort() { + void test_push_down_multiple_sort() { requestBuilder.pushDownSort(List.of( SortBuilders.fieldSort("intA"), SortBuilders.fieldSort("intB"))); @@ -211,7 +216,7 @@ void testPushDownMultipleSort() { } @Test - void testPushDownProject() { + void test_push_down_project() { Set references = Set.of(DSL.ref("intA", INTEGER)); requestBuilder.pushDownProjects(references); @@ -225,7 +230,7 @@ void testPushDownProject() { } @Test - void testPushDownNested() { + void test_push_down_nested() { List> args = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), @@ -255,7 +260,7 @@ void testPushDownNested() { } @Test - void testPushDownMultipleNestedWithSamePath() { + void test_push_down_multiple_nested_with_same_path() { List> args = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), @@ -288,7 +293,7 @@ void testPushDownMultipleNestedWithSamePath() { } @Test - void testPushDownNestedWithFilter() { + void test_push_down_nested_with_filter() { List> args = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), @@ -325,7 +330,7 @@ void testPushDownNestedWithFilter() { } @Test - void testPushTypeMapping() { + void test_push_type_mapping() { Map typeMapping = Map.of("intA", OpenSearchDataType.of(INTEGER)); requestBuilder.pushTypeMapping(typeMapping); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestTest.java new file mode 100644 index 0000000000..d0a274ce2a --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestTest.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.request; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.withSettings; + +import org.junit.jupiter.api.Test; + +public class OpenSearchRequestTest { + + @Test + void toCursor() { + var request = mock(OpenSearchRequest.class, withSettings().defaultAnswer(CALLS_REAL_METHODS)); + assertEquals("", request.toCursor()); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java index b3c049ce03..461184e6d5 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchScrollRequestTest.java @@ -8,13 +8,20 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static 
org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; +import org.apache.lucene.search.TotalHits; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -22,6 +29,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -31,6 +39,7 @@ import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchScrollRequestTest { @Mock @@ -56,8 +65,9 @@ class OpenSearchScrollRequestTest { @Mock private OpenSearchExprValueFactory factory; - private final OpenSearchScrollRequest request = - new OpenSearchScrollRequest("test", factory); + private final OpenSearchScrollRequest request = new OpenSearchScrollRequest( + new OpenSearchRequest.IndexName("test"), TimeValue.timeValueMinutes(1), + new SearchSourceBuilder(), factory); @Test void searchRequest() { @@ -66,17 +76,20 @@ void searchRequest() { assertEquals( new SearchRequest() .indices("test") - .scroll(OpenSearchScrollRequest.DEFAULT_SCROLL_TIMEOUT) + .scroll(TimeValue.timeValueMinutes(1)) .source(new SearchSourceBuilder().query(QueryBuilders.termQuery("name", "John"))), request.searchRequest()); } @Test void isScrollStarted() { - assertFalse(request.isScrollStarted()); + assertFalse(request.isScroll()); request.setScrollId("scroll123"); - assertTrue(request.isScrollStarted()); + assertTrue(request.isScroll()); + + request.reset(); + assertFalse(request.isScroll()); } @Test @@ -84,7 +97,7 @@ void scrollRequest() { request.setScrollId("scroll123"); assertEquals( new SearchScrollRequest() - .scroll(OpenSearchScrollRequest.DEFAULT_SCROLL_TIMEOUT) + .scroll(TimeValue.timeValueMinutes(1)) .scrollId("scroll123"), request.scrollRequest()); } @@ -93,6 +106,7 @@ void scrollRequest() { void search() { OpenSearchScrollRequest request = new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), sourceBuilder, factory ); @@ -113,6 +127,7 @@ void search() { void search_withoutContext() { OpenSearchScrollRequest request = new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), sourceBuilder, factory ); @@ -131,6 +146,7 @@ void search_withoutContext() { void search_withoutIncludes() { OpenSearchScrollRequest request = new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), + TimeValue.timeValueMinutes(1), sourceBuilder, factory ); @@ -145,4 +161,60 @@ void search_withoutIncludes() { verify(fetchSourceContext, times(1)).includes(); assertFalse(searchResponse.isEmpty()); } + + @Test + void toCursor() { + request.setScrollId("scroll123"); + assertEquals("scroll123", request.toCursor()); + + request.reset(); + assertNull(request.toCursor()); + } + + @Test + void clean_on_empty_response() { + // This could happen on sequential search calls + SearchResponse searchResponse = 
mock(); + when(searchResponse.getScrollId()).thenReturn("scroll1", "scroll2"); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[1], new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1F), + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertEquals("scroll1", request.getScrollId()); + request.search((x) -> searchResponse, (x) -> searchResponse); + assertEquals("scroll1", request.getScrollId()); + + AtomicBoolean cleanCalled = new AtomicBoolean(false); + request.clean((s) -> cleanCalled.set(true)); + + assertNull(request.getScrollId()); + assertTrue(cleanCalled.get()); + } + + @Test + void no_clean_on_non_empty_response() { + SearchResponse searchResponse = mock(); + when(searchResponse.getScrollId()).thenReturn("scroll"); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[1], new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1F)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertEquals("scroll", request.getScrollId()); + + request.clean((s) -> fail()); + assertNull(request.getScrollId()); + } + + @Test + void no_clean_if_no_scroll_in_response() { + SearchResponse searchResponse = mock(); + when(searchResponse.getHits()).thenReturn( + new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1F)); + + request.search((x) -> searchResponse, (x) -> searchResponse); + assertNull(request.getScrollId()); + + request.clean((s) -> fail()); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 65568cf5f1..8add6c8c85 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -80,20 +80,29 @@ void isEmpty() { new TotalHits(2L, TotalHits.Relation.EQUAL_TO), 1.0F)); - assertFalse(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); + var response = new OpenSearchResponse(searchResponse, factory, includes); + assertFalse(response.isEmpty()); + assertEquals(2L, response.getTotalHits()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(null); - assertTrue(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); + + response = new OpenSearchResponse(searchResponse, factory, includes); + assertTrue(response.isEmpty()); + assertEquals(0L, response.getTotalHits()); when(searchResponse.getHits()) .thenReturn(new SearchHits(null, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0)); - OpenSearchResponse response3 = new OpenSearchResponse(searchResponse, factory, includes); - assertTrue(response3.isEmpty()); + response = new OpenSearchResponse(searchResponse, factory, includes); + assertTrue(response.isEmpty()); + assertEquals(0L, response.getTotalHits()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(new Aggregations(emptyList())); - assertFalse(new OpenSearchResponse(searchResponse, factory, includes).isEmpty()); + + response = new OpenSearchResponse(searchResponse, factory, includes); + assertFalse(response.isEmpty()); + assertEquals(0L, response.getTotalHits()); } @Test diff --git 
a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 3d856cb1e2..2ff1de862b 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; @@ -41,6 +42,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; @@ -56,6 +58,12 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.mapping.IndexMapping; +import org.opensearch.sql.opensearch.request.InitialPageRequestBuilder; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.request.PagedRequestBuilder; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchPagedIndexScan; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; @@ -201,30 +209,48 @@ void getReservedFieldTypes() { @Test void implementRelationOperatorOnly() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - assertEquals( - new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), - index.implement(plan)); + assertEquals(new OpenSearchIndexScan(client, settings, indexName, + maxResultWindow, exprValueFactory), index.implement(index.optimize(plan))); + } + + @Test + void implementPagedRelationOperatorOnly() { + when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); + + LogicalPlan plan = index.createPagedScanBuilder(42); + Integer maxResultWindow = index.getMaxResultWindow(); + PagedRequestBuilder builder = new InitialPageRequestBuilder( + new OpenSearchRequest.IndexName(indexName), + maxResultWindow, mock(), exprValueFactory); + assertEquals(new OpenSearchPagedIndexScan(client, builder), index.implement(plan)); } @Test void implementRelationOperatorWithOptimization() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); 
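
Several of the reworked assertions that follow (in OpenSearchStorageEngineTest and OpenSearchIndexScanTest) wrap related checks in JUnit 5's assertAll, so every failing check in the group is reported instead of the test stopping at the first failure. A minimal, self-contained sketch of the pattern (names and values are illustrative):

import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import org.junit.jupiter.api.Test;

class AssertAllSketchTest {

  @Test
  void reports_every_failing_check_in_the_group() {
    int totalHits = 3;
    // Each lambda runs even if an earlier one fails; failures are reported together.
    assertAll(
        () -> assertTrue(totalHits > 0),
        () -> assertEquals(3, totalHits)
    );
  }
}
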
when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); - assertEquals( - new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), - index.implement(index.optimize(plan))); + assertEquals(new OpenSearchIndexScan(client, settings, indexName, + maxResultWindow, exprValueFactory), index.implement(plan)); } @Test void implementOtherLogicalOperators() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); NamedExpression include = named("age", ref("age", INTEGER)); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java index ab87f4531c..1089e7e252 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.storage; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; @@ -35,7 +36,10 @@ public void getTable() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); Table table = engine.getTable(new DataSourceSchemaName(DEFAULT_DATASOURCE_NAME, "default"), "test"); - assertNotNull(table); + assertAll( + () -> assertNotNull(table), + () -> assertTrue(table instanceof OpenSearchIndex) + ); } @Test @@ -43,7 +47,9 @@ public void getSystemTable() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); Table table = engine.getTable(new DataSourceSchemaName(DEFAULT_DATASOURCE_NAME, "default"), TABLE_INFO); - assertNotNull(table); - assertTrue(table instanceof OpenSearchSystemIndex); + assertAll( + () -> assertNotNull(table), + () -> assertTrue(table instanceof OpenSearchSystemIndex) + ); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java index fa98f5a3b9..bde940a939 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -78,7 +78,6 @@ import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; -import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalNested; @@ -702,7 +701,7 @@ private void assertEqualsAfterOptimization(LogicalPlan expected, LogicalPlan act } private 
Runnable withFilterPushedDown(QueryBuilder filteringCondition) { - return () -> verify(requestBuilder, times(1)).pushDown(filteringCondition); + return () -> verify(requestBuilder, times(1)).pushDownFilter(filteringCondition); } private Runnable withAggregationPushedDown( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java similarity index 61% rename from opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java rename to opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java index 8aec6a7d13..c788e78f1a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java @@ -4,12 +4,14 @@ */ -package org.opensearch.sql.opensearch.storage; +package org.opensearch.sql.opensearch.storage.scan; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -21,6 +23,8 @@ import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -28,6 +32,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.stubbing.Answer; import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; @@ -43,9 +48,11 @@ import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchIndexScanTest { @Mock @@ -61,122 +68,166 @@ class OpenSearchIndexScanTest { @BeforeEach void setup() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) + .thenReturn(TimeValue.timeValueMinutes(1)); } @Test - void queryEmptyResult() { - mockResponse(); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "test", 3, exprValueFactory)) { + void query_empty_result() { + mockResponse(client); + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "test", 3, exprValueFactory)) { indexScan.open(); - assertFalse(indexScan.hasNext()); + assertAll( + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(0, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void queryAllResultsWithQuery() { - 
mockResponse(new ExprValue[]{ + void query_all_results_with_query() { + mockResponse(client, new ExprValue[]{ employee(1, "John", "IT"), employee(2, "Smith", "HR"), employee(3, "Allen", "IT")}); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 10, exprValueFactory)) { + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "employees", 10, exprValueFactory)) { indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void queryAllResultsWithScroll() { - mockResponse( + void query_all_results_with_scroll() { + mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT")}); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 2, exprValueFactory)) { + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "employees", 10, exprValueFactory)) { indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void querySomeResultsWithQuery() { - mockResponse(new ExprValue[]{ + void query_some_results_with_query() { + mockResponse(client, new ExprValue[]{ employee(1, "John", "IT"), employee(2, "Smith", "HR"), employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 10, exprValueFactory)) { + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "employees", 10, exprValueFactory)) { indexScan.getRequestBuilder().pushDownLimit(3, 0); indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - 
assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void querySomeResultsWithScroll() { - mockResponse( + void query_some_results_with_scroll() { + mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "employees", 2, exprValueFactory)) { + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "employees", 2, exprValueFactory)) { indexScan.getRequestBuilder().pushDownLimit(3, 0); indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void pushDownFilters() { + void query_results_limited_by_query_size() { + mockResponse(client, new ExprValue[]{ + employee(1, "John", "IT"), + employee(2, "Smith", "HR"), + employee(3, "Allen", "IT"), + employee(4, "Bob", "HR")}); + when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(2); + + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "employees", 10, exprValueFactory)) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(2, indexScan.getTotalHits()) + ); + } + verify(client).cleanup(any()); + } + + @Test + void push_down_filters() { assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) .shouldQuery(QueryBuilders.termQuery("name", "John")) @@ -194,7 +245,7 @@ void pushDownFilters() { } @Test - void pushDownHighlight() { + void push_down_highlight() { Map args = new HashMap<>(); assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) @@ -205,7 +256,7 @@ void pushDownHighlight() { } @Test - void pushDownHighlightWithArguments() { + void push_down_highlight_with_arguments() { Map args = new HashMap<>(); args.put("pre_tags", new Literal("", DataType.STRING)); args.put("post_tags", new Literal("", DataType.STRING)); @@ -220,13 +271,13 @@ void pushDownHighlightWithArguments() { } @Test - void pushDownHighlightWithRepeatingFields() { - mockResponse( + void push_down_highlight_with_repeating_fields() { + mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT"), 
employee(4, "Bob", "HR")}); - try (OpenSearchIndexScan indexScan = - new OpenSearchIndexScan(client, settings, "test", 2, exprValueFactory)) { + try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, + "test", 2, exprValueFactory)) { indexScan.getRequestBuilder().pushDownLimit(3, 0); indexScan.open(); Map args = new HashMap<>(); @@ -252,14 +303,15 @@ public PushDownAssertion(OpenSearchClient client, OpenSearchExprValueFactory valueFactory, Settings settings) { this.client = client; - this.indexScan = new OpenSearchIndexScan(client, settings, "test", 10000, valueFactory); + this.indexScan = new OpenSearchIndexScan(client, settings, + "test", 10000, valueFactory); this.response = mock(OpenSearchResponse.class); this.factory = valueFactory; when(response.isEmpty()).thenReturn(true); } PushDownAssertion pushDown(QueryBuilder query) { - indexScan.getRequestBuilder().pushDown(query); + indexScan.getRequestBuilder().pushDownFilter(query); return this; } @@ -290,7 +342,7 @@ PushDownAssertion shouldQuery(QueryBuilder expected) { } } - private void mockResponse(ExprValue[]... searchHitBatches) { + public static void mockResponse(OpenSearchClient client, ExprValue[]... searchHitBatches) { when(client.search(any())) .thenAnswer( new Answer() { @@ -304,6 +356,9 @@ public OpenSearchResponse answer(InvocationOnMock invocation) { when(response.isEmpty()).thenReturn(false); ExprValue[] searchHit = searchHitBatches[batchNum]; when(response.iterator()).thenReturn(Arrays.asList(searchHit).iterator()); + // used in OpenSearchPagedIndexScanTest + lenient().when(response.getTotalHits()) + .thenReturn((long) searchHitBatches[batchNum].length); } else { when(response.isEmpty()).thenReturn(true); } @@ -314,14 +369,14 @@ public OpenSearchResponse answer(InvocationOnMock invocation) { }); } - protected ExprValue employee(int docId, String name, String department) { + public static ExprValue employee(int docId, String name, String department) { SearchHit hit = new SearchHit(docId); hit.sourceRef( new BytesArray("{\"name\":\"" + name + "\",\"department\":\"" + department + "\"}")); return tupleValue(hit); } - private ExprValue tupleValue(SearchHit hit) { + private static ExprValue tupleValue(SearchHit hit) { return ExprValueUtils.tupleValue(hit.getSourceAsMap()); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java new file mode 100644 index 0000000000..cd94154012 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java @@ -0,0 +1,215 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import 
static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.employee; +import static org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanTest.mockResponse; + +import com.google.common.collect.ImmutableMap; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.NoCursorException; +import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.ContinuePageRequestBuilder; +import org.opensearch.sql.opensearch.request.InitialPageRequestBuilder; +import org.opensearch.sql.opensearch.request.OpenSearchRequest; +import org.opensearch.sql.opensearch.request.PagedRequestBuilder; +import org.opensearch.sql.opensearch.response.OpenSearchResponse; +import org.opensearch.sql.opensearch.storage.OpenSearchStorageEngine; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class OpenSearchPagedIndexScanTest { + @Mock + private OpenSearchClient client; + + private final OpenSearchExprValueFactory exprValueFactory = new OpenSearchExprValueFactory( + ImmutableMap.of( + "name", OpenSearchDataType.of(STRING), + "department", OpenSearchDataType.of(STRING))); + + @Test + void query_empty_result() { + mockResponse(client); + InitialPageRequestBuilder builder = new InitialPageRequestBuilder( + new OpenSearchRequest.IndexName("test"), 3, mock(), exprValueFactory); + try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { + indexScan.open(); + assertFalse(indexScan.hasNext()); + } + verify(client).cleanup(any()); + } + + @Test + void query_all_results_initial_scroll_request() { + mockResponse(client, new ExprValue[]{ + employee(1, "John", "IT"), + employee(2, "Smith", "HR"), + employee(3, "Allen", "IT")}); + + PagedRequestBuilder builder = new InitialPageRequestBuilder( + new OpenSearchRequest.IndexName("test"), 3, mock(), exprValueFactory); + try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); + } + verify(client).cleanup(any()); + + builder = new ContinuePageRequestBuilder( + new OpenSearchRequest.IndexName("test"), "scroll", mock(), exprValueFactory); + try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { + indexScan.open(); + + assertFalse(indexScan.hasNext()); + } + verify(client, times(2)).cleanup(any()); + } + + 
@Test + void query_all_results_continuation_scroll_request() { + mockResponse(client, new ExprValue[]{ + employee(1, "John", "IT"), + employee(2, "Smith", "HR"), + employee(3, "Allen", "IT")}); + + ContinuePageRequestBuilder builder = new ContinuePageRequestBuilder( + new OpenSearchRequest.IndexName("test"), "scroll", mock(), exprValueFactory); + try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { + indexScan.open(); + + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), + + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), + + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); + } + verify(client).cleanup(any()); + + builder = new ContinuePageRequestBuilder( + new OpenSearchRequest.IndexName("test"), "scroll", mock(), exprValueFactory); + try (OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder)) { + indexScan.open(); + + assertFalse(indexScan.hasNext()); + } + verify(client, times(2)).cleanup(any()); + } + + @Test + void explain_not_implemented() { + assertThrows(Throwable.class, () -> mock(OpenSearchPagedIndexScan.class, + withSettings().defaultAnswer(CALLS_REAL_METHODS)).explain()); + } + + @Test + @SneakyThrows + void serialization() { + PagedRequestBuilder builder = mock(); + OpenSearchRequest request = mock(); + OpenSearchResponse response = mock(); + when(request.toCursor()).thenReturn("cu-cursor"); + when(builder.build()).thenReturn(request); + var indexName = new OpenSearchRequest.IndexName("index"); + when(builder.getIndexName()).thenReturn(indexName); + when(client.search(any())).thenReturn(response); + OpenSearchPagedIndexScan indexScan = new OpenSearchPagedIndexScan(client, builder); + indexScan.open(); + + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(indexScan); + objectOutput.flush(); + + when(client.getIndexMappings(any())).thenReturn(Map.of()); + OpenSearchStorageEngine engine = mock(); + when(engine.getClient()).thenReturn(client); + when(engine.getSettings()).thenReturn(mock()); + ObjectInputStream objectInput = new PlanSerializer(engine) + .getCursorDeserializationStream(new ByteArrayInputStream(output.toByteArray())); + var roundTripScan = (OpenSearchPagedIndexScan) objectInput.readObject(); + roundTripScan.open(); + + // indexScan's request could be a OpenSearchScrollRequest or a ContinuePageRequest, but + // roundTripScan's request is always a ContinuePageRequest + // Thus, we can't compare those scans + //assertEquals(indexScan, roundTripScan); + // But we can validate that index name and scroll was serialized-deserialized correctly + assertEquals(indexName, roundTripScan.getRequestBuilder().getIndexName()); + assertTrue(roundTripScan.getRequestBuilder() instanceof ContinuePageRequestBuilder); + assertEquals("cu-cursor", + ((ContinuePageRequestBuilder) roundTripScan.getRequestBuilder()).getScrollId()); + } + + @Test + @SneakyThrows + void dont_serialize_if_no_cursor() { + PagedRequestBuilder builder = mock(); + OpenSearchRequest request = mock(); + OpenSearchResponse response = mock(); + when(builder.build()).thenReturn(request); + when(client.search(any())).thenReturn(response); + OpenSearchPagedIndexScan 
indexScan = new OpenSearchPagedIndexScan(client, builder); + indexScan.open(); + + when(request.toCursor()).thenReturn(null, ""); + for (int i = 0; i < 2; i++) { + assertThrows(NoCursorException.class, () -> { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(indexScan); + objectOutput.flush(); + }); + } + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScanTest.java index 494f3ff2d0..c04ef25611 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/system/OpenSearchSystemIndexScanTest.java @@ -32,6 +32,7 @@ public void queryData() { systemIndexScan.open(); assertTrue(systemIndexScan.hasNext()); assertEquals(stringValue("text"), systemIndexScan.next()); + assertEquals(1, systemIndexScan.getTotalHits()); } @Test diff --git a/plugin/build.gradle b/plugin/build.gradle index e318103859..4a6c175d61 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -246,6 +246,7 @@ afterEvaluate { testClusters.integTest { plugin(project.tasks.bundlePlugin.archiveFile) + testDistribution = "ARCHIVE" // debug with command, ./gradlew opensearch-sql:run -DdebugJVM. --debug-jvm does not work with keystore. if (System.getProperty("debugJVM") != null) { diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java index 5ab4bbaecd..b80cb3faab 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/config/OpenSearchPluginModule.java @@ -18,6 +18,7 @@ import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.monitor.ResourceMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; @@ -58,8 +59,9 @@ public StorageEngine storageEngine(OpenSearchClient client, Settings settings) { } @Provides - public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector) { - return new OpenSearchExecutionEngine(client, protector); + public ExecutionEngine executionEngine(OpenSearchClient client, ExecutionProtector protector, + PlanSerializer planSerializer) { + return new OpenSearchExecutionEngine(client, protector, planSerializer); } @Provides @@ -72,6 +74,11 @@ public ExecutionProtector protector(ResourceMonitor resourceMonitor) { return new OpenSearchExecutionProtector(resourceMonitor); } + @Provides + public PlanSerializer paginatedPlanCache(StorageEngine storageEngine) { + return new PlanSerializer(storageEngine); + } + @Provides @Singleton public QueryManager queryManager(NodeClient nodeClient) { @@ -92,12 +99,15 @@ public SQLService sqlService(QueryManager queryManager, QueryPlanFactory queryPl * {@link QueryPlanFactory}. 
*/ @Provides - public QueryPlanFactory queryPlanFactory( - DataSourceService dataSourceService, ExecutionEngine executionEngine) { + public QueryPlanFactory queryPlanFactory(DataSourceService dataSourceService, + ExecutionEngine executionEngine, + PlanSerializer planSerializer) { Analyzer analyzer = new Analyzer( new ExpressionAnalyzer(functionRepository), dataSourceService, functionRepository); Planner planner = new Planner(LogicalPlanOptimizer.create()); - return new QueryPlanFactory(new QueryService(analyzer, executionEngine, planner)); + QueryService queryService = new QueryService( + analyzer, executionEngine, planner); + return new QueryPlanFactory(queryService, planSerializer); } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index a5c094e956..acac65bd54 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -139,7 +139,8 @@ private ResponseListener createListener( @Override public void onResponse(ExecutionEngine.QueryResponse response) { String responseContent = - formatter.format(new QueryResult(response.getSchema(), response.getResults())); + formatter.format(new QueryResult(response.getSchema(), response.getResults(), + response.getCursor(), response.getTotal())); listener.onResponse(new TransportPPLQueryResponse(responseContent)); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java index e11edc1646..f91ac7222f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/PPLService.java @@ -90,6 +90,7 @@ private AbstractPlan plan( QueryContext.getRequestId(), anonymizer.anonymizeStatement(statement)); - return queryExecutionFactory.create(statement, queryListener, explainListener); + return queryExecutionFactory.createContinuePaginatedPlan( + statement, queryListener, explainListener); } } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index e4f40e9a11..3b7e5a78dd 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -33,7 +33,7 @@ public class AstStatementBuilder extends OpenSearchPPLParserBaseVisitor { ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + listener.onResponse(new QueryResponse(schema, Collections.emptyList(), 0, Cursor.None)); return null; }).when(queryService).execute(any(), any()); @@ -87,7 +93,7 @@ public void onFailure(Exception e) { public void testExecuteCsvFormatShouldPass() { doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + listener.onResponse(new QueryResponse(schema, Collections.emptyList(), 0, Cursor.None)); return null; }).when(queryService).execute(any(), any()); @@ -161,7 +167,7 @@ public void onFailure(Exception e) { public void testPrometheusQuery() { doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); + listener.onResponse(new 
QueryResponse(schema, Collections.emptyList(), 0, Cursor.None)); return null; }).when(queryService).execute(any(), any()); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java index 4760024692..de74e4932f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstStatementBuilderTest.java @@ -39,7 +39,8 @@ public void buildQueryStatement() { "search source=t a=1", new Query( project( - filter(relation("t"), compare("=", field("a"), intLiteral(1))), AllFields.of()))); + filter(relation("t"), compare("=", field("a"), + intLiteral(1))), AllFields.of()), 0)); } @Test @@ -50,7 +51,7 @@ public void buildExplainStatement() { new Query( project( filter(relation("t"), compare("=", field("a"), intLiteral(1))), - AllFields.of())))); + AllFields.of()), 0))); } private void assertEqual(String query, Statement expectedStatement) { diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java index 915a61f361..d06dba7719 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java @@ -16,6 +16,7 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.executor.ExecutionEngine.Schema.Column; +import org.opensearch.sql.executor.pagination.Cursor; /** * Query response that encapsulates query results and isolate {@link ExprValue} @@ -32,6 +33,16 @@ public class QueryResult implements Iterable { */ private final Collection exprValues; + @Getter + private final Cursor cursor; + + @Getter + private final long total; + + + public QueryResult(ExecutionEngine.Schema schema, Collection exprValues) { + this(schema, exprValues, Cursor.None, exprValues.size()); + } /** * size of results. 
diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java index 943287cb62..b9a2d2fcc6 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java @@ -15,6 +15,7 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.QueryEngineException; import org.opensearch.sql.executor.ExecutionEngine.Schema; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.opensearch.response.error.ErrorMessage; import org.opensearch.sql.opensearch.response.error.ErrorMessageFactory; import org.opensearch.sql.protocol.response.QueryResult; @@ -39,9 +40,12 @@ protected Object buildJsonObject(QueryResult response) { json.datarows(fetchDataRows(response)); // Populate other fields - json.total(response.size()) + json.total(response.getTotal()) .size(response.size()) .status(200); + if (!response.getCursor().equals(Cursor.None)) { + json.cursor(response.getCursor().toString()); + } return json.build(); } @@ -95,6 +99,8 @@ public static class JdbcResponse { private final long total; private final long size; private final int status; + + private final String cursor; } @RequiredArgsConstructor diff --git a/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java b/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java index 319965e2d0..470bb205a8 100644 --- a/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java +++ b/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java @@ -19,6 +19,7 @@ import java.util.Collections; import org.junit.jupiter.api.Test; import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.executor.pagination.Cursor; class QueryResultTest { @@ -35,7 +36,7 @@ void size() { tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Allen", "age", 30)), tupleValue(ImmutableMap.of("name", "Smith", "age", 40)) - )); + ), Cursor.None, 0); assertEquals(3, response.size()); } @@ -45,7 +46,7 @@ void columnNameTypes() { schema, Collections.singletonList( tupleValue(ImmutableMap.of("name", "John", "age", 20)) - )); + ), Cursor.None, 0); assertEquals( ImmutableMap.of("name", "string", "age", "integer"), @@ -59,7 +60,8 @@ void columnNameTypesWithAlias() { new ExecutionEngine.Schema.Column("name", "n", STRING))); QueryResult response = new QueryResult( schema, - Collections.singletonList(tupleValue(ImmutableMap.of("n", "John")))); + Collections.singletonList(tupleValue(ImmutableMap.of("n", "John"))), + Cursor.None, 0); assertEquals( ImmutableMap.of("n", "string"), @@ -71,7 +73,7 @@ void columnNameTypesWithAlias() { void columnNameTypesFromEmptyExprValues() { QueryResult response = new QueryResult( schema, - Collections.emptyList()); + Collections.emptyList(), Cursor.None, 0); assertEquals( ImmutableMap.of("name", "string", "age", "integer"), response.columnNameTypes() @@ -100,7 +102,7 @@ void iterate() { Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Allen", "age", 30)) - )); + ), Cursor.None, 0); int i = 0; for (Object[] objects : response) { diff --git a/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java 
b/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java index a6671c66f8..047e297c26 100644 --- a/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java +++ b/protocol/src/test/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatterTest.java @@ -31,6 +31,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.executor.pagination.Cursor; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.protocol.response.QueryResult; @@ -83,6 +84,37 @@ void format_response() { formatter.format(response)); } + @Test + void format_response_with_cursor() { + QueryResult response = new QueryResult( + new Schema(ImmutableList.of( + new Column("name", "name", STRING), + new Column("address", "address", OpenSearchTextType.of()), + new Column("age", "age", INTEGER))), + ImmutableList.of( + tupleValue(ImmutableMap.builder() + .put("name", "John") + .put("address", "Seattle") + .put("age", 20) + .build())), + new Cursor("test_cursor"), 42); + + assertJsonEquals( + "{" + + "\"schema\":[" + + "{\"name\":\"name\",\"alias\":\"name\",\"type\":\"keyword\"}," + + "{\"name\":\"address\",\"alias\":\"address\",\"type\":\"text\"}," + + "{\"name\":\"age\",\"alias\":\"age\",\"type\":\"integer\"}" + + "]," + + "\"datarows\":[" + + "[\"John\",\"Seattle\",20]]," + + "\"total\":42," + + "\"size\":1," + + "\"cursor\":\"test_cursor\"," + + "\"status\":200}", + formatter.format(response)); + } + @Test void format_response_with_missing_and_null_value() { QueryResult response = diff --git a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java index 082a3e9581..4ecf9e699b 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/SQLService.java +++ b/sql/src/main/java/org/opensearch/sql/sql/SQLService.java @@ -65,16 +65,24 @@ private AbstractPlan plan( SQLQueryRequest request, Optional> queryListener, Optional> explainListener) { - // 1.Parse query and convert parse tree (CST) to abstract syntax tree (AST) - ParseTree cst = parser.parse(request.getQuery()); - Statement statement = - cst.accept( - new AstStatementBuilder( - new AstBuilder(request.getQuery()), - AstStatementBuilder.StatementBuilderContext.builder() - .isExplain(request.isExplainRequest()) - .build())); + if (request.getCursor().isPresent()) { + // Handle v2 cursor here -- legacy cursor was handled earlier. 
+ return queryExecutionFactory.createContinuePaginatedPlan(request.getCursor().get(), + request.isExplainRequest(), queryListener.orElse(null), explainListener.orElse(null)); + } else { + // 1.Parse query and convert parse tree (CST) to abstract syntax tree (AST) + ParseTree cst = parser.parse(request.getQuery()); + Statement statement = + cst.accept( + new AstStatementBuilder( + new AstBuilder(request.getQuery()), + AstStatementBuilder.StatementBuilderContext.builder() + .isExplain(request.isExplainRequest()) + .fetchSize(request.getFetchSize()) + .build())); - return queryExecutionFactory.create(statement, queryListener, explainListener); + return queryExecutionFactory.createContinuePaginatedPlan( + statement, queryListener, explainListener); + } } } diff --git a/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java b/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java index 508f80cee4..7545f4cc19 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java +++ b/sql/src/main/java/org/opensearch/sql/sql/domain/SQLQueryRequest.java @@ -6,13 +6,12 @@ package org.opensearch.sql.sql.domain; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableSet; import java.util.Collections; import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Stream; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -28,9 +27,9 @@ @EqualsAndHashCode @RequiredArgsConstructor public class SQLQueryRequest { - - private static final Set SUPPORTED_FIELDS = ImmutableSet.of( - "query", "fetch_size", "parameters"); + private static final String QUERY_FIELD_CURSOR = "cursor"; + private static final Set SUPPORTED_FIELDS = Set.of( + "query", "fetch_size", "parameters", QUERY_FIELD_CURSOR); private static final String QUERY_PARAMS_FORMAT = "format"; private static final String QUERY_PARAMS_SANITIZE = "sanitize"; @@ -64,36 +63,50 @@ public class SQLQueryRequest { @Accessors(fluent = true) private boolean sanitize = true; + private String cursor; + /** * Constructor of SQLQueryRequest that passes request params. */ - public SQLQueryRequest( - JSONObject jsonContent, String query, String path, Map params) { + public SQLQueryRequest(JSONObject jsonContent, String query, String path, + Map params, String cursor) { this.jsonContent = jsonContent; this.query = query; this.path = path; this.params = params; this.format = getFormat(params); this.sanitize = shouldSanitize(params); + this.cursor = cursor; } /** * Pre-check if the request can be supported by meeting ALL the following criteria: * 1.Only supported fields present in request body, ex. "filter" and "cursor" are not supported - * 2.No fetch_size or "fetch_size=0". In other word, it's not a cursor request - * 3.Response format is default or can be supported. + * 2.Response format is default or can be supported. * - * @return true if supported. + * @return true if supported. 
*/ public boolean isSupported() { - return isOnlySupportedFieldInPayload() - && isFetchSizeZeroIfPresent() - && isSupportedFormat(); + var noCursor = !isCursor(); + var noQuery = query == null; + var noUnsupportedParams = params.isEmpty() + || (params.size() == 1 && params.containsKey(QUERY_PARAMS_FORMAT)); + var noContent = jsonContent == null || jsonContent.isEmpty(); + + return ((!noCursor && noQuery + && noUnsupportedParams && noContent) // if cursor is given, but other things + || (noCursor && !noQuery)) // or if cursor is not given, but query + && isOnlySupportedFieldInPayload() // and request has supported fields only + && isSupportedFormat(); // and request is in supported format + } + + private boolean isCursor() { + return cursor != null && !cursor.isEmpty(); } /** * Check if request is to explain rather than execute the query. - * @return true if it is a explain request + * @return true if it is an explain request */ public boolean isExplainRequest() { return path.endsWith("/_explain"); @@ -113,23 +126,23 @@ public Format format() { } private boolean isOnlySupportedFieldInPayload() { - return SUPPORTED_FIELDS.containsAll(jsonContent.keySet()); + return jsonContent == null || SUPPORTED_FIELDS.containsAll(jsonContent.keySet()); } - private boolean isFetchSizeZeroIfPresent() { - return (jsonContent.optInt("fetch_size") == 0); + public Optional getCursor() { + return Optional.ofNullable(cursor); + } + + public int getFetchSize() { + return jsonContent.optInt("fetch_size"); } private boolean isSupportedFormat() { - return Strings.isNullOrEmpty(format) || "jdbc".equalsIgnoreCase(format) - || "csv".equalsIgnoreCase(format) || "raw".equalsIgnoreCase(format); + return Stream.of("csv", "jdbc", "raw").anyMatch(format::equalsIgnoreCase); } private String getFormat(Map params) { - if (params.containsKey(QUERY_PARAMS_FORMAT)) { - return params.get(QUERY_PARAMS_FORMAT); - } - return "jdbc"; + return params.getOrDefault(QUERY_PARAMS_FORMAT, "jdbc"); } private boolean shouldSanitize(Map params) { diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java index 40d549764a..593e7b51ff 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstStatementBuilder.java @@ -26,7 +26,7 @@ public class AstStatementBuilder extends OpenSearchSQLParserBaseVisitor { - ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); - return null; - }).when(queryService).execute(any(), any()); - + public void can_execute_sql_query() { sqlService.execute( new SQLQueryRequest(new JSONObject(), "SELECT 123", QUERY, "jdbc"), - new ResponseListener() { + new ResponseListener<>() { @Override public void onResponse(QueryResponse response) { assertNotNull(response); @@ -84,13 +82,24 @@ public void onFailure(Exception e) { } @Test - public void canExecuteCsvFormatRequest() { - doAnswer(invocation -> { - ResponseListener listener = invocation.getArgument(1); - listener.onResponse(new QueryResponse(schema, Collections.emptyList())); - return null; - }).when(queryService).execute(any(), any()); + public void can_execute_cursor_query() { + sqlService.execute( + new SQLQueryRequest(new JSONObject(), null, QUERY, Map.of("format", "jdbc"), "n:cursor"), + new ResponseListener<>() { + @Override + public void onResponse(QueryResponse response) { + assertNotNull(response); + } + + 
@Override + public void onFailure(Exception e) { + fail(e); + } + }); + } + @Test + public void can_execute_csv_format_request() { sqlService.execute( new SQLQueryRequest(new JSONObject(), "SELECT 123", QUERY, "csv"), new ResponseListener() { @@ -107,7 +116,7 @@ public void onFailure(Exception e) { } @Test - public void canExplainSqlQuery() { + public void can_explain_sql_query() { doAnswer(invocation -> { ResponseListener listener = invocation.getArgument(1); listener.onResponse(new ExplainResponse(new ExplainResponseNode("Test"))); @@ -129,7 +138,25 @@ public void onFailure(Exception e) { } @Test - public void canCaptureErrorDuringExecution() { + public void cannot_explain_cursor_query() { + sqlService.explain(new SQLQueryRequest(new JSONObject(), null, EXPLAIN, + Map.of("format", "jdbc"), "n:cursor"), + new ResponseListener() { + @Override + public void onResponse(ExplainResponse response) { + fail(response.toString()); + } + + @Override + public void onFailure(Exception e) { + assertTrue(e.getMessage() + .contains("`explain` request for cursor requests is not supported.")); + } + }); + } + + @Test + public void can_capture_error_during_execution() { sqlService.execute( new SQLQueryRequest(new JSONObject(), "SELECT", QUERY, ""), new ResponseListener() { @@ -146,7 +173,7 @@ public void onFailure(Exception e) { } @Test - public void canCaptureErrorDuringExplain() { + public void can_capture_error_during_explain() { sqlService.explain( new SQLQueryRequest(new JSONObject(), "SELECT", EXPLAIN, ""), new ResponseListener() { @@ -161,5 +188,4 @@ public void onFailure(Exception e) { } }); } - } diff --git a/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java b/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java index 52a1f534e9..62bb665537 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/domain/SQLQueryRequestTest.java @@ -6,36 +6,43 @@ package org.opensearch.sql.sql.domain; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.Map; import org.json.JSONObject; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.opensearch.sql.protocol.response.format.Format; +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) public class SQLQueryRequestTest { @Test - public void shouldSupportQuery() { + public void should_support_query() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1").build(); assertTrue(request.isSupported()); } @Test - public void shouldSupportQueryWithJDBCFormat() { + public void should_support_query_with_JDBC_format() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") .format("jdbc") .build(); - assertTrue(request.isSupported()); - assertEquals(request.format(), Format.JDBC); + assertAll( + () -> assertTrue(request.isSupported()), + () -> assertEquals(request.format(), Format.JDBC) + ); } @Test - public void shouldSupportQueryWithQueryFieldOnly() { + public void should_support_query_with_query_field_only() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") 
.jsonContent("{\"query\": \"SELECT 1\"}") @@ -44,16 +51,32 @@ public void shouldSupportQueryWithQueryFieldOnly() { } @Test - public void shouldSupportQueryWithParameters() { - SQLQueryRequest request = + public void should_support_query_with_parameters() { + SQLQueryRequest requestWithContent = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"parameters\":[]}") .build(); - assertTrue(request.isSupported()); + SQLQueryRequest requestWithParams = + SQLQueryRequestBuilder.request("SELECT 1") + .params(Map.of("one", "two")) + .build(); + assertAll( + () -> assertTrue(requestWithContent.isSupported()), + () -> assertTrue(requestWithParams.isSupported()) + ); + } + + @Test + public void should_support_query_without_parameters() { + SQLQueryRequest requestWithNoParams = + SQLQueryRequestBuilder.request("SELECT 1") + .params(Map.of()) + .build(); + assertTrue(requestWithNoParams.isSupported()); } @Test - public void shouldSupportQueryWithZeroFetchSize() { + public void should_support_query_with_zero_fetch_size() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 0}") @@ -62,7 +85,7 @@ public void shouldSupportQueryWithZeroFetchSize() { } @Test - public void shouldSupportQueryWithParametersAndZeroFetchSize() { + public void should_support_query_with_parameters_and_zero_fetch_size() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 0, \"parameters\":[]}") @@ -71,70 +94,155 @@ public void shouldSupportQueryWithParametersAndZeroFetchSize() { } @Test - public void shouldSupportExplain() { + public void should_support_explain() { SQLQueryRequest explainRequest = SQLQueryRequestBuilder.request("SELECT 1") .path("_plugins/_sql/_explain") .build(); - assertTrue(explainRequest.isExplainRequest()); - assertTrue(explainRequest.isSupported()); + + assertAll( + () -> assertTrue(explainRequest.isExplainRequest()), + () -> assertTrue(explainRequest.isSupported()) + ); } @Test - public void shouldNotSupportCursorRequest() { + public void should_support_cursor_request() { SQLQueryRequest fetchSizeRequest = SQLQueryRequestBuilder.request("SELECT 1") .jsonContent("{\"query\": \"SELECT 1\", \"fetch_size\": 5}") .build(); - assertFalse(fetchSizeRequest.isSupported()); SQLQueryRequest cursorRequest = + SQLQueryRequestBuilder.request(null) + .cursor("abcdefgh...") + .build(); + + assertAll( + () -> assertTrue(fetchSizeRequest.isSupported()), + () -> assertTrue(cursorRequest.isSupported()) + ); + } + + @Test + public void should_not_support_request_with_empty_cursor() { + SQLQueryRequest requestWithEmptyCursor = + SQLQueryRequestBuilder.request(null) + .cursor("") + .build(); + SQLQueryRequest requestWithNullCursor = + SQLQueryRequestBuilder.request(null) + .cursor(null) + .build(); + assertAll( + () -> assertFalse(requestWithEmptyCursor.isSupported()), + () -> assertFalse(requestWithNullCursor.isSupported()) + ); + } + + @Test + public void should_not_support_request_with_unknown_field() { + SQLQueryRequest request = + SQLQueryRequestBuilder.request("SELECT 1") + .jsonContent("{\"pewpew\": 42}") + .build(); + assertFalse(request.isSupported()); + } + + @Test + public void should_not_support_request_with_cursor_and_something_else() { + SQLQueryRequest requestWithQuery = SQLQueryRequestBuilder.request("SELECT 1") - .jsonContent("{\"cursor\": \"abcdefgh...\"}") + .cursor("n:12356") + .build(); + SQLQueryRequest requestWithParams = + 
SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of("one", "two")) + .build(); + SQLQueryRequest requestWithParamsWithFormat = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of("format", "jdbc")) .build(); - assertFalse(cursorRequest.isSupported()); + SQLQueryRequest requestWithParamsWithFormatAnd = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of("format", "jdbc", "something", "else")) + .build(); + SQLQueryRequest requestWithFetchSize = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .jsonContent("{\"fetch_size\": 5}") + .build(); + SQLQueryRequest requestWithNoParams = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .params(Map.of()) + .build(); + SQLQueryRequest requestWithNoContent = + SQLQueryRequestBuilder.request(null) + .cursor("n:12356") + .jsonContent("{}") + .build(); + assertAll( + () -> assertFalse(requestWithQuery.isSupported()), + () -> assertFalse(requestWithParams.isSupported()), + () -> assertFalse(requestWithFetchSize.isSupported()), + () -> assertTrue(requestWithNoParams.isSupported()), + () -> assertTrue(requestWithParamsWithFormat.isSupported()), + () -> assertFalse(requestWithParamsWithFormatAnd.isSupported()), + () -> assertTrue(requestWithNoContent.isSupported()) + ); } @Test - public void shouldUseJDBCFormatByDefault() { + public void should_use_JDBC_format_by_default() { SQLQueryRequest request = SQLQueryRequestBuilder.request("SELECT 1").params(ImmutableMap.of()).build(); assertEquals(request.format(), Format.JDBC); } @Test - public void shouldSupportCSVFormatAndSanitize() { + public void should_support_CSV_format_and_sanitize() { SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1") .format("csv") .build(); - assertTrue(csvRequest.isSupported()); - assertEquals(csvRequest.format(), Format.CSV); - assertTrue(csvRequest.sanitize()); + assertAll( + () -> assertTrue(csvRequest.isSupported()), + () -> assertEquals(csvRequest.format(), Format.CSV), + () -> assertTrue(csvRequest.sanitize()) + ); } @Test - public void shouldSkipSanitizeIfSetFalse() { + public void should_skip_sanitize_if_set_false() { ImmutableMap.Builder builder = ImmutableMap.builder(); Map params = builder.put("format", "csv").put("sanitize", "false").build(); SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1").params(params).build(); - assertEquals(csvRequest.format(), Format.CSV); - assertFalse(csvRequest.sanitize()); + assertAll( + () -> assertEquals(csvRequest.format(), Format.CSV), + () -> assertFalse(csvRequest.sanitize()) + ); } @Test - public void shouldNotSupportOtherFormat() { + public void should_not_support_other_format() { SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1") .format("other") .build(); - assertFalse(csvRequest.isSupported()); - assertThrows(IllegalArgumentException.class, csvRequest::format, - "response in other format is not supported."); + + assertAll( + () -> assertFalse(csvRequest.isSupported()), + () -> assertEquals("response in other format is not supported.", + assertThrows(IllegalArgumentException.class, csvRequest::format).getMessage()) + ); } @Test - public void shouldSupportRawFormat() { + public void should_support_raw_format() { SQLQueryRequest csvRequest = SQLQueryRequestBuilder.request("SELECT 1") .format("raw") @@ -150,7 +258,8 @@ private static class SQLQueryRequestBuilder { private String query; private String path = "_plugins/_sql"; private String format; - private Map params; + private 
String cursor; + private Map params = new HashMap<>(); static SQLQueryRequestBuilder request(String query) { SQLQueryRequestBuilder builder = new SQLQueryRequestBuilder(); @@ -178,14 +287,17 @@ SQLQueryRequestBuilder params(Map params) { return this; } + SQLQueryRequestBuilder cursor(String cursor) { + this.cursor = cursor; + return this; + } + SQLQueryRequest build() { - if (jsonContent == null) { - jsonContent = "{\"query\": \"" + query + "\"}"; - } - if (params != null) { - return new SQLQueryRequest(new JSONObject(jsonContent), query, path, params); + if (format != null) { + params.put("format", format); } - return new SQLQueryRequest(new JSONObject(jsonContent), query, path, format); + return new SQLQueryRequest(jsonContent == null ? null : new JSONObject(jsonContent), + query, path, params, cursor); } }
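
For context on the request-level change above: SQLQueryRequest now takes the cursor as a separate constructor argument, and isSupported() accepts either a regular query request or a cursor-only request; a cursor combined with a query, a non-empty body, or parameters other than format is rejected. Below is a minimal sketch of both request shapes, assuming the classes from this patch are on the classpath; the path, cursor value, and query text are illustrative only.

import java.util.Map;
import org.json.JSONObject;
import org.opensearch.sql.sql.domain.SQLQueryRequest;

public class SqlQueryRequestSketch {
  public static void main(String[] args) {
    // Plain query request: JSON body plus query text, no cursor.
    SQLQueryRequest queryRequest = new SQLQueryRequest(
        new JSONObject("{\"query\": \"SELECT 1\", \"fetch_size\": 5}"),
        "SELECT 1",
        "_plugins/_sql",
        Map.of("format", "jdbc"),
        null);

    // Cursor request: no query text and no body, only the cursor returned by a previous page.
    SQLQueryRequest cursorRequest = new SQLQueryRequest(
        null,
        null,
        "_plugins/_sql",
        Map.of("format", "jdbc"),
        "n:illustrative-cursor");

    System.out.println(queryRequest.isSupported());            // true: query without cursor
    System.out.println(queryRequest.getFetchSize());            // 5, read from the JSON body
    System.out.println(cursorRequest.isSupported());            // true: cursor without query or body
    System.out.println(cursorRequest.getCursor().isPresent());  // true
  }
}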
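
On the protocol side, QueryResult now carries a Cursor and an explicit total, with a two-argument convenience constructor that falls back to Cursor.None and the result size, and JdbcResponseFormatter emits the "cursor" field only when the cursor is not Cursor.None. A small sketch of the two constructors follows, assuming a single string column and one row built the same way the tests in this patch build them; the schema, values, cursor string, and total are illustrative only.

import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.executor.ExecutionEngine;
import org.opensearch.sql.executor.pagination.Cursor;
import org.opensearch.sql.protocol.response.QueryResult;

public class QueryResultSketch {
  public static void main(String[] args) {
    ExecutionEngine.Schema schema = new ExecutionEngine.Schema(ImmutableList.of(
        new ExecutionEngine.Schema.Column("name", "name", STRING)));
    List<ExprValue> rows = List.of(ExprValueUtils.tupleValue(Map.of("name", "John")));

    // One page of a paginated result: cursor for the next page, total hits across all pages.
    QueryResult page = new QueryResult(schema, rows, new Cursor("n:illustrative-cursor"), 42);

    // Non-paginated result: the new convenience constructor defaults to Cursor.None
    // and uses the row count as the total.
    QueryResult whole = new QueryResult(schema, rows);

    System.out.println(page.getTotal());                        // 42
    System.out.println(whole.getTotal());                       // 1
    System.out.println(whole.getCursor().equals(Cursor.None));  // true, so no "cursor" field in JDBC output
  }
}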