diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java index 0b08a194f7522..b51bea511e067 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -76,8 +76,8 @@ import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.aggregations.bucket.terms.LongKeyedBucketOrds; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.optimization.ranges.DateHistogramAggregatorBridge; -import org.opensearch.search.optimization.ranges.OptimizationContext; +import org.opensearch.search.optimization.filterrewrite.CompositeAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import org.opensearch.search.searchafter.SearchAfterBuilder; import org.opensearch.search.sort.SortAndFormats; @@ -94,6 +94,7 @@ import java.util.stream.Collectors; import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING; +import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; /** * Main aggregator that aggregates docs from multiple aggregations @@ -166,7 +167,7 @@ public final class CompositeAggregator extends BucketsAggregator { this.queue = new CompositeValuesCollectorQueue(context.bigArrays(), sources, size, rawAfterKey); this.rawAfterKey = rawAfterKey; - optimizationContext = new OptimizationContext(new DateHistogramAggregatorBridge() { + optimizationContext = new OptimizationContext(new CompositeAggregatorBridge() { private RoundingValuesSource valuesSource; private long afterKey = -1L; @@ -217,14 +218,9 @@ protected int getSize() { } @Override - protected Function bucketOrdProducer() { + protected Function bucketOrdProducer() { return (key) -> bucketOrds.add(0, getRoundingPrepared().round((long) key)); } - - @Override - protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException { - return segmentMatchAll(context, leaf); - } }); if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { optimizationContext.prepare(); @@ -563,7 +559,7 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount); + boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); finishLeaf(); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java index e4c35371fdc90..9263d3935538d 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java @@ -58,8 +58,8 @@ import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import 
org.opensearch.search.internal.SearchContext; -import org.opensearch.search.optimization.ranges.DateHistogramAggregatorBridge; -import org.opensearch.search.optimization.ranges.OptimizationContext; +import org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import java.io.IOException; import java.util.Collections; @@ -68,6 +68,8 @@ import java.util.function.Function; import java.util.function.LongToIntFunction; +import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; + /** * An aggregator for date values that attempts to return a specific number of * buckets, reconfiguring how it rounds dates to buckets on the fly as new @@ -198,14 +200,10 @@ protected Prepared getRoundingPrepared() { } @Override - protected Function bucketOrdProducer() { + protected Function bucketOrdProducer() { return (key) -> getBucketOrds().add(0, preparedRounding.round((long) key)); } - @Override - protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException { - return segmentMatchAll(context, leaf); - } }); if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { optimizationContext.prepare(); @@ -241,7 +239,7 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount); + boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDocValues values = valuesSource.longValues(ctx); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java index bc951d023a314..20f62f4d6e3f8 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java @@ -52,8 +52,8 @@ import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.optimization.ranges.DateHistogramAggregatorBridge; -import org.opensearch.search.optimization.ranges.OptimizationContext; +import org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; import java.io.IOException; import java.util.Collections; @@ -61,6 +61,8 @@ import java.util.function.BiConsumer; import java.util.function.Function; +import static org.opensearch.search.optimization.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; + /** * An aggregator for date values. Every date is rounded down using a configured * {@link Rounding}. 
@@ -144,14 +146,10 @@ protected long[] processHardBounds(long[] bounds) { } @Override - protected Function bucketOrdProducer() { + protected Function bucketOrdProducer() { return (key) -> bucketOrds.add(0, preparedRounding.round((long) key)); } - @Override - protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException { - return segmentMatchAll(context, leaf); - } }); if (optimizationContext.canOptimize(parent, subAggregators.length, context)) { optimizationContext.prepare(); @@ -172,7 +170,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount); + boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); if (optimized) throw new CollectionTerminatedException(); SortedNumericDocValues values = valuesSource.longValues(ctx); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index 8f250f8a43b85..c206d1e522e01 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -58,8 +58,8 @@ import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.optimization.ranges.OptimizationContext; -import org.opensearch.search.optimization.ranges.RangeAggregatorBridge; +import org.opensearch.search.optimization.filterrewrite.OptimizationContext; +import org.opensearch.search.optimization.filterrewrite.RangeAggregatorBridge; import java.io.IOException; import java.util.ArrayList; @@ -284,12 +284,12 @@ public RangeAggregator( optimizationContext = new OptimizationContext(new RangeAggregatorBridge() { @Override public boolean canOptimize() { - return canOptimize(config, RangeAggregator.this.ranges); + return canOptimize(config, ranges); } @Override public void prepare() { - buildRanges(RangeAggregator.this.ranges); + buildRanges(ranges); } @Override @@ -312,7 +312,7 @@ public ScoreMode scoreMode() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount); + boolean optimized = optimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java new file mode 100644 index 0000000000000..9e1c75e659989 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/AggregatorBridge.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */
+
+package org.opensearch.search.optimization.filterrewrite;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.opensearch.index.mapper.MappedFieldType;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+
+/**
+ * This class provides a bridge between an aggregator and the optimization context, allowing
+ * the aggregator to supply the data needed to optimize the aggregation process.
+ * <p>
+ * The main purpose of this class is to encapsulate the aggregator-specific optimization
+ * logic and provide access to the data in the aggregator that is required for the optimization,
+ * while keeping the optimization business logic separate from the aggregator implementation.
+ * <p>
+ * To use this class to optimize an aggregator, subclass it in this package
+ * and put any aggregator-specific optimization business logic in the subclass. Then implement
+ * the subclass in the aggregator to provide the data needed for the optimization.
+ *
+ * @opensearch.internal
+ */
+public abstract class AggregatorBridge {
+
+    /**
+     * The optimization context associated with this aggregator bridge.
+     */
+    OptimizationContext optimizationContext;
+
+    /**
+     * The field type associated with this aggregator bridge.
+     */
+    MappedFieldType fieldType;
+
+    void setOptimizationContext(OptimizationContext context) {
+        this.optimizationContext = context;
+    }
+
+    /**
+     * Checks whether the aggregator can be optimized.
+     *
+     * @return {@code true} if the aggregator can be optimized, {@code false} otherwise.
+     * The result will be saved in the optimization context.
+     */
+    public abstract boolean canOptimize();
+
+    /**
+     * Prepares the optimization at shard level.
+     * For example, figure out the ranges from the aggregation to do the optimization later.
+     */
+    public abstract void prepare() throws IOException;
+
+    /**
+     * Prepares the optimization for a specific segment, ignoring whatever was built at the shard level.
+     *
+     * @param leaf the leaf reader context for the segment
+     */
+    public abstract void prepareFromSegment(LeafReaderContext leaf) throws IOException;
+
+    /**
+     * Attempts to build aggregation results for a segment.
+     *
+     * @param values the point values (index structure for numeric values) for a segment
+     * @param incrementDocCount a consumer to increment the document count for a range bucket. The first parameter is the document count, the second is the key of the bucket
+     */
+    public abstract void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount) throws IOException;
+}
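Note: `AggregatorBridge` splits responsibilities between aggregator-specific logic (the bridge) and generic orchestration (`OptimizationContext`). A minimal, self-contained sketch of that shape follows; the types are toys invented for illustration, and only the `canOptimize`/`prepare`/`tryOptimize` contract mirrors the classes in this patch:

```java
import java.util.function.BiConsumer;

// Toy model of the bridge pattern: the bridge owns aggregator-specific logic,
// the context owns the generic orchestration. All names here are invented.
abstract class ToyBridge {
    abstract boolean canOptimize();                        // gate the optimization up front
    abstract void prepare();                               // build shard-level state, e.g. ranges
    abstract void tryOptimize(BiConsumer<Long, Long> out); // emit (bucket ordinal, doc count) pairs
}

final class ToyContext {
    private final ToyBridge bridge;
    private boolean canOptimize;

    ToyContext(ToyBridge bridge) {
        this.bridge = bridge;
    }

    boolean canOptimize() {
        return canOptimize = bridge.canOptimize();         // result cached, as in the real context
    }

    void prepare() {
        if (canOptimize) bridge.prepare();
    }

    // Returns true when the segment was fully handled by the optimized path.
    boolean tryOptimize(BiConsumer<Long, Long> incrementDocCount) {
        if (!canOptimize) return false;
        bridge.tryOptimize(incrementDocCount);
        return true;
    }

    public static void main(String[] args) {
        ToyContext ctx = new ToyContext(new ToyBridge() {
            @Override
            boolean canOptimize() {
                return true;
            }

            @Override
            void prepare() { /* e.g. precompute the bucket ranges */ }

            @Override
            void tryOptimize(BiConsumer<Long, Long> out) {
                out.accept(0L, 42L);                       // bucket 0 gets 42 docs
            }
        });
        if (ctx.canOptimize()) ctx.prepare();
        ctx.tryOptimize((ord, count) -> System.out.println("bucket " + ord + " += " + count));
    }
}
```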
diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java
new file mode 100644
index 0000000000000..1982793332605
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/CompositeAggregatorBridge.java
@@ -0,0 +1,36 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.optimization.filterrewrite;
+
+import org.opensearch.index.mapper.DateFieldMapper;
+import org.opensearch.index.mapper.MappedFieldType;
+import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig;
+import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource;
+
+/**
+ * For a composite aggregation to perform the optimization when it has only a single date histogram source
+ */
+public abstract class CompositeAggregatorBridge extends DateHistogramAggregatorBridge {
+    protected boolean canOptimize(CompositeValuesSourceConfig[] sourceConfigs) {
+        if (sourceConfigs.length != 1 || !(sourceConfigs[0].valuesSource() instanceof RoundingValuesSource)) return false;
+        return canOptimize(sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript(), sourceConfigs[0].fieldType());
+    }
+
+    private boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType fieldType) {
+        if (!missing && !hasScript) {
+            if (fieldType instanceof DateFieldMapper.DateFieldType) {
+                if (fieldType.isSearchable()) {
+                    this.fieldType = fieldType;
+                    return true;
+                }
+            }
+        }
+        return false;
+    }
+}
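The composite gating above is a chain of cheap structural checks: exactly one source, backed by a rounding (date histogram) values source, no missing bucket, no script, and a searchable date field. A hedged restatement of the same shape, with toy types invented for illustration:

```java
// Toy restatement of CompositeAggregatorBridge#canOptimize: all checks are
// cheap and structural. Every name below is invented for illustration.
record ToyFieldType(boolean isDateField, boolean isSearchable) {}

record ToySource(boolean roundingSource, boolean missingBucket, boolean hasScript, ToyFieldType fieldType) {}

final class ToyCompositeGate {
    static boolean canOptimize(ToySource[] sources) {
        // single source, and it must be a rounding (date histogram) source
        if (sources.length != 1 || !sources[0].roundingSource()) return false;
        ToySource s = sources[0];
        // no missing bucket, no script, and a searchable date field
        return !s.missingBucket() && !s.hasScript() && s.fieldType().isDateField() && s.fieldType().isSearchable();
    }

    public static void main(String[] args) {
        ToySource ok = new ToySource(true, false, false, new ToyFieldType(true, true));
        System.out.println(canOptimize(new ToySource[] { ok }));     // true
        System.out.println(canOptimize(new ToySource[] { ok, ok })); // false: more than one source
    }
}
```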
diff --git a/server/src/main/java/org/opensearch/search/optimization/ranges/DateHistogramAggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java
similarity index 82%
rename from server/src/main/java/org/opensearch/search/optimization/ranges/DateHistogramAggregatorBridge.java
rename to server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java
index 369cd9bdd02a7..da53e4aa73684 100644
--- a/server/src/main/java/org/opensearch/search/optimization/ranges/DateHistogramAggregatorBridge.java
+++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/DateHistogramAggregatorBridge.java
@@ -6,7 +6,7 @@
  * compatible open source license.
  */
 
-package org.opensearch.search.optimization.ranges;
+package org.opensearch.search.optimization.filterrewrite;
 
 import org.apache.lucene.document.LongPoint;
 import org.apache.lucene.index.LeafReaderContext;
@@ -16,8 +16,6 @@
 import org.opensearch.common.Rounding;
 import org.opensearch.index.mapper.DateFieldMapper;
 import org.opensearch.index.mapper.MappedFieldType;
-import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig;
-import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource;
 import org.opensearch.search.aggregations.bucket.histogram.LongBounds;
 import org.opensearch.search.aggregations.support.ValuesSourceConfig;
 import org.opensearch.search.internal.SearchContext;
@@ -25,26 +23,15 @@
 import java.io.IOException;
 import java.util.OptionalLong;
 import java.util.function.BiConsumer;
+import java.util.function.Function;
 
-import static org.opensearch.search.optimization.ranges.Helper.multiRangesTraverse;
+import static org.opensearch.search.optimization.filterrewrite.TreeTraversal.multiRangesTraverse;
 
 /**
  * For date histogram aggregation
  */
 public abstract class DateHistogramAggregatorBridge extends AggregatorBridge {
 
-    protected boolean canOptimize(boolean missing, boolean hasScript, MappedFieldType fieldType) {
-        if (!missing && !hasScript) {
-            if (fieldType instanceof DateFieldMapper.DateFieldType) {
-                if (fieldType.isSearchable()) {
-                    this.fieldType = fieldType;
-                    return true;
-                }
-            }
-        }
-        return false;
-    }
-
     protected boolean canOptimize(ValuesSourceConfig config) {
         if (config.script() == null && config.missing() == null) {
             MappedFieldType fieldType = config.fieldType();
@@ -58,11 +45,6 @@ protected boolean canOptimize(ValuesSourceConfig config) {
         return false;
     }
 
-    protected boolean canOptimize(CompositeValuesSourceConfig[] sourceConfigs) {
-        if (sourceConfigs.length != 1 || !(sourceConfigs[0].valuesSource() instanceof RoundingValuesSource)) return false;
-        return canOptimize(sourceConfigs[0].missingBucket(), sourceConfigs[0].hasScript(), sourceConfigs[0].fieldType());
-    }
-
     protected void buildRanges(SearchContext context) throws IOException {
         long[] bounds = Helper.getDateHistoAggBounds(context, fieldType.name());
         optimizationContext.setRanges(buildRanges(bounds));
@@ -165,7 +147,22 @@ private static long getBucketOrd(long bucketOrd) {
         return bucketOrd;
     }
 
-    protected boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
+    /**
+     * Provides a function to produce bucket ordinals from the lower bound of the range
+     */
+    protected abstract Function<Object, Long> bucketOrdProducer();
+
+    /**
+     * Checks whether the top level query matches all documents on the segment
+     * <p>
This method creates a weight from the search context's query and checks whether the weight's + * document count matches the total number of documents in the leaf reader context. + * + * @param ctx the search context + * @param leafCtx the leaf reader context for the segment + * @return {@code true} if the segment matches all documents, {@code false} otherwise + */ + public static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException { Weight weight = ctx.query().rewrite(ctx.searcher()).createWeight(ctx.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1f); return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs(); } diff --git a/server/src/main/java/org/opensearch/search/optimization/ranges/Helper.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java similarity index 54% rename from server/src/main/java/org/opensearch/search/optimization/ranges/Helper.java rename to server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java index a06182f803683..eb57cd90b9ad9 100644 --- a/server/src/main/java/org/opensearch/search/optimization/ranges/Helper.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Helper.java @@ -6,38 +6,32 @@ * compatible open source license. */ -package org.opensearch.search.optimization.ranges; +package org.opensearch.search.optimization.filterrewrite; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; -import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.ConstantScoreQuery; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.PointRangeQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.NumericUtils; -import org.opensearch.common.CheckedRunnable; import org.opensearch.common.Rounding; import org.opensearch.common.lucene.search.function.FunctionScoreQuery; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.query.DateRangeIncludingNowQuery; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.optimization.ranges.OptimizationContext.DebugInfo; import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.BiConsumer; import java.util.function.Function; import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** * Utility class to help range filters rewrite optimization @@ -46,6 +40,8 @@ */ final class Helper { + private Helper() {} + static final String loggerName = Helper.class.getPackageName(); private static final Logger logger = LogManager.getLogger(loggerName); @@ -60,8 +56,6 @@ final class Helper { queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery()); } - private Helper() {} - /** * Recursively unwraps query into the concrete form * for applying the optimization @@ -216,183 +210,4 @@ static Ranges createRangesFromAgg( return new Ranges(lowers, uppers); } - - /** - * @param maxNumNonZeroRanges the number of non-zero ranges to collect - */ - static DebugInfo multiRangesTraverse( - final PointValues.PointTree tree, - 
final Ranges ranges, - final BiConsumer incrementDocCount, - final int maxNumNonZeroRanges - ) throws IOException { - DebugInfo debugInfo = new DebugInfo(); - int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue()); - if (activeIndex < 0) { - logger.debug("No ranges match the query, skip the fast filter optimization"); - return debugInfo; - } - RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex); - PointValues.IntersectVisitor visitor = getIntersectVisitor(collector); - try { - intersectWithRanges(visitor, tree, collector, debugInfo); - } catch (CollectionTerminatedException e) { - logger.debug("Early terminate since no more range to collect"); - } - collector.finalizePreviousRange(); - - return debugInfo; - } - - private static void intersectWithRanges( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - RangeCollectorForPointTree collector, - DebugInfo debug - ) throws IOException { - PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - - switch (r) { - case CELL_INSIDE_QUERY: - collector.countNode((int) pointTree.size()); - debug.visitInner(); - break; - case CELL_CROSSES_QUERY: - if (pointTree.moveToChild()) { - do { - intersectWithRanges(visitor, pointTree, collector, debug); - } while (pointTree.moveToSibling()); - pointTree.moveToParent(); - } else { - pointTree.visitDocValues(visitor); - debug.visitLeaf(); - } - break; - case CELL_OUTSIDE_QUERY: - } - } - - private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) { - return new PointValues.IntersectVisitor() { - @Override - public void visit(int docID) { - // this branch should be unreachable - throw new UnsupportedOperationException( - "This IntersectVisitor does not perform any actions on a " + "docID=" + docID + " node being visited" - ); - } - - @Override - public void visit(int docID, byte[] packedValue) throws IOException { - visitPoints(packedValue, collector::count); - } - - @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - visitPoints(packedValue, () -> { - for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { - collector.count(); - } - }); - } - - private void visitPoints(byte[] packedValue, CheckedRunnable collect) throws IOException { - if (!collector.withinUpperBound(packedValue)) { - collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(packedValue)) { - throw new CollectionTerminatedException(); - } - } - - if (collector.withinRange(packedValue)) { - collect.run(); - } - } - - @Override - public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - // try to find the first range that may collect values from this cell - if (!collector.withinUpperBound(minPackedValue)) { - collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(minPackedValue)) { - throw new CollectionTerminatedException(); - } - } - // after the loop, min < upper - // cell could be outside [min max] lower - if (!collector.withinLowerBound(maxPackedValue)) { - return PointValues.Relation.CELL_OUTSIDE_QUERY; - } - if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) { - return PointValues.Relation.CELL_INSIDE_QUERY; - } - return PointValues.Relation.CELL_CROSSES_QUERY; - } - }; - } - - private static class RangeCollectorForPointTree { - private final BiConsumer 
incrementRangeDocCount; - private int counter = 0; - - private final Ranges ranges; - private int activeIndex; - - private int visitedRange = 0; - private final int maxNumNonZeroRange; - - public RangeCollectorForPointTree( - BiConsumer incrementRangeDocCount, - int maxNumNonZeroRange, - Ranges ranges, - int activeIndex - ) { - this.incrementRangeDocCount = incrementRangeDocCount; - this.maxNumNonZeroRange = maxNumNonZeroRange; - this.ranges = ranges; - this.activeIndex = activeIndex; - } - - private void count() { - counter++; - } - - private void countNode(int count) { - counter += count; - } - - private void finalizePreviousRange() { - if (counter > 0) { - incrementRangeDocCount.accept(activeIndex, counter); - counter = 0; - } - } - - /** - * @return true when iterator exhausted or collect enough non-zero ranges - */ - private boolean iterateRangeEnd(byte[] value) { - // the new value may not be contiguous to the previous one - // so try to find the first next range that cross the new value - while (!withinUpperBound(value)) { - if (++activeIndex >= ranges.size) { - return true; - } - } - visitedRange++; - return visitedRange > maxNumNonZeroRange; - } - - private boolean withinLowerBound(byte[] value) { - return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]); - } - - private boolean withinUpperBound(byte[] value) { - return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]); - } - - private boolean withinRange(byte[] value) { - return withinLowerBound(value) && withinUpperBound(value); - } - } } diff --git a/server/src/main/java/org/opensearch/search/optimization/ranges/OptimizationContext.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java similarity index 90% rename from server/src/main/java/org/opensearch/search/optimization/ranges/OptimizationContext.java rename to server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java index 192d640c7cd8f..d4d5880b37ce1 100644 --- a/server/src/main/java/org/opensearch/search/optimization/ranges/OptimizationContext.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/OptimizationContext.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.search.optimization.ranges; +package org.opensearch.search.optimization.filterrewrite; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -20,7 +20,7 @@ import java.io.IOException; import java.util.function.BiConsumer; -import static org.opensearch.search.optimization.ranges.Helper.loggerName; +import static org.opensearch.search.optimization.filterrewrite.Helper.loggerName; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** @@ -100,8 +100,10 @@ Ranges getRanges() { * Usage: invoked at segment level — in getLeafCollector of aggregator * * @param incrementDocCount consume the doc_count results for certain ordinal + * @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment */ - public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer incrementDocCount) throws IOException { + public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer incrementDocCount, boolean segmentMatchAll) + throws IOException { segments++; if (!canOptimize) { return false; @@ -124,7 +126,7 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer increme multiRangesTraverse(values.getPointTree(), optimizationContext.getRanges(), incrementFunc, size) ); } + + /** + * Provides a function to produce bucket ordinals from index of the corresponding range in the range array + */ + protected abstract Function bucketOrdProducer(); } diff --git a/server/src/main/java/org/opensearch/search/optimization/ranges/Ranges.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java similarity index 96% rename from server/src/main/java/org/opensearch/search/optimization/ranges/Ranges.java rename to server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java index 453d626a2d9e6..ebf4b5c9b2b9c 100644 --- a/server/src/main/java/org/opensearch/search/optimization/ranges/Ranges.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/Ranges.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.search.optimization.ranges; +package org.opensearch.search.optimization.filterrewrite; import org.apache.lucene.util.ArrayUtil; diff --git a/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java new file mode 100644 index 0000000000000..aad833324a841 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/TreeTraversal.java @@ -0,0 +1,224 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */
+
+package org.opensearch.search.optimization.filterrewrite;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.opensearch.common.CheckedRunnable;
+
+import java.io.IOException;
+import java.util.function.BiConsumer;
+
+import static org.opensearch.search.optimization.filterrewrite.Helper.loggerName;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+/**
+ * Utility class for traversing a {@link PointValues.PointTree} and collecting document counts for the ranges.
+ * <p>
+ * The main entry point is the {@link #multiRangesTraverse(PointValues.PointTree, Ranges,
+ * BiConsumer, int)} method.
+ * <p>
The class uses a {@link RangeCollectorForPointTree} to keep track of the active ranges and + * determine which parts of the tree to visit. The {@link + * PointValues.IntersectVisitor} implementation is responsible for the actual visitation and + * document count collection. + */ +final class TreeTraversal { + private TreeTraversal() {} + + private static final Logger logger = LogManager.getLogger(loggerName); + + /** + * Traverses the given {@link PointValues.PointTree} and collects document counts for the intersecting ranges. + * + * @param tree the point tree to traverse + * @param ranges the set of ranges to intersect with + * @param incrementDocCount a callback to increment the document count for a range bucket + * @param maxNumNonZeroRanges the maximum number of non-zero ranges to collect + * @return a {@link OptimizationContext.DebugInfo} object containing debug information about the traversal + */ + static OptimizationContext.DebugInfo multiRangesTraverse( + final PointValues.PointTree tree, + final Ranges ranges, + final BiConsumer incrementDocCount, + final int maxNumNonZeroRanges + ) throws IOException { + OptimizationContext.DebugInfo debugInfo = new OptimizationContext.DebugInfo(); + int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue()); + if (activeIndex < 0) { + logger.debug("No ranges match the query, skip the fast filter optimization"); + return debugInfo; + } + RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex); + PointValues.IntersectVisitor visitor = getIntersectVisitor(collector); + try { + intersectWithRanges(visitor, tree, collector, debugInfo); + } catch (CollectionTerminatedException e) { + logger.debug("Early terminate since no more range to collect"); + } + collector.finalizePreviousRange(); + + return debugInfo; + } + + private static void intersectWithRanges( + PointValues.IntersectVisitor visitor, + PointValues.PointTree pointTree, + RangeCollectorForPointTree collector, + OptimizationContext.DebugInfo debug + ) throws IOException { + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + + switch (r) { + case CELL_INSIDE_QUERY: + collector.countNode((int) pointTree.size()); + debug.visitInner(); + break; + case CELL_CROSSES_QUERY: + if (pointTree.moveToChild()) { + do { + intersectWithRanges(visitor, pointTree, collector, debug); + } while (pointTree.moveToSibling()); + pointTree.moveToParent(); + } else { + pointTree.visitDocValues(visitor); + debug.visitLeaf(); + } + break; + case CELL_OUTSIDE_QUERY: + } + } + + private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) { + return new PointValues.IntersectVisitor() { + @Override + public void visit(int docID) { + // this branch should be unreachable + throw new UnsupportedOperationException( + "This IntersectVisitor does not perform any actions on a " + "docID=" + docID + " node being visited" + ); + } + + @Override + public void visit(int docID, byte[] packedValue) throws IOException { + visitPoints(packedValue, collector::count); + } + + @Override + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + visitPoints(packedValue, () -> { + for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { + collector.count(); + } + }); + } + + private void visitPoints(byte[] packedValue, CheckedRunnable collect) throws IOException { + if 
(!collector.withinUpperBound(packedValue)) { + collector.finalizePreviousRange(); + if (collector.iterateRangeEnd(packedValue)) { + throw new CollectionTerminatedException(); + } + } + + if (collector.withinRange(packedValue)) { + collect.run(); + } + } + + @Override + public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + // try to find the first range that may collect values from this cell + if (!collector.withinUpperBound(minPackedValue)) { + collector.finalizePreviousRange(); + if (collector.iterateRangeEnd(minPackedValue)) { + throw new CollectionTerminatedException(); + } + } + // after the loop, min < upper + // cell could be outside [min max] lower + if (!collector.withinLowerBound(maxPackedValue)) { + return PointValues.Relation.CELL_OUTSIDE_QUERY; + } + if (collector.withinRange(minPackedValue) && collector.withinRange(maxPackedValue)) { + return PointValues.Relation.CELL_INSIDE_QUERY; + } + return PointValues.Relation.CELL_CROSSES_QUERY; + } + }; + } + + private static class RangeCollectorForPointTree { + private final BiConsumer incrementRangeDocCount; + private int counter = 0; + + private final Ranges ranges; + private int activeIndex; + + private int visitedRange = 0; + private final int maxNumNonZeroRange; + + public RangeCollectorForPointTree( + BiConsumer incrementRangeDocCount, + int maxNumNonZeroRange, + Ranges ranges, + int activeIndex + ) { + this.incrementRangeDocCount = incrementRangeDocCount; + this.maxNumNonZeroRange = maxNumNonZeroRange; + this.ranges = ranges; + this.activeIndex = activeIndex; + } + + private void count() { + counter++; + } + + private void countNode(int count) { + counter += count; + } + + private void finalizePreviousRange() { + if (counter > 0) { + incrementRangeDocCount.accept(activeIndex, counter); + counter = 0; + } + } + + /** + * @return true when iterator exhausted or collect enough non-zero ranges + */ + private boolean iterateRangeEnd(byte[] value) { + // the new value may not be contiguous to the previous one + // so try to find the first next range that cross the new value + while (!withinUpperBound(value)) { + if (++activeIndex >= ranges.size) { + return true; + } + } + visitedRange++; + return visitedRange > maxNumNonZeroRange; + } + + private boolean withinLowerBound(byte[] value) { + return Ranges.withinLowerBound(value, ranges.lowers[activeIndex]); + } + + private boolean withinUpperBound(byte[] value) { + return Ranges.withinUpperBound(value, ranges.uppers[activeIndex]); + } + + private boolean withinRange(byte[] value) { + return withinLowerBound(value) && withinUpperBound(value); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/optimization/ranges/package-info.java b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/package-info.java similarity index 50% rename from server/src/main/java/org/opensearch/search/optimization/ranges/package-info.java rename to server/src/main/java/org/opensearch/search/optimization/filterrewrite/package-info.java index 81cf915ddafa0..7c7385bb6102d 100644 --- a/server/src/main/java/org/opensearch/search/optimization/ranges/package-info.java +++ b/server/src/main/java/org/opensearch/search/optimization/filterrewrite/package-info.java @@ -7,12 +7,13 @@ */ /** - * This package contains optimization for range-type aggregations + * This package contains filter rewrite optimization for range-type aggregations *
<p>
  * The idea is to
  * <ul>
  * <li> figure out the "ranges" from the aggregation
- * <li> leverage the "range filter" to get the result of range bucket quickly
+ * <li> leverage the ranges and bkd index to get the result of each range bucket quickly
  * </ul>
+ * More details in https://github.com/opensearch-project/OpenSearch/pull/14464
  */
-package org.opensearch.search.optimization.ranges;
+package org.opensearch.search.optimization.filterrewrite;
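The package doc's idea (figure out the ranges, then count per range in one pass over the index structure) can be sketched with a sorted array standing in for BKD leaf values. This is a toy; the real traversal in `TreeTraversal` walks the point tree and can count whole cells without visiting individual docs:

```java
import java.util.Arrays;

// Toy version of the one-pass multi-range counting idea (see TreeTraversal):
// walk sorted values once, advancing the active range instead of testing every
// range per value. The real code walks the BKD tree and counts whole cells.
final class ToyMultiRangeCount {
    // ranges are [lowers[i], uppers[i]] inclusive, sorted and non-overlapping
    static long[] count(long[] sortedValues, long[] lowers, long[] uppers) {
        long[] counts = new long[lowers.length];
        int active = 0;
        for (long v : sortedValues) {
            while (active < lowers.length && v > uppers[active]) active++; // move to the next range
            if (active == lowers.length) break;        // no ranges left: early terminate
            if (v >= lowers[active]) counts[active]++; // v falls inside the active range
        }
        return counts;
    }

    public static void main(String[] args) {
        long[] values = { 1, 2, 5, 7, 8, 13 };
        long[] lowers = { 0, 5, 10 };
        long[] uppers = { 4, 9, 14 };
        System.out.println(Arrays.toString(count(values, lowers, uppers))); // [2, 3, 1]
    }
}
```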
diff --git a/server/src/main/java/org/opensearch/search/optimization/ranges/AggregatorBridge.java b/server/src/main/java/org/opensearch/search/optimization/ranges/AggregatorBridge.java
deleted file mode 100644
index b0e4f95c66366..0000000000000
--- a/server/src/main/java/org/opensearch/search/optimization/ranges/AggregatorBridge.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- */
-
-package org.opensearch.search.optimization.ranges;
-
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.PointValues;
-import org.opensearch.index.mapper.MappedFieldType;
-
-import java.io.IOException;
-import java.util.function.BiConsumer;
-import java.util.function.Function;
-
-/**
- * This class holds aggregator-specific optimization logic and
- * provides optimization necessary access to the data from Aggregator
- * <p>
- * To provide the access to data, instantiate this class inside the aggregator
- * and send in data through the implemented methods
- * <p>
- * The optimization business logic other than providing data should stay in this package.
- *
- * @opensearch.internal
- */
-public abstract class AggregatorBridge {
-
-    OptimizationContext optimizationContext;
-    MappedFieldType fieldType;
-
-    void setOptimizationContext(OptimizationContext context) {
-        this.optimizationContext = context;
-    }
-
-    /**
-     * Check whether we can optimize the aggregator
-     * If not, don't call the other methods
-     *
-     * @return result will be saved in optimization context
-     */
-    public abstract boolean canOptimize();
-
-    public abstract void prepare() throws IOException;
-
-    public abstract void prepareFromSegment(LeafReaderContext leaf) throws IOException;
-
-    public abstract void tryOptimize(PointValues values, BiConsumer<Long, Long> incrementDocCount) throws IOException;
-
-    protected abstract Function<Object, Long> bucketOrdProducer();
-
-    protected boolean segmentMatchAll(LeafReaderContext leaf) throws IOException {
-        return false;
-    }
-}
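Net effect of deleting the old bridge: `segmentMatchAll` is no longer an overridable hook defaulting to `false`. Each aggregator now computes it once per segment via the static `DateHistogramAggregatorBridge.segmentMatchAll(context, ctx)` and passes the flag into `tryOptimize`, which (per the new javadoc) uses it to decide whether to `prepareFromSegment`. A toy sketch of one plausible reading of that per-segment flow, with all names invented:

```java
// Toy model of the new per-segment decision: segmentMatchAll is computed by
// the caller and passed in, gating whether ranges are rebuilt from the segment.
// This is an assumption-laden sketch, not the actual OptimizationContext logic.
final class ToySegmentFlow {
    static String tryOptimize(boolean canOptimize, boolean segmentMatchAll) {
        if (!canOptimize) return "fallback: default doc-by-doc collection";
        if (segmentMatchAll) {
            // per the new tryOptimize javadoc: the flag decides whether to
            // prepareFromSegment, i.e. rebuild ranges from this segment's bounds
            return "optimized: prepareFromSegment, then traverse the point tree";
        }
        return "optimized: use ranges prepared at the shard level";
    }

    public static void main(String[] args) {
        System.out.println(tryOptimize(true, true));
        System.out.println(tryOptimize(true, false));
        System.out.println(tryOptimize(false, false));
    }
}
```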