Skip to content

Commit

Permalink
merge in upstream branch_3.3.4_08fd66 with hll fix from StarRocks#52540
Browse files Browse the repository at this point in the history
  • Loading branch information
ctbrennan committed Nov 5, 2024
2 parents edf06cd + 00491fa commit 76facf7
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
2 changes: 1 addition & 1 deletion fe/fe-core/src/main/java/com/starrocks/catalog/Column.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public Column(Column column) {
this.name = column.getName();
this.columnId = column.getColumnId();
this.type = column.type;
this.type.setAggStateDesc(this.aggStateDesc);
this.type.setAggStateDesc(column.aggStateDesc);
this.aggregationType = column.getAggregationType();
this.isAggregationTypeImplicit = column.isAggregationTypeImplicit();
this.isKey = column.isKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ public AggStateUnionCombinator(AggStateUnionCombinator other) {

public static Optional<AggStateUnionCombinator> of(AggregateFunction aggFunc) {
try {
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType();
Type intermediateType = aggFunc.getIntermediateTypeOrReturnType().clone();
FunctionName functionName = new FunctionName(aggFunc.functionName() + FunctionSet.AGG_STATE_UNION_SUFFIX);
AggStateUnionCombinator aggStateUnionFunc =
new AggStateUnionCombinator(functionName, intermediateType);
AggStateUnionCombinator aggStateUnionFunc = new AggStateUnionCombinator(functionName, intermediateType);
aggStateUnionFunc.setBinaryType(TFunctionBinaryType.BUILTIN);
aggStateUnionFunc.setPolymorphic(aggFunc.isPolymorphic());
AggStateDesc aggStateDesc;
Expand All @@ -63,6 +62,8 @@ public static Optional<AggStateUnionCombinator> of(AggregateFunction aggFunc) {
aggStateDesc = new AggStateDesc(aggFunc);
}
aggStateUnionFunc.setAggStateDesc(aggStateDesc);
// set agg state desc for the function's result type so can be used as the later agg state functions.
intermediateType.setAggStateDesc(aggStateDesc);
// use agg state desc's nullable as `agg_state` function's nullable
aggStateUnionFunc.setIsNullable(aggStateDesc.getResultNullable());
LOG.info("Register agg state function: {}", aggStateUnionFunc.functionName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.starrocks.catalog.Type;
import com.starrocks.qe.ConnectContext;
import com.starrocks.sql.analyzer.FunctionAnalyzer;
import com.starrocks.sql.analyzer.SemanticException;
import com.starrocks.sql.parser.NodePosition;

import java.util.List;
Expand Down Expand Up @@ -165,13 +166,13 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session,
result = AggStateCombinator.of(aggFunc);
}
} else if (func instanceof AggStateUnionCombinator) {
AggregateFunction argFn = getAggStateFunction(session, argumentTypes, pos);
AggregateFunction argFn = getAggStateFunction(session, func, argumentTypes, pos);
if (argFn == null) {
return null;
}
result = AggStateUnionCombinator.of(argFn);
} else if (func instanceof AggStateMergeCombinator) {
AggregateFunction argFn = getAggStateFunction(session, argumentTypes, pos);
AggregateFunction argFn = getAggStateFunction(session, func, argumentTypes, pos);
if (argFn == null) {
return null;
}
Expand All @@ -185,14 +186,16 @@ public static Function getAnalyzedCombinatorFunction(ConnectContext session,
}

private static AggregateFunction getAggStateFunction(ConnectContext session,
Function inputFunc,
Type[] argumentTypes,
NodePosition pos) {
Preconditions.checkArgument(argumentTypes.length == 1,
"AggState's AggFunc should have only one argument");
Type arg0Type = argumentTypes[0];
Preconditions.checkArgument(arg0Type.getAggStateDesc() != null,
String.format("AggState's agg state desc is null"));

if (arg0Type.getAggStateDesc() == null) {
throw new SemanticException(String.format("AggState's AggFunc should have AggStateDesc: %s",
inputFunc), pos);
}
AggStateDesc aggStateDesc = arg0Type.getAggStateDesc();
List<Type> argTypes = aggStateDesc.getArgTypes();
String argFnName = aggStateDesc.getFunctionName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5609,6 +5609,25 @@ public void testCreateMVWithAggStateRewrite() throws Exception {
" REFRESH DEFERRED MANUAL " +
"as \n" +
"SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1, k2;");
{
String query = "select k1, k2, avg_merge(v1) from (" +
"SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1,k2) t " +
"group by k1, k2;";
String plan = UtFrameUtils.getFragmentPlan(connectContext, query);
PlanTestBase.assertContains(plan, "test_mv1");
}
{
String query = "select k1, k2, avg_merge(v1) from (" +
" SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1,k2 " +
" UNION ALL" +
" SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1,k2 " +
") t " +
"group by k1, k2;";
String plan = UtFrameUtils.getFragmentPlan(connectContext, query);
PlanTestBase.assertContains(plan, "test_mv1");
}


{
String query = "SELECT k1, k2, avg_union(avg_state(k3 * 4)) as v1 from s1 where k1 != 'a' group by k1, k2;";
String plan = UtFrameUtils.getFragmentPlan(connectContext, query);
Expand Down

0 comments on commit 76facf7

Please sign in to comment.