feat: Add CumCountWhere() and RollingCountWhere() features to UpdateBy #6566

Open · wants to merge 8 commits into base: main
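For orientation before the diff: a minimal sketch of how the new operations would be used from the table API. `Table.updateBy` and grouping by a column are existing Deephaven API; the `CumCountWhere`/`RollingCountWhereTick` factory names and argument order below are assumptions read off the new `CumCountWhereSpec`/`RollingCountWhereSpec` types in this diff, not confirmed by this page.

```java
import io.deephaven.api.updateby.UpdateByOperation;
import io.deephaven.engine.table.Table;
import java.util.List;

public class CountWhereUsageSketch {
    // Adds two result columns, computed per "Sym" group:
    //  - "HighSoFar": cumulative count of rows where Value > 10
    //    (e.g. Value = 5, 12, 7, 20  ->  HighSoFar = 0, 1, 1, 2)
    //  - "HighLast100": the same count restricted to a trailing 100-row tick window
    public static Table addCountWhereColumns(final Table source) {
        return source.updateBy(List.of(
                UpdateByOperation.CumCountWhere("HighSoFar", "Value > 10"), // assumed factory name
                UpdateByOperation.RollingCountWhereTick(100, "HighLast100", "Value > 10")), // assumed factory name
                "Sym");
    }
}
```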
@@ -33,7 +33,7 @@
* {@link UpdateByOperator#initializeRolling(Context, RowSet)} (Context)} for windowed operators</li>
* <li>{@link UpdateByOperator.Context#accumulateCumulative(RowSequence, Chunk[], LongChunk, int)} for cumulative
* operators or
* {@link UpdateByOperator.Context#accumulateRolling(RowSequence, Chunk[], LongChunk, LongChunk, IntChunk, IntChunk, int)}
* {@link UpdateByOperator.Context#accumulateRolling(RowSequence, Chunk[], LongChunk, LongChunk, IntChunk, IntChunk, int, int)}
* for windowed operators</li>
* <li>{@link #finishUpdate(UpdateByOperator.Context)}</li>
* </ol>
@@ -99,18 +99,21 @@ protected void pop(int count) {
throw new UnsupportedOperationException("pop() must be overridden by rolling operators");
}

public abstract void accumulateCumulative(RowSequence inputKeys,
public abstract void accumulateCumulative(
RowSequence inputKeys,
Chunk<? extends Values>[] valueChunkArr,
LongChunk<? extends Values> tsChunk,
int len);

public abstract void accumulateRolling(RowSequence inputKeys,
public abstract void accumulateRolling(
RowSequence inputKeys,
Chunk<? extends Values>[] influencerValueChunkArr,
LongChunk<OrderedRowKeys> affectedPosChunk,
LongChunk<OrderedRowKeys> influencerPosChunk,
IntChunk<? extends Values> pushChunk,
IntChunk<? extends Values> popChunk,
int len);
int affectedCount,
int influencerCount);
Contributor review comment: The method parameters should have javadoc.

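A possible javadoc along the lines the reviewer asks for, sketched here as a suggestion. The parameter descriptions are inferred from the names above and from the call site in the last file of this diff (which passes `affectedChunkSize` and `influencerCount`), so treat the wording as illustrative rather than the author's.

```java
/**
 * Accumulate a chunk of influencer values into the window state and write a result for each
 * affected row.
 *
 * @param inputKeys the row keys of the affected (output) rows
 * @param influencerValueChunkArr the input value chunks for the influencer rows
 * @param affectedPosChunk the row positions of the affected rows
 * @param influencerPosChunk the row positions of the influencer rows
 * @param pushChunk for each affected row, how many influencer values enter its window
 * @param popChunk for each affected row, how many values leave its window
 * @param affectedCount the number of affected rows in this chunk
 * @param influencerCount the number of influencer rows in this chunk
 */
public abstract void accumulateRolling(
        RowSequence inputKeys,
        Chunk<? extends Values>[] influencerValueChunkArr,
        LongChunk<OrderedRowKeys> affectedPosChunk,
        LongChunk<OrderedRowKeys> influencerPosChunk,
        IntChunk<? extends Values> pushChunk,
        IntChunk<? extends Values> popChunk,
        int affectedCount,
        int influencerCount);
```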
/**
* Write the current value for this row to the output chunk
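Before moving to the factory changes, a self-contained toy of the push/pop contract above as it applies to a rolling count-where. This is plain Java with no engine types; it is a sketch of the mechanism, not Deephaven's implementation: for each affected row, first pop the values leaving its window, then push the values entering it, then emit the current count of matches.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.IntPredicate;

/** Toy model of a rolling count-where accumulation driven by per-row push/pop counts. */
public final class RollingCountWhereToy {
    private final IntPredicate filter;
    private final Deque<Boolean> window = new ArrayDeque<>(); // per-value "did it match?"
    private int matchCount;

    public RollingCountWhereToy(final IntPredicate filter) {
        this.filter = filter;
    }

    /**
     * @param influencerValues every value that enters any window in this chunk (influencerCount long)
     * @param pushCounts per affected row, how many influencer values enter its window
     * @param popCounts per affected row, how many values leave its window
     * @return one count per affected row (affectedCount long)
     */
    public int[] accumulate(final int[] influencerValues, final int[] pushCounts, final int[] popCounts) {
        final int[] results = new int[pushCounts.length];
        int nextInfluencer = 0;
        for (int row = 0; row < pushCounts.length; row++) {
            for (int i = 0; i < popCounts[row]; i++) { // values falling out of the window
                if (window.removeFirst()) {
                    matchCount--;
                }
            }
            for (int i = 0; i < pushCounts[row]; i++) { // values entering the window
                final boolean matches = filter.test(influencerValues[nextInfluencer++]);
                window.addLast(matches);
                if (matches) {
                    matchCount++;
                }
            }
            results[row] = matchCount; // the window's current count for this row
        }
        return results;
    }
}
```

For a plain trailing window of three rows, every affected row pushes one value and rows after the third pop one, so each result is "how many of the last three values passed the filter".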
@@ -10,12 +10,21 @@
import io.deephaven.api.updateby.UpdateByControl;
import io.deephaven.api.updateby.UpdateByOperation;
import io.deephaven.api.updateby.spec.*;
import io.deephaven.base.verify.Require;
import io.deephaven.chunk.ChunkType;
import io.deephaven.engine.rowset.RowSetFactory;
import io.deephaven.engine.table.ColumnDefinition;
import io.deephaven.engine.table.ColumnSource;
import io.deephaven.engine.table.Table;
import io.deephaven.engine.table.TableDefinition;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.QueryCompilerRequestProcessor;
import io.deephaven.engine.table.impl.QueryTable;
import io.deephaven.engine.table.impl.select.FormulaColumn;
import io.deephaven.engine.table.impl.select.SelectColumn;
import io.deephaven.engine.table.impl.select.WhereFilter;
import io.deephaven.engine.table.impl.sources.NullValueColumnSource;
import io.deephaven.engine.table.impl.updateby.countwhere.CountWhereOperator;
import io.deephaven.engine.table.impl.updateby.delta.*;
import io.deephaven.engine.table.impl.updateby.em.*;
import io.deephaven.engine.table.impl.updateby.emstd.*;
@@ -414,6 +423,12 @@ public Void visit(CumProdSpec cps) {
return null;
}

@Override
public Void visit(CumCountWhereSpec spec) {
ops.add(makeCountWhereOperator(tableDef, spec));
return null;
}

@Override
public Void visit(@NotNull final DeltaSpec spec) {
Arrays.stream(pairs)
@@ -537,6 +552,12 @@ public Void visit(@NotNull final RollingCountSpec spec) {
return null;
}

@Override
public Void visit(@NotNull final RollingCountWhereSpec spec) {
ops.add(makeCountWhereOperator(tableDef, spec));
return null;
}

@Override
public Void visit(@NotNull final RollingFormulaSpec spec) {
final boolean isTimeBased = spec.revWindowScale().isTimeBased();
@@ -1240,6 +1261,122 @@ private UpdateByOperator makeRollingCountOperator(@NotNull final MatchPair pair,
}
}

/**
* Make the {@link CountWhereOperator} used to implement both {@link CumCountWhereSpec} and
* {@link RollingCountWhereSpec}.
*/
private UpdateByOperator makeCountWhereOperator(
@NotNull final TableDefinition tableDef,
@NotNull final UpdateBySpec spec) {

Require.eqTrue(spec instanceof CumCountWhereSpec || spec instanceof RollingCountWhereSpec,
"spec instanceof CumCountWhereSpec || spec instanceof RollingCountWhereSpec");

final boolean isCumulative = spec instanceof CumCountWhereSpec;

final WhereFilter[] whereFilters = isCumulative
? WhereFilter.fromInternal(((CumCountWhereSpec) spec).filter())
: WhereFilter.fromInternal(((RollingCountWhereSpec) spec).filter());

final List<String> inputColumnNameList = new ArrayList<>();
final Map<String, Integer> inputColumnMap = new HashMap<>();
final List<int[]> filterInputColumnIndicesList = new ArrayList<>();

// Verify all the columns in the where filters are present in the table def and valid for use.
for (final WhereFilter whereFilter : whereFilters) {
whereFilter.init(tableDef);
if (whereFilter.isRefreshing()) {
throw new UnsupportedOperationException("CountWhere does not support refreshing filters");
}

// Compute which input sources this filter will use.
final List<String> filterColumnNames = whereFilter.getColumns();
final int inputColumnCount = filterColumnNames.size();
final int[] inputColumnIndices = new int[inputColumnCount];
for (int ii = 0; ii < inputColumnCount; ++ii) {
final String inputColumnName = filterColumnNames.get(ii);
final int inputColumnIndex = inputColumnMap.computeIfAbsent(inputColumnName, k -> {
inputColumnNameList.add(inputColumnName);
return inputColumnNameList.size() - 1;
});
inputColumnIndices[ii] = inputColumnIndex;
}
filterInputColumnIndicesList.add(inputColumnIndices);
}

// Gather the input column type info.
final String[] inputColumnNames = inputColumnNameList.toArray(String[]::new);
final ChunkType[] inputChunkTypes = new ChunkType[inputColumnNames.length];
final Class<?>[] inputColumnTypes = new Class[inputColumnNames.length];
final Class<?>[] inputComponentTypes = new Class[inputColumnNames.length];
for (int i = 0; i < inputColumnNames.length; i++) {
final ColumnDefinition<?> columnDef = tableDef.getColumn(inputColumnNames[i]);
inputColumnTypes[i] = columnDef.getDataType();
inputChunkTypes[i] = ChunkType.fromElementType(inputColumnTypes[i]);
inputComponentTypes[i] = columnDef.getComponentType();
}

// Create a dummy table we can use to initialize filters.
final Map<String, ColumnSource<?>> columnSourceMap = new LinkedHashMap<>();
for (int i = 0; i < inputColumnNames.length; i++) {
final ColumnDefinition<?> columnDef = tableDef.getColumn(inputColumnNames[i]);
final ColumnSource<?> source =
NullValueColumnSource.getInstance(columnDef.getDataType(), columnDef.getComponentType());
columnSourceMap.put(inputColumnNames[i], source);
}
final Table dummyTable = new QueryTable(RowSetFactory.empty().toTracking(), columnSourceMap);

final CountWhereOperator.CountFilter[] countFilters =
CountWhereOperator.CountFilter.createCountFilters(whereFilters, dummyTable,
filterInputColumnIndicesList);

// If any filter is a standard WhereFilter, we need a chunk source table.
final boolean chunkSourceTableRequired =
Arrays.stream(countFilters).anyMatch(filter -> filter.whereFilter() != null);

// Create a new column pair with the same name for the left and right columns
final String columnName = isCumulative
? ((CumCountWhereSpec) spec).column().name()
: ((RollingCountWhereSpec) spec).column().name();
final MatchPair pair = new MatchPair(columnName, columnName);

// Create and return the operator.
if (isCumulative) {
return new CountWhereOperator(
pair,
countFilters,
inputColumnNames,
inputChunkTypes,
inputColumnTypes,
inputComponentTypes,
chunkSourceTableRequired);
} else {
final RollingCountWhereSpec rs = (RollingCountWhereSpec) spec;

final String[] affectingColumns;
if (rs.revWindowScale().timestampCol() == null) {
affectingColumns = inputColumnNames;
} else {
affectingColumns = ArrayUtils.add(inputColumnNames, rs.revWindowScale().timestampCol());
}

final long prevWindowScaleUnits = rs.revWindowScale().getTimeScaleUnits();
final long fwdWindowScaleUnits = rs.fwdWindowScale().getTimeScaleUnits();

return new CountWhereOperator(
pair,
affectingColumns,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits,
fwdWindowScaleUnits,
countFilters,
inputColumnNames,
inputChunkTypes,
inputColumnTypes,
inputComponentTypes,
chunkSourceTableRequired);
}
}
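A side note on the column bookkeeping at the top of this method: the `computeIfAbsent` loop deduplicates filter input columns so that several filters can share one chunk array, and records for each filter which slots of that array it reads. A stripped-down, self-contained illustration of just that bookkeeping (the column names here are invented for the example):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class FilterColumnIndexing {
    public static void main(String[] args) {
        // Pretend two filters were supplied: one reading columns (A, B), one reading (B, C).
        final List<List<String>> filterColumns = List.of(List.of("A", "B"), List.of("B", "C"));

        final List<String> inputColumnNameList = new ArrayList<>();
        final Map<String, Integer> inputColumnMap = new HashMap<>();
        final List<int[]> filterInputColumnIndicesList = new ArrayList<>();

        for (final List<String> columns : filterColumns) {
            final int[] indices = new int[columns.size()];
            for (int ii = 0; ii < columns.size(); ++ii) {
                final String name = columns.get(ii);
                // The first filter to mention a column claims the next slot; later filters reuse it.
                indices[ii] = inputColumnMap.computeIfAbsent(name, k -> {
                    inputColumnNameList.add(name);
                    return inputColumnNameList.size() - 1;
                });
            }
            filterInputColumnIndicesList.add(indices);
        }

        System.out.println(inputColumnNameList);                                  // [A, B, C]
        System.out.println(Arrays.toString(filterInputColumnIndicesList.get(0))); // [0, 1]
        System.out.println(Arrays.toString(filterInputColumnIndicesList.get(1))); // [1, 2]
    }
}
```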

private UpdateByOperator makeRollingStdOperator(@NotNull final MatchPair pair,
@NotNull final TableDefinition tableDef,
@NotNull final RollingStdSpec rs) {
@@ -204,7 +204,8 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context,
influencePosChunk,
ctx.pushChunks[affectedChunkOffset],
ctx.popChunks[affectedChunkOffset],
affectedChunkSize);
affectedChunkSize,
influencerCount);
}

affectedChunkOffset++;