Skip to content

Commit

Permalink
feat: support join on or predicates rewrite to Union All
Browse files Browse the repository at this point in the history
Signed-off-by: Kould <[email protected]>
  • Loading branch information
KKould committed Feb 4, 2025
1 parent 4b2e004 commit c98645b
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ private OptExpression logicalRuleRewrite(
// apply skew join optimize after push down join on expression to child project,
// we need to compute the stats of child project(like subfield).
skewJoinOptimize(tree, rootTaskContext);
scheduler.rewriteIterative(tree, rootTaskContext, new DrivingTableSelection());
scheduler.rewriteOnce(tree, rootTaskContext, new DrivingTableSelection());
scheduler.rewriteOnce(tree, rootTaskContext, new IcebergEqualityDeleteRewriteRule());

tree = pruneSubfield(tree, rootTaskContext, requiredColumns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,24 @@
import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.BaseScalarOperatorShuttle;
import com.starrocks.sql.optimizer.rule.RuleType;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;

public class DrivingTableSelection extends TransformationRule {

Expand Down Expand Up @@ -81,23 +84,33 @@ private Optional<Long> getSourceTableId(OptExpression parent, int childIdx, OptE
return Optional.empty();
}

private Optional<Pair<Node, Integer>> findDrivingTable(OptExpression parent, int childIdx, OptExpression root,
Long drivingTableId, Integer rootChildIdx, List<Node> projections) {
Optional<Long> sourceTableId = getSourceTableId(parent, childIdx, root, projections);
if (sourceTableId.isPresent() && sourceTableId.get().equals(drivingTableId)) {
return Optional.of(Pair.create(new Node(parent, root, childIdx), rootChildIdx));
}
for (int i = 0; i < root.getInputs().size(); ++i) {
OptExpression child = root.inputAt(i);
if (childIdx == -1) {
rootChildIdx = i;
private void extractJoins(OptExpression parent, int childIdx, OptExpression root, List<Pair<Long, OptExpression>> joinTables,
List<Pair<Node, Integer>> joinWithTableIdx, List<Node> projections) {
if (root.getOp() instanceof LogicalJoinOperator joinOperator &&
(joinOperator.getJoinType().isCrossJoin() || joinOperator.getJoinType().isInnerJoin() && childIdx == -1)) {
Optional<Integer> tableIdx = Optional.empty();
for (int i = 0; i < root.getInputs().size(); i++) {
OptExpression child = root.inputAt(i);
Optional<Long> sourceTableId = getSourceTableId(root, i, child, projections);
if (sourceTableId.isPresent()) {
OptExpression tableChild = root.inputAt(i);
joinTables.add(new Pair<>(sourceTableId.get(), tableChild));
joinWithTableIdx.add(new Pair<>(new Node(parent, root, childIdx), i));

if (tableIdx.isPresent()) {
tableIdx = Optional.empty();
} else {
tableIdx = Optional.of(i);
}
}
}
Optional<Pair<Node, Integer>> newChild = findDrivingTable(root, i, child, drivingTableId, rootChildIdx, projections);
if (newChild.isPresent()) {
return newChild;
if (tableIdx.isPresent()) {
int joinIdx = tableIdx.get() == 1 ? 0 : 1;
extractJoins(root, joinIdx, root.inputAt(joinIdx), joinTables, joinWithTableIdx, projections);
}
} else if (root.getOp() instanceof LogicalProjectOperator) {
extractJoins(root, 0, root.inputAt(0), joinTables, joinWithTableIdx, projections);
}
return Optional.empty();
}

boolean isCrossJoin(OptExpression root) {
Expand Down Expand Up @@ -160,68 +173,126 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
Map<Integer, Integer> outputColumnMapping = new HashMap<>();
extractJoinOutputColumnMapping(input, outputColumnMapping, true);

Map<Long, HashSet<Long>> tableRelations = new HashMap<>();
Map<Long, HashMap<Long, Integer>> tableRelations = new HashMap<>();
List<Pair<Integer, CompoundPredicateOperator.CompoundType>> compoundTypes = new ArrayList<>();
if (rootJoinOp.getJoinType().isInnerJoin() && rootJoinOp.getOnPredicate() != null) {
binaryRelation(rootJoinOp.getOnPredicate(), columnRefFactory, tableRelations, outputColumnMapping);
if (binaryRelation(rootJoinOp.getOnPredicate(), columnRefFactory, tableRelations, outputColumnMapping, compoundTypes,
0, null)) {
return Collections.emptyList();
}
}
if (tableRelations.size() <= 1) {
return Collections.emptyList();
}
// all tables have a relationship with only the same table.
// e.g. select * from t1, t2, t3 inner join t4 on t4.c1 = t1.c1 or t4.c1 = t2.c1 or t4.c1 = t3.c1
Optional<Long> drivingTableId = Optional.empty();
for (Map.Entry<Long, HashSet<Long>> entry : tableRelations.entrySet()) {
if (entry.getValue().size() > 1) {
if (drivingTableId.isPresent()) {
Optional<Map<Long, Integer>> tableDepthMap = Optional.empty();
for (Map.Entry<Long, HashMap<Long, Integer>> entry : tableRelations.entrySet()) {
HashMap<Long, Integer> map = entry.getValue();
if (map.size() > 1) {
if (tableDepthMap.isPresent()) {
return Collections.emptyList();
}
drivingTableId = Optional.of(entry.getKey());
// driving table first
map.put(entry.getKey(), 0);
tableDepthMap = Optional.of(map);
}
}
if (drivingTableId.isPresent()) {
if (tableDepthMap.isPresent()) {
List<Pair<Long, OptExpression>> joinTables = new ArrayList<>();
List<Pair<Node, Integer>> joinWithTableIdx = new ArrayList<>();
List<Node> projections = new ArrayList<>();
Optional<Pair<Node, Integer>> pair = findDrivingTable(null, -1, input, drivingTableId.get(), null, projections);
if (pair.isEmpty()) {
return Collections.emptyList();
extractJoins(null, -1, input, joinTables, joinWithTableIdx, projections);

// JoinReorder
Map<Long, Integer> finalTableDepthMap = tableDepthMap.get();
joinTables.sort(Comparator.comparingInt((Pair<Long, OptExpression> pair) -> finalTableDepthMap.get(pair.first))
.thenComparingLong(pair -> pair.first));

for (int i = 0; i < joinTables.size(); i++) {
Pair<Node, Integer> joinPair = joinWithTableIdx.get(i);
Node join = joinPair.first;
OptExpression joinRoot = join.child;
int tableChildId = joinPair.second;
Pair<Long, OptExpression> tablePair = joinTables.get(i);
OptExpression tableChild = tablePair.second;

joinRoot.setChild(tableChildId, tableChild);
}
Node drivingTableNode = pair.get().first;
Integer rootChildIdx = pair.get().second;

if (drivingTableNode.childIndex != -1 && drivingTableNode.parent != input) {
OptExpression drivingTableRoot = drivingTableNode.child;
int drivingTableChildIndex = drivingTableNode.childIndex;
OptExpression drivingTableParent = drivingTableNode.parent;
// Projection
Collections.reverse(projections);
for (Node node : projections) {
OptExpression child = node.child;
List<OptExpression> childInputs = node.parent.inputAt(node.childIndex).getInputs();

int replaceChildIdx = rootChildIdx == 0 ? 1 : 0;
OptExpression replace = input.getInputs().get(replaceChildIdx);
if (replace == drivingTableRoot) {
return Collections.emptyList();
Map<ColumnRefOperator, ScalarOperator> newMap = new HashMap<>();
for (OptExpression projectionChild : childInputs) {
newMap.putAll(projectionChild.getRowOutputInfo().getColumnRefMap());
}
drivingTableParent.setChild(drivingTableChildIndex, replace);
ScalarOperator newOnPredicate =
innerOnPredicate.accept(new ColumnMappingRewriter(outputColumnMapping, columnRefFactory), null);
OptExpression newRoot = OptExpression.create(LogicalJoinOperator.builder().withOperator(rootJoinOp)
.setOnPredicate(newOnPredicate).build(), input.getInputs());
newRoot.setChild(replaceChildIdx, drivingTableRoot);

Collections.reverse(projections);
for (Node node : projections) {
OptExpression child = node.child;
List<OptExpression> childInputs = node.parent.inputAt(node.childIndex).getInputs();

Map<ColumnRefOperator, ScalarOperator> newMap = new HashMap<>();
for (OptExpression projectionChild : childInputs) {
newMap.putAll(projectionChild.getRowOutputInfo().getColumnRefMap());
}

LogicalProjectOperator projectionOp = (LogicalProjectOperator) child.getOp();
node.parent.setChild(node.childIndex, OptExpression.create(
LogicalProjectOperator.builder().withOperator(projectionOp)
.setColumnRefMap(newMap).build(), childInputs));
LogicalProjectOperator projectionOp = (LogicalProjectOperator) child.getOp();
node.parent.setChild(node.childIndex, OptExpression.create(
LogicalProjectOperator.builder().withOperator(projectionOp).setColumnRefMap(newMap).build(),
childInputs));
}
ScalarOperator newOnPredicate =
innerOnPredicate.accept(new ColumnMappingRewriter(outputColumnMapping, columnRefFactory), null);

compoundTypes.sort(Comparator.comparingInt((pair -> pair.first)));
for (int i = 0; i < joinWithTableIdx.size(); i++) {
Pair<Node, Integer> joinPair = joinWithTableIdx.get(i);
Node join = joinPair.first;
OptExpression joinRoot = join.child;

if (join.parent != null) {
Pair<Integer, CompoundPredicateOperator.CompoundType> compoundTypePair = compoundTypes.get(i - 1);
if (!compoundTypePair.second.equals(CompoundPredicateOperator.CompoundType.OR)) {
continue;
}
// rewrite to union all
List<ColumnRefOperator> result = new ArrayList<>();
List<List<ColumnRefOperator>> childOutputColumns = List.of(new ArrayList<>(), new ArrayList<>());

Map<ColumnRefOperator, ScalarOperator> leftMap = new HashMap<>();
Map<ColumnRefOperator, ScalarOperator> rightMap = new HashMap<>();
OptExpression leftInput = joinRoot.inputAt(0);
OptExpression rightInput = joinRoot.inputAt(1);

extractUnionInput(result, childOutputColumns, leftMap, rightMap, leftInput);
extractUnionInput(result, childOutputColumns, rightMap, leftMap, rightInput);
List<OptExpression> newInputs = new ArrayList<>();
newInputs.add(
OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(leftMap).build(), leftInput));
newInputs.add(
OptExpression.create(LogicalProjectOperator.builder().setColumnRefMap(rightMap).build(), rightInput));

join.parent.setChild(join.childIndex,
OptExpression.create(new LogicalUnionOperator(result, childOutputColumns, true), newInputs));
}
return List.of(newRoot);
}

return List.of(OptExpression.create(
LogicalJoinOperator.builder().withOperator(rootJoinOp).setOnPredicate(newOnPredicate).build(),
input.getInputs()));
}
return Collections.emptyList();
}

private void extractUnionInput(List<ColumnRefOperator> result, List<List<ColumnRefOperator>> childOutputColumns,
Map<ColumnRefOperator, ScalarOperator> leftMap,
Map<ColumnRefOperator, ScalarOperator> rightMap, OptExpression input) {
input.getRowOutputInfo().getColumnRefMap().forEach((columnOp, scalarOp) -> {
ColumnRefOperator nullableColumnOp =
new ColumnRefOperator(columnOp.getId(), columnOp.getType(), columnOp.getName(), true);
result.add(nullableColumnOp);
childOutputColumns.get(0).add(nullableColumnOp);
childOutputColumns.get(1).add(nullableColumnOp);
leftMap.put(nullableColumnOp, scalarOp);
rightMap.put(nullableColumnOp, ConstantOperator.createNull(scalarOp.getType()));
});
}

private Long getTableIdByColumnId(int columnId, ColumnRefFactory columnRefFactory, Map<Integer, Integer> columnMapping) {
Table table = columnRefFactory.getTableForColumn(columnId);

Expand All @@ -236,13 +307,18 @@ private Long getTableIdByColumnId(int columnId, ColumnRefFactory columnRefFactor
}
}

private void binaryRelation(ScalarOperator onPredicate, ColumnRefFactory columnRefFactory,
Map<Long, HashSet<Long>> tableRelations, Map<Integer, Integer> columnMapping) {
if (onPredicate instanceof CompoundPredicateOperator) {
for (ScalarOperator scalarOperator : ((CompoundPredicateOperator) onPredicate).normalizeChildren()) {
binaryRelation(scalarOperator, columnRefFactory, tableRelations, columnMapping);
private boolean binaryRelation(ScalarOperator onPredicate, ColumnRefFactory columnRefFactory,
Map<Long, HashMap<Long, Integer>> tableRelations, Map<Integer, Integer> columnMapping,
List<Pair<Integer, CompoundPredicateOperator.CompoundType>> compoundTypes,
int depth, CompoundPredicateOperator.CompoundType compoundType) {
if (onPredicate instanceof CompoundPredicateOperator compoundPredicate) {
for (ScalarOperator scalarOperator : compoundPredicate.normalizeChildren()) {
if (binaryRelation(scalarOperator, columnRefFactory, tableRelations, columnMapping, compoundTypes, depth + 1,
compoundPredicate.getCompoundType())) {
return true;
}
}
} else if (onPredicate instanceof BinaryPredicateOperator) {
} else if (onPredicate instanceof BinaryPredicateOperator && depth != 0) {
int[] columnIds = onPredicate.getUsedColumns().getColumnIds();
if (columnIds.length == 2) {
int leftIdx = columnIds[0];
Expand All @@ -251,12 +327,19 @@ private void binaryRelation(ScalarOperator onPredicate, ColumnRefFactory columnR
Long leftTableId = getTableIdByColumnId(leftIdx, columnRefFactory, columnMapping);
Long rightTableId = getTableIdByColumnId(rightIdx, columnRefFactory, columnMapping);

if (rightTableId != null && leftTableId != null) {
tableRelations.computeIfAbsent(leftTableId, k -> new HashSet<>()).add(rightTableId);
tableRelations.computeIfAbsent(rightTableId, k -> new HashSet<>()).add(leftTableId);
if (leftTableId != null && rightTableId != null) {
BiFunction<Long, Integer, Integer> function =
(tableId, tableDepth) -> tableDepth == null ? depth : Math.min(depth, tableDepth);
tableRelations.computeIfAbsent(leftTableId, k -> new HashMap<>()).compute(rightTableId, function);
tableRelations.computeIfAbsent(rightTableId, k -> new HashMap<>()).compute(leftTableId, function);
compoundTypes.add(new Pair<>(depth, compoundType));
}
} else {
// e,g, `t1.c1 = t2.c1 + t3.c1`
return columnIds.length > 2;
}
}
return false;
}

private class ColumnMappingRewriter extends BaseScalarOperatorShuttle {
Expand Down
Loading

0 comments on commit c98645b

Please sign in to comment.