diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/Column.java b/fe/fe-core/src/main/java/com/starrocks/catalog/Column.java index 5252bca2b9203f..23c72b6e303016 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/Column.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/Column.java @@ -232,7 +232,7 @@ public Column(Column column) { this.name = column.getName(); this.columnId = column.getColumnId(); this.type = column.type; - this.type.setAggStateDesc(this.aggStateDesc); + this.type.setAggStateDesc(column.aggStateDesc); this.aggregationType = column.getAggregationType(); this.isAggregationTypeImplicit = column.isAggregationTypeImplicit(); this.isKey = column.isKey(); diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUnionCombinator.java b/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUnionCombinator.java index f6eb6968c57c08..daf1045bc85cae 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUnionCombinator.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUnionCombinator.java @@ -50,10 +50,9 @@ public AggStateUnionCombinator(AggStateUnionCombinator other) { public static Optional of(AggregateFunction aggFunc) { try { - Type intermediateType = aggFunc.getIntermediateTypeOrReturnType(); + Type intermediateType = aggFunc.getIntermediateTypeOrReturnType().clone(); FunctionName functionName = new FunctionName(aggFunc.functionName() + FunctionSet.AGG_STATE_UNION_SUFFIX); - AggStateUnionCombinator aggStateUnionFunc = - new AggStateUnionCombinator(functionName, intermediateType); + AggStateUnionCombinator aggStateUnionFunc = new AggStateUnionCombinator(functionName, intermediateType); aggStateUnionFunc.setBinaryType(TFunctionBinaryType.BUILTIN); aggStateUnionFunc.setPolymorphic(aggFunc.isPolymorphic()); AggStateDesc aggStateDesc; @@ -63,6 +62,8 @@ public static Optional of(AggregateFunction aggFunc) { aggStateDesc = new AggStateDesc(aggFunc); } aggStateUnionFunc.setAggStateDesc(aggStateDesc); + // set agg state desc for the function's result type so can be used as the later agg state functions. + intermediateType.setAggStateDesc(aggStateDesc); // use agg state desc's nullable as `agg_state` function's nullable aggStateUnionFunc.setIsNullable(aggStateDesc.getResultNullable()); LOG.info("Register agg state function: {}", aggStateUnionFunc.functionName()); diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUtils.java b/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUtils.java index 1f0b88fd297613..02b91c1b09ccd9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUtils.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateUtils.java @@ -25,6 +25,7 @@ import com.starrocks.catalog.Type; import com.starrocks.qe.ConnectContext; import com.starrocks.sql.analyzer.FunctionAnalyzer; +import com.starrocks.sql.analyzer.SemanticException; import com.starrocks.sql.parser.NodePosition; import java.util.List; @@ -165,13 +166,13 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session, result = AggStateCombinator.of(aggFunc); } } else if (func instanceof AggStateUnionCombinator) { - AggregateFunction argFn = getAggStateFunction(session, argumentTypes, pos); + AggregateFunction argFn = getAggStateFunction(session, func, argumentTypes, pos); if (argFn == null) { return null; } result = AggStateUnionCombinator.of(argFn); } else if (func instanceof AggStateMergeCombinator) { - AggregateFunction argFn = getAggStateFunction(session, argumentTypes, pos); + AggregateFunction argFn = getAggStateFunction(session, func, argumentTypes, pos); if (argFn == null) { return null; } @@ -185,14 +186,16 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session, } private static AggregateFunction getAggStateFunction(ConnectContext session, + Function inputFunc, Type[] argumentTypes, NodePosition pos) { Preconditions.checkArgument(argumentTypes.length == 1, "AggState's AggFunc should have only one argument"); Type arg0Type = argumentTypes[0]; - Preconditions.checkArgument(arg0Type.getAggStateDesc() != null, - String.format("AggState's agg state desc is null")); - + if (arg0Type.getAggStateDesc() == null) { + throw new SemanticException(String.format("AggState's AggFunc should have AggStateDesc: %s", + inputFunc), pos); + } AggStateDesc aggStateDesc = arg0Type.getAggStateDesc(); List argTypes = aggStateDesc.getArgTypes(); String argFnName = aggStateDesc.getFunctionName(); diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java index 509c688d974a40..173efc4e0ebe46 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java @@ -5609,6 +5609,25 @@ public void testCreateMVWithAggStateRewrite() throws Exception { " REFRESH DEFERRED MANUAL " + "as \n" + "SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1, k2;"); + { + String query = "select k1, k2, avg_merge(v1) from (" + + "SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1,k2) t " + + "group by k1, k2;"; + String plan = UtFrameUtils.getFragmentPlan(connectContext, query); + PlanTestBase.assertContains(plan, "test_mv1"); + } + { + String query = "select k1, k2, avg_merge(v1) from (" + + " SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1,k2 " + + " UNION ALL" + + " SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1,k2 " + + ") t " + + "group by k1, k2;"; + String plan = UtFrameUtils.getFragmentPlan(connectContext, query); + PlanTestBase.assertContains(plan, "test_mv1"); + } + + { String query = "SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1, k2;"; String plan = UtFrameUtils.getFragmentPlan(connectContext, query);