range agg now does query by query
Signed-off-by: bowenlan-amzn <[email protected]>

trying the each-bucket method

Signed-off-by: bowenlan-amzn <[email protected]>

optimizing the CompositeDocIdSetIterator

Signed-off-by: bowenlan-amzn <[email protected]>

finding performance bottleneck

Signed-off-by: bowenlan-amzn <[email protected]>

sub agg hooked in

Signed-off-by: bowenlan-amzn <[email protected]>

understand the bucket ordinal

Signed-off-by: bowenlan-amzn <[email protected]>

grow seems to work for doc id set builder, now verifying correctness

Signed-off-by: bowenlan-amzn <[email protected]>

investigate the behavior of DocIdSetBuilder

Signed-off-by: bowenlan-amzn <[email protected]>

couldn't easily figure out grow by size

Signed-off-by: bowenlan-amzn <[email protected]>

play around with grow

Signed-off-by: bowenlan-amzn <[email protected]>

collect disi also

going to see how performance compares between the default path and this one

Signed-off-by: bowenlan-amzn <[email protected]>
bowenlan-amzn committed Feb 4, 2025
1 parent d24c9e4 commit 8c1a3c0
Showing 11 changed files with 513 additions and 22 deletions.
@@ -557,7 +557,12 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
- boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
+ boolean optimized = filterRewriteOptimizationContext.tryOptimize(
+ ctx,
+ this::incrementBucketDocCount,
+ segmentMatchAll(context, ctx),
+ collectableSubAggregators
+ );
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
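For context, as an assumption about the surrounding Lucene contract rather than part of this diff: throwing CollectionTerminatedException from a collector is the sanctioned way to tell the search loop that this leaf needs no further doc-by-doc collection. A simplified sketch of the caller side (collector, bulkScorer, and ctx are hypothetical names):

try {
    LeafCollector leaf = collector.getLeafCollector(ctx);
    bulkScorer.score(leaf, ctx.reader().getLiveDocs()); // normal per-doc collection
} catch (CollectionTerminatedException e) {
    // expected: the leaf was fully handled (here, by the filter rewrite); continue with the next leaf
}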
@@ -79,6 +79,7 @@ void setRangesConsumer(Consumer<Ranges> setRanges) {
abstract FilterRewriteOptimizationContext.DebugInfo tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
- Ranges ranges
+ Ranges ranges,
+ int maxDoc
) throws IOException;
}
new file: CompositeDocIdSetIterator.java
@@ -0,0 +1,112 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.aggregations.bucket.filterrewrite;

import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;

/**
* A composite iterator over multiple DocIdSetIterators where each document
* belongs to exactly one bucket within a single segment.
*/
public class CompositeDocIdSetIterator extends DocIdSetIterator {
private final DocIdSetIterator[] iterators;

// Track active iterators so exhausted ones are not rescanned
private final int[] activeIterators; // bucket index of each non-exhausted iterator
private int numActiveIterators; // number of non-exhausted iterators

private int currentDoc = -1;
private int currentBucket = -1;

public CompositeDocIdSetIterator(DocIdSetIterator[] iterators) {
this.iterators = iterators;
int numBuckets = iterators.length;
this.activeIterators = new int[numBuckets];
this.numActiveIterators = 0;

// Initialize active iterator tracking
for (int i = 0; i < numBuckets; i++) {
if (iterators[i] != null) {
activeIterators[numActiveIterators++] = i;
}
}
}

@Override
public int docID() {
return currentDoc;
}

public int getCurrentBucket() {
return currentBucket;
}

@Override
public int nextDoc() throws IOException {
return advance(currentDoc + 1);
}

@Override
public int advance(int target) throws IOException {
if (target == NO_MORE_DOCS || numActiveIterators == 0) {
currentDoc = NO_MORE_DOCS;
currentBucket = -1;
return NO_MORE_DOCS;
}

int minDoc = NO_MORE_DOCS;
int minBucket = -1;
int remainingActive = 0; // Counter for non-exhausted iterators

// Only check currently active iterators
for (int i = 0; i < numActiveIterators; i++) {
int bucket = activeIterators[i];
DocIdSetIterator iterator = iterators[bucket];

int doc = iterator.docID();
if (doc < target) {
doc = iterator.advance(target);
}

if (doc == NO_MORE_DOCS) {
// Iterator is exhausted, don't include it in active set
continue;
}

// Keep this iterator in our active set
activeIterators[remainingActive] = bucket;
remainingActive++;

if (doc < minDoc) {
minDoc = doc;
minBucket = bucket;
}
}

// Update count of active iterators
numActiveIterators = remainingActive;

currentDoc = minDoc;
currentBucket = minBucket;

return currentDoc;
}

@Override
public long cost() {
long cost = 0;
for (int i = 0; i < numActiveIterators; i++) {
DocIdSetIterator iterator = iterators[activeIterators[i]];
cost += iterator.cost();
}
return cost;
}
}
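A minimal usage sketch for this iterator (illustration only, not part of the commit; perBucketIterators and sub are assumed to come from the traversal result and the sub-aggregator's leaf collector). It mirrors the commented-out composite path further down in FilterRewriteOptimizationContext:

// perBucketIterators[i] holds bucket i's DocIdSetIterator, or null if range i matched nothing
CompositeDocIdSetIterator it = new CompositeDocIdSetIterator(perBucketIterators);
for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
    sub.collect(doc, it.getCurrentBucket()); // bucket ordinal = range index
}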
DateHistogramAggregatorBridge.java
@@ -8,11 +8,14 @@

package org.opensearch.search.aggregations.bucket.filterrewrite;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.Rounding;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -24,6 +27,7 @@
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.search.aggregations.bucket.filterrewrite.PointTreeTraversal.multiRangesTraverse;

@@ -32,6 +36,8 @@
*/
public abstract class DateHistogramAggregatorBridge extends AggregatorBridge {

private static final Logger logger = LogManager.getLogger(Helper.loggerName);

int maxRewriteFilters;

protected boolean canOptimize(ValuesSourceConfig config) {
@@ -129,7 +135,8 @@ protected int getSize() {
final FilterRewriteOptimizationContext.DebugInfo tryOptimize(
PointValues values,
BiConsumer<Long, Long> incrementDocCount,
- Ranges ranges
+ Ranges ranges,
+ int maxDoc
) throws IOException {
int size = getSize();

@@ -141,7 +148,22 @@ final FilterRewriteOptimizationContext.DebugInfo tryOptimize(
incrementDocCount.accept(bucketOrd, (long) docCount);
};

- return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size);
+ Function<Integer, Long> getBucketOrd = (activeIndex) -> {
+ long rangeStart = LongPoint.decodeDimension(ranges.lowers[activeIndex], 0);
+ rangeStart = fieldType.convertNanosToMillis(rangeStart);
+ return getBucketOrd(bucketOrdProducer().apply(rangeStart));
+ };
+
+ Supplier<DocIdSetBuilder> disBuilderSupplier = () -> {
+ try {
+ logger.trace("create DocIdSetBuilder of max doc {}", maxDoc);
+ return new DocIdSetBuilder(maxDoc, values, fieldType.name());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ };
+
+ return multiRangesTraverse(values.getPointTree(), ranges, incrementFunc, size, disBuilderSupplier, getBucketOrd);
}

private static long getBucketOrd(long bucketOrd) {
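The traversal that consumes disBuilderSupplier lives in PointTreeTraversal and is not shown in this diff; what follows is a hedged sketch of the assumed pattern, built on Lucene's DocIdSetBuilder.grow/BulkAdder API that the commit messages experiment with (leafMatchCount and matchingDocs are hypothetical):

DocIdSetBuilder builder = disBuilderSupplier.get(); // one builder per matched range, created lazily
DocIdSetBuilder.BulkAdder adder = builder.grow(leafMatchCount); // reserve capacity before adding
for (int doc : matchingDocs) { // doc IDs visited under a matching point-tree node
    adder.add(doc);
}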
FilterRewriteOptimizationContext.java
@@ -14,10 +14,20 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.index.mapper.DocCountFieldMapper;
import org.opensearch.search.aggregations.BucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

@@ -65,7 +75,8 @@ public FilterRewriteOptimizationContext(
private boolean canOptimize(final Object parent, final int subAggLength, SearchContext context) throws IOException {
if (context.maxAggRewriteFilters() == 0) return false;

- if (parent != null || subAggLength != 0) return false;
+ // if (parent != null || subAggLength != 0) return false;
+ if (parent != null) return false;

boolean canOptimize = aggregatorBridge.canOptimize();
if (canOptimize) {
@@ -96,8 +107,12 @@ void setRanges(Ranges ranges) {
* @param incrementDocCount consume the doc_count results for certain ordinal
* @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment
*/
- public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Long, Long> incrementDocCount, boolean segmentMatchAll)
- throws IOException {
+ public boolean tryOptimize(
+ final LeafReaderContext leafCtx,
+ final BiConsumer<Long, Long> incrementDocCount,
+ boolean segmentMatchAll,
+ BucketCollector collectableSubAggregators
+ ) throws IOException {
segments.incrementAndGet();
if (!canOptimize) {
return false;
@@ -123,12 +138,91 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
Ranges ranges = getRanges(leafCtx, segmentMatchAll);
if (ranges == null) return false;

- consumeDebugInfo(aggregatorBridge.tryOptimize(values, incrementDocCount, ranges));
+ DebugInfo debugInfo = aggregatorBridge.tryOptimize(values, incrementDocCount, ranges, leafCtx.reader().maxDoc());
+ consumeDebugInfo(debugInfo);

optimizedSegments.incrementAndGet();
logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
logger.debug("Crossed leaf nodes: {}, inner nodes: {}", leafNodeVisited, innerNodeVisited);

// TODO refactor tryOptimize to return a Result object that contains not only DebugInfo
// but also the DocIdSetIterator for sub aggregation
// There are at least 2 ways to do the iterating:
// 1. a list of iterators, one per range
// 2. a composite iterator

// CompositeDocIdSetIterator iter = new CompositeDocIdSetIterator(debugInfo.iterators);
// while (iter.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
// int currentDoc = iter.docID();
// int bucket = iter.getCurrentBucket();
// sub.collect(currentDoc, bucket);
// }

// let's not use the composite disi;
// instead, rebuild the sub-agg leaf collector
// for each bucket ord

LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(leafCtx);
for (int bucketOrd = 0; bucketOrd < debugInfo.builders.length; bucketOrd++) {
logger.debug("Collecting bucket {} for sub aggregation", bucketOrd);
DocIdSetBuilder builder = debugInfo.builders[bucketOrd];
if (builder == null) {
continue;
}
DocIdSetIterator iterator = builder.build().iterator();
while (iterator.nextDoc() != NO_MORE_DOCS) {
int currentDoc = iterator.docID();
sub.collect(currentDoc, bucketOrd);
}
// resetting the sub collector after processing each bucket
sub = collectableSubAggregators.getLeafCollector(leafCtx);
}

return true;
}

List<Weight> weights;

public List<Weight> getWeights() {
return weights;
}

public boolean tryGetRanges(final LeafReaderContext leafCtx, boolean segmentMatchAll, SearchContext context) throws IOException {
if (!canOptimize) {
return false;
}

if (leafCtx.reader().hasDeletions()) return false;

PointValues values = leafCtx.reader().getPointValues(aggregatorBridge.fieldType.name());
if (values == null) return false;
// only proceed if every document corresponds to exactly one point
if (values.getDocCount() != values.size()) return false;

NumericDocValues docCountValues = DocValues.getNumeric(leafCtx.reader(), DocCountFieldMapper.NAME);
if (docCountValues.nextDoc() != NO_MORE_DOCS) {
logger.debug(
"Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization",
shardId,
leafCtx.ord
);
return false;
}

Ranges ranges = getRanges(leafCtx, segmentMatchAll);
if (ranges == null) return false;

List<Weight> weights = new ArrayList<>();
for (int i = 0; i < ranges.size; i++) {
Query query = new PointRangeQuery(aggregatorBridge.fieldType.name(), ranges.lowers[i], ranges.uppers[i], 1) {
@Override
protected String toString(int dimension, byte[] value) {
return "";
}
};
weights.add(query.rewrite(context.searcher()).createWeight(context.searcher(), ScoreMode.COMPLETE_NO_SCORES, 1));
}
this.weights = weights;
return true;
}
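How getWeights() is consumed is not shown in this commit; here is a hedged sketch of the query-by-query idea from the commit title (opCtx, leafCtx, and sub are assumed to come from the calling aggregator):

List<Weight> weights = opCtx.getWeights();
for (int range = 0; range < weights.size(); range++) {
    Scorer scorer = weights.get(range).scorer(leafCtx); // null when the range matches nothing in this leaf
    if (scorer == null) continue;
    DocIdSetIterator disi = scorer.iterator();
    for (int doc = disi.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disi.nextDoc()) {
        sub.collect(doc, range); // bucket ordinal = range index
    }
}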

@@ -141,6 +235,7 @@ Ranges getRanges(LeafReaderContext leafCtx, boolean segmentMatchAll) {
return null;
}
}
logger.debug("number of ranges: {}", ranges.lowers.length);
return ranges;
}

@@ -164,6 +259,9 @@ static class DebugInfo {
private final AtomicInteger leafNodeVisited = new AtomicInteger(); // leaf node visited
private final AtomicInteger innerNodeVisited = new AtomicInteger(); // inner node visited

public DocIdSetIterator[] iterators;
public DocIdSetBuilder[] builders;

void visitLeaf() {
leafNodeVisited.incrementAndGet();
}