diff --git a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp index c9f72f182955..242fbc0d4de2 100644 --- a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp @@ -258,53 +258,48 @@ class TDqsLogicalOptProposalTransformer : public TOptimizeTransformerBase { } TMaybeNode OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) { - if (TypesCtx.CostBasedOptimizer != ECostBasedOptimizerType::Disable) { - std::function log = [&](auto str) { - YQL_CLOG(INFO, ProviderDq) << str; - }; - - auto factory = MakeCBOOptimizerFactory(); - std::shared_ptr opt; - TDqCBOProviderContext pctx(TypesCtx, Config); - - switch (TypesCtx.CostBasedOptimizer) { - case ECostBasedOptimizerType::Native: - opt = factory->MakeJoinCostBasedOptimizerNative(pctx, ctx, {.MaxDPhypDPTableSize = 100000}); - break; - case ECostBasedOptimizerType::PG: - opt = factory->MakeJoinCostBasedOptimizerPG(pctx, ctx, {.Logger = log}); - break; - default: - YQL_ENSURE(false, "Unknown CBO type"); - break; - } - std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)> providerCollect = [](auto& rels, auto label, auto node, auto stats) { - Y_UNUSED(node); - auto rel = std::make_shared(TString(label), *stats); - rels.push_back(rel); - }; - - return DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, 2, *opt, providerCollect); - } else { + auto equiJoin = node.Cast(); + if (!HasDqConnectionsInEquiJoin(equiJoin)) { return node; } + if (TypesCtx.CostBasedOptimizer == ECostBasedOptimizerType::Disable) { + return node; + } + + std::function log = [&](auto str) { + YQL_CLOG(INFO, ProviderDq) << str; + }; + + auto factory = MakeCBOOptimizerFactory(); + std::shared_ptr opt; + TDqCBOProviderContext pctx(TypesCtx, Config); + + switch (TypesCtx.CostBasedOptimizer) { + case ECostBasedOptimizerType::Native: + opt = factory->MakeJoinCostBasedOptimizerNative(pctx, ctx, {.MaxDPhypDPTableSize = 100000}); + break; + case ECostBasedOptimizerType::PG: + opt = factory->MakeJoinCostBasedOptimizerPG(pctx, ctx, {.Logger = log}); + break; + case NYql::ECostBasedOptimizerType::Disable: + break; + } + std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)> providerCollect = [](auto& rels, auto label, auto node, auto stats) { + Y_UNUSED(node); + auto rel = std::make_shared(TString(label), *stats); + rels.push_back(rel); + }; + + return DqOptimizeEquiJoinWithCosts(node, ctx, TypesCtx, 2, *opt, providerCollect); } TMaybeNode RewriteEquiJoin(TExprBase node, TExprContext& ctx) { auto equiJoin = node.Cast(); - bool hasDqConnections = false; - for (size_t i = 0; i + 2 < equiJoin.ArgCount(); ++i) { - auto list = equiJoin.Arg(i).Cast().List(); - if (auto maybeExtractMembers = list.Maybe()) { - list = maybeExtractMembers.Cast().Input(); - } - if (auto maybeFlatMap = list.Maybe()) { - list = maybeFlatMap.Cast().Input(); - } - hasDqConnections |= !!list.Maybe(); + if (!HasDqConnectionsInEquiJoin(equiJoin)) { + return node; } - return hasDqConnections ? DqRewriteEquiJoin(node, Config->HashJoinMode.Get().GetOrElse(EHashJoinMode::Off), false, ctx, TypesCtx) : node; + return DqRewriteEquiJoin(node, Config->HashJoinMode.Get().GetOrElse(EHashJoinMode::Off), false, ctx, TypesCtx); } TMaybeNode ExpandWindowFunctions(TExprBase node, TExprContext& ctx) { @@ -349,6 +344,17 @@ class TDqsLogicalOptProposalTransformer : public TOptimizeTransformerBase { private: + bool HasDqConnectionsInEquiJoin(const TCoEquiJoin& equiJoin) { + for (size_t i = 0; i + 2 < equiJoin.ArgCount(); ++i) { + const auto& list = SkipCallables(equiJoin.Arg(i).Cast().List().Ref(), + {"ExtractMembers", "FlatMap", "OrderedFlatMap"}); + if (TDqConnection::Match(&list)) { + return true; + } + } + return false; + } + void EnsureNotDistinct(const TCoAggregate& aggregate) { const auto& aggregateHandlers = aggregate.Handlers();