Skip to content

Commit

Permalink
fix: correct instantiation of rollup count_where operator. (#6502)
Browse files Browse the repository at this point in the history
#6497 merged with an
error where the `rollup` `count_where` aggregation was aliased to the
regular `count_` aggregation.

This PR corrects that mistake.
  • Loading branch information
lbooker42 authored Dec 30, 2024
1 parent 0ae8509 commit 492121d
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -651,45 +651,8 @@ final void addWeightedAvgOrSumOperator(
addOperator(resultOperator, r.source, r.pair.input().name(), weightName);
});
}
}

// -----------------------------------------------------------------------------------------------------------------
// Standard Aggregations
// -----------------------------------------------------------------------------------------------------------------

/**
* Implementation class for conversion from a collection of {@link Aggregation aggregations} to an
* {@link AggregationContext} for standard aggregations. Accumulates state by visiting each aggregation.
*/
private final class NormalConverter extends Converter {
private final QueryCompilerRequestProcessor.BatchProcessor compilationProcessor;

private NormalConverter(
@NotNull final Table table,
final boolean requireStateChangeRecorder,
@NotNull final String... groupByColumnNames) {
super(table, requireStateChangeRecorder, groupByColumnNames);
this.compilationProcessor = QueryCompilerRequestProcessor.batch();
}

@Override
AggregationContext build() {
final AggregationContext resultContext = super.build();
compilationProcessor.compile();
return resultContext;
}

// -------------------------------------------------------------------------------------------------------------
// Aggregation.Visitor
// -------------------------------------------------------------------------------------------------------------

@Override
public void visit(@NotNull final Count count) {
addNoInputOperator(new CountAggregationOperator(count.column().name()));
}

@Override
public void visit(@NotNull final CountWhere countWhere) {
final void addCountWhereOperator(@NotNull CountWhere countWhere) {
final WhereFilter[] whereFilters = WhereFilter.fromInternal(countWhere.filter());

final Map<String, RecordingInternalOperator> inputColumnRecorderMap = new HashMap<>();
Expand Down Expand Up @@ -737,6 +700,47 @@ public void visit(@NotNull final CountWhere countWhere) {
addOperator(new CountWhereOperator(countWhere.column().name(), whereFilters, recorders, filterRecorders),
null, inputColumnNames);
}
}

// -----------------------------------------------------------------------------------------------------------------
// Standard Aggregations
// -----------------------------------------------------------------------------------------------------------------

/**
* Implementation class for conversion from a collection of {@link Aggregation aggregations} to an
* {@link AggregationContext} for standard aggregations. Accumulates state by visiting each aggregation.
*/
private final class NormalConverter extends Converter {
private final QueryCompilerRequestProcessor.BatchProcessor compilationProcessor;

private NormalConverter(
@NotNull final Table table,
final boolean requireStateChangeRecorder,
@NotNull final String... groupByColumnNames) {
super(table, requireStateChangeRecorder, groupByColumnNames);
this.compilationProcessor = QueryCompilerRequestProcessor.batch();
}

@Override
AggregationContext build() {
final AggregationContext resultContext = super.build();
compilationProcessor.compile();
return resultContext;
}

// -------------------------------------------------------------------------------------------------------------
// Aggregation.Visitor
// -------------------------------------------------------------------------------------------------------------

@Override
public void visit(@NotNull final Count count) {
addNoInputOperator(new CountAggregationOperator(count.column().name()));
}

@Override
public void visit(@NotNull final CountWhere countWhere) {
addCountWhereOperator(countWhere);
}

@Override
public void visit(@NotNull final FirstRowKey firstRowKey) {
Expand Down Expand Up @@ -1051,7 +1055,7 @@ public void visit(@NotNull final Count count) {

@Override
public void visit(@NotNull final CountWhere countWhere) {
addNoInputOperator(new CountAggregationOperator(countWhere.column().name()));
addCountWhereOperator(countWhere);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
//
// Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
//
package io.deephaven.engine.table.impl;

import io.deephaven.api.agg.Aggregation;
import io.deephaven.engine.table.Table;
import io.deephaven.engine.table.hierarchical.RollupTable;
import io.deephaven.engine.testutil.*;
import io.deephaven.engine.testutil.generator.*;
import io.deephaven.engine.testutil.testcase.RefreshingTableTestCase;
import io.deephaven.test.types.OutOfBandTest;
import org.junit.Test;
import org.junit.experimental.categories.Category;

import java.util.*;

import static io.deephaven.api.agg.Aggregation.*;
import static io.deephaven.engine.testutil.TstUtils.*;

@Category(OutOfBandTest.class)
public class TestRollup extends RefreshingTableTestCase {
// This is the list of supported aggregations for rollup. These are all using `intCol` as the column to aggregate
// because the re-aggregation logic is effectively the same for all column types.
private final Collection<Aggregation> aggs = List.of(
AggAbsSum("absSum=intCol"),
AggAvg("avg=intCol"),
AggCount("count"),
AggCountWhere("countWhere", "intCol > 50"),
AggCountDistinct("countDistinct=intCol"),
AggDistinct("distinct=intCol"),
AggFirst("first=intCol"),
AggLast("last=intCol"),
AggMax("max=intCol"),
AggMin("min=intCol"),
AggSortedFirst("Sym", "firstSorted=intCol"),
AggSortedLast("Sym", "lastSorted=intCol"),
AggStd("std=intCol"),
AggSum("sum=intCol"),
AggUnique("unique=intCol"),
AggVar("var=intCol"),
AggWAvg("intCol", "wavg=intCol"),
AggWSum("intCol", "wsum=intCol"));

// Companion list of columns to compare between rollup root and the zero-key equivalent
private final String[] columnsToCompare = new String[] {
"absSum",
"avg",
"count",
"countWhere",
"countDistinct",
"distinct",
"first",
"last",
"max",
"min",
"firstSorted",
"lastSorted",
"std",
"sum",
"unique",
"var",
"wavg",
"wsum"
};

@SuppressWarnings("rawtypes")
private final ColumnInfo[] columnInfo = initColumnInfos(
new String[] {"Sym", "intCol"},
new SetGenerator<>("a", "b", "c", "d"),
new IntGenerator(10, 100));

private QueryTable createTable(boolean refreshing, int size, Random random) {
return getTable(refreshing, size, random, columnInfo);
}

@Override
public void setUp() throws Exception {
super.setUp();
}

@Test
public void testRollup() {
final Random random = new Random(0);
// Create the test table
final Table testTable = createTable(false, 100_000, random);

final RollupTable rollupTable = testTable.rollup(aggs, false, "Sym");
final Table rootTable = rollupTable.getRoot();

final Table actual = rootTable.select(columnsToCompare);
final Table expected = testTable.aggBy(aggs);

// Compare the zero-key equivalent table to the rollup table root
TstUtils.assertTableEquals(actual, expected);
}

@Test
public void testRollupIncremental() {
for (int size = 10; size <= 1000; size *= 10) {
testRollupIncrementalInternal("size-" + size, size);
}
}

private void testRollupIncrementalInternal(final String ctxt, final int size) {
final Random random = new Random(0);

final QueryTable testTable = createTable(true, size * 10, random);
EvalNuggetInterface[] en = new EvalNuggetInterface[] {
new QueryTableTest.TableComparator(
testTable.rollup(aggs, false, "Sym")
.getRoot().select(columnsToCompare),
testTable.aggBy(aggs))
};

final int steps = 100;
for (int step = 0; step < steps; step++) {
if (RefreshingTableTestCase.printTableUpdates) {
System.out.println("Step = " + step);
}
simulateShiftAwareStep(ctxt + " step == " + step, size, random, testTable, columnInfo, en);
}
}
}
3 changes: 2 additions & 1 deletion py/server/tests/test_rollup_tree_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

from deephaven import read_csv, empty_table
from deephaven.agg import sum_, avg, count_, first, last, max_, min_, std, abs_sum, \
from deephaven.agg import sum_, avg, count_, count_where, first, last, max_, min_, std, abs_sum, \
var
from deephaven.filters import Filter
from deephaven.table import NodeType
Expand All @@ -18,6 +18,7 @@ def setUp(self):
self.aggs_for_rollup = [
avg(["aggAvg=var"]),
count_("aggCount"),
count_where("aggCountWhere", "var > 0"),
first(["aggFirst=var"]),
last(["aggLast=var"]),
max_(["aggMax=var"]),
Expand Down

0 comments on commit 492121d

Please sign in to comment.