Skip to content

Commit

Permalink
sub agg hooked in
Browse files Browse the repository at this point in the history
Signed-off-by: bowenlan-amzn <[email protected]>
  • Loading branch information
bowenlan-amzn committed Feb 3, 2025
1 parent bf9365b commit 8aef653
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,12 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
boolean optimized = filterRewriteOptimizationContext.tryOptimize(
ctx,
this::incrementBucketDocCount,
segmentMatchAll(context, ctx),
null
);
if (optimized) throw new CollectionTerminatedException();

finishLeaf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,41 @@
import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;
import java.util.Arrays;

/**
* A composite view of multiple DocIdSetIterators from single segment
* A composite iterator over multiple DocIdSetIterators where each document
* belongs to exactly one bucket within a single segment.
*/
public class CompositeDocIdSetIterator extends DocIdSetIterator {
private final DocIdSetIterator[] iterators;
private final int[] currentDocs; // Current docId for each iterator
private final boolean[] exhausted; // Track if each iterator is exhausted
private int currentDoc = -1; // Current doc for this composite iterator
private final int numIterators;
private final int numBuckets;
private int currentDoc = -1;
private int currentBucket = -1;

/**
* Creates a composite view of multiple DocIdSetIterators
* Creates a composite view of DocIdSetIterators for a segment where
* each document belongs to exactly one bucket.
* @param ordinalToIterator Mapping of bucket ordinal to its DocIdSetIterator
* @param maxOrdinal The maximum bucket ordinal (exclusive)
*/
public CompositeDocIdSetIterator(DocIdSetIterator[] ordinalToIterator, int maxOrdinal) {
this.iterators = Arrays.copyOf(ordinalToIterator, maxOrdinal);
this.numIterators = maxOrdinal;
this.currentDocs = new int[maxOrdinal];
this.exhausted = new boolean[maxOrdinal];

// Initialize currentDocs array to -1 for all iterators
Arrays.fill(currentDocs, -1);
public CompositeDocIdSetIterator(DocIdSetIterator[] ordinalToIterator) {
this.iterators = ordinalToIterator;
this.numBuckets = ordinalToIterator.length;
}

@Override
public int docID() {
return currentDoc;
}

/**
* Returns the bucket ordinal for the current document.
* Should only be called when positioned on a valid document.
* @return bucket ordinal for the current document, or -1 if no current document
*/
public int getCurrentBucket() {
return currentBucket;
}

@Override
public int nextDoc() throws IOException {
return advance(currentDoc + 1);
Expand All @@ -52,73 +55,44 @@ public int nextDoc() throws IOException {
public int advance(int target) throws IOException {
if (target == NO_MORE_DOCS) {
currentDoc = NO_MORE_DOCS;
currentBucket = -1;
return NO_MORE_DOCS;
}

int minDoc = NO_MORE_DOCS;
int minDocBucket = -1;

// Advance all iterators that are behind target
for (int i = 0; i < numIterators; i++) {
if (iterators[i] == null) {
exhausted[i] = true;
// Find the iterator with the lowest docID >= target
for (int bucketOrd = 0; bucketOrd < numBuckets; bucketOrd++) {
DocIdSetIterator iterator = iterators[bucketOrd];
if (iterator == null) {
continue;
}

if (!exhausted[i] && currentDocs[i] < target) {
int doc = iterators[i].advance(target);
if (doc == NO_MORE_DOCS) {
exhausted[i] = true;
} else {
currentDocs[i] = doc;
minDoc = Math.min(minDoc, doc);
}
} else if (!exhausted[i]) {
minDoc = Math.min(minDoc, currentDocs[i]);
int doc = iterator.docID();
if (doc < target) {
doc = iterator.advance(target);
}

if (doc < minDoc) {
minDoc = doc;
minDocBucket = bucketOrd;
}
}

currentDoc = minDoc;
currentBucket = minDocBucket;
return currentDoc;
}

@Override
public long cost() {
long maxCost = 0;
long totalCost = 0;
for (DocIdSetIterator iterator : iterators) {
if (iterator != null) {
maxCost = Math.max(maxCost, iterator.cost());
}
}
return maxCost;
}

/**
* Checks if a specific bucket matches the current document
* @param ordinal The bucket ordinal to check
* @return true if the bucket matches the current document
*/
public boolean matches(int ordinal) {
if (ordinal >= numIterators || currentDoc == NO_MORE_DOCS) {
return false;
}
return !exhausted[ordinal] && currentDocs[ordinal] == currentDoc;
}

/**
* Gets a bit set representing all buckets that match the current document
* @return A long where each bit position represents whether the corresponding bucket matches
*/
public long getMatchingBuckets() {
if (currentDoc == NO_MORE_DOCS || numIterators > 64) {
return 0L;
}

long result = 0L;
for (int i = 0; i < numIterators; i++) {
if (matches(i)) {
result |= 1L << i;
totalCost += iterator.cost();
}
}
return result;
return totalCost;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.index.mapper.DocCountFieldMapper;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
Expand Down Expand Up @@ -96,8 +98,12 @@ void setRanges(Ranges ranges) {
* @param incrementDocCount consume the doc_count results for certain ordinal
* @param segmentMatchAll if your optimization can prepareFromSegment, you should pass in this flag to decide whether to prepareFromSegment
*/
public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Long, Long> incrementDocCount, boolean segmentMatchAll)
throws IOException {
public boolean tryOptimize(
final LeafReaderContext leafCtx,
final BiConsumer<Long, Long> incrementDocCount,
boolean segmentMatchAll,
LeafBucketCollector sub
) throws IOException {
segments.incrementAndGet();
if (!canOptimize) {
return false;
Expand All @@ -123,7 +129,8 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
Ranges ranges = getRanges(leafCtx, segmentMatchAll);
if (ranges == null) return false;

consumeDebugInfo(aggregatorBridge.tryOptimize(values, incrementDocCount, ranges, leafCtx.reader().maxDoc()));
DebugInfo debugInfo = aggregatorBridge.tryOptimize(values, incrementDocCount, ranges, leafCtx.reader().maxDoc());
consumeDebugInfo(debugInfo);

optimizedSegments.incrementAndGet();
logger.debug("Fast filter optimization applied to shard {} segment {}", shardId, leafCtx.ord);
Expand All @@ -135,6 +142,13 @@ public boolean tryOptimize(final LeafReaderContext leafCtx, final BiConsumer<Lon
// 1. List of Iterators per ranges
// 2. Composite iterator

CompositeDocIdSetIterator iter = new CompositeDocIdSetIterator(debugInfo.iterators);
while (iter.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
int currentDoc = iter.docID();
int bucket = iter.getCurrentBucket();
sub.collect(currentDoc, bucket);
}

return true;
}

Expand Down Expand Up @@ -171,6 +185,8 @@ static class DebugInfo {
private final AtomicInteger leafNodeVisited = new AtomicInteger(); // leaf node visited
private final AtomicInteger innerNodeVisited = new AtomicInteger(); // inner node visited

public DocIdSetIterator[] iterators;

void visitLeaf() {
leafNodeVisited.incrementAndGet();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.DocIdSetBuilder;
import org.opensearch.common.CheckedRunnable;
Expand Down Expand Up @@ -82,24 +83,33 @@ static FilterRewriteOptimizationContext.DebugInfo multiRangesTraverse(
}
collector.finalizePreviousRange();

DocIdSetBuilder[] builders = collector.docIdSetBuilders;
logger.debug("length of docIdSetBuilders: {}", builders.length);
int totalCount = 0;
for (int i = 0; i < builders.length; i++) {
if (builders[i] != null) {
int count = 0;
DocIdSetIterator iterator = builders[i].build().iterator();
while (iterator.nextDoc() != NO_MORE_DOCS) {
count++;
}
logger.trace(" docIdSetBuilder[{}] disi has documents: {}", i, count);
totalCount += count;
}
// DocIdSetBuilder[] builders = collector.docIdSetBuilders;
// logger.debug("length of docIdSetBuilders: {}", builders.length);
// int totalCount = 0;
// for (int i = 0; i < builders.length; i++) {
// if (builders[i] != null) {
// int count = 0;
// DocIdSetIterator iterator = builders[i].build().iterator();
// while (iterator.nextDoc() != NO_MORE_DOCS) {
// count++;
// }
// logger.trace(" docIdSetBuilder[{}] disi has documents: {}", i, count);
// totalCount += count;
// }
// }
// logger.debug("total count of documents from docIdSetBuilder: {}", totalCount);

Map<Long, DocIdSetBuilder> ordinalToBuilder = collector.bucketOrdinalToDocIdSetBuilder;
logger.debug("keys of bucketOrdinalToDocIdSetBuilder: {}", ordinalToBuilder.keySet());
int maxOrdinal = ordinalToBuilder.keySet().stream().mapToInt(Long::intValue).max().orElse(0) + 1;
DocIdSetIterator[] iterators = new DocIdSetIterator[maxOrdinal];
for (Map.Entry<Long, DocIdSetBuilder> entry : ordinalToBuilder.entrySet()) {
int ordinal = Math.toIntExact(entry.getKey());
DocIdSetBuilder builder = entry.getValue();
DocIdSet docIdSet = builder.build();
iterators[ordinal] = docIdSet.iterator();
}
logger.debug("total count of documents from docIdSetBuilder: {}", totalCount);

Map<Long, DocIdSetBuilder> map = collector.bucketOrdinalToDocIdSetBuilder;
logger.debug("keys of bucketOrdinalToDocIdSetBuilder: {}", map.keySet());
debugInfo.iterators = iterators;

return debugInfo;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,12 @@ public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBuc
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
boolean optimized = filterRewriteOptimizationContext.tryOptimize(
ctx,
this::incrementBucketDocCount,
segmentMatchAll(context, ctx),
null
);
if (optimized) throw new CollectionTerminatedException();

final SortedNumericDocValues values = valuesSource.longValues(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,15 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();
boolean optimized = filterRewriteOptimizationContext.tryOptimize(
ctx,
this::incrementBucketDocCount,
segmentMatchAll(context, ctx),
sub
);
if (optimized) {
throw new CollectionTerminatedException();
}

SortedNumericDocValues values = valuesSource.longValues(ctx);
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ public ScoreMode scoreMode() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
if (segmentMatchAll(context, ctx)
&& filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false, null)) {
throw new CollectionTerminatedException();
}

Expand Down

0 comments on commit 8aef653

Please sign in to comment.