Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Eliminate non-required unnest computation (backport #55431) #55675

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions be/src/exec/pipeline/table_function_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,13 @@ Status TableFunctionOperator::prepare(RuntimeState* state) {
if (_table_function == nullptr) {
return Status::InternalError("can't find table function " + table_function_name);
}
if (_tnode.table_function_node.__isset.fn_result_required) {
_fn_result_required = _tnode.table_function_node.fn_result_required;
} else {
_fn_result_required = true;
}
RETURN_IF_ERROR(_table_function->init(table_fn, &_table_function_state));

_table_function_state->set_is_required(_fn_result_required);
_table_function_exec_timer = ADD_TIMER(_unique_metrics, "TableFunctionExecTime");
_table_function_exec_counter = ADD_COUNTER(_unique_metrics, "TableFunctionExecCount", TUnit::UNIT);
RETURN_IF_ERROR(_table_function->prepare(_table_function_state));
Expand Down Expand Up @@ -159,8 +164,11 @@ ChunkPtr TableFunctionOperator::_build_chunk(const std::vector<ColumnPtr>& colum
for (size_t i = 0; i < _outer_slots.size(); ++i) {
chunk->append_column(columns[i], _outer_slots[i]);
}
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
chunk->append_column(columns[_outer_slots.size() + i], _fn_result_slots[i]);

if (_fn_result_required) {
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
chunk->append_column(columns[_outer_slots.size() + i], _fn_result_slots[i]);
}
}

return chunk;
Expand Down Expand Up @@ -222,8 +230,10 @@ void TableFunctionOperator::_copy_result(const std::vector<ColumnPtr>& columns,
}

// Build table function result
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
columns[_outer_slots.size() + i]->append(*(fn_result_cols[i]), start, copy_rows);
if (_fn_result_required) {
for (size_t i = 0; i < _fn_result_slots.size(); ++i) {
columns[_outer_slots.size() + i]->append(*(fn_result_cols[i]), start, copy_rows);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions be/src/exec/pipeline/table_function_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class TableFunctionOperator final : public Operator {
size_t _next_output_row_offset = 0;
// table function result
std::pair<Columns, UInt32Column::Ptr> _table_function_result;
bool _fn_result_required = true;
// table function param and return offset
TableFunctionState* _table_function_state = nullptr;

Expand Down
5 changes: 5 additions & 0 deletions be/src/exprs/table_function/table_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class TableFunctionState {

[[nodiscard]] const Status& status() const { return _status; }

void set_is_required(bool is_required) { _is_required = is_required; }

bool is_required() { return _is_required; }

private:
virtual void on_new_params(){};

Expand All @@ -79,6 +83,7 @@ class TableFunctionState {

// used to identify left join for table function
bool _is_left_join = false;
bool _is_required = true;
};

class TableFunction {
Expand Down
16 changes: 10 additions & 6 deletions be/src/exprs/table_function/unnest.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,33 @@ class Unnest final : public TableFunction {
auto offset_column = col_array->offsets_column();
auto copy_count_column = UInt32Column::create();
copy_count_column->append(0);

ColumnPtr unnested_array_elements = col_array->elements_column()->clone_empty();

uint32_t offset = 0;
for (int row_idx = 0; row_idx < arg0->size(); ++row_idx) {
if (arg0->is_null(row_idx)) {
if (state->get_is_left_join()) {
// to support unnest with null.
unnested_array_elements->append_nulls(1);
if (state->is_required()) {
unnested_array_elements->append_nulls(1);
}
offset += 1;
}
copy_count_column->append(offset);
} else {
if (offset_column->get(row_idx + 1).get_int32() == offset_column->get(row_idx).get_int32() &&
state->get_is_left_join()) {
// to support unnest with null.
unnested_array_elements->append_nulls(1);
if (state->is_required()) {
unnested_array_elements->append_nulls(1);
}
offset += 1;
} else {
auto length =
offset_column->get(row_idx + 1).get_int32() - offset_column->get(row_idx).get_int32();
unnested_array_elements->append(*(col_array->elements_column()),
offset_column->get(row_idx).get_int32(), length);
if (state->is_required()) {
unnested_array_elements->append(*(col_array->elements_column()),
offset_column->get(row_idx).get_int32(), length);
}
offset += length;
}
copy_count_column->append(offset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,33 @@
import com.starrocks.thrift.TTypeDesc;

import java.util.List;
import java.util.stream.Collectors;

public class TableFunctionNode extends PlanNode {
private final TableFunction tableFunction;

// Slots output by the table function
private final List<Integer> fnResultSlots;

//External column slots of the join logic generated by the table function
private final List<Integer> outerSlots;
//Slots of table function input parameters
private final List<Integer> paramSlots;

private final boolean fnResultRequired;
public TableFunctionNode(PlanNodeId id, PlanNode child, TupleDescriptor outputTupleDesc,
TableFunction tableFunction,
List<Integer> paramSlots,
List<Integer> outerSlots,
List<Integer> fnResultSlots) {
List<Integer> fnResultSlots,
boolean fnResultRequired) {
super(id, "TableValueFunction");
this.children.add(child);
this.tableFunction = tableFunction;

this.paramSlots = paramSlots;
this.outerSlots = outerSlots;
this.fnResultSlots = fnResultSlots;
this.fnResultRequired = fnResultRequired;
this.tupleIds.add(outputTupleDesc.getId());
}

Expand All @@ -69,6 +74,7 @@ protected void toThrift(TPlanNode msg) {
msg.table_function_node.setParam_columns(paramSlots);
msg.table_function_node.setOuter_columns(outerSlots);
msg.table_function_node.setFn_result_columns(fnResultSlots);
msg.table_function_node.setFn_result_required(fnResultRequired);
}

@Override
Expand All @@ -77,6 +83,7 @@ protected String getNodeExplainString(String prefix, TExplainLevel detailLevel)
output.append(prefix).append("tableFunctionName: ").append(tableFunction.getFunctionName()).append('\n');
output.append(prefix).append("columns: ").append(tableFunction.getDefaultColumnNames()).append('\n');
output.append(prefix).append("returnTypes: ").append(tableFunction.getTableFnReturnTypes()).append('\n');

return output.toString();
}

Expand All @@ -101,4 +108,8 @@ protected void toNormalForm(TNormalPlanNode planNode, FragmentNormalizer normali
// Always reports true for this node type.
// NOTE(review): presumably opts table function nodes into execution-statistics collection — confirm against the caller.
public boolean needCollectExecStats() {
return true;
}

/**
 * Returns the fnResultRequired flag that was passed to the constructor.
 * The same value is sent to the backend via setFn_result_required in toThrift.
 *
 * @return the stored fnResultRequired value
 */
public boolean isFnResultRequired() {
return fnResultRequired;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3340,6 +3340,12 @@ public PlanFragment visitPhysicalTableFunction(OptExpression optExpression, Exec
}
udtfOutputTuple.computeMemLayout();

ColumnRefSet fnResultsRequired = ColumnRefSet.of();
optExpression.getRowOutputInfo().getColumnRefMap().values()
.forEach(expr -> fnResultsRequired.union(expr.getUsedColumns()));
Optional.ofNullable(physicalTableFunction.getPredicate())
.ifPresent(pred -> fnResultsRequired.union(pred.getUsedColumns()));
fnResultsRequired.intersect(physicalTableFunction.getFnResultColRefs());
TableFunctionNode tableFunctionNode = new TableFunctionNode(context.getNextNodeId(),
inputFragment.getPlanRoot(),
udtfOutputTuple,
Expand All @@ -3349,8 +3355,10 @@ public PlanFragment visitPhysicalTableFunction(OptExpression optExpression, Exec
physicalTableFunction.getOuterColRefs().stream().map(ColumnRefOperator::getId)
.collect(Collectors.toList()),
physicalTableFunction.getFnResultColRefs().stream().map(ColumnRefOperator::getId)
.collect(Collectors.toList())
.collect(Collectors.toList()),
!fnResultsRequired.isEmpty() || physicalTableFunction.getOuterColRefs().isEmpty()
);

tableFunctionNode.computeStatistics(optExpression.getStatistics());
tableFunctionNode.setLimit(physicalTableFunction.getLimit());
currentExecGroup.add(tableFunctionNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@

package com.starrocks.sql.plan;

import com.starrocks.planner.TableFunctionNode;
import com.starrocks.sql.analyzer.SemanticException;
import org.junit.Assert;
import org.junit.Test;

import java.util.Optional;

public class TableFunctionTest extends PlanTestBase {
@Test
public void testSql0() throws Exception {
Expand Down Expand Up @@ -289,4 +295,38 @@ public void testUnnesetBitmapToArrayToUnnestBitmapRewrite() throws Exception {
PlanTestBase.assertContains(plan, "tableFunctionName: unnest_bitmap");
PlanTestBase.assertNotContains(plan, "bitmap_to_array");
}

@Test
public void testUnnesetFnResultNotRequired() throws Exception {
    // Each case pairs a query with the fnResultRequired flag expected on the first
    // TableFunctionNode found in its plan.
    Object[][] testCaseList = new Object[][] {
            {
                    // Only columns of the outer table are projected.
                    "select t.* from test_all_type t, unnest(split(t1a, ','))",
                    false
            },
            {
                    // The unnest output column is projected as well.
                    "select t.*, unnest from test_all_type t, unnest(split(t1a, ','))",
                    true
            },
            {
                    // The column produced by the lateral generate_series is projected.
                    "SELECT y FROM TABLE(generate_series(1, 2)) t(x), LATERAL generate_series(1, 5000) t2(y);",
                    true
            }
    };

    for (Object[] tc : testCaseList) {
        String sql = (String) tc[0];
        boolean expectedRequired = (Boolean) tc[1];
        ExecPlan plan = getExecPlan(sql);

        Optional<TableFunctionNode> optTableFuncNode = plan.getFragments()
                .stream()
                .flatMap(fragment -> fragment.collectNodes().stream())
                .filter(planNode -> planNode instanceof TableFunctionNode)
                .map(planNode -> (TableFunctionNode) planNode)
                .findFirst();

        // Include the SQL in every message so a failure identifies the offending case.
        Assert.assertTrue("no TableFunctionNode found in plan for: " + sql, optTableFuncNode.isPresent());
        // JUnit expects (message, expected, actual); keeping that order makes failure output accurate.
        Assert.assertEquals("unexpected fnResultRequired for: " + sql,
                expectedRequired, optTableFuncNode.get().isFnResultRequired());
    }
}
}
1 change: 1 addition & 0 deletions gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,7 @@ struct TTableFunctionNode {
2: optional list<Types.TSlotId> param_columns
3: optional list<Types.TSlotId> outer_columns
4: optional list<Types.TSlotId> fn_result_columns
5: optional bool fn_result_required
}

struct TConnectorScanNode {
Expand Down
Loading