diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc
index 2ccde69c42..2a1ae9711c 100644
--- a/cinn/hlir/framework/op_lowering.cc
+++ b/cinn/hlir/framework/op_lowering.cc
@@ -1225,8 +1225,9 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
       }
     }
 
+    auto masters = GetMasters(node, nodes_inline, nodes_set);
     // node can be inline.
-    if (CanbeInline(node, consumers, reducer, nodes_in_order.front(), group, nodes_set, this->shape_dict_)) {
+    if (CanbeInline(node, consumers, reducer, masters, group, nodes_set, this->shape_dict_)) {
       auto block = ir_sch.GetBlock(GetNodeData(node)->id());
       ir::ComputeInlineChecker checker(ir_sch, block);
       if (!checker.Check()) {
@@ -1326,7 +1327,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
   }
 
   VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
-  SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map);
+  SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group);
   VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
 }
 
diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc
index e193b44970..01650c3446 100644
--- a/cinn/hlir/framework/op_lowering_util.cc
+++ b/cinn/hlir/framework/op_lowering_util.cc
@@ -165,11 +165,10 @@ bool IsConstOp(const framework::Node* node) {
 }
 
 std::vector<int> GetInputShape(const Node* node, const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
-  auto producers = GetProducers(node);
-  CHECK(producers.size());
+  auto input_data = GetInputNodeData(node);
+  CHECK(input_data.size());
 
-  auto producer_data = GetNodeData(producers.front());
-  return shape_dict.at(producer_data->id());
+  return shape_dict.at(input_data.front()->id());
 }
 
 std::vector<int> GetOutputShape(const Node* node, const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
@@ -636,7 +635,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch,
 bool CanbeInline(Node* node,
                  const std::vector<Node*> consumers,
                  const Node* reducer,
-                 const Node* laster,
+                 const std::unordered_set<Node*> masters,
                  const GroupPtr& group,
                  const std::unordered_set<Node*>& nodes_set,
                  const absl::flat_hash_map<std::string, shape_t>& shape_dict) {
@@ -678,10 +677,14 @@ bool CanbeInline(Node* node,
       return false;
     } else {
       auto node_shape = GetOutputShape(node, shape_dict);
-      auto last_shape = GetOutputShape(laster, shape_dict);
-      if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>()) !=
-          std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies<int>())) {
-        return true;
+      auto node_size  = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
+
+      for (auto master : masters) {
+        auto master_shape = GetOutputShape(master, shape_dict);
+        auto master_size  = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
+        if (node_size != master_size) {
+          return true;
+        }
       }
 
       return false;
@@ -1313,7 +1316,7 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
   auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
   if (!group->output_nodes.count(node)) {
     auto block = ir_sch.GetBlock(GetNodeData(node)->id());
-    ir_sch.SetBuffer(block, "local", true);
+    ir_sch.SetBuffer(block, "local");
   }
 
   if (op_pattern_dict[node->op()] == framework::kReduction) {
@@ -1370,11 +1373,14 @@ std::unordered_map<std::string, NodeData*> GetNodeDataSet(const std::unordered_s
   return node_data_set;
 }
 
-Node* GetMaster(Node* node, const std::unordered_set<Node*>& nodes_inline, const std::unordered_set<Node*>& nodes_set) {
+std::unordered_set<Node*> GetMasters(Node* node,
+                                     const std::unordered_set<Node*>& nodes_inline,
+                                     const std::unordered_set<Node*>& nodes_set) {
   // find consumer
   std::unordered_set<Node*> visited;
   std::queue<Node*> candidates;
   candidates.push(node);
+  std::unordered_set<Node*> masters;
 
   while (!candidates.empty()) {
     auto candidate = candidates.front();
@@ -1389,19 +1395,20 @@ Node* GetMaster(Node* node, const std::unordered_set<Node*>& nodes_inline, const
         candidates.push(consumer);
         visited.insert(consumer);
       } else {
-        return consumer;
+        masters.insert(consumer);
       }
     }
   }
 
-  return nullptr;
+  return masters;
 }
 
 void SyncThreadWithShared(ir::IRSchedule& ir_sch,
                           const std::unordered_set<Node*>& nodes_inline,
                           const std::unordered_set<Node*>& nodes_set,
                           const absl::flat_hash_map<std::string, shape_t>& shape_dict,
-                          const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
+                          const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+                          const GroupPtr& group) {
   auto exprs_inorder    = ir_sch.GetAllBlocks();
   auto node_data_set    = GetNodeDataSet(nodes_set);
   auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
@@ -1438,34 +1445,35 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
 
     auto node       = node_data->source_node.get();
     auto node_shape = shape_dict.at(node_data->id());
 
-    auto master = GetMaster(node, nodes_inline, nodes_set);
-    if (!master) {
+    auto masters = GetMasters(node, nodes_inline, nodes_set);
+    if (masters.empty()) {
       continue;
     }
 
-    auto master_data  = GetNodeData(master);
-    auto master_shape = shape_dict.at(master_data->id());
-    if (op_pattern_dict[master->op()] == framework::kReduction) {
-      master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());
-    }
+    bool do_set_buffer_to_shared = false;
+    for (auto master : masters) {
+      auto master_data  = GetNodeData(master);
+      auto master_shape = shape_dict.at(master_data->id());
+      if (op_pattern_dict[master->op()] == framework::kReduction) {
+        master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());
+      }
 
-    auto node_size   = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
-    auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
+      auto node_size   = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies<int>());
+      auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies<int>());
 
-    if (node_size == master_size) {
-      continue;
+      if (node_size != master_size) {
+        if (check_sync_mark(idx, master_data->id())) {
+          auto loops = ir_sch.GetLoops(master_data->id());
+          ir_sch.SyncThreads(loops.back(), false);
+          sync_mark.insert(master_data->id());
+        }
+        do_set_buffer_to_shared = true;
+      }
     }
-
-    {
+    if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) {
       auto block = ir_sch.GetBlock(node_data->id());
       ir_sch.SetBuffer(block, "shared", true);
     }
-
-    if (check_sync_mark(idx, master_data->id())) {
-      auto loops = ir_sch.GetLoops(master_data->id());
-      ir_sch.SyncThreads(loops.back(), false);
-      sync_mark.insert(master_data->id());
-    }
   }
 }
diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h
index 01a33ae876..db92b74c68 100644
--- a/cinn/hlir/framework/op_lowering_util.h
+++ b/cinn/hlir/framework/op_lowering_util.h
@@ -60,7 +60,7 @@ Node* FindNearestReducer(const Node* node, const std::unordered_set<Node*>& node
 bool CanbeInline(Node* node,
                  const std::vector<Node*> consumers,
                  const Node* reducer,
-                 const Node* laster,
+                 const std::unordered_set<Node*> masters,
                  const GroupPtr& group,
                  const std::unordered_set<Node*>& nodes_set,
                  const absl::flat_hash_map<std::string, shape_t>& shape_dict);
@@ -72,6 +72,10 @@ Node* GetMasterToComputeAt(Node* node,
                            const std::unordered_map<Node*, Node*>& virtual_consumers,
                            const absl::flat_hash_map<std::string, shape_t>& shape_dict);
 
+std::unordered_set<Node*> GetMasters(Node* node,
+                                     const std::unordered_set<Node*>& nodes_inline,
+                                     const std::unordered_set<Node*>& nodes_set);
+
 void LoopAssignReduce(ir::IRSchedule& ir_sch,
                       const Node* node,
                       const Node* reducer,
@@ -90,7 +94,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch,
                           const std::unordered_set<Node*>& nodes_inline,
                           const std::unordered_set<Node*>& nodes_set,
                           const absl::flat_hash_map<std::string, shape_t>& shape_dict,
-                          const std::unordered_map<std::string, ir::Tensor>& tensor_map);
+                          const std::unordered_map<std::string, ir::Tensor>& tensor_map,
+                          const GroupPtr& group);
 
 }  // namespace framework
 }  // namespace hlir
diff --git a/cinn/hlir/pass/fusion_helper_base.h b/cinn/hlir/pass/fusion_helper_base.h
index 94ef3460aa..7658cc0792 100644
--- a/cinn/hlir/pass/fusion_helper_base.h
+++ b/cinn/hlir/pass/fusion_helper_base.h
@@ -112,6 +112,17 @@ class FusionHelperBase {
     return producer_node;
   }
 
+  std::vector<Node*> GetConsumerNode(const Node* node) const {
+    std::vector<Node*> consumer_nodes;
+    auto node_data = GetNodeData(node);
+    for (auto& link : node_data->outlinks()) {
+      auto consumer = link->sink()->safe_as<Node>();
+      CHECK(consumer);
+      consumer_nodes.push_back(consumer);
+    }
+    return consumer_nodes;
+  }
+
   bool WithoutLastDimInReduce(const std::vector<int>& inshape, const std::vector<int>& axes) const {
     // if last axis is in reduce.
     if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() ||
diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc
index 88f54dc566..0121f8f056 100644
--- a/cinn/hlir/pass/fusion_merge_pass.cc
+++ b/cinn/hlir/pass/fusion_merge_pass.cc
@@ -617,10 +617,6 @@ class FusionMergePassHelper : public FusionHelperBase {
 
   void RecomputeWithCostModel(const GroupPtr& producer,
                               std::unordered_set<GroupPtr, Hasher, Comparator>& fusionable_consumers) {
-    if (producer->op_pattern_kind == framework::kReduction) {
-      CHECK_EQ(fusionable_consumers.size(), 1) << "Find more than one consumer can fuse to " << producer->group_id;
-    }
-
     // if is const op
     if (is_const_group(this, producer)) {
       std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
@@ -818,14 +814,23 @@ class FusionMergePassHelper : public FusionHelperBase {
       auto& consumers = input_consumers.second;
      std::unordered_set<GroupPtr, Hasher, Comparator> updated_consumers;
       for (auto& consumer : consumers) {
-        // if group is sub group
-        if (consumer->belong_groups.size()) {
-          // inset belong group to consumers.
-          for (auto& belong_group : consumer->belong_groups) {
-            updated_consumers.insert(belong_group);
+        std::queue<GroupPtr> fused_groups;
+        fused_groups.push(consumer);
+        while (!fused_groups.empty()) {
+          auto& cur = fused_groups.front();
+          fused_groups.pop();
+          // if group is sub group
+          if (cur->belong_groups.empty()) {
+            updated_consumers.insert(cur);
+          } else {
+            for (auto& belong_group : cur->belong_groups) {
+              if (belong_group->group_id == cur->group_id) {
+                updated_consumers.insert(belong_group);
+              } else {
+                fused_groups.push(belong_group);
+              }
+            }
           }
-        } else {
-          updated_consumers.insert(consumer);
         }
       }
       consumers = updated_consumers;
@@ -976,7 +981,7 @@ class FusionMergePassHelper : public FusionHelperBase {
       relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation.
                                     {OpPatternKind::kElementWise, reduce_fuse_elementwise},
                                     // reduce and broadcast op must be horizontal relation.
-                                    {OpPatternKind::kBroadcast, is_same_size},
+                                    {OpPatternKind::kBroadcast, reduce_fuse_broadcast},
                                     // reduce and injective op must be horizontal relation.
                                     {OpPatternKind::kInjective, horizontal_with_injective},
                                     // reduce and reduce must be horizontal relation.
diff --git a/cinn/hlir/pass/fusion_merge_pass_test.cc b/cinn/hlir/pass/fusion_merge_pass_test.cc
index e834da510c..544f86019c 100755
--- a/cinn/hlir/pass/fusion_merge_pass_test.cc
+++ b/cinn/hlir/pass/fusion_merge_pass_test.cc
@@ -401,7 +401,7 @@ TEST(FusionMergePass, Reduce_Test_2) {
 
   auto graph = std::make_shared<Graph>(program, target);
   hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
-  CHECK_EQ(graph->fusion_groups.size(), 4);
+  CHECK_EQ(graph->fusion_groups.size(), 3);
   hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
   CHECK_EQ(graph->fusion_groups.size(), 2);
 }
diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h
index 696a55f1cf..82bbabd20f 100644
--- a/cinn/hlir/pass/fusion_merge_pass_util.h
+++ b/cinn/hlir/pass/fusion_merge_pass_util.h
@@ -285,11 +285,109 @@ CONDITION_FUNC(injective_horizontal_with_reduce) {
   return elementwise_fuse_reduce(helper, first, second);
 }
 
-CONDITION_FUNC(reduce_fuse_reduce) {
-  // check reduce horizontal with reduce.
-  if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) {
-    return false;
+CONDITION_FUNC(reduce_fuse_broadcast) {
+  // if same shape with horizontal relation
+  if (is_same_size(helper, first, second)) {
+    return true;
   }
+
+  // Traversing all reducers in all producers requires two types of conditions to be met.
+  // The first type is the condition that the reducer itself needs to meet,
+  // and the second type is the condition that the relationship between each reducer and its consumers with type of
+  // Broadcast needs to meet. Each consumer of type Broadcast must produce the same shape after broadcast as the
+  // input shape before reduce.
+  for (auto& node_in_master : first->master_nodes) {
+    if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) {
+      continue;
+    }
+    Node* reducer = node_in_master;
+    // First type conditions
+    // Get some reduce information
+    auto reducer_input_shape  = helper->GetNodeInputShape(reducer);
+    auto reducer_output_shape = helper->GetNodeDataShape(reducer);
+    auto reduce_axes          = absl::get<std::vector<int>>(reducer->attrs.attr_store.at("dim"));
+    auto keep_dim             = absl::get<bool>(reducer->attrs.attr_store.at("keep_dim"));
+    for (auto& axis : reduce_axes) {
+      if (axis == -1) {
+        axis = reducer_input_shape.size() - 1;
+      }
+    }
+    // Check if the reduce axes are continuous
+    int reduce_size = reducer_input_shape.back();
+    for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) {
+      if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) {
+        return false;
+      }
+      reduce_size *= reducer_input_shape[idx - 1];
+    }
+    // Check if the reduce size exceeds the hardware limit
+    if (helper->target_ == common::DefaultNVGPUTarget() && reduce_size > helper->target_.max_num_threads()) {
+      return false;
+    }
+
+    // Second type conditions
+    // Find directly or indirectly consumers with type of Broadcast in the second group
+    auto find_broadcasters_in_descendants = [&](const Node* producer) -> std::unordered_set<const Node*> {
+      std::queue<const Node*> candidates;
+      std::unordered_set<const Node*> visited_set;
+      std::unordered_set<const Node*> broadcasters;
+      candidates.push(producer);
+
+      while (!candidates.empty()) {
+        auto candidate = candidates.front();
+        candidates.pop();
+
+        for (auto consumer : helper->GetConsumerNode(candidate)) {
+          if (helper->GetOpKind(consumer) == OpPatternKind::kBroadcast &&
+              second->NodeSet().find(consumer) != second->NodeSet().end()) {
+            broadcasters.insert(consumer);
+          } else if (!visited_set.count(consumer)) {
+            visited_set.insert(consumer);
+            candidates.push(consumer);
+          }
+        }
+      }
+
+      return broadcasters;
+    };
+
+    // Check if each broadcast node meets the conditions
+    std::unordered_set<const Node*> broadcasters_in_consumers = find_broadcasters_in_descendants(reducer);
+    for (auto broadcaster : broadcasters_in_consumers) {
+      auto broadcaster_output_shape = absl::get<std::vector<int>>(broadcaster->attrs.attr_store.at("out_shape"));
+      auto broadcast_axes           = absl::get<std::vector<int>>(broadcaster->attrs.attr_store.at("broadcast_axes"));
+      for (auto& axis : broadcast_axes) {
+        if (axis == -1) {
+          axis = broadcaster_output_shape.size() - 1;
+        }
+      }
+
+      if (reducer_input_shape != broadcaster_output_shape) {
+        return false;
+      }
+
+      if (keep_dim) {
+        continue;
+      } else {
+        // if reducer_output_shape = [1]
+        if (reducer_output_shape.size() == 1 && reducer_output_shape[0] == 1) {
+          continue;
+        }
+        // check union [reduce_axes, broadcast_axes] = reducer_input_shape
+        for (int idx = 0; idx < reducer_input_shape.size(); ++idx) {
+          if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == broadcast_axes.end()) ^
+              std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) {
+            return false;
+          }
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
+CONDITION_FUNC(reduce_fuse_reduce) {
   if (!limit_args(helper, first, second)) {
     return false;
   }
diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc
index 026f2c6195..021e66e9d3 100644
--- a/cinn/hlir/pass/op_fusion_pass.cc
+++ b/cinn/hlir/pass/op_fusion_pass.cc
@@ -267,7 +267,7 @@ class OpFusionPassHelper : public FusionHelperBase {
         // producer -> fusion
         relation.fusion_op_kind = {
             // horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce.
-            {framework::kElementWise, without_last_dimension_in_reduce},
+            {framework::kElementWise, is_same_size},
            // must be horizontal relation, check with same output shape and without last dimension in reduce.
             {framework::kBroadcast, reduce_fuse_broadcast},
             // must be horizontal relation and with same reduce attr.