diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java index fcf2a40dada14..bc9e824c87186 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -94,7 +94,7 @@ import java.util.stream.Collectors; import static org.opensearch.search.aggregations.MultiBucketConsumerService.MAX_BUCKET_SETTING; -import static org.opensearch.search.aggregations.bucket.filterrewrite.DateHistogramAggregatorBridge.segmentMatchAll; +import static org.opensearch.search.aggregations.bucket.filterrewrite.AggregatorBridge.segmentMatchAll; /** * Main aggregator that aggregates docs from multiple aggregations @@ -563,14 +563,23 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t } } - @Override - protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { - finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed - return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); - } + // @Override + // protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + // finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed + // return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + // } @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + boolean optimized = filterRewriteOptimizationContext.tryOptimize( + ctx, + this::incrementBucketDocCount, + segmentMatchAll(context, ctx), + collectableSubAggregators, + sub + ); + if (optimized) throw new CollectionTerminatedException(); + finishLeaf(); boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR; diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java index 145a60373b4f3..96e78d634ad84 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/AggregatorBridge.java @@ -8,16 +8,23 @@ package org.opensearch.search.aggregations.bucket.filterrewrite; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; +import org.apache.lucene.util.DocIdSetBuilder; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.search.internal.SearchContext; import java.io.IOException; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse; /** * This interface provides a bridge between an aggregator and the optimization context, allowing @@ -35,6 +42,8 @@ */ public abstract class AggregatorBridge { + static final Logger logger = 
LogManager.getLogger(Helper.loggerName); + /** * The field type associated with this aggregator bridge. */ @@ -79,12 +88,46 @@ void setRangesConsumer(Consumer setRanges) { * @param incrementDocCount a consumer to increment the document count for a range bucket. The First parameter is document count, the second is the key of the bucket * @param ranges */ - abstract FilterRewriteOptimizationContext.DebugInfo tryOptimize( + abstract FilterRewriteOptimizationContext.OptimizeResult tryOptimize( PointValues values, BiConsumer incrementDocCount, - Ranges ranges + Ranges ranges, + Supplier disBuilderSupplier ) throws IOException; + static FilterRewriteOptimizationContext.OptimizeResult getResult( + PointValues values, + BiConsumer incrementDocCount, + Ranges ranges, + Supplier disBuilderSupplier, + Function getBucketOrd, + int size + ) throws IOException { + BiConsumer incrementFunc = (activeIndex, docCount) -> { + long bucketOrd = getBucketOrd.apply(activeIndex); + incrementDocCount.accept(bucketOrd, (long) docCount); + }; + + PointValues.PointTree tree = values.getPointTree(); + FilterRewriteOptimizationContext.OptimizeResult optimizeResult = new FilterRewriteOptimizationContext.OptimizeResult(); + int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue()); + if (activeIndex < 0) { + logger.debug("No ranges match the query, skip the fast filter optimization"); + return optimizeResult; + } + PointTreeTraversal.RangeCollectorForPointTree collector = new PointTreeTraversal.RangeCollectorForPointTree( + ranges, + incrementFunc, + size, + activeIndex, + disBuilderSupplier, + getBucketOrd, + optimizeResult + ); + + return multiRangesTraverse(tree, collector); + } + /** * Checks whether the top level query matches all documents on the segment * diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java index 50fe6a8cbf69f..18d47f0716f3e 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/DateHistogramAggregatorBridge.java @@ -11,6 +11,7 @@ import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; +import org.apache.lucene.util.DocIdSetBuilder; import org.opensearch.common.Rounding; import org.opensearch.index.mapper.DateFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -22,8 +23,7 @@ import java.util.OptionalLong; import java.util.function.BiConsumer; import java.util.function.Function; - -import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse; +import java.util.function.Supplier; /** * For date histogram aggregation @@ -127,27 +127,31 @@ private DateFieldMapper.DateFieldType getFieldType() { return (DateFieldMapper.DateFieldType) fieldType; } + /** + * Get the size of buckets to stop early + */ protected int getSize() { return Integer.MAX_VALUE; } @Override - final FilterRewriteOptimizationContext.DebugInfo tryOptimize( + final FilterRewriteOptimizationContext.OptimizeResult tryOptimize( PointValues values, BiConsumer incrementDocCount, - Ranges ranges + Ranges ranges, + Supplier disBuilderSupplier ) throws IOException { int size = getSize(); DateFieldMapper.DateFieldType fieldType = 
getFieldType(); - BiConsumer incrementFunc = (activeIndex, docCount) -> { + + Function getBucketOrd = (activeIndex) -> { long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0); rangeStart = fieldType.convertNanosToMillis(rangeStart); - long bucketOrd = getBucketOrd(bucketOrdProducer().apply(rangeStart)); - incrementDocCount.accept(bucketOrd, (long) docCount); + return getBucketOrd(bucketOrdProducer().apply(rangeStart)); }; - return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size); + return getResult(values, incrementDocCount, ranges, disBuilderSupplier, getBucketOrd, size); } private static long getBucketOrd(long bucketOrd) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java index 87faafe4526de..fe9976d019a0f 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteOptimizationContext.java @@ -14,12 +14,17 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.PointValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.DocIdSetBuilder; import org.opensearch.index.mapper.DocCountFieldMapper; +import org.opensearch.search.aggregations.BucketCollector; +import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.internal.SearchContext; import java.io.IOException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; +import java.util.function.Supplier; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -42,6 +47,8 @@ public final class FilterRewriteOptimizationContext { private Ranges ranges; // built at shard level + private int subAggLength; + // debug info related fields private final AtomicInteger leafNodeVisited = new AtomicInteger(); private final AtomicInteger innerNodeVisited = new AtomicInteger(); @@ -65,7 +72,8 @@ public FilterRewriteOptimizationContext( private boolean canOptimize(final Object parent, final int subAggLength, SearchContext context) throws IOException { if (context.maxAggRewriteFilters() == 0) return false; - if (parent != null || subAggLength != 0) return false; + if (parent != null) return false; + this.subAggLength = subAggLength; boolean canOptimize = aggregatorBridge.canOptimize(); if (canOptimize) { @@ -96,8 +104,13 @@ void setRanges(Ranges ranges) { * @param incrementDocCount consume the doc_count results for certain ordinal * @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment */ - public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer incrementDocCount, boolean segmentMatchAll) - throws IOException { + public boolean tryOptimize( + final LeafReaderContext leafCtx, + final BiConsumer incrementDocCount, + boolean segmentMatchAll, + BucketCollector collectableSubAggregators, + LeafBucketCollector sub + ) throws IOException { segments.incrementAndGet(); if (!canOptimize) { return false; @@ -123,12 +136,43 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer disBuilderSupplier = null; + if (subAggLength != 0) { + disBuilderSupplier = () 
-> { + try { + return new DocIdSetBuilder(leafCtx.reader().maxDoc(), values, aggregatorBridge.fieldType.name()); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + } + OptimizeResult optimizeResult = aggregatorBridge.tryOptimize(values, incrementDocCount, ranges, disBuilderSupplier); + consumeDebugInfo(optimizeResult); optimizedSegments.incrementAndGet(); logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord); logger.debug("Crossed leaf nodes: {}, inner nodes: {}", leafNodeVisited, innerNodeVisited); + if (subAggLength == 0) { + return true; + } + + // Handle sub aggregation + for (int bucketOrd = 0; bucketOrd < optimizeResult.builders.length; bucketOrd++) { + logger.debug("Collecting bucket {} for sub aggregation", bucketOrd); + DocIdSetBuilder builder = optimizeResult.builders[bucketOrd]; + if (builder == null) { + continue; + } + DocIdSetIterator iterator = optimizeResult.builders[bucketOrd].build().iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + int currentDoc = iterator.docID(); + sub.collect(currentDoc, bucketOrd); + } + // resetting the sub collector after processing each bucket + sub = collectableSubAggregators.getLeafCollector(leafCtx); + } + return true; } @@ -160,10 +204,12 @@ private Ranges getRangesFromSegment(LeafReaderContext leafCtx, boolean segmentMa /** * Contains debug info of BKD traversal to show in profile */ - static class DebugInfo { + static class OptimizeResult { private final AtomicInteger leafNodeVisited = new AtomicInteger(); // leaf node visited private final AtomicInteger innerNodeVisited = new AtomicInteger(); // inner node visited + public DocIdSetBuilder[] builders; + void visitLeaf() { leafNodeVisited.incrementAndGet(); } @@ -173,7 +219,7 @@ void visitInner() { } } - void consumeDebugInfo(DebugInfo debug) { + void consumeDebugInfo(OptimizeResult debug) { leafNodeVisited.addAndGet(debug.leafNodeVisited.get()); innerNodeVisited.addAndGet(debug.innerNodeVisited.get()); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java index 581ecc416f486..47271fa64c23d 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/PointTreeTraversal.java @@ -13,18 +13,22 @@ import org.apache.lucene.index.PointValues; import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.DocIdSetBuilder; import org.opensearch.common.CheckedRunnable; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.Supplier; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; /** * Utility class for traversing a {@link PointValues.PointTree} and collecting document counts for the ranges. * - *

<p> The main entry point is the {@link #multiRangesTraverse(PointValues.PointTree, Ranges, - * BiConsumer, int)} method + * <p> The main entry point is the {@link #multiRangesTraverse} method * * <p>
The class uses a {@link RangeCollectorForPointTree} to keep track of the active ranges and * determine which parts of the tree to visit. The {@link @@ -39,58 +43,49 @@ private PointTreeTraversal() {} /** * Traverses the given {@link PointValues.PointTree} and collects document counts for the intersecting ranges. * - * @param tree the point tree to traverse - * @param ranges the set of ranges to intersect with - * @param incrementDocCount a callback to increment the document count for a range bucket - * @param maxNumNonZeroRanges the maximum number of non-zero ranges to collect - * @return a {@link FilterRewriteOptimizationContext.DebugInfo} object containing debug information about the traversal + * @param tree the point tree to traverse + * @param collector + * @return a {@link FilterRewriteOptimizationContext.OptimizeResult} object containing debug information about the traversal */ - static FilterRewriteOptimizationContext.DebugInfo multiRangesTraverse( + static FilterRewriteOptimizationContext.OptimizeResult multiRangesTraverse( final PointValues.PointTree tree, - final Ranges ranges, - final BiConsumer incrementDocCount, - final int maxNumNonZeroRanges + RangeCollectorForPointTree collector ) throws IOException { - FilterRewriteOptimizationContext.DebugInfo debugInfo = new FilterRewriteOptimizationContext.DebugInfo(); - int activeIndex = ranges.firstRangeIndex(tree.getMinPackedValue(), tree.getMaxPackedValue()); - if (activeIndex < 0) { - logger.debug("No ranges match the query, skip the fast filter optimization"); - return debugInfo; - } - RangeCollectorForPointTree collector = new RangeCollectorForPointTree(incrementDocCount, maxNumNonZeroRanges, ranges, activeIndex); PointValues.IntersectVisitor visitor = getIntersectVisitor(collector); try { - intersectWithRanges(visitor, tree, collector, debugInfo); + intersectWithRanges(visitor, tree, collector); } catch (CollectionTerminatedException e) { logger.debug("Early terminate since no more range to collect"); } collector.finalizePreviousRange(); - - return debugInfo; + collector.finalizeDocIdSetBuildersResult(); + return collector.result; } private static void intersectWithRanges( PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, - RangeCollectorForPointTree collector, - FilterRewriteOptimizationContext.DebugInfo debug + RangeCollectorForPointTree collector ) throws IOException { PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); switch (r) { case CELL_INSIDE_QUERY: collector.countNode((int) pointTree.size()); - debug.visitInner(); + if (collector.hasSubAgg) { + pointTree.visitDocIDs(visitor); + } + collector.result.visitInner(); break; case CELL_CROSSES_QUERY: if (pointTree.moveToChild()) { do { - intersectWithRanges(visitor, pointTree, collector, debug); + intersectWithRanges(visitor, pointTree, collector); } while (pointTree.moveToSibling()); pointTree.moveToParent(); } else { pointTree.visitDocValues(visitor); - debug.visitLeaf(); + collector.result.visitLeaf(); } break; case CELL_OUTSIDE_QUERY: @@ -99,24 +94,53 @@ private static void intersectWithRanges( private static PointValues.IntersectVisitor getIntersectVisitor(RangeCollectorForPointTree collector) { return new PointValues.IntersectVisitor() { + + @Override + public void grow(int count) { + if (collector.hasSubAgg) { + collector.grow(count); + } + } + @Override public void visit(int docID) { - // this branch should be unreachable - throw new UnsupportedOperationException( - "This IntersectVisitor does 
not perform any actions on a " + "docID=" + docID + " node being visited" - ); + if (!collector.hasSubAgg) { + throw new UnsupportedOperationException( + "This visitor should not visit when there's no subAgg and node is fully contained by the query" + ); + } + collector.collectDocId(docID); + } + + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + if (!collector.hasSubAgg) { + throw new UnsupportedOperationException( + "This visitor should not visit when there's no subAgg and node is fully contained by the query" + ); + } + collector.collectDocIdSet(iterator); } @Override public void visit(int docID, byte[] packedValue) throws IOException { - visitPoints(packedValue, collector::count); + visitPoints(packedValue, () -> { + collector.count(); + if (collector.hasSubAgg) { + collector.collectDocId(docID); + } + }); } @Override public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { visitPoints(packedValue, () -> { + // note: iterator can only iterate once for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { collector.count(); + if (collector.hasSubAgg) { + collector.collectDocId(doc); + } } }); } @@ -124,7 +148,7 @@ public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOExcept private void visitPoints(byte[] packedValue, CheckedRunnable collect) throws IOException { if (!collector.withinUpperBound(packedValue)) { collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(packedValue)) { + if (collector.iterateRangeEnd(packedValue, true)) { throw new CollectionTerminatedException(); } } @@ -139,7 +163,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue // try to find the first range that may collect values from this cell if (!collector.withinUpperBound(minPackedValue)) { collector.finalizePreviousRange(); - if (collector.iterateRangeEnd(minPackedValue)) { + if (collector.iterateRangeEnd(minPackedValue, false)) { throw new CollectionTerminatedException(); } } @@ -156,7 +180,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue }; } - private static class RangeCollectorForPointTree { + static class RangeCollectorForPointTree { private final BiConsumer incrementRangeDocCount; private int counter = 0; @@ -166,46 +190,115 @@ private static class RangeCollectorForPointTree { private int visitedRange = 0; private final int maxNumNonZeroRange; + private boolean hasSubAgg = false; + private final DocIdSetBuilder[] docIdSetBuilders; + private final Supplier disBuilderSupplier; + private final Map bucketOrdinalToDocIdSetBuilder = new HashMap<>(); + private DocIdSetBuilder.BulkAdder currentAdder; + private final Function getBucketOrd; + private final FilterRewriteOptimizationContext.OptimizeResult result; + + private int lastGrowCount; + public RangeCollectorForPointTree( + Ranges ranges, BiConsumer incrementRangeDocCount, int maxNumNonZeroRange, - Ranges ranges, - int activeIndex + int activeIndex, + Supplier disBuilderSupplier, + Function getBucketOrd, + FilterRewriteOptimizationContext.OptimizeResult result ) { this.incrementRangeDocCount = incrementRangeDocCount; this.maxNumNonZeroRange = maxNumNonZeroRange; this.ranges = ranges; this.activeIndex = activeIndex; + this.docIdSetBuilders = new DocIdSetBuilder[ranges.size]; + this.disBuilderSupplier = disBuilderSupplier; + this.getBucketOrd = getBucketOrd; + if (disBuilderSupplier != null) { + hasSubAgg = true; + } + this.result = result; } - private void count() { - 
counter++; + private void grow(int count) { + if (docIdSetBuilders[activeIndex] == null) { + docIdSetBuilders[activeIndex] = disBuilderSupplier.get(); + } + logger.trace("grow docIdSetBuilder[{}] with count {}", activeIndex, count); + currentAdder = docIdSetBuilders[activeIndex].grow(count); + lastGrowCount = count; } private void countNode(int count) { counter += count; } + private void count() { + counter++; + } + + private void collectDocId(int docId) { + logger.trace("collect docId {}", docId); + currentAdder.add(docId); + } + + private void collectDocIdSet(DocIdSetIterator iter) throws IOException { + logger.trace("collect disi {}", iter); + currentAdder.add(iter); + } + private void finalizePreviousRange() { if (counter > 0) { incrementRangeDocCount.accept(activeIndex, counter); counter = 0; } + + if (hasSubAgg && currentAdder != null) { + long bucketOrd = getBucketOrd.apply(activeIndex); + logger.trace("finalize docIdSetBuilder[{}] with bucket ordinal {}", activeIndex, bucketOrd); + bucketOrdinalToDocIdSetBuilder.put(bucketOrd, docIdSetBuilders[activeIndex]); + currentAdder = null; + } + } + + private void finalizeDocIdSetBuildersResult() { + int maxOrdinal = bucketOrdinalToDocIdSetBuilder.keySet().stream().mapToInt(Long::intValue).max().orElse(0) + 1; + DocIdSetBuilder[] builder = new DocIdSetBuilder[maxOrdinal]; + for (Map.Entry entry : bucketOrdinalToDocIdSetBuilder.entrySet()) { + int ordinal = Math.toIntExact(entry.getKey()); + builder[ordinal] = entry.getValue(); + } + result.builders = builder; } /** + * Iterate to the first range that can include the given value + * under the assumption that ranges are not overlapping and increasing + * + * @param value the value that is outside current lower bound + * @param inLeaf whether this method is called when in the leaf node * @return true when iterator exhausted or collect enough non-zero ranges */ - private boolean iterateRangeEnd(byte[] value) { - // the new value may not be contiguous to the previous one - // so try to find the first next range that cross the new value + private boolean iterateRangeEnd(byte[] value, boolean inLeaf) { while (!withinUpperBound(value)) { if (++activeIndex >= ranges.size) { return true; } } visitedRange++; - return visitedRange > maxNumNonZeroRange; + if (visitedRange > maxNumNonZeroRange) { + return true; + } else { + // edge case: if finalizePreviousRange is called within the leaf node + // currentAdder is reset and grow would not be called immediately + // here we reuse previous grow count + if (hasSubAgg && inLeaf && currentAdder == null) { + grow(lastGrowCount); + } + return false; + } } private boolean withinLowerBound(byte[] value) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java index fc1bcd83f2c1b..b7a16f3791ba6 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/filterrewrite/RangeAggregatorBridge.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; +import org.apache.lucene.util.DocIdSetBuilder; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumericPointEncoder; import org.opensearch.search.aggregations.bucket.range.RangeAggregator; @@ -19,8 +20,7 @@ import java.io.IOException; 
import java.util.function.BiConsumer; import java.util.function.Function; - -import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse; +import java.util.function.Supplier; /** * For range aggregation @@ -74,18 +74,17 @@ final Ranges tryBuildRangesFromSegment(LeafReaderContext leaf) { } @Override - final FilterRewriteOptimizationContext.DebugInfo tryOptimize( + final FilterRewriteOptimizationContext.OptimizeResult tryOptimize( PointValues values, BiConsumer incrementDocCount, - Ranges ranges + Ranges ranges, + Supplier disBuilderSupplier ) throws IOException { int size = Integer.MAX_VALUE; - BiConsumer incrementFunc = (activeIndex, docCount) -> { - long bucketOrd = bucketOrdProducer().apply(activeIndex); - incrementDocCount.accept(bucketOrd, (long) docCount); - }; - return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size); + Function getBucketOrd = (activeIndex) -> bucketOrdProducer().apply(activeIndex); + + return getResult(values, incrementDocCount, ranges, disBuilderSupplier, getBucketOrd, size); } /** diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java index cbeb27e8a3e63..9cba37c099e7c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/AutoDateHistogramAggregator.java @@ -216,7 +216,7 @@ protected Prepared getRoundingPrepared() { @Override protected Function bucketOrdProducer() { - return (key) -> getBucketOrds().add(0, preparedRounding.round((long) key)); + return (key) -> getBucketOrds().add(0, preparedRounding.round(key)); } }; filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context); @@ -251,7 +251,13 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc return LeafBucketCollector.NO_OP_COLLECTOR; } - boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + boolean optimized = filterRewriteOptimizationContext.tryOptimize( + ctx, + this::incrementBucketDocCount, + segmentMatchAll(context, ctx), + collectableSubAggregators, + sub + ); if (optimized) throw new CollectionTerminatedException(); final SortedNumericDocValues values = valuesSource.longValues(ctx); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java index 2294ba6f9a2b5..ab6f86df66a91 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java @@ -33,6 +33,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.util.CollectionUtil; @@ -196,7 +197,7 @@ protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws return true; } } - return filterRewriteOptimizationContext.tryOptimize(ctx, 
this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + return false; } @Override @@ -205,6 +206,17 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol return LeafBucketCollector.NO_OP_COLLECTOR; } + boolean optimized = filterRewriteOptimizationContext.tryOptimize( + ctx, + this::incrementBucketDocCount, + segmentMatchAll(context, ctx), + collectableSubAggregators, + sub + ); + if (optimized) { + throw new CollectionTerminatedException(); + } + SortedNumericDocValues values = valuesSource.longValues(ctx); return new LeafBucketCollectorBase(sub, values) { @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index c7303011b5800..8588eb0a57d52 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -32,6 +32,7 @@ package org.opensearch.search.aggregations.bucket.range; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.ScoreMode; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; @@ -308,16 +309,20 @@ public ScoreMode scoreMode() { return super.scoreMode(); } - @Override - protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { - if (segmentMatchAll(context, ctx)) { - return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false); - } - return false; - } + // @Override + // protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + // if (segmentMatchAll(context, ctx)) { + // return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false); + // } + // return false; + // } @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + if (segmentMatchAll(context, ctx) + && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false, collectableSubAggregators, sub)) { + throw new CollectionTerminatedException(); + } final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); return new LeafBucketCollectorBase(sub, values) { diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/filterrewrite/DocIdSetBuilderTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/filterrewrite/DocIdSetBuilderTests.java new file mode 100644 index 0000000000000..2786c279c6980 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/filterrewrite/DocIdSetBuilderTests.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.filterrewrite; + +import org.apache.lucene.document.BinaryDocValuesField; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.DocIdSetBuilder; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +public class DocIdSetBuilderTests extends OpenSearchTestCase { + + private Directory directory; + private IndexWriter iw; + private DirectoryReader reader; + + @Override + public void setUp() throws Exception { + super.setUp(); + directory = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + iw = new IndexWriter(directory, iwc); + + // Add some test documents + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new StringField("text_field", "value_" + i, Field.Store.NO)); + doc.add(new NumericDocValuesField("numeric_field", i)); + doc.add(new BinaryDocValuesField("binary_field", new BytesRef("value_" + i))); + iw.addDocument(doc); + } + iw.commit(); + reader = DirectoryReader.open(iw); + } + + @Override + public void tearDown() throws Exception { + reader.close(); + iw.close(); + directory.close(); + super.tearDown(); + } + + public void testBasicDocIdSetBuilding() throws IOException { + LeafReaderContext context = reader.leaves().get(0); + + // Test with different maxDoc sizes + DocIdSetBuilder builder = new DocIdSetBuilder(context.reader().maxDoc()); + DocIdSetBuilder.BulkAdder adder = builder.grow(10); + + adder.add(0); + adder.add(5); + adder.add(10); + + DocIdSet docIdSet = builder.build(); + assertNotNull(docIdSet); + + DocIdSetIterator iterator = docIdSet.iterator(); + assertNotNull(iterator); + + assertEquals(0, iterator.nextDoc()); + assertEquals(5, iterator.nextDoc()); + assertEquals(10, iterator.nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } +} diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteSubAggTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteSubAggTests.java new file mode 100644 index 0000000000000..1187eeac34d84 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/filterrewrite/FilterRewriteSubAggTests.java @@ -0,0 +1,452 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.bucket.filterrewrite; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.LongField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.indices.breaker.NoneCircuitBreakerService; +import org.opensearch.index.mapper.DateFieldMapper; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregator; +import org.opensearch.search.aggregations.AggregatorTestCase; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.MultiBucketConsumerService; +import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.bucket.histogram.Histogram; +import org.opensearch.search.aggregations.bucket.histogram.InternalAutoDateHistogram; +import org.opensearch.search.aggregations.bucket.histogram.InternalDateHistogram; +import org.opensearch.search.aggregations.bucket.range.InternalRange; +import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalStats; +import org.opensearch.search.aggregations.pipeline.PipelineAggregator; +import org.opensearch.search.internal.SearchContext; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; + +public class FilterRewriteSubAggTests extends AggregatorTestCase { + private final String longFieldName = "metric"; + private final String dateFieldName = "timestamp"; + private final Query matchAllQuery = new MatchAllDocsQuery(); + private final NumberFieldMapper.NumberFieldType longFieldType = new NumberFieldMapper.NumberFieldType( + longFieldName, + NumberFieldMapper.NumberType.LONG + ); + private final DateFieldMapper.DateFieldType dateFieldType = aggregableDateFieldType(false, true); + private final NumberFieldMapper.NumberType numberType = longFieldType.numberType(); + private final String rangeAggName = "range"; + private final String autoDateAggName = "auto"; + private final String dateAggName = "date"; + private final String statsAggName = "stats"; + private final List DEFAULT_DATA = List.of( + new TestDoc(0, Instant.parse("2020-03-01T00:00:00Z")), + new TestDoc(1, Instant.parse("2020-03-01T00:00:00Z")), + new TestDoc(1, Instant.parse("2020-03-01T00:00:01Z")), + new TestDoc(2, Instant.parse("2020-03-01T01:00:00Z")), + new TestDoc(3, Instant.parse("2020-03-01T02:00:00Z")), + 
new TestDoc(4, Instant.parse("2020-03-01T03:00:00Z")), + new TestDoc(5, Instant.parse("2020-03-01T04:00:00Z")), + new TestDoc(6, Instant.parse("2020-03-01T04:00:00Z")) + ); + + public void testRange() throws IOException { + RangeAggregationBuilder rangeAggregationBuilder = new RangeAggregationBuilder(rangeAggName).field(longFieldName) + .addRange(1, 2) + .addRange(2, 4) + .addRange(4, 6) + .subAggregation(new AutoDateHistogramAggregationBuilder(autoDateAggName).field(dateFieldName).setNumBuckets(3)); + + InternalRange result = executeAggregation(DEFAULT_DATA, rangeAggregationBuilder, true); + + // Verify results + List buckets = result.getBuckets(); + assertEquals(3, buckets.size()); + + InternalRange.Bucket firstBucket = buckets.get(0); + assertEquals(2, firstBucket.getDocCount()); + InternalAutoDateHistogram firstAuto = firstBucket.getAggregations().get(autoDateAggName); + assertEquals(2, firstAuto.getBuckets().size()); + + InternalRange.Bucket secondBucket = buckets.get(1); + assertEquals(2, secondBucket.getDocCount()); + InternalAutoDateHistogram secondAuto = secondBucket.getAggregations().get(autoDateAggName); + assertEquals(3, secondAuto.getBuckets().size()); + + InternalRange.Bucket thirdBucket = buckets.get(2); + assertEquals(2, thirdBucket.getDocCount()); + InternalAutoDateHistogram thirdAuto = thirdBucket.getAggregations().get(autoDateAggName); + assertEquals(3, thirdAuto.getBuckets().size()); + } + + public void testDateHisto() throws IOException { + DateHistogramAggregationBuilder dateHistogramAggregationBuilder = new DateHistogramAggregationBuilder(dateAggName).field( + dateFieldName + ).calendarInterval(DateHistogramInterval.HOUR).subAggregation(AggregationBuilders.stats(statsAggName).field(longFieldName)); + + InternalDateHistogram result = executeAggregation(DEFAULT_DATA, dateHistogramAggregationBuilder, true); + + // Verify results + List buckets = result.getBuckets(); + assertEquals(5, buckets.size()); + + InternalDateHistogram.Bucket firstBucket = buckets.get(0); + assertEquals("2020-03-01T00:00:00.000Z", firstBucket.getKeyAsString()); + assertEquals(3, firstBucket.getDocCount()); + InternalStats firstStats = firstBucket.getAggregations().get(statsAggName); + assertEquals(3, firstStats.getCount()); + assertEquals(1, firstStats.getMax(), 0); + assertEquals(0, firstStats.getMin(), 0); + + InternalDateHistogram.Bucket secondBucket = buckets.get(1); + assertEquals("2020-03-01T01:00:00.000Z", secondBucket.getKeyAsString()); + assertEquals(1, secondBucket.getDocCount()); + InternalStats secondStats = secondBucket.getAggregations().get(statsAggName); + assertEquals(1, secondStats.getCount()); + assertEquals(2, secondStats.getMax(), 0); + assertEquals(2, secondStats.getMin(), 0); + + InternalDateHistogram.Bucket thirdBucket = buckets.get(2); + assertEquals("2020-03-01T02:00:00.000Z", thirdBucket.getKeyAsString()); + assertEquals(1, thirdBucket.getDocCount()); + InternalStats thirdStats = thirdBucket.getAggregations().get(statsAggName); + assertEquals(1, thirdStats.getCount()); + assertEquals(3, thirdStats.getMax(), 0); + assertEquals(3, thirdStats.getMin(), 0); + + InternalDateHistogram.Bucket fourthBucket = buckets.get(3); + assertEquals("2020-03-01T03:00:00.000Z", fourthBucket.getKeyAsString()); + assertEquals(1, fourthBucket.getDocCount()); + InternalStats fourthStats = fourthBucket.getAggregations().get(statsAggName); + assertEquals(1, fourthStats.getCount()); + assertEquals(4, fourthStats.getMax(), 0); + assertEquals(4, fourthStats.getMin(), 0); + + 
InternalDateHistogram.Bucket fifthBucket = buckets.get(4); + assertEquals("2020-03-01T04:00:00.000Z", fifthBucket.getKeyAsString()); + assertEquals(2, fifthBucket.getDocCount()); + InternalStats fifthStats = fifthBucket.getAggregations().get(statsAggName); + assertEquals(2, fifthStats.getCount()); + assertEquals(6, fifthStats.getMax(), 0); + assertEquals(5, fifthStats.getMin(), 0); + } + + public void testAutoDateHisto() throws IOException { + AutoDateHistogramAggregationBuilder autoDateHistogramAggregationBuilder = new AutoDateHistogramAggregationBuilder(dateAggName) + .field(dateFieldName) + .setNumBuckets(5) + .subAggregation(AggregationBuilders.stats(statsAggName).field(longFieldName)); + + InternalAutoDateHistogram result = executeAggregation(DEFAULT_DATA, autoDateHistogramAggregationBuilder, true); + + // Verify results + List buckets = result.getBuckets(); + assertEquals(5, buckets.size()); + + Histogram.Bucket firstBucket = buckets.get(0); + assertEquals("2020-03-01T00:00:00.000Z", firstBucket.getKeyAsString()); + assertEquals(3, firstBucket.getDocCount()); + InternalStats firstStats = firstBucket.getAggregations().get(statsAggName); + assertEquals(3, firstStats.getCount()); + assertEquals(1, firstStats.getMax(), 0); + assertEquals(0, firstStats.getMin(), 0); + + Histogram.Bucket secondBucket = buckets.get(1); + assertEquals("2020-03-01T01:00:00.000Z", secondBucket.getKeyAsString()); + assertEquals(1, secondBucket.getDocCount()); + InternalStats secondStats = secondBucket.getAggregations().get(statsAggName); + assertEquals(1, secondStats.getCount()); + assertEquals(2, secondStats.getMax(), 0); + assertEquals(2, secondStats.getMin(), 0); + + Histogram.Bucket thirdBucket = buckets.get(2); + assertEquals("2020-03-01T02:00:00.000Z", thirdBucket.getKeyAsString()); + assertEquals(1, thirdBucket.getDocCount()); + InternalStats thirdStats = thirdBucket.getAggregations().get(statsAggName); + assertEquals(1, thirdStats.getCount()); + assertEquals(3, thirdStats.getMax(), 0); + assertEquals(3, thirdStats.getMin(), 0); + + Histogram.Bucket fourthBucket = buckets.get(3); + assertEquals("2020-03-01T03:00:00.000Z", fourthBucket.getKeyAsString()); + assertEquals(1, fourthBucket.getDocCount()); + InternalStats fourthStats = fourthBucket.getAggregations().get(statsAggName); + assertEquals(1, fourthStats.getCount()); + assertEquals(4, fourthStats.getMax(), 0); + assertEquals(4, fourthStats.getMin(), 0); + + Histogram.Bucket fifthBucket = buckets.get(4); + assertEquals("2020-03-01T04:00:00.000Z", fifthBucket.getKeyAsString()); + assertEquals(2, fifthBucket.getDocCount()); + InternalStats fifthStats = fifthBucket.getAggregations().get(statsAggName); + assertEquals(2, fifthStats.getCount()); + assertEquals(6, fifthStats.getMax(), 0); + assertEquals(5, fifthStats.getMin(), 0); + + } + + public void testRandom() throws IOException { + Map dataset = new HashMap<>(); + dataset.put("2017-02-01T09:02:00.000Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T09:59:59.999Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T10:00:00.001Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T13:06:00.000Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T14:04:00.000Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T14:05:00.000Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T15:59:00.000Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T16:06:00.000Z", randomIntBetween(100, 2000)); + dataset.put("2017-02-01T16:48:00.000Z", randomIntBetween(100, 2000)); + 
dataset.put("2017-02-01T16:59:00.000Z", randomIntBetween(100, 2000)); + + Map subAggToVerify = new HashMap<>(); + List docs = new ArrayList<>(); + for (Map.Entry entry : dataset.entrySet()) { + String date = entry.getKey(); + int docCount = entry.getValue(); + // loop value times and generate TestDoc + if (!subAggToVerify.containsKey(date)) { + subAggToVerify.put(date, new SubAggToVerify()); + } + SubAggToVerify subAgg = subAggToVerify.get(date); + subAgg.count = docCount; + for (int i = 0; i < docCount; i++) { + Instant instant = Instant.parse(date); + int docValue = randomIntBetween(0, 10_000); + subAgg.min = Math.min(subAgg.min, docValue); + subAgg.max = Math.max(subAgg.max, docValue); + docs.add(new TestDoc(docValue, instant)); + } + } + + DateHistogramAggregationBuilder dateHistogramAggregationBuilder = new DateHistogramAggregationBuilder(dateAggName).field( + dateFieldName + ) + .calendarInterval(DateHistogramInterval.HOUR) + .minDocCount(1L) + .subAggregation(AggregationBuilders.stats(statsAggName).field(longFieldName)); + + InternalDateHistogram result = executeAggregation(docs, dateHistogramAggregationBuilder, true); + List buckets = result.getBuckets(); + assertEquals(6, buckets.size()); + for (InternalDateHistogram.Bucket bucket : buckets) { + String date = bucket.getKeyAsString(); + SubAggToVerify subAgg = subAggToVerify.get(date); + if (subAgg == null) continue; + InternalStats stats = bucket.getAggregations().get(statsAggName); + assertEquals(subAgg.count, stats.getCount()); + assertEquals(subAgg.max, stats.getMax(), 0); + assertEquals(subAgg.min, stats.getMin(), 0); + } + } + + public void testLeafTraversal() throws IOException { + Map dataset = new HashMap<>(); + dataset.put("2017-02-01T09:02:00.000Z", 512); + dataset.put("2017-02-01T09:59:59.999Z", 256); + dataset.put("2017-02-01T10:00:00.001Z", 256); + dataset.put("2017-02-01T13:06:00.000Z", 512); + dataset.put("2017-02-01T14:04:00.000Z", 256); + dataset.put("2017-02-01T14:05:00.000Z", 256); + dataset.put("2017-02-01T15:59:00.000Z", 768); + + Map subAggToVerify = new HashMap<>(); + List docs = new ArrayList<>(); + for (Map.Entry entry : dataset.entrySet()) { + String date = entry.getKey(); + int docCount = entry.getValue(); + // loop value times and generate TestDoc + if (!subAggToVerify.containsKey(date)) { + subAggToVerify.put(date, new SubAggToVerify()); + } + SubAggToVerify subAgg = subAggToVerify.get(date); + subAgg.count = docCount; + for (int i = 0; i < docCount; i++) { + Instant instant = Instant.parse(date); + int docValue = randomIntBetween(0, 10_000); + subAgg.min = Math.min(subAgg.min, docValue); + subAgg.max = Math.max(subAgg.max, docValue); + docs.add(new TestDoc(docValue, instant)); + } + } + + DateHistogramAggregationBuilder dateHistogramAggregationBuilder = new DateHistogramAggregationBuilder(dateAggName).field( + dateFieldName + ) + .calendarInterval(DateHistogramInterval.HOUR) + .minDocCount(1L) + .subAggregation(AggregationBuilders.stats(statsAggName).field(longFieldName)); + + InternalDateHistogram result = executeAggregation(docs, dateHistogramAggregationBuilder, false); + List buckets = result.getBuckets(); + assertEquals(5, buckets.size()); + for (InternalDateHistogram.Bucket bucket : buckets) { + String date = bucket.getKeyAsString(); + SubAggToVerify subAgg = subAggToVerify.get(date); + if (subAgg == null) continue; + InternalStats stats = bucket.getAggregations().get(statsAggName); + assertEquals(subAgg.count, stats.getCount()); + assertEquals(subAgg.max, stats.getMax(), 0); + 
assertEquals(subAgg.min, stats.getMin(), 0); + } + } + + private IA executeAggregation( + List docs, + AggregationBuilder aggregationBuilder, + boolean random + ) throws IOException { + try (Directory directory = setupIndex(docs, random)) { + try (DirectoryReader indexReader = DirectoryReader.open(directory)) { + return executeAggregationOnReader(indexReader, aggregationBuilder); + } + } + } + + private Directory setupIndex(List docs, boolean random) throws IOException { + Directory directory = newDirectory(); + if (!random) { + try (IndexWriter indexWriter = new IndexWriter(directory, new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()))) { + for (TestDoc doc : docs) { + indexWriter.addDocument(doc.toDocument()); + } + } + } else { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (TestDoc doc : docs) { + indexWriter.addDocument(doc.toDocument()); + } + } + } + return directory; + } + + private IA executeAggregationOnReader( + DirectoryReader indexReader, + AggregationBuilder aggregationBuilder + ) throws IOException { + IndexSearcher indexSearcher = new IndexSearcher(indexReader); + + MultiBucketConsumerService.MultiBucketConsumer bucketConsumer = createBucketConsumer(); + SearchContext searchContext = createSearchContext( + indexSearcher, + createIndexSettings(), + matchAllQuery, + bucketConsumer, + longFieldType, + dateFieldType + ); + Aggregator aggregator = createAggregator(aggregationBuilder, searchContext); + CountingAggregator countingAggregator = new CountingAggregator(new AtomicInteger(), aggregator); + + // Execute aggregation + countingAggregator.preCollection(); + indexSearcher.search(matchAllQuery, countingAggregator); + countingAggregator.postCollection(); + + // Reduce results + IA topLevel = (IA) countingAggregator.buildTopLevel(); + MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer = createReduceBucketConsumer(); + InternalAggregation.ReduceContext context = createReduceContext(countingAggregator, reduceBucketConsumer); + + IA result = (IA) topLevel.reduce(Collections.singletonList(topLevel), context); + doAssertReducedMultiBucketConsumer(result, reduceBucketConsumer); + + assertEquals("Expect not using collect to do aggregation", 0, countingAggregator.getCollectCount().get()); + + return result; + } + + private MultiBucketConsumerService.MultiBucketConsumer createBucketConsumer() { + return new MultiBucketConsumerService.MultiBucketConsumer( + DEFAULT_MAX_BUCKETS, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + } + + private MultiBucketConsumerService.MultiBucketConsumer createReduceBucketConsumer() { + return new MultiBucketConsumerService.MultiBucketConsumer( + Integer.MAX_VALUE, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + } + + private InternalAggregation.ReduceContext createReduceContext( + Aggregator aggregator, + MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer + ) { + return InternalAggregation.ReduceContext.forFinalReduction( + aggregator.context().bigArrays(), + getMockScriptService(), + reduceBucketConsumer, + PipelineAggregator.PipelineTree.EMPTY + ); + } + + private class TestDoc { + private final long metric; + private final Instant timestamp; + + public TestDoc(long metric, Instant timestamp) { + this.metric = metric; + this.timestamp = timestamp; + } + + public ParseContext.Document toDocument() { + ParseContext.Document doc = new ParseContext.Document(); + + List fieldList = 
numberType.createFields(longFieldName, metric, true, true, false); + for (Field fld : fieldList) + doc.add(fld); + doc.add(new LongField(dateFieldName, dateFieldType.parse(timestamp.toString()), Field.Store.NO)); + + return doc; + } + } + + private static class SubAggToVerify { + int min = Integer.MAX_VALUE; + int max = Integer.MIN_VALUE; + int count; + } + + protected final DateFieldMapper.DateFieldType aggregableDateFieldType(boolean useNanosecondResolution, boolean isSearchable) { + return new DateFieldMapper.DateFieldType( + "timestamp", + isSearchable, + false, + true, + DateFieldMapper.getDefaultDateTimeFormatter(), + useNanosecondResolution ? DateFieldMapper.Resolution.NANOSECONDS : DateFieldMapper.Resolution.MILLISECONDS, + null, + Collections.emptyMap() + ); + } +}
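
For readers skimming the diff, the core of the new sub-aggregation support is a collect-then-replay pattern: during the BKD point-tree traversal, matching doc IDs are buffered into one DocIdSetBuilder per bucket ordinal, and once a segment has been fully traversed those doc IDs are replayed into the sub-aggregator's leaf collector. Below is a minimal, self-contained sketch of that pattern using only Lucene classes; the class, interface, and method names are illustrative and are not part of this change.

import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.DocIdSetBuilder;

import java.io.IOException;

public class SubAggReplaySketch {

    /** Stand-in for LeafBucketCollector#collect(doc, bucketOrd). */
    interface PerBucketDocConsumer {
        void collect(int doc, int bucketOrd) throws IOException;
    }

    /** Record that doc matched the range mapped to bucketOrd. */
    static void record(DocIdSetBuilder[] builders, int bucketOrd, int maxDoc, int doc) {
        if (builders[bucketOrd] == null) {
            // lazily created, mirroring the disBuilderSupplier in the diff
            builders[bucketOrd] = new DocIdSetBuilder(maxDoc);
        }
        // grow(n) reserves room for n docs on the returned adder
        builders[bucketOrd].grow(1).add(doc);
    }

    /** After the tree walk, feed every bucket's doc IDs to the sub-aggregation collector. */
    static void replay(DocIdSetBuilder[] builders, PerBucketDocConsumer sub) throws IOException {
        for (int bucketOrd = 0; bucketOrd < builders.length; bucketOrd++) {
            if (builders[bucketOrd] == null) {
                continue; // no docs fell into this bucket
            }
            DocIdSet docs = builders[bucketOrd].build(); // a builder may only be built once
            DocIdSetIterator it = docs.iterator();
            if (it == null) {
                continue;
            }
            for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
                sub.collect(doc, bucketOrd);
            }
        }
    }
}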
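
The lastGrowCount bookkeeping in RangeCollectorForPointTree.iterateRangeEnd exists because of the DocIdSetBuilder.BulkAdder contract: grow(n) reserves capacity and the returned adder should not be asked to add more than n documents, so when a range boundary clears the current adder mid-leaf it must be re-grown before further add calls. A small sketch of that contract, assuming only standard Lucene APIs (class name is illustrative):

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.DocIdSetBuilder;

import java.io.IOException;

public class BulkAdderContractSketch {
    public static void main(String[] args) throws IOException {
        DocIdSetBuilder builder = new DocIdSetBuilder(100); // maxDoc of the segment

        // Reserve room for three docs; the adder is only valid for that many add() calls.
        DocIdSetBuilder.BulkAdder adder = builder.grow(3);
        adder.add(4);
        adder.add(17);
        adder.add(42);

        // build() may be called once; the iterator yields the collected docs in order.
        DocIdSetIterator it = builder.build().iterator();
        for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
            System.out.println(doc);
        }
    }
}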