diff --git a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java index 5023670b..dc9cdbb1 100644 --- a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java +++ b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/SlidingWindowUtils.java @@ -224,13 +224,14 @@ public SlidingWindowPreprocessAggregateFunction( @Override public Row createAccumulator() { - Row acc = Row.withNames(); + int arity = keyFields.size() + 1 + aggDescriptors.getAggFieldDescriptors().size(); + Object[] values = new Object[arity]; for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor : aggDescriptors.getAggFieldDescriptors()) { - acc.setField( - descriptor.fieldName, descriptor.aggFuncWithoutRetract.createAccumulator()); + int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor); + values[pos] = descriptor.aggFuncWithoutRetract.createAccumulator(); } - return acc; + return Row.of(values); } @Override @@ -239,16 +240,19 @@ public Row add(Row row, Row acc) { for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor : aggDescriptors.getAggFieldDescriptors()) { Object fieldValue = row.getFieldAs(descriptor.fieldName); - Object fieldAcc = acc.getField(descriptor.fieldName); + int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor); + Object fieldAcc = acc.getFieldAs(pos); descriptor.aggFuncWithoutRetract.add(fieldAcc, fieldValue, timestamp); } - if (acc.getField(rowTimeFieldName) == null) { + if (acc.getField(keyFields.size()) == null) { acc.setField( - rowTimeFieldName, + keyFields.size(), Instant.ofEpochMilli(getWindowTime(timestamp, size, offset))); + int idx = 0; for (String key : keyFields) { - acc.setField(key, row.getField(key)); + acc.setField(idx, row.getField(key)); + idx += 1; } } @@ -264,14 +268,17 @@ public Row getResult(Row acc) { public Row merge(Row acc1, Row acc2) { for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor : aggDescriptors.getAggFieldDescriptors()) { - Object fieldAcc1 = acc1.getField(descriptor.fieldName); - Object fieldAcc2 = acc2.getField(descriptor.fieldName); + int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor); + Object fieldAcc1 = acc1.getField(pos); + Object fieldAcc2 = acc2.getField(pos); descriptor.aggFuncWithoutRetract.merge(fieldAcc1, fieldAcc2); } - if (acc1.getField(rowTimeFieldName) == null) { - acc1.setField(rowTimeFieldName, acc2.getField(rowTimeFieldName)); + if (acc1.getField(keyFields.size()) == null) { + acc1.setField(keyFields.size(), acc2.getField(keyFields.size())); + int idx = 0; for (String key : keyFields) { - acc1.setField(key, acc2.getField(key)); + acc1.setField(idx, acc2.getField(idx)); + idx += 1; } } return acc1; @@ -358,11 +365,19 @@ public static Table applySlidingWindowAggregationProcess( rowDataStream .keyBy( (KeySelector) - value -> - Row.of( - Arrays.stream(keyFieldNames) - .map(value::getField) - .toArray())) + row -> { + List values = new ArrayList<>(); + for (int i = 0; i < keyFieldNames.length; i += 1) { + Object value; + try { + value = row.getField(i); + } catch (IllegalArgumentException e) { + value = row.getField(keyFieldNames[i]); + } + values.add(value); + } + return Row.of(values.toArray(new Object[0])); + }) .process( new SlidingWindowKeyedProcessFunction( aggregationFieldsDescriptor, diff --git a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java index 9124a534..3ee66ce2 100644 --- a/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java +++ b/java/feathub-udf/flink-udf/src/main/java/com/alibaba/feathub/flink/udf/processfunction/SlidingWindowKeyedProcessFunction.java @@ -237,8 +237,17 @@ public void onTimer( break; } for (Row row : state.timestampToRows.get(rowTime)) { - descriptor.aggFunc.retractAccumulator( - accumulatorState, row.getField(descriptor.fieldName)); + Object value; + try { + int idx = + keyFieldNames.length + + 1 + + aggregationFieldsDescriptor.getAggFieldIdx(descriptor); + value = row.getField(idx); + } catch (IllegalArgumentException e) { + value = row.getField(descriptor.fieldName); + } + descriptor.aggFunc.retractAccumulator(accumulatorState, value); } } if (leftIdx < timestampList.size() && timestampList.get(leftIdx) <= timestamp) {