Skip to content

Commit

Permalink
range aggs changes
Browse files Browse the repository at this point in the history
Signed-off-by: Sandesh Kumar <[email protected]>
  • Loading branch information
sandeshkr419 committed Mar 4, 2025
1 parent f6d6aa6 commit e2ef393
Show file tree
Hide file tree
Showing 4 changed files with 422 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG-3.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Added offset management for the pull-based Ingestion ([#17354](https://github.com/opensearch-project/OpenSearch/pull/17354))
- Add filter function for AbstractQueryBuilder, BoolQueryBuilder, ConstantScoreQueryBuilder([#17409](https://github.com/opensearch-project/OpenSearch/pull/17409))
- [Star Tree] [Search] Resolving keyword & numeric bucket aggregation with metric aggregation using star-tree ([#17165](https://github.com/opensearch-project/OpenSearch/pull/17165))
- [Star Tree] [Search] Resolving numeric range aggregation with metric aggregation using star-tree ([#17273](https://github.com/opensearch-project/OpenSearch/pull/17273))

### Dependencies
- Update Apache Lucene to 10.1.0 ([#16366](https://github.com/opensearch-project/OpenSearch/pull/16366))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
package org.opensearch.search.aggregations.bucket.range;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -43,7 +45,13 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils;
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.AggregatorFactories;
Expand All @@ -53,12 +61,17 @@
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.NonCollectingAggregator;
import org.opensearch.search.aggregations.StarTreeBucketCollector;
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
import org.opensearch.search.aggregations.bucket.BucketsAggregator;
import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext;
import org.opensearch.search.aggregations.bucket.filterrewrite.RangeAggregatorBridge;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.aggregations.support.ValuesSourceConfig;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.StarTreeQueryHelper;
import org.opensearch.search.startree.StarTreeTraversalUtil;
import org.opensearch.search.startree.filter.DimensionFilter;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -70,16 +83,18 @@

import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.opensearch.search.aggregations.bucket.filterrewrite.AggregatorBridge.segmentMatchAll;
import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree;

/**
* Aggregate all docs that match given ranges.
*
* @opensearch.internal
*/
public class RangeAggregator extends BucketsAggregator {
public class RangeAggregator extends BucketsAggregator implements StarTreePreComputeCollector {

public static final ParseField RANGES_FIELD = new ParseField("ranges");
public static final ParseField KEYED_FIELD = new ParseField("keyed");
public final String fieldName;

/**
* Range for the range aggregator
Expand Down Expand Up @@ -298,6 +313,9 @@ protected Function<Object, Long> bucketOrdProducer() {
}
};
filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context);
this.fieldName = (valuesSource instanceof ValuesSource.Numeric.FieldData)
? ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName()
: null;
}

@Override
Expand All @@ -310,8 +328,13 @@ public ScoreMode scoreMode() {

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (segmentMatchAll(context, ctx)) {
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
return true;
}
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
if (supportedStarTree != null) {
preComputeWithStarTree(ctx, supportedStarTree);
return true;
}
return false;
}
Expand All @@ -333,52 +356,106 @@ public void collect(int doc, long bucket) throws IOException {
}

private int collect(int doc, double value, long owningBucketOrdinal, int lowBound) throws IOException {
int lo = lowBound, hi = ranges.length - 1; // all candidates are between these indexes
int mid = (lo + hi) >>> 1;
while (lo <= hi) {
if (value < ranges[mid].from) {
hi = mid - 1;
} else if (value >= maxTo[mid]) {
lo = mid + 1;
} else {
break;
MatchedRange range = new MatchedRange(ranges, lowBound, value);
for (int i = range.startLo; i <= range.endHi; ++i) {
if (ranges[i].matches(value)) {
collectBucket(sub, doc, subBucketOrdinal(owningBucketOrdinal, i));
}
mid = (lo + hi) >>> 1;
}
if (lo > hi) return lo; // no potential candidate

// binary search the lower bound
int startLo = lo, startHi = mid;
while (startLo <= startHi) {
final int startMid = (startLo + startHi) >>> 1;
if (value >= maxTo[startMid]) {
startLo = startMid + 1;
} else {
startHi = startMid - 1;
}
return range.endHi + 1;
}
};
}

/**
 * Pre-computes this aggregation for one segment from star-tree data.
 * <p>
 * Builds the star-tree bucket collector for the segment (with no parent collector) and hands it every
 * star-tree entry whose bit is set in the collector's matching-docs bitset, always under owning bucket
 * ordinal {@code 0}.
 *
 * @param ctx      the segment being pre-computed
 * @param starTree the star-tree index field used to resolve the aggregation
 * @throws IOException if reading star-tree values fails
 */
private void preComputeWithStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
    final StarTreeBucketCollector collector = getStarTreeBucketCollector(ctx, starTree, null);
    final FixedBitSet matchingEntries = collector.getMatchingDocsBitSet();
    final int bitCount = matchingEntries.length();

    // Nothing to scan in an empty bitset.
    if (bitCount <= 0) {
        return;
    }

    int entry = matchingEntries.nextSetBit(0);
    while (entry != DocIdSetIterator.NO_MORE_DOCS) {
        collector.collectStarTreeEntry(entry, 0);

        // Only probe for the next set bit while the start index is still inside the bitset.
        final int nextStart = entry + 1;
        if (nextStart < bitCount) {
            entry = matchingEntries.nextSetBit(nextStart);
        } else {
            entry = DocIdSetIterator.NO_MORE_DOCS;
        }
    }
}

@Override
public StarTreeBucketCollector getStarTreeBucketCollector(
LeafReaderContext ctx,
CompositeIndexFieldInfo starTree,
StarTreeBucketCollector parentCollector
) throws IOException {
assert parentCollector == null;
StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
return new StarTreeBucketCollector(
starTreeValues,
StarTreeTraversalUtil.getStarTreeResult(
starTreeValues,
StarTreeQueryHelper.mergeDimensionFilterIfNotExists(
context.getQueryShardContext().getStarTreeQueryContext().getBaseQueryStarTreeFilter(),
fieldName,
List.of(DimensionFilter.MATCH_ALL_DEFAULT)
),
context
)
) {
/**
 * Creates one star-tree bucket collector per sub-aggregator, with this collector as its parent,
 * and appends each to {@code subCollectors}.
 */
@Override
public void setSubCollectors() throws IOException {
    for (Aggregator child : subAggregators) {
        // Every sub-aggregator is cast to StarTreePreComputeCollector; a non-conforming one fails here.
        final StarTreePreComputeCollector preComputeChild = (StarTreePreComputeCollector) child;
        final StarTreeBucketCollector childCollector = preComputeChild.getStarTreeBucketCollector(ctx, starTree, this);
        this.subCollectors.add(childCollector);
    }
}

SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
.getDimensionValuesIterator(fieldName);

String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
starTree.getField(),
"_doc_count",
MetricStat.DOC_COUNT.getTypeName()
);

SortedNumericStarTreeValuesIterator docCountsIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
.getMetricValuesIterator(metricName);

// binary search the upper bound
int endLo = mid, endHi = hi;
while (endLo <= endHi) {
final int endMid = (endLo + endHi) >>> 1;
if (value < ranges[endMid].from) {
endHi = endMid - 1;
@Override
public void collectStarTreeEntry(int starTreeEntry, long owningBucketOrd) throws IOException {
if (!valuesIterator.advanceExact(starTreeEntry)) {
return;
}

for (int i = 0, count = valuesIterator.entryValueCount(); i < count; i++) {
long dimensionLongValue = valuesIterator.nextValue();
double dimensionValue;

// Only numeric & floating points are supported as of now in star-tree
// TODO: Add support for isBigInteger() when it gets supported in star-tree
if (valuesSource.isFloatingPoint()) {
dimensionValue = ((NumberFieldMapper.NumberFieldType) context.mapperService().fieldType(fieldName)).toDoubleValue(
dimensionLongValue
);
} else {
endLo = endMid + 1;
dimensionValue = dimensionLongValue;
}
}

assert startLo == lowBound || value >= maxTo[startLo - 1];
assert endHi == ranges.length - 1 || value < ranges[endHi + 1].from;
MatchedRange matchedRange = new MatchedRange(ranges, 0, dimensionValue);
if (matchedRange.startLo > matchedRange.endHi) {
continue; // No matching range
}

for (int i = startLo; i <= endHi; ++i) {
if (ranges[i].matches(value)) {
collectBucket(sub, doc, subBucketOrdinal(owningBucketOrdinal, i));
if (docCountsIterator.advanceExact(starTreeEntry)) {
long metricValue = docCountsIterator.nextValue();
for (int j = matchedRange.startLo; j <= matchedRange.endHi; ++j) {
if (ranges[j].matches(dimensionValue)) {
long bucketOrd = subBucketOrdinal(owningBucketOrd, j);
collectStarTreeBucket(this, metricValue, bucketOrd, starTreeEntry);
}
}
}
}

return endHi + 1;
}
};
}
Expand Down Expand Up @@ -421,6 +498,60 @@ public InternalAggregation buildEmptyAggregation() {
return rangeFactory.create(name, buckets, format, keyed, metadata());
}

/**
 * Window of candidate range indices for a single value.
 * <p>
 * After construction, {@code startLo..endHi} (inclusive) are the indices of {@code ranges} that the
 * binary searches could not rule out; callers still test each index with {@code Range#matches}.
 * An empty window is encoded as {@code startLo > endHi} (specifically {@code endHi == startLo - 1}).
 * <p>
 * The searches read {@code maxTo}, which is resolved from the enclosing scope.
 */
class MatchedRange {
    int startLo, endHi;

    /**
     * @param ranges   the ranges to search
     * @param lowBound first index of {@code ranges} to consider
     * @param value    the value to locate
     */
    MatchedRange(RangeAggregator.Range[] ranges, int lowBound, double value) {
        computeMatchingRange(ranges, lowBound, value);
    }

    /** Fills {@code startLo} and {@code endHi} for {@code value}, searching indices from {@code lowBound} upwards. */
    private void computeMatchingRange(RangeAggregator.Range[] ranges, int lowBound, double value) {
        int low = lowBound;
        int high = ranges.length - 1;
        int pivot = (low + high) >>> 1;
        boolean hit = false;

        // Find any index whose [from, maxTo) span could hold the value.
        while (low <= high && !hit) {
            pivot = (low + high) >>> 1;
            if (value < ranges[pivot].from) {
                high = pivot - 1;
            } else if (value >= maxTo[pivot]) {
                low = pivot + 1;
            } else {
                hit = true;
            }
        }

        if (!hit) {
            // No potential candidate: report an empty window positioned at the insertion point.
            this.startLo = low;
            this.endHi = low - 1;
            return;
        }

        final int lowerEdge = searchLowerBound(low, pivot, value);
        final int upperEdge = searchUpperBound(ranges, pivot, high, value);

        this.startLo = lowerEdge;
        this.endHi = upperEdge;
    }

    /** Binary search of the lower bound: skips indices in {@code [left, right]} with {@code value >= maxTo[index]}. */
    private int searchLowerBound(int left, int right, double value) {
        while (left <= right) {
            final int probe = (left + right) >>> 1;
            if (value >= maxTo[probe]) {
                left = probe + 1;
            } else {
                right = probe - 1;
            }
        }
        return left;
    }

    /** Binary search of the upper bound: drops indices in {@code [left, right]} with {@code value < ranges[index].from}. */
    private int searchUpperBound(RangeAggregator.Range[] ranges, int left, int right, double value) {
        while (left <= right) {
            final int probe = (left + right) >>> 1;
            if (value < ranges[probe].from) {
                right = probe - 1;
            } else {
                left = probe + 1;
            }
        }
        return right;
    }
}

/**
* Unmapped range
*
Expand Down Expand Up @@ -456,7 +587,7 @@ public Unmapped(
public InternalAggregation buildEmptyAggregation() {
InternalAggregations subAggs = buildEmptySubAggregations();
List<org.opensearch.search.aggregations.bucket.range.Range.Bucket> buckets = new ArrayList<>(ranges.length);
for (RangeAggregator.Range range : ranges) {
for (Range range : ranges) {
buckets.add(factory.createBucket(range.key, range.from, range.to, 0, subAggs, keyed, format));
}
return factory.create(name, buckets, format, keyed, metadata());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregatorFactory;
import org.opensearch.search.aggregations.bucket.range.RangeAggregatorFactory;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregatorFactory;
import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory;
import org.opensearch.search.internal.SearchContext;
Expand Down Expand Up @@ -120,6 +121,10 @@ public boolean consolidateAllFilters(SearchContext context) {
continue;
}

// validation for range aggregation
if (validateRangeAggregationSupport(compositeMappedFieldType, aggregatorFactory)) {
continue;
}
// invalid query shape
return false;
}
Expand Down Expand Up @@ -184,6 +189,31 @@ private static boolean validateKeywordTermsAggregationSupport(
return true;
}

/**
 * Checks whether an aggregator factory is a range aggregation that the given star-tree field can resolve.
 * <p>
 * The factory is accepted only when it is a {@link RangeAggregatorFactory}, its field is one of the
 * star-tree dimensions, and every sub-factory passes {@code validateStarTreeMetricSupport}.
 *
 * @param compositeIndexFieldInfo the star-tree field type whose dimensions are consulted
 * @param aggregatorFactory       the aggregator factory to validate
 * @return {@code true} if all checks pass, {@code false} otherwise
 */
private static boolean validateRangeAggregationSupport(
    CompositeDataCubeFieldType compositeIndexFieldInfo,
    AggregatorFactory aggregatorFactory
) {
    if (!(aggregatorFactory instanceof RangeAggregatorFactory rangeAggregatorFactory)) {
        return false;
    }

    // A bound method reference (getField()::equals) throws NullPointerException when its receiver is null.
    // Read the field once and treat a missing field as "not supported" instead of failing the request.
    final String requestField = rangeAggregatorFactory.getField();
    if (requestField == null) {
        return false;
    }

    // Validate request field is part of dimensions
    final boolean fieldIsDimension = compositeIndexFieldInfo.getDimensions()
        .stream()
        .map(Dimension::getField)
        .anyMatch(requestField::equals);
    if (!fieldIsDimension) {
        return false;
    }

    // Validate all sub-factories
    for (AggregatorFactory subFactory : aggregatorFactory.getSubFactories().getFactories()) {
        if (!validateStarTreeMetricSupport(compositeIndexFieldInfo, subFactory)) {
            return false;
        }
    }
    return true;
}

private StarTreeFilter getStarTreeFilter(
SearchContext context,
QueryBuilder queryBuilder,
Expand Down
Loading

0 comments on commit e2ef393

Please sign in to comment.