Skip to content

Commit

Permalink
[Enhancement] Eliminate non-required unnest computation (#55431)
Browse files Browse the repository at this point in the history
Signed-off-by: satanson <[email protected]>
(cherry picked from commit 4be4b5c)

# Conflicts:
#	be/src/exprs/table_function/table_function.h
#	be/src/exprs/table_function/unnest.h
#	fe/fe-core/src/main/java/com/starrocks/planner/TableFunctionNode.java
#	fe/fe-core/src/test/java/com/starrocks/sql/plan/TableFunctionTest.java
  • Loading branch information
satanson authored and mergify[bot] committed Feb 7, 2025
1 parent 09211ba commit caf8aca
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 8 deletions.
20 changes: 15 additions & 5 deletions be/src/exec/pipeline/table_function_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,13 @@ Status TableFunctionOperator::prepare(RuntimeState* state) {
if (_table_function == nullptr) {
return Status::InternalError("can't find table function " + table_function_name);
}
if (_tnode.table_function_node.__isset.fn_result_required) {
_fn_result_required = _tnode.table_function_node.fn_result_required;
} else {
_fn_result_required = true;
}
RETURN_IF_ERROR(_table_function->init(table_fn, &_table_function_state));

_table_function_state->set_is_required(_fn_result_required);
_table_function_exec_timer = ADD_TIMER(_unique_metrics, "TableFunctionExecTime");
_table_function_exec_counter = ADD_COUNTER(_unique_metrics, "TableFunctionExecCount", TUnit::UNIT);
RETURN_IF_ERROR(_table_function->prepare(_table_function_state));
Expand Down Expand Up @@ -159,8 +164,11 @@ ChunkPtr TableFunctionOperator::_build_chunk(const std::vector<ColumnPtr>& colum
for (size_t i = 0; i < _outer_slots.size(); ++i) {
chunk->append_column(columns[i], _outer_slots[i]);
}
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
chunk->append_column(columns[_outer_slots.size() + i], _fn_result_slots[i]);

if (_fn_result_required) {
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
chunk->append_column(columns[_outer_slots.size() + i], _fn_result_slots[i]);
}
}

return chunk;
Expand Down Expand Up @@ -222,8 +230,10 @@ void TableFunctionOperator::_copy_result(const std::vector<ColumnPtr>& columns,
}

// Build table function result
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
columns[_outer_slots.size() + i]->append(*(fn_result_cols[i]), start, copy_rows);
if (_fn_result_required) {
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
columns[_outer_slots.size() + i]->append(*(fn_result_cols[i]), start, copy_rows);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions be/src/exec/pipeline/table_function_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class TableFunctionOperator final : public Operator {
size_t _next_output_row_offset = 0;
// table function result
std::pair<Columns, UInt32Column::Ptr> _table_function_result;
bool _fn_result_required = true;
// table function param and return offset
TableFunctionState* _table_function_state = nullptr;

Expand Down
11 changes: 11 additions & 0 deletions be/src/exprs/table_function/table_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class TableFunctionState {

const Status& status() const { return _status; }

// Stores whether the table function's result is required; is_required()
// reports the value stored here.
void set_is_required(bool is_required) { _is_required = is_required; }

bool is_required() { return _is_required; }

private:
// Overridable hook whose default implementation does nothing.
// NOTE(review): presumably invoked after new parameter columns are installed —
// confirm against the caller in this class.
virtual void on_new_params() {}

Expand All @@ -72,6 +76,13 @@ class TableFunctionState {
int64_t _offset = 0;

Status _status;
<<<<<<< HEAD
=======

// used to identify left join for table function
bool _is_left_join = false;
bool _is_required = true;
>>>>>>> 4be4b5c34e ([Enhancement] Eliminate non-required unnest computation (#55431))
};

class TableFunction {
Expand Down
35 changes: 35 additions & 0 deletions be/src/exprs/table_function/unnest.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Unnest final : public TableFunction {
auto* nullable_array_column = down_cast<NullableColumn*>(arg0);

auto offset_column = col_array->offsets_column();
<<<<<<< HEAD
auto compacted_offset_column = UInt32Column::create();
compacted_offset_column->append_datum(Datum(0));

Expand All @@ -61,6 +62,40 @@ class Unnest final : public TableFunction {
compacted_array_elements->append(
*(col_array->elements_column()), offset_column->get(row_idx).get_int32(),
offset_column->get(row_idx + 1).get_int32() - offset_column->get(row_idx).get_int32());
=======
auto copy_count_column = UInt32Column::create();
copy_count_column->append(0);
ColumnPtr unnested_array_elements = col_array->elements_column()->clone_empty();
uint32_t offset = 0;
for (int row_idx = 0; row_idx < arg0->size(); ++row_idx) {
if (arg0->is_null(row_idx)) {
if (state->get_is_left_join()) {
// to support unnest with null.
if (state->is_required()) {
unnested_array_elements->append_nulls(1);
}
offset += 1;
}
copy_count_column->append(offset);
} else {
if (offset_column->get(row_idx + 1).get_int32() == offset_column->get(row_idx).get_int32() &&
state->get_is_left_join()) {
// to support unnest with null.
if (state->is_required()) {
unnested_array_elements->append_nulls(1);
}
offset += 1;
} else {
auto length =
offset_column->get(row_idx + 1).get_int32() - offset_column->get(row_idx).get_int32();
if (state->is_required()) {
unnested_array_elements->append(*(col_array->elements_column()),
offset_column->get(row_idx).get_int32(), length);
}
offset += length;
}
copy_count_column->append(offset);
>>>>>>> 4be4b5c34e ([Enhancement] Eliminate non-required unnest computation (#55431))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,33 @@
import com.starrocks.thrift.TTypeDesc;

import java.util.List;
import java.util.stream.Collectors;

public class TableFunctionNode extends PlanNode {
private final TableFunction tableFunction;

// Slots output by the table function
private final List<Integer> fnResultSlots;

//External column slots of the join logic generated by the table function
private final List<Integer> outerSlots;
//Slots of table function input parameters
private final List<Integer> paramSlots;

private final boolean fnResultRequired;
/**
 * Builds a table function plan node placed on top of {@code child}.
 *
 * @param id               id assigned to this plan node
 * @param child            input plan node; registered as this node's child
 * @param outputTupleDesc  output tuple descriptor; its id is added to this node's tuple ids
 * @param tableFunction    the table function evaluated by this node
 * @param paramSlots       slots of the table function's input parameters
 * @param outerSlots       outer column slots carried alongside the function output
 * @param fnResultSlots    slots output by the table function
 * @param fnResultRequired value serialized as {@code fn_result_required} in {@code toThrift}
 */
public TableFunctionNode(PlanNodeId id, PlanNode child, TupleDescriptor outputTupleDesc,
                         TableFunction tableFunction,
                         List<Integer> paramSlots,
                         List<Integer> outerSlots,
                         List<Integer> fnResultSlots,
                         boolean fnResultRequired) {
    super(id, "TableValueFunction");

    // Own state first.
    this.tableFunction = tableFunction;
    this.fnResultSlots = fnResultSlots;
    this.outerSlots = outerSlots;
    this.paramSlots = paramSlots;
    this.fnResultRequired = fnResultRequired;

    // Then wire the node into the plan tree.
    children.add(child);
    tupleIds.add(outputTupleDesc.getId());
}

Expand All @@ -69,6 +74,7 @@ protected void toThrift(TPlanNode msg) {
msg.table_function_node.setParam_columns(paramSlots);
msg.table_function_node.setOuter_columns(outerSlots);
msg.table_function_node.setFn_result_columns(fnResultSlots);
msg.table_function_node.setFn_result_required(fnResultRequired);
}

@Override
Expand All @@ -77,6 +83,7 @@ protected String getNodeExplainString(String prefix, TExplainLevel detailLevel)
output.append(prefix).append("tableFunctionName: ").append(tableFunction.getFunctionName()).append('\n');
output.append(prefix).append("columns: ").append(tableFunction.getDefaultColumnNames()).append('\n');
output.append(prefix).append("returnTypes: ").append(tableFunction.getTableFnReturnTypes()).append('\n');

return output.toString();
}

Expand All @@ -101,4 +108,16 @@ protected void toNormalForm(TNormalPlanNode planNode, FragmentNormalizer normali
planNode.setNode_type(TPlanNodeType.TABLE_FUNCTION_NODE);
normalizeConjuncts(normalizer, planNode, conjuncts);
}
<<<<<<< HEAD
=======

// Always returns true for table function nodes.
// NOTE(review): presumably opts this node into execution-statistics collection —
// confirm against the overridden method in the parent class.
@Override
public boolean needCollectExecStats() {
return true;
}

// Returns the fnResultRequired flag passed to the constructor; toThrift sends the
// same value as fn_result_required.
public boolean isFnResultRequired() {
return fnResultRequired;
}
>>>>>>> 4be4b5c34e ([Enhancement] Eliminate non-required unnest computation (#55431))
}
Original file line number Diff line number Diff line change
Expand Up @@ -2777,6 +2777,12 @@ public PlanFragment visitPhysicalTableFunction(OptExpression optExpression, Exec
}
udtfOutputTuple.computeMemLayout();

ColumnRefSet fnResultsRequired = ColumnRefSet.of();
optExpression.getRowOutputInfo().getColumnRefMap().values()
.forEach(expr -> fnResultsRequired.union(expr.getUsedColumns()));
Optional.ofNullable(physicalTableFunction.getPredicate())
.ifPresent(pred -> fnResultsRequired.union(pred.getUsedColumns()));
fnResultsRequired.intersect(physicalTableFunction.getFnResultColRefs());
TableFunctionNode tableFunctionNode = new TableFunctionNode(context.getNextNodeId(),
inputFragment.getPlanRoot(),
udtfOutputTuple,
Expand All @@ -2786,8 +2792,10 @@ public PlanFragment visitPhysicalTableFunction(OptExpression optExpression, Exec
physicalTableFunction.getOuterColRefs().stream().map(ColumnRefOperator::getId)
.collect(Collectors.toList()),
physicalTableFunction.getFnResultColRefs().stream().map(ColumnRefOperator::getId)
.collect(Collectors.toList())
.collect(Collectors.toList()),
!fnResultsRequired.isEmpty() || physicalTableFunction.getOuterColRefs().isEmpty()
);

tableFunctionNode.computeStatistics(optExpression.getStatistics());
tableFunctionNode.setLimit(physicalTableFunction.getLimit());
inputFragment.setPlanRoot(tableFunctionNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@

package com.starrocks.sql.plan;

import com.starrocks.planner.TableFunctionNode;
import com.starrocks.sql.analyzer.SemanticException;
import org.junit.Assert;
import org.junit.Test;

<<<<<<< HEAD
=======
import java.util.Optional;

>>>>>>> 4be4b5c34e ([Enhancement] Eliminate non-required unnest computation (#55431))
public class TableFunctionTest extends PlanTestBase {
@Test
public void testSql0() throws Exception {
Expand Down Expand Up @@ -293,4 +299,38 @@ public void testUnnesetBitmapToArrayToUnnestBitmapRewrite() throws Exception {
PlanTestBase.assertContains(plan, "tableFunctionName: unnest_bitmap");
PlanTestBase.assertNotContains(plan, "bitmap_to_array");
}

/**
 * Checks the {@code fnResultRequired} flag the planner puts on the first
 * {@link TableFunctionNode} of each query's plan.
 *
 * Each test case is a pair {sql, expected flag}. The SQL text is used as the
 * assertion message so a failure identifies the offending query without
 * relying on console output.
 */
@Test
public void testUnnesetFnResultNotRequired() throws Exception {
    Object[][] testCaseList = new Object[][] {
            {
                    "select t.* from test_all_type t, unnest(split(t1a, ','))",
                    false
            },
            {
                    "select t.*, unnest from test_all_type t, unnest(split(t1a, ','))",
                    true
            },
            {
                    "SELECT y FROM TABLE(generate_series(1, 2)) t(x), LATERAL generate_series(1, 5000) t2(y);",
                    true
            }
    };

    for (Object[] tc : testCaseList) {
        String sql = (String) tc[0];
        Boolean isRequired = (Boolean) tc[1];
        ExecPlan plan = getExecPlan(sql);

        // Only the first table function node found across all fragments is inspected.
        Optional<TableFunctionNode> optTableFuncNode = plan.getFragments()
                .stream()
                .flatMap(fragment -> fragment.collectNodes().stream())
                .filter(planNode -> planNode instanceof TableFunctionNode)
                .map(planNode -> (TableFunctionNode) planNode)
                .findFirst();
        Assert.assertTrue("no TableFunctionNode in plan for: " + sql, optTableFuncNode.isPresent());
        // JUnit order is (message, expected, actual); keeping it makes failure output read correctly.
        Assert.assertEquals(sql, isRequired, optTableFuncNode.get().isFnResultRequired());
    }
}
}
1 change: 1 addition & 0 deletions gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ struct TTableFunctionNode {
2: optional list<Types.TSlotId> param_columns
3: optional list<Types.TSlotId> outer_columns
4: optional list<Types.TSlotId> fn_result_columns
5: optional bool fn_result_required
}

struct TConnectorScanNode {
Expand Down

0 comments on commit caf8aca

Please sign in to comment.