diff --git a/CHANGELOG-3.0.md b/CHANGELOG-3.0.md index 99b636822fb72..077d113e7d4b0 100644 --- a/CHANGELOG-3.0.md +++ b/CHANGELOG-3.0.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add filter function for AbstractQueryBuilder, BoolQueryBuilder, ConstantScoreQueryBuilder([#17409](https://github.com/opensearch-project/OpenSearch/pull/17409)) - [Star Tree] [Search] Resolving keyword & numeric bucket aggregation with metric aggregation using star-tree ([#17165](https://github.com/opensearch-project/OpenSearch/pull/17165)) - Added error handling support for the pull-based ingestion ([#17427](https://github.com/opensearch-project/OpenSearch/pull/17427)) - +- [Star Tree] [Search] Resolving numeric range aggregation with metric aggregation using star-tree ([#17273](https://github.com/opensearch-project/OpenSearch/pull/17273)) ### Dependencies - Update Apache Lucene to 10.1.0 ([#16366](https://github.com/opensearch-project/OpenSearch/pull/16366)) diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index c7303011b5800..efd37e3115b20 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -32,7 +32,9 @@ package org.opensearch.search.aggregations.bucket.range; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.util.FixedBitSet; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -43,7 +45,13 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import 
org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.compositeindex.datacube.MetricStat; +import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues; +import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils; +import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator; import org.opensearch.index.fielddata.SortedNumericDoubleValues; +import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.search.DocValueFormat; import org.opensearch.search.aggregations.Aggregator; import org.opensearch.search.aggregations.AggregatorFactories; @@ -53,12 +61,17 @@ import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.LeafBucketCollectorBase; import org.opensearch.search.aggregations.NonCollectingAggregator; +import org.opensearch.search.aggregations.StarTreeBucketCollector; +import org.opensearch.search.aggregations.StarTreePreComputeCollector; import org.opensearch.search.aggregations.bucket.BucketsAggregator; import org.opensearch.search.aggregations.bucket.filterrewrite.FilterRewriteOptimizationContext; import org.opensearch.search.aggregations.bucket.filterrewrite.RangeAggregatorBridge; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.startree.StarTreeQueryHelper; +import org.opensearch.search.startree.StarTreeTraversalUtil; +import org.opensearch.search.startree.filter.DimensionFilter; import java.io.IOException; import java.util.ArrayList; @@ -70,16 +83,18 @@ import static org.opensearch.core.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.opensearch.search.aggregations.bucket.filterrewrite.AggregatorBridge.segmentMatchAll; 
+import static org.opensearch.search.startree.StarTreeQueryHelper.getSupportedStarTree; /** * Aggregate all docs that match given ranges. * * @opensearch.internal */ -public class RangeAggregator extends BucketsAggregator { +public class RangeAggregator extends BucketsAggregator implements StarTreePreComputeCollector { public static final ParseField RANGES_FIELD = new ParseField("ranges"); public static final ParseField KEYED_FIELD = new ParseField("keyed"); + public final String fieldName; /** * Range for the range aggregator @@ -298,6 +313,9 @@ protected Function bucketOrdProducer() { } }; filterRewriteOptimizationContext = new FilterRewriteOptimizationContext(bridge, parent, subAggregators.length, context); + this.fieldName = (valuesSource instanceof ValuesSource.Numeric.FieldData) + ? ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName() + : null; } @Override @@ -310,8 +328,13 @@ public ScoreMode scoreMode() { @Override protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { - if (segmentMatchAll(context, ctx)) { - return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false); + if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) { + return true; + } + CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); + if (supportedStarTree != null) { + preComputeWithStarTree(ctx, supportedStarTree); + return true; } return false; } @@ -333,52 +356,106 @@ public void collect(int doc, long bucket) throws IOException { } private int collect(int doc, double value, long owningBucketOrdinal, int lowBound) throws IOException { - int lo = lowBound, hi = ranges.length - 1; // all candidates are between these indexes - int mid = (lo + hi) >>> 1; - while (lo <= hi) { - if (value < ranges[mid].from) { - hi = mid - 1; - } else if (value >= maxTo[mid]) { - lo = mid + 1; - } else { 
- break; + MatchedRange range = new MatchedRange(ranges, lowBound, value); + for (int i = range.startLo; i <= range.endHi; ++i) { + if (ranges[i].matches(value)) { + collectBucket(sub, doc, subBucketOrdinal(owningBucketOrdinal, i)); } - mid = (lo + hi) >>> 1; } - if (lo > hi) return lo; // no potential candidate - - // binary search the lower bound - int startLo = lo, startHi = mid; - while (startLo <= startHi) { - final int startMid = (startLo + startHi) >>> 1; - if (value >= maxTo[startMid]) { - startLo = startMid + 1; - } else { - startHi = startMid - 1; - } + return range.endHi + 1; + } + }; + } + + private void preComputeWithStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { + StarTreeBucketCollector starTreeBucketCollector = getStarTreeBucketCollector(ctx, starTree, null); + FixedBitSet matchingDocsBitSet = starTreeBucketCollector.getMatchingDocsBitSet(); + + int numBits = matchingDocsBitSet.length(); + + if (numBits > 0) { + for (int bit = matchingDocsBitSet.nextSetBit(0); bit != DocIdSetIterator.NO_MORE_DOCS; bit = (bit + 1 < numBits) + ? 
matchingDocsBitSet.nextSetBit(bit + 1) + : DocIdSetIterator.NO_MORE_DOCS) { + starTreeBucketCollector.collectStarTreeEntry(bit, 0); + } + } + } + + @Override + public StarTreeBucketCollector getStarTreeBucketCollector( + LeafReaderContext ctx, + CompositeIndexFieldInfo starTree, + StarTreeBucketCollector parentCollector + ) throws IOException { + assert parentCollector == null; + StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree); + return new StarTreeBucketCollector( + starTreeValues, + StarTreeTraversalUtil.getStarTreeResult( + starTreeValues, + StarTreeQueryHelper.mergeDimensionFilterIfNotExists( + context.getQueryShardContext().getStarTreeQueryContext().getBaseQueryStarTreeFilter(), + fieldName, + List.of(DimensionFilter.MATCH_ALL_DEFAULT) + ), + context + ) + ) { + @Override + public void setSubCollectors() throws IOException { + for (Aggregator aggregator : subAggregators) { + this.subCollectors.add(((StarTreePreComputeCollector) aggregator).getStarTreeBucketCollector(ctx, starTree, this)); } + } + + SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues + .getDimensionValuesIterator(fieldName); + + String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues( + starTree.getField(), + "_doc_count", + MetricStat.DOC_COUNT.getTypeName() + ); + + SortedNumericStarTreeValuesIterator docCountsIterator = (SortedNumericStarTreeValuesIterator) starTreeValues + .getMetricValuesIterator(metricName); - // binary search the upper bound - int endLo = mid, endHi = hi; - while (endLo <= endHi) { - final int endMid = (endLo + endHi) >>> 1; - if (value < ranges[endMid].from) { - endHi = endMid - 1; + @Override + public void collectStarTreeEntry(int starTreeEntry, long owningBucketOrd) throws IOException { + if (!valuesIterator.advanceExact(starTreeEntry)) { + return; + } + + for (int i = 0, count = valuesIterator.entryValueCount(); i < count; i++) { + long 
dimensionLongValue = valuesIterator.nextValue(); + double dimensionValue; + + // Only numeric & floating points are supported as of now in star-tree + // TODO: Add support for isBigInteger() when it gets supported in star-tree + if (valuesSource.isFloatingPoint()) { + dimensionValue = ((NumberFieldMapper.NumberFieldType) context.mapperService().fieldType(fieldName)).toDoubleValue( + dimensionLongValue + ); } else { - endLo = endMid + 1; + dimensionValue = dimensionLongValue; } - } - assert startLo == lowBound || value >= maxTo[startLo - 1]; - assert endHi == ranges.length - 1 || value < ranges[endHi + 1].from; + MatchedRange matchedRange = new MatchedRange(ranges, 0, dimensionValue); + if (matchedRange.startLo > matchedRange.endHi) { + continue; // No matching range + } - for (int i = startLo; i <= endHi; ++i) { - if (ranges[i].matches(value)) { - collectBucket(sub, doc, subBucketOrdinal(owningBucketOrdinal, i)); + if (docCountsIterator.advanceExact(starTreeEntry)) { + long metricValue = docCountsIterator.nextValue(); + for (int j = matchedRange.startLo; j <= matchedRange.endHi; ++j) { + if (ranges[j].matches(dimensionValue)) { + long bucketOrd = subBucketOrdinal(owningBucketOrd, j); + collectStarTreeBucket(this, metricValue, bucketOrd, starTreeEntry); + } + } } } - - return endHi + 1; } }; } @@ -421,6 +498,60 @@ public InternalAggregation buildEmptyAggregation() { return rangeFactory.create(name, buckets, format, keyed, metadata()); } + class MatchedRange { + int startLo, endHi; + + MatchedRange(RangeAggregator.Range[] ranges, int lowBound, double value) { + computeMatchingRange(ranges, lowBound, value); + } + + private void computeMatchingRange(RangeAggregator.Range[] ranges, int lowBound, double value) { + int lo = lowBound, hi = ranges.length - 1; + int mid = (lo + hi) >>> 1; + + while (lo <= hi) { + if (value < ranges[mid].from) { + hi = mid - 1; + } else if (value >= maxTo[mid]) { + lo = mid + 1; + } else { + break; + } + mid = (lo + hi) >>> 1; + } + if (lo > 
hi) { + this.startLo = lo; + this.endHi = lo - 1; + return; + } + + // binary search the lower bound + int startLo = lo, startHi = mid; + while (startLo <= startHi) { + int startMid = (startLo + startHi) >>> 1; + if (value >= maxTo[startMid]) { + startLo = startMid + 1; + } else { + startHi = startMid - 1; + } + } + + // binary search the upper bound + int endLo = mid, endHi = hi; + while (endLo <= endHi) { + int endMid = (endLo + endHi) >>> 1; + if (value < ranges[endMid].from) { + endHi = endMid - 1; + } else { + endLo = endMid + 1; + } + } + + this.startLo = startLo; + this.endHi = endHi; + } + } + /** * Unmapped range * @@ -456,7 +587,7 @@ public Unmapped( public InternalAggregation buildEmptyAggregation() { InternalAggregations subAggs = buildEmptySubAggregations(); List buckets = new ArrayList<>(ranges.length); - for (RangeAggregator.Range range : ranges) { + for (Range range : ranges) { buckets.add(factory.createBucket(range.key, range.from, range.to, 0, subAggs, keyed, format)); } return factory.create(name, buckets, format, keyed, metadata()); diff --git a/server/src/main/java/org/opensearch/search/startree/StarTreeQueryContext.java b/server/src/main/java/org/opensearch/search/startree/StarTreeQueryContext.java index a8f54f5793551..423646b1a07aa 100644 --- a/server/src/main/java/org/opensearch/search/startree/StarTreeQueryContext.java +++ b/server/src/main/java/org/opensearch/search/startree/StarTreeQueryContext.java @@ -21,6 +21,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.search.aggregations.AggregatorFactory; import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregatorFactory; +import org.opensearch.search.aggregations.bucket.range.RangeAggregatorFactory; import org.opensearch.search.aggregations.bucket.terms.TermsAggregatorFactory; import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory; import org.opensearch.search.internal.SearchContext; @@ -120,6 +121,10 @@ public boolean 
consolidateAllFilters(SearchContext context) { continue; } + // validation for range aggregation + if (validateRangeAggregationSupport(compositeMappedFieldType, aggregatorFactory)) { + continue; + } // invalid query shape return false; } @@ -184,6 +189,31 @@ private static boolean validateKeywordTermsAggregationSupport( return true; } + private static boolean validateRangeAggregationSupport( + CompositeDataCubeFieldType compositeIndexFieldInfo, + AggregatorFactory aggregatorFactory + ) { + if (!(aggregatorFactory instanceof RangeAggregatorFactory rangeAggregatorFactory)) { + return false; + } + + // Validate request field is part of dimensions + if (compositeIndexFieldInfo.getDimensions() + .stream() + .map(Dimension::getField) + .noneMatch(rangeAggregatorFactory.getField()::equals)) { + return false; + } + + // Validate all sub-factories + for (AggregatorFactory subFactory : aggregatorFactory.getSubFactories().getFactories()) { + if (!validateStarTreeMetricSupport(compositeIndexFieldInfo, subFactory)) { + return false; + } + } + return true; + } + private StarTreeFilter getStarTreeFilter( SearchContext context, QueryBuilder queryBuilder, diff --git a/server/src/test/java/org/opensearch/search/SearchServiceStarTreeTests.java b/server/src/test/java/org/opensearch/search/SearchServiceStarTreeTests.java index 95c877bfce0a8..ea62f528439ff 100644 --- a/server/src/test/java/org/opensearch/search/SearchServiceStarTreeTests.java +++ b/server/src/test/java/org/opensearch/search/SearchServiceStarTreeTests.java @@ -47,6 +47,7 @@ import org.opensearch.search.aggregations.SearchContextAggregations; import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; import 
org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; import org.opensearch.search.aggregations.metrics.MedianAbsoluteDeviationAggregationBuilder; @@ -71,6 +72,7 @@ import static org.opensearch.search.aggregations.AggregationBuilders.dateHistogram; import static org.opensearch.search.aggregations.AggregationBuilders.max; import static org.opensearch.search.aggregations.AggregationBuilders.medianAbsoluteDeviation; +import static org.opensearch.search.aggregations.AggregationBuilders.range; import static org.opensearch.search.aggregations.AggregationBuilders.sum; import static org.opensearch.search.aggregations.AggregationBuilders.terms; import static org.hamcrest.CoreMatchers.notNullValue; @@ -689,6 +691,113 @@ public void testQueryParsingForBucketAggregations() throws IOException { setStarTreeIndexSetting(null); } + /** + * Test query parsing for range aggregations, with/without numeric term query + */ + public void testQueryParsingForRangeAggregations() throws IOException { + FeatureFlags.initializeFeatureFlags(Settings.builder().put(FeatureFlags.STAR_TREE_INDEX, true).build()); + setStarTreeIndexSetting("true"); + + Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1) + .put(StarTreeIndexSettings.IS_COMPOSITE_INDEX_SETTING.getKey(), true) + .put(IndexMetadata.INDEX_APPEND_ONLY_ENABLED_SETTING.getKey(), true) + .build(); + CreateIndexRequestBuilder builder = client().admin() + .indices() + .prepareCreate("test") + .setSettings(settings) + .setMapping(NumericTermsAggregatorTests.getExpandedMapping(1, false)); + createIndex("test", builder); + + IndicesService indicesService = getInstanceFromNode(IndicesService.class); + IndexService indexService = indicesService.indexServiceSafe(resolveIndex("test")); + IndexShard indexShard = indexService.getShard(0); + ShardSearchRequest request = new ShardSearchRequest( + OriginalIndices.NONE, + new 
SearchRequest().allowPartialSearchResults(true), + indexShard.shardId(), + 1, + new AliasFilter(null, Strings.EMPTY_ARRAY), + 1.0f, + -1, + null, + null + ); + String KEYWORD_FIELD = "clientip"; + String NUMERIC_FIELD = "size"; + + MaxAggregationBuilder maxAggNoSub = max("max").field(FIELD_NAME); + SumAggregationBuilder sumAggSub = sum("sum").field(FIELD_NAME).subAggregation(maxAggNoSub); + MedianAbsoluteDeviationAggregationBuilder medianAgg = medianAbsoluteDeviation("median").field(FIELD_NAME); + + QueryBuilder baseQuery; + SearchContext searchContext = createSearchContext(indexService); + StarTreeFieldConfiguration starTreeFieldConfiguration = new StarTreeFieldConfiguration( + 1, + Collections.emptySet(), + StarTreeFieldConfiguration.StarTreeBuildMode.ON_HEAP + ); + + // Case 1: MatchAllQuery with a non-nested metric aggregation nested within the range aggregation, should use star tree + RangeAggregationBuilder rangeAggregationBuilder = range("range").field(NUMERIC_FIELD).addRange(0, 10).subAggregation(maxAggNoSub); + baseQuery = new MatchAllQueryBuilder(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(0).query(baseQuery).aggregation(rangeAggregationBuilder); + + assertStarTreeContext( + request, + sourceBuilder, + getStarTreeQueryContext( + searchContext, + starTreeFieldConfiguration, + "startree1", + -1, + List.of(new NumericDimension(NUMERIC_FIELD), new OrdinalDimension(KEYWORD_FIELD)), + List.of(new Metric(FIELD_NAME, List.of(MetricStat.SUM, MetricStat.MAX))), + baseQuery, + sourceBuilder, + true + ), + -1 + ); + + // Case 2: NumericTermsQuery with a non-nested metric aggregation nested within the range aggregation, should use star tree + rangeAggregationBuilder = range("range").field(NUMERIC_FIELD).addRange(0, 100).subAggregation(maxAggNoSub); + baseQuery = new TermQueryBuilder(FIELD_NAME, 1); + sourceBuilder = new SearchSourceBuilder().size(0).query(baseQuery).aggregation(rangeAggregationBuilder); + + assertStarTreeContext( + request, 
sourceBuilder, + getStarTreeQueryContext( + searchContext, + starTreeFieldConfiguration, + "startree1", + -1, + List.of(new NumericDimension(NUMERIC_FIELD), new OrdinalDimension(KEYWORD_FIELD), new NumericDimension(FIELD_NAME)), + List.of(new Metric(FIELD_NAME, List.of(MetricStat.SUM, MetricStat.MAX))), + baseQuery, + sourceBuilder, + true + ), + -1 + ); + + // Case 3: Nested metric aggregations within range aggregation, should not use star tree + rangeAggregationBuilder = range("range").field(NUMERIC_FIELD).addRange(0, 100).subAggregation(sumAggSub); + sourceBuilder = new SearchSourceBuilder().size(0).query(new TermQueryBuilder(FIELD_NAME, 1)).aggregation(rangeAggregationBuilder); + assertStarTreeContext(request, sourceBuilder, null, -1); + + // Case 4: Unsupported aggregations within range aggregation, should not use star tree + rangeAggregationBuilder = range("range").field(NUMERIC_FIELD).addRange(0, 100).subAggregation(medianAgg); + sourceBuilder = new SearchSourceBuilder().size(0).query(new TermQueryBuilder(FIELD_NAME, 1)).aggregation(rangeAggregationBuilder); + assertStarTreeContext(request, sourceBuilder, null, -1); + + setStarTreeIndexSetting(null); + } + + private void setStarTreeIndexSetting(String value) { client().admin() .cluster() diff --git a/server/src/test/java/org/opensearch/search/aggregations/startree/RangeAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/startree/RangeAggregatorTests.java new file mode 100644 index 0000000000000..e7b1edd65682c --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/startree/RangeAggregatorTests.java @@ -0,0 +1,220 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.aggregations.startree; + +import com.carrotsearch.randomizedtesting.RandomizedTest; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene101.Lucene101Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.NumericUtils; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.CompositeIndexReader; +import org.opensearch.index.codec.composite.composite101.Composite101Codec; +import org.opensearch.index.codec.composite912.datacube.startree.StarTreeDocValuesFormatTests; +import org.opensearch.index.compositeindex.datacube.Dimension; +import org.opensearch.index.compositeindex.datacube.NumericDimension; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.AggregatorTestCase; +import org.opensearch.search.aggregations.bucket.range.InternalRange; +import org.opensearch.search.aggregations.bucket.range.RangeAggregationBuilder; +import 
org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Random; + +import static org.opensearch.search.aggregations.AggregationBuilders.avg; +import static org.opensearch.search.aggregations.AggregationBuilders.count; +import static org.opensearch.search.aggregations.AggregationBuilders.max; +import static org.opensearch.search.aggregations.AggregationBuilders.min; +import static org.opensearch.search.aggregations.AggregationBuilders.range; +import static org.opensearch.search.aggregations.AggregationBuilders.sum; +import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; + +public class RangeAggregatorTests extends AggregatorTestCase { + final static String STATUS = "status"; + final static String SIZE = "size"; + private static final MappedFieldType STATUS_FIELD_TYPE = new NumberFieldMapper.NumberFieldType( + STATUS, + NumberFieldMapper.NumberType.LONG + ); + private static final MappedFieldType SIZE_FIELD_NAME = new NumberFieldMapper.NumberFieldType(SIZE, NumberFieldMapper.NumberType.FLOAT); + + @Before + public void setup() { + FeatureFlags.initializeFeatureFlags(Settings.builder().put(FeatureFlags.STAR_TREE_INDEX, true).build()); + } + + @After + public void teardown() throws IOException { + FeatureFlags.initializeFeatureFlags(Settings.EMPTY); + } + + protected Codec getCodec() { + final Logger testLogger = LogManager.getLogger(NumericTermsAggregatorTests.class); + MapperService mapperService; + try { + mapperService = StarTreeDocValuesFormatTests.createMapperService(NumericTermsAggregatorTests.getExpandedMapping(1, false)); + } catch (IOException e) { + throw new RuntimeException(e); + } + return new Composite101Codec(Lucene101Codec.Mode.BEST_SPEED, mapperService, testLogger); + } + + public void testRangeAggregation() throws 
IOException { + Directory directory = newDirectory(); + IndexWriterConfig conf = newIndexWriterConfig(null); + conf.setCodec(getCodec()); + conf.setMergePolicy(newLogMergePolicy()); + RandomIndexWriter iw = new RandomIndexWriter(random(), directory, conf); + + Random random = RandomizedTest.getRandom(); + int totalDocs = 100; + List docs = new ArrayList<>(); + long val; + + // Index 100 random documents + for (int i = 0; i < totalDocs; i++) { + Document doc = new Document(); + if (random.nextBoolean()) { + val = random.nextInt(100); // Random int between 0 and 99 for status + doc.add(new SortedNumericDocValuesField(STATUS, val)); + } + if (random.nextBoolean()) { + val = NumericUtils.doubleToSortableLong(random.nextInt(100) + 0.5f); + doc.add(new SortedNumericDocValuesField(SIZE, val)); + } + iw.addDocument(doc); + docs.add(doc); + } + + if (randomBoolean()) { + iw.forceMerge(1); + } + iw.close(); + + DirectoryReader ir = DirectoryReader.open(directory); + LeafReaderContext context = ir.leaves().get(0); + + SegmentReader reader = Lucene.segmentReader(context.reader()); + IndexSearcher indexSearcher = newSearcher(reader, false, false); + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + + List compositeIndexFields = starTreeDocValuesReader.getCompositeIndexFields(); + CompositeIndexFieldInfo starTree = compositeIndexFields.get(0); + + LinkedHashMap supportedDimensions = new LinkedHashMap<>(); + supportedDimensions.put(new NumericDimension(STATUS), STATUS_FIELD_TYPE); + supportedDimensions.put(new NumericDimension(SIZE), SIZE_FIELD_NAME); + + Query query = new MatchAllDocsQuery(); + QueryBuilder queryBuilder = null; + RangeAggregationBuilder rangeAggregationBuilder = range("range_agg").field(STATUS).addRange(10, 30).addRange(30, 50); + // no sub-aggregation + testCase(indexSearcher, query, queryBuilder, rangeAggregationBuilder, starTree, supportedDimensions); + + ValuesSourceAggregationBuilder[] aggBuilders = { + 
sum("_sum").field(SIZE), + max("_max").field(SIZE), + min("_min").field(SIZE), + count("_count").field(SIZE), + avg("_avg").field(SIZE) }; + + for (ValuesSourceAggregationBuilder aggregationBuilder : aggBuilders) { + query = new MatchAllDocsQuery(); + queryBuilder = null; + rangeAggregationBuilder = range("range_agg").field(STATUS).addRange(10, 30).addRange(30, 50).subAggregation(aggregationBuilder); + // sub-aggregation, no top level query + testCase(indexSearcher, query, queryBuilder, rangeAggregationBuilder, starTree, supportedDimensions); + + // Numeric-terms query with range aggregation + for (int cases = 0; cases < 100; cases++) { + // query of size field + String queryField = SIZE; + long queryValue = NumericUtils.floatToSortableInt(random.nextInt(50) + 0.5f); + query = SortedNumericDocValuesField.newSlowExactQuery(queryField, queryValue); + queryBuilder = new TermQueryBuilder(queryField, queryValue); + testCase(indexSearcher, query, queryBuilder, rangeAggregationBuilder, starTree, supportedDimensions); + } + } + + reader.close(); + directory.close(); + } + + private void testCase( + IndexSearcher indexSearcher, + Query query, + QueryBuilder queryBuilder, + RangeAggregationBuilder rangeAggregationBuilder, + CompositeIndexFieldInfo starTree, + LinkedHashMap supportedDimensions + ) throws IOException { + InternalRange starTreeAggregation = searchAndReduceStarTree( + createIndexSettings(), + indexSearcher, + query, + queryBuilder, + rangeAggregationBuilder, + starTree, + supportedDimensions, + null, + DEFAULT_MAX_BUCKETS, + false, + null, + true, + STATUS_FIELD_TYPE, + SIZE_FIELD_NAME + ); + + InternalRange defaultAggregation = searchAndReduceStarTree( + createIndexSettings(), + indexSearcher, + query, + queryBuilder, + rangeAggregationBuilder, + null, + null, + null, + DEFAULT_MAX_BUCKETS, + false, + null, + false, + STATUS_FIELD_TYPE, + SIZE_FIELD_NAME + ); + + assertEquals(defaultAggregation.getBuckets().size(), starTreeAggregation.getBuckets().size()); + 
assertEquals(defaultAggregation.getBuckets(), starTreeAggregation.getBuckets()); + } +}