Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Hoist heavy-cost(decimal divide) upon top-n (backport #55417) #55625

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import com.starrocks.sql.optimizer.rule.transformation.EliminateConstantCTERule;
import com.starrocks.sql.optimizer.rule.transformation.ForceCTEReuseRule;
import com.starrocks.sql.optimizer.rule.transformation.GroupByCountDistinctRewriteRule;
import com.starrocks.sql.optimizer.rule.transformation.HoistHeavyCostExprsUponTopnRule;
import com.starrocks.sql.optimizer.rule.transformation.JoinLeftAsscomRule;
import com.starrocks.sql.optimizer.rule.transformation.MaterializedViewTransparentRewriteRule;
import com.starrocks.sql.optimizer.rule.transformation.MergeProjectWithChildRule;
Expand Down Expand Up @@ -554,7 +555,9 @@ private OptExpression logicalRuleRewrite(
// Limit push must be after the column prune,
// otherwise the Node containing limit may be prune
ruleRewriteIterative(tree, rootTaskContext, RuleSetType.MERGE_LIMIT);
ruleRewriteIterative(tree, rootTaskContext, new HoistHeavyCostExprsUponTopnRule());
ruleRewriteIterative(tree, rootTaskContext, new PushDownProjectLimitRule());
ruleRewriteIterative(tree, rootTaskContext, new HoistHeavyCostExprsUponTopnRule());

ruleRewriteOnlyOnce(tree, rootTaskContext, new PushDownLimitRankingWindowRule());
rewriteGroupingSets(tree, rootTaskContext, sessionVariable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ public enum RuleType {

TF_CTE_ADD_PROJECTION,

TF_HOIST_HEAVY_COST_UPON_TOPN,

// The following are implementation rules:
IMP_OLAP_LSCAN_TO_PSCAN,
IMP_HIVE_LSCAN_TO_PSCAN,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.optimizer.rule.transformation;

import com.google.common.collect.Lists;
import com.starrocks.analysis.ArithmeticExpr;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.base.Ordering;
import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalTopNOperator;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rule.RuleType;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Transformation rule that rewrites {@code TopN -> Project} into
 * {@code Project -> TopN -> Project}, hoisting "heavy-cost" projected expressions
 * above the TopN so they are evaluated only on the (limited) TopN output rows
 * instead of on every input row.
 *
 * <p>"Heavy cost" is currently defined as any expression containing a divide
 * call whose result type is LARGEINT or DECIMAL128 (see {@link #isHeavyCost}).
 */
public class HoistHeavyCostExprsUponTopnRule extends TransformationRule {
    public HoistHeavyCostExprsUponTopnRule() {
        // Matches a logical TopN whose single child is a logical Project.
        super(RuleType.TF_HOIST_HEAVY_COST_UPON_TOPN,
                Pattern.create(OperatorType.LOGICAL_TOPN, OperatorType.LOGICAL_PROJECT));
    }

    /**
     * Returns true if {@code op} is, or contains anywhere in its subtree, a
     * divide call producing a LARGEINT or DECIMAL128 result.
     *
     * <p>NOTE(review): the LARGEINT branch may rarely fire in practice — the
     * companion test shows LARGEINT '/' being rewritten to a DOUBLE-typed
     * divide by the analyzer, which this predicate does not match; confirm.
     */
    private boolean isHeavyCost(ScalarOperator op) {
        if (op instanceof CallOperator) {
            CallOperator call = op.cast();
            // Function name comparison against the canonical divide operator name.
            if (call.getFnName().equals(ArithmeticExpr.Operator.DIVIDE.getName()) && (
                    call.getType().getPrimitiveType().equals(PrimitiveType.LARGEINT) ||
                    call.getType().getPrimitiveType().equals(PrimitiveType.DECIMAL128))) {
                return true;
            }
        }
        // Recurse into children so nested heavy divides are also detected.
        return op.getChildren().stream().anyMatch(this::isHeavyCost);
    }

    /**
     * Applicability check. The rule fires only when ALL of the following hold:
     * <ul>
     *   <li>the TopN has no partition-by columns (i.e. it is a plain ORDER BY
     *       ... LIMIT, not a ranking-window TopN);</li>
     *   <li>the TopN carries no predicate and has a non-negative limit;</li>
     *   <li>the child Project produces at least one heavy-cost output column;</li>
     *   <li>no heavy-cost column is referenced by the child's predicate;</li>
     *   <li>no heavy-cost column is used as an ORDER BY key (hoisting it above
     *       the sort would change what the sort sees).</li>
     * </ul>
     */
    @Override
    public boolean check(OptExpression input, OptimizerContext context) {
        LogicalTopNOperator topnOp = input.getOp().cast();
        if (topnOp.getPartitionByColumns() != null && !topnOp.getPartitionByColumns().isEmpty()) {
            return false;
        }
        // getLimit() < 0 presumably encodes "no limit"; without a limit there is
        // no row reduction to exploit — TODO confirm the sentinel convention.
        if (topnOp.getPredicate() != null || topnOp.getLimit() < 0) {
            return false;
        }
        OptExpression child = input.inputAt(0);
        // Collect the child Project's output columns whose defining expression is heavy.
        Set<ColumnRefOperator> heavyCostColumnRefs = child.getRowOutputInfo().getColumnRefMap().entrySet()
                .stream()
                .filter(e -> isHeavyCost(e.getValue()))
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());

        if (heavyCostColumnRefs.isEmpty()) {
            return false;
        }

        // A heavy column consumed by the child's own predicate must be computed
        // below the TopN anyway, so hoisting would be incorrect.
        boolean isUsedByPredicate = Optional.ofNullable(child.getOp().getPredicate())
                .map(pred -> pred.getUsedColumns().containsAny(heavyCostColumnRefs))
                .orElse(false);

        if (isUsedByPredicate) {
            return false;
        }

        Set<ColumnRefOperator> orderByColumnRefs = topnOp.getOrderByElements().stream()
                .map(Ordering::getColumnRef)
                .collect(Collectors.toSet());

        // Intersect (order-by columns) with (heavy columns) via ColumnRefSet;
        // union/intersect here mutate the receiver in place.
        ColumnRefSet heavyColumnUsedAsOrderBy = ColumnRefSet.of();
        heavyColumnUsedAsOrderBy.union(orderByColumnRefs);
        ColumnRefSet heavyCostColumnRefSet = ColumnRefSet.of();
        heavyCostColumnRefSet.union(heavyCostColumnRefs);
        heavyColumnUsedAsOrderBy.intersect(heavyCostColumnRefSet);
        // Fire only when no heavy column participates in the sort key.
        return heavyColumnUsedAsOrderBy.isEmpty();
    }

    /**
     * Performs the rewrite: splits the child Project's output map into heavy and
     * non-heavy entries, keeps the non-heavy entries below the TopN, and emits a
     * new Project above the TopN that evaluates the heavy expressions.
     */
    @Override
    public List<OptExpression> transform(OptExpression input, OptimizerContext context) {
        LogicalTopNOperator topnOp = input.getOp().cast();
        OptExpression child = input.inputAt(0);
        // true -> heavy-cost entries, false -> everything else.
        Map<Boolean, Map<ColumnRefOperator, ScalarOperator>> columnRefMaps =
                child.getRowOutputInfo().getColumnRefMap().entrySet()
                        .stream()
                        .collect(Collectors.partitioningBy(e -> isHeavyCost(e.getValue()),
                                Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));

        Map<ColumnRefOperator, ScalarOperator> heavyCostColumnRefMap = columnRefMaps.get(true);
        Map<ColumnRefOperator, ScalarOperator> childColumnRefMap = columnRefMaps.get(false);

        // Every column the heavy expressions read must still be produced by the
        // remaining (non-heavy) child Project, otherwise the hoisted expressions
        // would reference columns the TopN no longer outputs.
        Set<ColumnRefOperator> usedColumnRefs = heavyCostColumnRefMap.values().stream()
                .map(ScalarOperator::getColumnRefs)
                .flatMap(Collection::stream)
                .collect(Collectors.toSet());

        if (!childColumnRefMap.keySet().containsAll(usedColumnRefs)) {
            // Bail out (no rewrite) rather than introduce dangling references.
            return Collections.emptyList();
        }

        // NOTE(review): putAll mutates the map returned by getRowOutputInfo();
        // this assumes getColumnRefMap() hands back a fresh copy rather than
        // shared internal state — verify against RowOutputInfo's contract.
        Map<ColumnRefOperator, ScalarOperator> topnColumnRefMap = input.getRowOutputInfo().getColumnRefMap();
        topnColumnRefMap.putAll(columnRefMaps.get(true));
        LogicalProjectOperator projectOp = child.getOp().cast();

        // Lower Project: original Project minus the heavy expressions.
        Operator newProjectOp = LogicalProjectOperator.builder().withOperator(projectOp)
                .setColumnRefMap(childColumnRefMap)
                .build();

        OptExpression newChild = OptExpression.builder().setOp(newProjectOp).setInputs(child.getInputs()).build();
        // Upper Project: TopN's output plus the hoisted heavy expressions.
        Operator upperProjectOp = LogicalProjectOperator.builder()
                .setColumnRefMap(topnColumnRefMap).build();

        // Reuse the original TopN operator on top of the slimmed-down Project.
        OptExpression newTopn =
                OptExpression.builder().setOp(topnOp).setInputs(Lists.newArrayList(newChild)).build();

        return Collections.singletonList(OptExpression.create(upperProjectOp, newTopn));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.planner;

import com.starrocks.sql.plan.PlanTestNoneDBBase;
import com.starrocks.utframe.UtFrameUtils;
import org.junit.BeforeClass;
import org.junit.Test;

/**
 * Planner tests for {@code HoistHeavyCostExprsUponTopnRule}: verifies whether
 * heavy-cost divide expressions are evaluated above or below the TopN/exchange
 * by asserting on verbose fragment-plan snippets.
 */
public class HoistHeavyCostExprsUponTopnTest extends PlanTestNoneDBBase {
    @BeforeClass
    public static void beforeClass() throws Exception {
        PlanTestNoneDBBase.beforeClass();
        starRocksAssert.withDatabase("test_db0").useDatabase("test_db0");
        // Table with DECIMAL(20,2) (-> DECIMAL128) and LARGEINT measures so both
        // heavy-cost branches of the rule can be exercised.
        String createTableSql = "CREATE TABLE t0 (\n" +
                "  EventDate DATE NOT NULL,\n" +
                "  UserID STRING NOT NULL,\n" +
                "  M0 DECIMAL(20,2),\n" +
                "  M1 DECIMAL(20,2),\n" +
                "  M2 LARGEINT,\n" +
                "  M3 LARGEINT\n" +
                ") \n" +
                "DUPLICATE KEY (EventDate)\n" +
                "DISTRIBUTED BY HASH(UserID) BUCKETS 1\n" +
                "PROPERTIES ( \"replication_num\"=\"1\");";
        starRocksAssert.withTable(createTableSql);
    }

    @Test
    public void testDecimalDivExprHoisted() throws Exception {
        // M0/M1 is DECIMAL128 divide and is not an ORDER BY key, so the rule
        // should hoist it: the Project computing slot 7 sits ABOVE the
        // MERGING-EXCHANGE (i.e. after the TopN reduced the row count).
        String sql = "select M1, M0, M0/M1\n" +
                "from t0\n" +
                "order by EventDate\n" +
                "limit 10 offset 20";
        String plan = UtFrameUtils.getVerboseFragmentPlan(connectContext, sql);
        assertCContains(plan, "  3:Project\n" +
                "  |  output columns:\n" +
                "  |  3 <-> [3: M0, DECIMAL128(20,2), true]\n" +
                "  |  4 <-> [4: M1, DECIMAL128(20,2), true]\n" +
                "  |  7 <-> [3: M0, DECIMAL128(20,2), true] / [4: M1, DECIMAL128(20,2), true]\n" +
                "  |  limit: 10\n" +
                "  |  cardinality: 1\n" +
                "  |  \n" +
                "  2:MERGING-EXCHANGE");
    }

    @Test
    public void testDecimalDivExprNotHoisted() throws Exception {
        // Here the lower Project would need to keep producing M0 solely for the
        // hoisted expression; the expected plan shows the divide evaluated in
        // the Project directly above the scan, i.e. NOT hoisted.
        String sql = "select M1, M0/M1\n" +
                "from t0\n" +
                "order by EventDate\n" +
                "limit 10 offset 20";
        String plan = UtFrameUtils.getVerboseFragmentPlan(connectContext, sql);
        assertCContains(plan, "  1:Project\n" +
                "  |  output columns:\n" +
                "  |  1 <-> [1: EventDate, DATE, false]\n" +
                "  |  4 <-> [4: M1, DECIMAL128(20,2), true]\n" +
                "  |  7 <-> [3: M0, DECIMAL128(20,2), true] / [4: M1, DECIMAL128(20,2), true]\n" +
                "  |  cardinality: 1\n" +
                "  |  \n" +
                "  0:OlapScanNode");
    }

    @Test
    public void testLargeInt128DivHoisted() throws Exception {
        // NOTE(review): despite the test name, the asserted plan has the divide
        // in Project 1 directly above the OlapScanNode — the same shape as the
        // not-hoisted case. The plan shows M2/M3 being cast to DOUBLE, whose
        // result type matches neither LARGEINT nor DECIMAL128 in isHeavyCost,
        // so the rule presumably does not fire here. Confirm whether the test
        // name (or the rule's LARGEINT branch) is what needs fixing.
        String sql = "select M2, M3, M2/M3\n" +
                "from t0\n" +
                "order by EventDate\n" +
                "limit 10 offset 20";
        String plan = UtFrameUtils.getVerboseFragmentPlan(connectContext, sql);
        assertCContains(plan, "  1:Project\n" +
                "  |  output columns:\n" +
                "  |  1 <-> [1: EventDate, DATE, false]\n" +
                "  |  5 <-> [5: M2, LARGEINT, true]\n" +
                "  |  6 <-> [6: M3, LARGEINT, true]\n" +
                "  |  7 <-> cast([5: M2, LARGEINT, true] as DOUBLE) / cast([6: M3, LARGEINT, true] as DOUBLE)\n" +
                "  |  cardinality: 1\n" +
                "  |  \n" +
                "  0:OlapScanNode");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ public void testQuery58() throws Exception {
connectContext.getSessionVariable().setCboPushDownAggregateMode(1);
String sql = getTPCDS("Q58");
String plan = getCostExplain(sql);
assertContains(plan, "|----5:EXCHANGE\n" +
assertContains(plan, " |----5:EXCHANGE\n" +
" | distribution type: BROADCAST\n" +
" | cardinality: 73049\n" +
" | probe runtime filters:\n" +
" | - filter_id = 3, probe_expr = (48: d_date)");
" | - filter_id = 3, probe_expr = (334: d_date)");
}

// @ParameterizedTest(name = "{0}")
Expand Down
Loading