From 5d4487998314bd3292218ae7b76c36da4a874e4e Mon Sep 17 00:00:00 2001 From: Larry Booker Date: Mon, 18 Nov 2024 19:06:04 -0800 Subject: [PATCH] feat!: provide key columns as scalars (vs. vectors) to `RollingFormula` (#6375) ### Example: NOTE: `Sym` is a key column and is constant for each bucket. It is presented to the UDF as a string (not a vector). `intCol` / `longCol` are vectors containing the window data. ``` t_out = t.updateBy(UpdateByOperation.RollingFormula(prevTicks, postTicks, "out_val=sum(intCol) - max(longCol) + (Sym == null ? 0 : Sym.length())"), "Sym"); ``` --- .../BucketedPartitionedUpdateByManager.java | 11 +- .../impl/updateby/UpdateByBucketHelper.java | 13 +- .../table/impl/updateby/UpdateByOperator.java | 31 ++++- .../updateby/UpdateByOperatorFactory.java | 70 +++++++---- .../table/impl/updateby/UpdateByWindow.java | 13 +- .../updateby/UpdateByWindowCumulative.java | 17 ++- .../updateby/UpdateByWindowRollingBase.java | 10 +- .../updateby/UpdateByWindowRollingTicks.java | 17 ++- .../updateby/UpdateByWindowRollingTime.java | 10 +- .../impl/updateby/ZeroKeyUpdateByManager.java | 7 +- .../RollingFormulaMultiColumnOperator.java | 112 +++++++++++++----- .../impl/updateby/TestRollingFormula.java | 47 ++++++++ 12 files changed, 266 insertions(+), 92 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/BucketedPartitionedUpdateByManager.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/BucketedPartitionedUpdateByManager.java index 83461f2a75b..aad2b027925 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/BucketedPartitionedUpdateByManager.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/BucketedPartitionedUpdateByManager.java @@ -3,7 +3,6 @@ // package io.deephaven.engine.table.impl.updateby; -import io.deephaven.UncheckedDeephavenException; import io.deephaven.api.updateby.UpdateByControl; import io.deephaven.base.verify.Assert; import io.deephaven.engine.exceptions.CancellationException; @@ -83,9 +82,12 @@ class BucketedPartitionedUpdateByManager extends UpdateBy { final PartitionedTable partitioned = source.partitionedAggBy(List.of(), true, null, byColumnNames); final PartitionedTable transformed = partitioned.transform(t -> { final long firstSourceRowKey = t.getRowSet().firstRowKey(); + final Object[] bucketKeyValues = Arrays.stream(byColumnNames) + .map(colName -> t.getColumnSource(colName).get(firstSourceRowKey)) + .toArray(); final String bucketDescription = BucketedPartitionedUpdateByManager.this + "-bucket-" + - Arrays.stream(byColumnNames) - .map(bcn -> Objects.toString(t.getColumnSource(bcn).get(firstSourceRowKey))) + Arrays.stream(bucketKeyValues) + .map(Objects::toString) .collect(Collectors.joining(", ", "[", "]")); UpdateByBucketHelper bucket = new UpdateByBucketHelper( bucketDescription, @@ -94,7 +96,8 @@ class BucketedPartitionedUpdateByManager extends UpdateBy { resultSources, timestampColumnName, control, - this::onBucketFailure); + this::onBucketFailure, + bucketKeyValues); // add this to the bucket list synchronized (buckets) { buckets.offer(bucket); diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByBucketHelper.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByBucketHelper.java index 17bff00fec5..3c99b57a5b7 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByBucketHelper.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByBucketHelper.java @@ -50,6 +50,9 @@ class UpdateByBucketHelper extends IntrusiveDoublyLinkedNode.Impl timestampColumnSource; private final ModifiedColumnSet timestampColumnSet; + /** Store boxed key values for this bucket */ + private final Object[] bucketKeyValues; + /** Indicates this bucket needs to be processed (at least one window and operator are dirty) */ private boolean isDirty; /** This rowset will store row keys where the timestamp is not null (will mirror the SSA contents) */ @@ -65,8 +68,9 @@ class UpdateByBucketHelper extends IntrusiveDoublyLinkedNode.Impl> resultSources, @Nullable final String timestampColumnName, @NotNull final UpdateByControl control, - @NotNull final BiConsumer failureNotifier) { + @NotNull final BiConsumer failureNotifier, + @NotNull final Object[] bucketKeyValues) { this.description = description; this.source = source; // some columns will have multiple inputs, such as time-based and Weighted computations this.windows = windows; this.control = control; this.failureNotifier = failureNotifier; + this.bucketKeyValues = bucketKeyValues; result = new QueryTable(source.getRowSet(), resultSources); @@ -331,7 +337,8 @@ public void prepareForUpdate(final TableUpdate upstream, final boolean initialSt timestampValidRowSet, timestampsModified, control.chunkCapacityOrDefault(), - initialStep); + initialStep, + bucketKeyValues); // compute the affected/influenced operators and rowsets within this window windows[winIdx].computeAffectedRowsAndOperators(windowContexts[winIdx], upstream); diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperator.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperator.java index 95ef56f059e..ab36bbb757e 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperator.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperator.java @@ -159,19 +159,46 @@ protected UpdateByOperator( */ public abstract void initializeSources(@NotNull Table source, @Nullable RowRedirection rowRedirection); + /** + * Initialize the bucket context for a cumulative operator and pass in the bucket key values. Most operators will + * not need the key values, but those that do can override this method. + */ + public void initializeCumulativeWithKeyValues( + @NotNull final Context context, + final long firstUnmodifiedKey, + final long firstUnmodifiedTimestamp, + @NotNull final RowSet bucketRowSet, + @NotNull Object[] bucketKeyValues) { + initializeCumulative(context, firstUnmodifiedKey, firstUnmodifiedTimestamp, bucketRowSet); + } + /** * Initialize the bucket context for a cumulative operator */ - public void initializeCumulative(@NotNull final Context context, final long firstUnmodifiedKey, + public void initializeCumulative( + @NotNull final Context context, + final long firstUnmodifiedKey, final long firstUnmodifiedTimestamp, @NotNull final RowSet bucketRowSet) { context.reset(); } + /** + * Initialize the bucket context for a windowed operator and pass in the bucket key values. Most operators will not + * need the key values, but those that do can override this method. + */ + public void initializeRollingWithKeyValues( + @NotNull final Context context, + @NotNull final RowSet bucketRowSet, + @NotNull Object[] bucketKeyValues) { + initializeRolling(context, bucketRowSet); + } + /** * Initialize the bucket context for a windowed operator */ - public void initializeRolling(@NotNull final Context context, + public void initializeRolling( + @NotNull final Context context, @NotNull final RowSet bucketRowSet) { context.reset(); } diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperatorFactory.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperatorFactory.java index 2463c7413c3..671e87bdd01 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperatorFactory.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperatorFactory.java @@ -59,7 +59,7 @@ public class UpdateByOperatorFactory { private final MatchPair[] groupByColumns; @NotNull private final UpdateByControl control; - private Map> vectorColumnNameMap; + private Map> vectorColumnDefinitions; public UpdateByOperatorFactory( @NotNull final TableDefinition tableDef, @@ -1437,7 +1437,6 @@ private UpdateByOperator makeRollingFormulaOperator(@NotNull final MatchPair pai private UpdateByOperator makeRollingFormulaMultiColumnOperator( @NotNull final TableDefinition tableDef, @NotNull final RollingFormulaSpec rs) { - final long prevWindowScaleUnits = rs.revWindowScale().getTimeScaleUnits(); final long fwdWindowScaleUnits = rs.fwdWindowScale().getTimeScaleUnits(); @@ -1446,21 +1445,24 @@ private UpdateByOperator makeRollingFormulaMultiColumnOperator( // Create the colum final SelectColumn selectColumn = SelectColumn.of(Selectable.parse(rs.formula())); - // Get or create a column definition map where the definitions are vectors of the original column types. - if (vectorColumnNameMap == null) { - vectorColumnNameMap = new HashMap<>(); - columnDefinitionMap.forEach((key, value) -> { - final ColumnDefinition columnDef = ColumnDefinition.fromGenericType( - key, - VectorFactory.forElementType(value.getDataType()).vectorType(), - value.getDataType()); - vectorColumnNameMap.put(key, columnDef); - }); + // Get or create a column definition map composed of vectors of the original column types (or scalars when + // part of the group_by columns). + final Set groupByColumnSet = + Arrays.stream(groupByColumns).map(MatchPair::rightColumn).collect(Collectors.toSet()); + if (vectorColumnDefinitions == null) { + vectorColumnDefinitions = tableDef.getColumnStream().collect(Collectors.toMap( + ColumnDefinition::getName, + (final ColumnDefinition cd) -> groupByColumnSet.contains(cd.getName()) + ? cd + : ColumnDefinition.fromGenericType( + cd.getName(), + VectorFactory.forElementType(cd.getDataType()).vectorType(), + cd.getDataType()))); } - // Get the input column names and data types from the formula. - final String[] inputColumnNames = - selectColumn.initDef(vectorColumnNameMap, compilationProcessor).toArray(String[]::new); + // Get the input column names from the formula and provide them to the rolling formula operator + final String[] allInputColumns = + selectColumn.initDef(vectorColumnDefinitions, compilationProcessor).toArray(String[]::new); if (!selectColumn.getColumnArrays().isEmpty()) { throw new IllegalArgumentException("RollingFormulaMultiColumnOperator does not support column arrays (" + selectColumn.getColumnArrays() + ")"); @@ -1468,20 +1470,33 @@ private UpdateByOperator makeRollingFormulaMultiColumnOperator( if (selectColumn.hasVirtualRowVariables()) { throw new IllegalArgumentException("RollingFormula does not support virtual row variables"); } - final Class[] inputColumnTypes = new Class[inputColumnNames.length]; - final Class[] inputVectorTypes = new Class[inputColumnNames.length]; - for (int i = 0; i < inputColumnNames.length; i++) { - final ColumnDefinition columnDef = columnDefinitionMap.get(inputColumnNames[i]); - inputColumnTypes[i] = columnDef.getDataType(); - inputVectorTypes[i] = vectorColumnNameMap.get(inputColumnNames[i]).getDataType(); + final Map> partitioned = Arrays.stream(allInputColumns) + .collect(Collectors.partitioningBy(groupByColumnSet::contains)); + final String[] inputKeyColumns = partitioned.get(true).toArray(String[]::new); + final String[] inputNonKeyColumns = partitioned.get(false).toArray(String[]::new); + + final Class[] inputKeyColumnTypes = new Class[inputKeyColumns.length]; + final Class[] inputKeyComponentTypes = new Class[inputKeyColumns.length]; + for (int i = 0; i < inputKeyColumns.length; i++) { + final ColumnDefinition columnDef = columnDefinitionMap.get(inputKeyColumns[i]); + inputKeyColumnTypes[i] = columnDef.getDataType(); + inputKeyComponentTypes[i] = columnDef.getComponentType(); + } + + final Class[] inputNonKeyColumnTypes = new Class[inputNonKeyColumns.length]; + final Class[] inputNonKeyVectorTypes = new Class[inputNonKeyColumns.length]; + for (int i = 0; i < inputNonKeyColumns.length; i++) { + final ColumnDefinition columnDef = columnDefinitionMap.get(inputNonKeyColumns[i]); + inputNonKeyColumnTypes[i] = columnDef.getDataType(); + inputNonKeyVectorTypes[i] = vectorColumnDefinitions.get(inputNonKeyColumns[i]).getDataType(); } final String[] affectingColumns; if (rs.revWindowScale().timestampCol() == null) { - affectingColumns = inputColumnNames; + affectingColumns = inputNonKeyColumns; } else { - affectingColumns = ArrayUtils.add(inputColumnNames, rs.revWindowScale().timestampCol()); + affectingColumns = ArrayUtils.add(inputNonKeyColumns, rs.revWindowScale().timestampCol()); } // Create a new column pair with the same name for the left and right columns @@ -1494,9 +1509,12 @@ private UpdateByOperator makeRollingFormulaMultiColumnOperator( prevWindowScaleUnits, fwdWindowScaleUnits, selectColumn, - inputColumnNames, - inputColumnTypes, - inputVectorTypes); + inputKeyColumns, + inputKeyColumnTypes, + inputKeyComponentTypes, + inputNonKeyColumns, + inputNonKeyColumnTypes, + inputNonKeyVectorTypes); } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindow.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindow.java index 501b37fd67f..5825e6900ca 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindow.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindow.java @@ -47,7 +47,8 @@ static class UpdateByWindowBucketContext implements SafeCloseable { protected final boolean timestampsModified; /** Whether this is the creation phase of this window */ protected final boolean initialStep; - + /** Store the key values for this bucket */ + protected final Object[] bucketKeyValues; /** An array of ColumnSources for each underlying operator */ protected ColumnSource[] inputSources; @@ -71,12 +72,14 @@ static class UpdateByWindowBucketContext implements SafeCloseable { final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean initialStep) { + final boolean initialStep, + @NotNull final Object[] bucketKeyValues) { this.sourceRowSet = sourceRowSet; this.timestampColumnSource = timestampColumnSource; this.timestampSsa = timestampSsa; this.timestampValidRowSet = timestampValidRowSet; this.timestampsModified = timestampsModified; + this.bucketKeyValues = bucketKeyValues; workingChunkSize = chunkSize; this.initialStep = initialStep; @@ -91,13 +94,15 @@ public void close() { } } - abstract UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet sourceRowSet, + abstract UpdateByWindowBucketContext makeWindowContext( + final TrackingRowSet sourceRowSet, final ColumnSource timestampColumnSource, final LongSegmentedSortedArray timestampSsa, final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean isInitializeStep); + final boolean isInitializeStep, + final Object[] bucketKeyValues); UpdateByWindow(final UpdateByOperator[] operators, final int[][] operatorInputSourceSlots, diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowCumulative.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowCumulative.java index fc0dba46273..8702abf2ede 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowCumulative.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowCumulative.java @@ -55,9 +55,17 @@ UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet sourceRowSet, final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean isInitializeStep) { - return new UpdateByWindowBucketContext(sourceRowSet, timestampColumnSource, timestampSsa, timestampValidRowSet, - timestampsModified, chunkSize, isInitializeStep); + final boolean isInitializeStep, + final Object[] bucketKeyValues) { + return new UpdateByWindowBucketContext( + sourceRowSet, + timestampColumnSource, + timestampSsa, + timestampValidRowSet, + timestampsModified, + chunkSize, + isInitializeStep, + bucketKeyValues); } @Override @@ -192,7 +200,8 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context, continue; } UpdateByOperator cumOp = operators[opIdx]; - cumOp.initializeCumulative(winOpContexts[ii], rowKey, timestamp, context.sourceRowSet); + cumOp.initializeCumulativeWithKeyValues(winOpContexts[ii], rowKey, timestamp, context.sourceRowSet, + context.bucketKeyValues); } while (affectedIt.hasMore()) { diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingBase.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingBase.java index e35973bc1a0..37ccac96fbe 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingBase.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingBase.java @@ -41,14 +41,16 @@ static class UpdateByWindowRollingBucketContext extends UpdateByWindowBucketCont final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean initialStep) { + final boolean initialStep, + final Object[] bucketKeyValues) { super(sourceRowSet, timestampColumnSource, timestampSsa, timestampValidRowSet, timestampsModified, chunkSize, - initialStep); + initialStep, + bucketKeyValues); } @Override @@ -60,7 +62,7 @@ public void close() { } UpdateByWindowRollingBase(@NotNull final UpdateByOperator[] operators, - @NotNull final int[][] operatorSourceSlots, + final int[][] operatorSourceSlots, final long prevUnits, final long fwdUnits, @Nullable final String timestampColumnName) { @@ -152,7 +154,7 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context, continue; } UpdateByOperator rollingOp = operators[opIdx]; - rollingOp.initializeRolling(winOpContexts[ii], bucketRowSet); + rollingOp.initializeRollingWithKeyValues(winOpContexts[ii], bucketRowSet, context.bucketKeyValues); } int affectedChunkOffset = 0; diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTicks.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTicks.java index 53db9547e5c..417a95bec4f 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTicks.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTicks.java @@ -30,9 +30,12 @@ static class UpdateByWindowTicksBucketContext extends UpdateByWindowRollingBucke private RowSet affectedRowPositions; private RowSet influencerPositions; - UpdateByWindowTicksBucketContext(final TrackingRowSet sourceRowSet, - final int chunkSize, final boolean initialStep) { - super(sourceRowSet, null, null, null, false, chunkSize, initialStep); + UpdateByWindowTicksBucketContext( + final TrackingRowSet sourceRowSet, + final int chunkSize, + final boolean initialStep, + final Object[] bucketKeyValues) { + super(sourceRowSet, null, null, null, false, chunkSize, initialStep, bucketKeyValues); } @Override @@ -77,14 +80,16 @@ void finalizeWindowBucket(UpdateByWindowBucketContext context) { } @Override - UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet sourceRowSet, + UpdateByWindowBucketContext makeWindowContext( + final TrackingRowSet sourceRowSet, final ColumnSource timestampColumnSource, final LongSegmentedSortedArray timestampSsa, final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean isInitializeStep) { - return new UpdateByWindowTicksBucketContext(sourceRowSet, chunkSize, isInitializeStep); + final boolean isInitializeStep, + final Object[] bucketKeyValues) { + return new UpdateByWindowTicksBucketContext(sourceRowSet, chunkSize, isInitializeStep, bucketKeyValues); } private static WritableRowSet computeAffectedRowsTicks(final RowSet sourceSet, final RowSet invertedSubSet, diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTime.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTime.java index ee42b5007b3..d5f7d74f58c 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTime.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByWindowRollingTime.java @@ -39,9 +39,10 @@ public UpdateByWindowTimeBucketContext(final TrackingRowSet sourceRowSet, final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean initialStep) { + final boolean initialStep, + final Object[] bucketKeyValues) { super(sourceRowSet, timestampColumnSource, timestampSsa, timestampValidRowSet, timestampsModified, - chunkSize, initialStep); + chunkSize, initialStep, bucketKeyValues); } } @@ -72,9 +73,10 @@ public UpdateByWindowBucketContext makeWindowContext(final TrackingRowSet source final TrackingRowSet timestampValidRowSet, final boolean timestampsModified, final int chunkSize, - final boolean isInitializeStep) { + final boolean isInitializeStep, + final Object[] bucketKeyValues) { return new UpdateByWindowTimeBucketContext(sourceRowSet, timestampColumnSource, timestampSsa, - timestampValidRowSet, timestampsModified, chunkSize, isInitializeStep); + timestampValidRowSet, timestampsModified, chunkSize, isInitializeStep, bucketKeyValues); } /** diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/ZeroKeyUpdateByManager.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/ZeroKeyUpdateByManager.java index 98f43be8cf7..fd5c70a8d91 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/ZeroKeyUpdateByManager.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/ZeroKeyUpdateByManager.java @@ -3,7 +3,6 @@ // package io.deephaven.engine.table.impl.updateby; -import io.deephaven.UncheckedDeephavenException; import io.deephaven.api.updateby.UpdateByControl; import io.deephaven.engine.exceptions.CancellationException; import io.deephaven.engine.exceptions.TableInitializationException; @@ -14,6 +13,7 @@ import io.deephaven.engine.table.impl.QueryTable; import io.deephaven.engine.table.impl.TableUpdateImpl; import io.deephaven.engine.table.impl.util.RowRedirection; +import io.deephaven.util.type.ArrayTypeUtils; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -70,7 +70,8 @@ public class ZeroKeyUpdateByManager extends UpdateBy { // create an updateby bucket instance directly from the source table zeroKeyUpdateBy = new UpdateByBucketHelper(bucketDescription, source, windows, resultSources, - timestampColumnName, control, (oe, se) -> deliverUpdateError(oe, se, true)); + timestampColumnName, control, (oe, se) -> deliverUpdateError(oe, se, true), + ArrayTypeUtils.EMPTY_OBJECT_ARRAY); buckets.offer(zeroKeyUpdateBy); // make the source->result transformer @@ -88,7 +89,7 @@ public class ZeroKeyUpdateByManager extends UpdateBy { zeroKeyUpdateBy = new UpdateByBucketHelper(bucketDescription, source, windows, resultSources, timestampColumnName, control, (oe, se) -> { throw new IllegalStateException("Update failure from static zero key updateBy"); - }); + }, ArrayTypeUtils.EMPTY_OBJECT_ARRAY); result = zeroKeyUpdateBy.result; buckets.offer(zeroKeyUpdateBy); sourceListener = null; diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingformulamulticolumn/RollingFormulaMultiColumnOperator.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingformulamulticolumn/RollingFormulaMultiColumnOperator.java index a97dc9d1617..ca1effa60e3 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingformulamulticolumn/RollingFormulaMultiColumnOperator.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingformulamulticolumn/RollingFormulaMultiColumnOperator.java @@ -20,6 +20,7 @@ import io.deephaven.engine.table.impl.util.ChunkUtils; import io.deephaven.engine.table.impl.util.RowRedirection; import io.deephaven.vector.Vector; +import org.apache.commons.lang3.ArrayUtils; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -32,9 +33,14 @@ public class RollingFormulaMultiColumnOperator extends UpdateByOperator { private static final int BUFFER_INITIAL_CAPACITY = 512; private final SelectColumn selectColumn; - private final String[] inputColumnNames; - private final Class[] inputColumnTypes; - private final Class[] inputVectorTypes; + + private final String[] inputKeyColumnNames; + private final Class[] inputKeyColumnTypes; + private final Class[] inputKeyComponentTypes; + + private final String[] inputNonKeyColumnNames; + private final Class[] inputNonKeyColumnTypes; + private final Class[] inputNonKeyVectorTypes; private WritableColumnSource primitiveOutputSource; private WritableColumnSource outputSource; @@ -48,6 +54,8 @@ private class Context extends UpdateByOperator.Context { private final IntConsumer outputSetter; private final IntConsumer outputNullSetter; + @SuppressWarnings("rawtypes") + private final SingleValueColumnSource[] keyValueSources; private final RingBufferWindowConsumer[] inputConsumers; @SuppressWarnings("unused") @@ -58,13 +66,14 @@ private Context(final int affectedChunkSize, final int influencerChunkSize) { // Make a copy of the operator formula column. final SelectColumn contextSelectColumn = selectColumn.copy(); - inputConsumers = new RingBufferWindowConsumer[inputColumnNames.length]; + keyValueSources = new SingleValueColumnSource[inputKeyColumnNames.length]; + inputConsumers = new RingBufferWindowConsumer[inputNonKeyColumnNames.length]; // To perform the calculation, we will leverage SelectColumn and for its input sources we create a set of - // SingleValueColumnSources, each containing a Vector of values. This vector will contain exactly the - // values from the input columns that are appropriate for output row given the window configuration. - // The formula column is evaluated once per output row and the result written to the output column - // source. + // SingleValueColumnSources, each containing a Vector of values (or a scalar when the source is a key). + // These sources will contain exactly the values from the input columns that are appropriate for the output + // row given the window configuration and state. The formula column is evaluated once per output row and + // the result written to the output column source. // The SingleValueColumnSources is backed by RingBuffers through use of a RingBufferVectorWrapper. // The underlying RingBuffer is updated with the values from the input columns with assistance from @@ -72,10 +81,22 @@ private Context(final int affectedChunkSize, final int influencerChunkSize) { // column data chunks into the RingBuffer. final Map> inputSources = new HashMap<>(); - for (int i = 0; i < inputColumnNames.length; i++) { - final String inputColumnName = inputColumnNames[i]; - final Class inputColumnType = inputColumnTypes[i]; - final Class inputVectorType = inputVectorTypes[i]; + + for (int i = 0; i < inputKeyColumnNames.length; i++) { + final String inputColumnName = inputKeyColumnNames[i]; + final Class inputColumnType = inputKeyColumnTypes[i]; + final Class inputComponentType = inputKeyComponentTypes[i]; + + // Create a single value column source wrapping a scalar of the appropriate type for this key. + keyValueSources[i] = + SingleValueColumnSource.getSingleValueColumnSource(inputColumnType, inputComponentType); + inputSources.put(inputColumnName, keyValueSources[i]); + } + + for (int i = 0; i < inputNonKeyColumnNames.length; i++) { + final String inputColumnName = inputNonKeyColumnNames[i]; + final Class inputColumnType = inputNonKeyColumnTypes[i]; + final Class inputVectorType = inputNonKeyVectorTypes[i]; // Create and store the ring buffer for the input column. final RingBuffer ringBuffer = RingBuffer.makeRingBuffer( @@ -88,19 +109,15 @@ private Context(final int affectedChunkSize, final int influencerChunkSize) { final SingleValueColumnSource> formulaInputSource = (SingleValueColumnSource>) SingleValueColumnSource .getSingleValueColumnSource(inputVectorType, inputColumnType); - final RingBufferVectorWrapper wrapper = RingBufferVectorWrapper.makeRingBufferVectorWrapper( ringBuffer, inputColumnType); - formulaInputSource.set(wrapper); - inputSources.put(inputColumnName, formulaInputSource); - inputConsumers[i] = RingBufferWindowConsumer.create(ringBuffer); } - contextSelectColumn.initInputs(RowSetFactory.flat(1).toTracking(), inputSources); + contextSelectColumn.initInputs(RowSetFactory.flat(1).toTracking(), inputSources); final ColumnSource formulaOutputSource = ReinterpretUtils.maybeConvertToPrimitive(contextSelectColumn.getDataView()); @@ -110,8 +127,7 @@ private Context(final int affectedChunkSize, final int influencerChunkSize) { @Override protected void setValueChunks(@NotNull Chunk[] valueChunks) { - // Assign the influencer values chunks to the input consumers. - for (int i = 0; i < valueChunks.length; i++) { + for (int i = 0; i < inputConsumers.length; i++) { inputConsumers[i].setInputChunk(valueChunks[i]); } } @@ -203,6 +219,13 @@ public void close() { outputValues.close(); outputFillContext.close(); } + + private void setBucketKeyValues(final Object[] bucketKeyValues) { + for (int i = 0; i < keyValueSources.length; i++) { + // noinspection unchecked + keyValueSources[i].set(bucketKeyValues[i]); + } + } } /** @@ -215,9 +238,12 @@ public void close() { * @param reverseWindowScaleUnits The size of the reverse window in ticks (or nanoseconds when time-based) * @param forwardWindowScaleUnits The size of the forward window in ticks (or nanoseconds when time-based) * @param selectColumn The {@link SelectColumn} specifying the calculation to be performed - * @param inputColumnNames The names of the columns to be used as inputs - * @param inputColumnTypes The types of the columns to be used as inputs - * @param inputVectorTypes The vector types of the columns to be used as inputs + * @param inputKeyColumnNames The names of the key columns to be used as inputs + * @param inputKeyColumnTypes The types of the key columns to be used as inputs + * @param inputKeyComponentTypes The component types of the key columns to be used as inputs + * @param inputNonKeyColumnNames The names of the non-key columns to be used as inputs + * @param inputNonKeyColumnTypes The types of the non-key columns to be used as inputs + * @param inputNonKeyVectorTypes The vector types of the non-key columns to be used as inputs */ public RollingFormulaMultiColumnOperator( @NotNull final MatchPair pair, @@ -226,14 +252,20 @@ public RollingFormulaMultiColumnOperator( final long reverseWindowScaleUnits, final long forwardWindowScaleUnits, @NotNull final SelectColumn selectColumn, - @NotNull final String[] inputColumnNames, - @NotNull final Class[] inputColumnTypes, - @NotNull final Class[] inputVectorTypes) { + @NotNull final String[] inputKeyColumnNames, + @NotNull final Class[] inputKeyColumnTypes, + @NotNull final Class[] inputKeyComponentTypes, + @NotNull final String[] inputNonKeyColumnNames, + @NotNull final Class[] inputNonKeyColumnTypes, + @NotNull final Class[] inputNonKeyVectorTypes) { super(pair, affectingColumns, timestampColumnName, reverseWindowScaleUnits, forwardWindowScaleUnits, true); this.selectColumn = selectColumn; - this.inputColumnNames = inputColumnNames; - this.inputColumnTypes = inputColumnTypes; - this.inputVectorTypes = inputVectorTypes; + this.inputKeyColumnNames = inputKeyColumnNames; + this.inputKeyColumnTypes = inputKeyColumnTypes; + this.inputKeyComponentTypes = inputKeyComponentTypes; + this.inputNonKeyColumnNames = inputNonKeyColumnNames; + this.inputNonKeyColumnTypes = inputNonKeyColumnTypes; + this.inputNonKeyVectorTypes = inputNonKeyVectorTypes; } @Override @@ -245,9 +277,12 @@ public UpdateByOperator copy() { reverseWindowScaleUnits, forwardWindowScaleUnits, selectColumn, - inputColumnNames, - inputColumnTypes, - inputVectorTypes); + inputKeyColumnNames, + inputKeyColumnTypes, + inputKeyComponentTypes, + inputNonKeyColumnNames, + inputNonKeyColumnTypes, + inputNonKeyVectorTypes); } @Override @@ -273,6 +308,7 @@ public void initializeSources(@NotNull final Table source, @Nullable final RowRe outputChunkType = primitiveOutputSource.getChunkType(); } + // region value-setters protected static IntConsumer getChunkSetter( final WritableChunk valueChunk, final ColumnSource formulaOutputSource) { @@ -363,6 +399,7 @@ protected static IntConsumer getChunkNullSetter(final WritableChunk objectChunk.set(index, null); } } + // endregion value-setters @Override public void startTrackingPrev() { @@ -389,6 +426,17 @@ public void prepareForParallelPopulation(final RowSet changedRows) { } } + @Override + public void initializeRollingWithKeyValues( + @NotNull final UpdateByOperator.Context context, + @NotNull final RowSet bucketRowSet, + @NotNull Object[] bucketKeyValues) { + super.initializeRollingWithKeyValues(context, bucketRowSet, bucketKeyValues); + + final Context rollingContext = (Context) context; + rollingContext.setBucketKeyValues(bucketKeyValues); + } + @NotNull @Override public Map> getOutputColumns() { @@ -414,6 +462,6 @@ public void applyOutputShift(@NotNull final RowSet subRowSetToShift, final long @Override @NotNull protected String[] getInputColumnNames() { - return inputColumnNames; + return ArrayUtils.addAll(inputNonKeyColumnNames); } } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/updateby/TestRollingFormula.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/updateby/TestRollingFormula.java index 314b96faf37..bdf95beb916 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/updateby/TestRollingFormula.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/updateby/TestRollingFormula.java @@ -797,6 +797,22 @@ private void doTestStaticBucketed(boolean grouped, int prevTicks, int postTicks) TstUtils.assertTableEquals(expected, actual, TableDiff.DiffItems.DoublesExact); + // using the key column + actual = t.updateBy(UpdateByOperation.RollingFormula(prevTicks, postTicks, + "out_val=sum(intCol) - max(longCol) + (Sym == null ? 0 : Sym.length())"), "Sym"); + expected = t.updateBy(UpdateByOperation.RollingGroup(prevTicks, postTicks, "a=intCol", "b=longCol"), "Sym") + .update("out_val=sum(a) - max(b) + (Sym == null ? 0 : Sym.length())").dropColumns("a", "b"); + + TstUtils.assertTableEquals(expected, actual, TableDiff.DiffItems.DoublesExact); + + // using the byte key column + actual = t.updateBy(UpdateByOperation.RollingFormula(prevTicks, postTicks, + "out_val=byteCol == null ? -1 : byteCol + sum(intCol) - max(longCol)"), "byteCol"); + expected = + t.updateBy(UpdateByOperation.RollingGroup(prevTicks, postTicks, "a=intCol", "b=longCol"), "byteCol") + .update("out_val=byteCol == null ? -1 : byteCol + sum(a) - max(b)").dropColumns("a", "b"); + + TstUtils.assertTableEquals(expected, actual, TableDiff.DiffItems.DoublesExact); } private void doTestStaticBucketedTimed(boolean grouped, Duration prevTime, Duration postTime) { @@ -965,6 +981,23 @@ private void doTestStaticBucketedTimed(boolean grouped, Duration prevTime, Durat .update("out_val=a + b").dropColumns("a", "b"); TstUtils.assertTableEquals(expected, actual, TableDiff.DiffItems.DoublesExact); + + // using the key column + actual = t.updateBy(UpdateByOperation.RollingFormula("ts", prevTime, postTime, + "out_val=(Sym == null ? 0 : Sym.length()) + sum(intCol) - max(longCol)"), "Sym"); + expected = t.updateBy(UpdateByOperation.RollingGroup("ts", prevTime, postTime, "a=intCol", "b=longCol"), "Sym") + .update("out_val=(Sym == null ? 0 : Sym.length()) + sum(a) - max(b)").dropColumns("a", "b"); + + TstUtils.assertTableEquals(expected, actual, TableDiff.DiffItems.DoublesExact); + + // using the byte key column + actual = t.updateBy(UpdateByOperation.RollingFormula("ts", prevTime, postTime, + "out_val=byteCol == null ? -1 : byteCol + sum(intCol) - max(longCol)"), "byteCol"); + expected = + t.updateBy(UpdateByOperation.RollingGroup("ts", prevTime, postTime, "a=intCol", "b=longCol"), "byteCol") + .update("out_val=byteCol == null ? -1 : byteCol + sum(a) - max(b)").dropColumns("a", "b"); + + TstUtils.assertTableEquals(expected, actual, TableDiff.DiffItems.DoublesExact); } // endregion @@ -1172,6 +1205,13 @@ private void doTestAppendOnly(boolean bucketed, int prevTicks, int postTicks) { "Sym") : t.updateBy(UpdateByOperation.RollingFormula(prevTicks, postTicks, "sum=min(intCol) + min(longCol) * max(doubleCol)"))), + EvalNugget.from(() -> bucketed + ? t.updateBy( + UpdateByOperation.RollingFormula(prevTicks, postTicks, + "out_col=Sym == null ? 0.0 : min(intCol) + min(longCol) * max(doubleCol)"), + "Sym") + : t.updateBy(UpdateByOperation.RollingFormula(prevTicks, postTicks, + "out_col=true"))), // This is a dummy test, we care about the bucketed test }; final Random billy = new Random(0xB177B177L); @@ -1234,6 +1274,13 @@ private void doTestAppendOnlyTimed(boolean bucketed, Duration prevTime, Duration "Sym") : t.updateBy(UpdateByOperation.RollingFormula("ts", prevTime, postTime, "sum=min(intCol) + min(longCol) * max(doubleCol)"))), + EvalNugget.from(() -> bucketed + ? t.updateBy( + UpdateByOperation.RollingFormula("ts", prevTime, postTime, + "out_col=Sym == null ? null : min(intCol) + min(longCol) * max(doubleCol)"), + "Sym") + : t.updateBy(UpdateByOperation.RollingFormula("ts", prevTime, postTime, + "out_col=true"))), // This is a dummy test, we care about the bucketed test }; final Random billy = new Random(0xB177B177L);