From cfde327ec84022a27c6238f23d9687fa066f7a8a Mon Sep 17 00:00:00 2001 From: luyang Date: Tue, 12 Nov 2024 04:56:31 +0000 Subject: [PATCH 1/4] fix device_type2sub_tsk_gph_builder_ --- oneflow/core/graph/task_graph.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 10280a8dfe5..1ece973f868 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -880,13 +880,14 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { if (device_type != DeviceType::kCPU && device_type2sub_tsk_gph_builder_.find(device_type) != device_type2sub_tsk_gph_builder_.end()) { - status = CHECK_JUST( // NOLINT + auto maybe_status = // NOLINT device_type2sub_tsk_gph_builder_ // NOLINT .at(device_type) // NOLINT ->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT - *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); // NOLINT + *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT + if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); } } else { status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, @@ -1052,6 +1053,12 @@ Maybe GlobalTaskGraph::Init() { OpGraph* op_graph = Singleton::Get(); sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); + // Register the corresponding task graph builder based on the device type and store them to map + const auto* global_device_type_create_sub_tsk_gph_builder_fn = + GlobalDeviceType2CreateSubTskGphBuilderFn(); + for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { + device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); + } hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); HashMap> op_node2sorted_comp_tasks; @@ -1088,6 +1095,13 @@ Maybe BoxingTaskGraph::Init( OpGraph* op_graph = Singleton::Get(); sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this)); boxing_logger_ = CreateBoxingLogger(); + // Register the corresponding task graph builder based on the device type and store them to map + const auto* global_device_type_create_sub_tsk_gph_builder_fn = + GlobalDeviceType2CreateSubTskGphBuilderFn(); + for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) { + device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second()); + } + hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder()); const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) { From 68c353e93f2b0f456faf178f29d80ce1882a53c0 Mon Sep 17 00:00:00 2001 From: luyang Date: Tue, 12 Nov 2024 09:52:55 +0000 Subject: [PATCH 2/4] refine --- oneflow/core/graph/task_graph.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 1ece973f868..082b45ba845 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -888,7 +888,8 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) { blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT *(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); } - } else { + } + if (!status) { status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build( sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, blob_desc, src_nd_sbp, dst_nd_sbp, From c6f954774d4302388ed13f8d1d851539fff8c9be Mon Sep 17 00:00:00 2001 From: luyang Date: Mon, 18 Nov 2024 11:48:01 +0000 Subject: [PATCH 3/4] npu 2d parallel --- cmake/oneflow.cmake | 2 ++ oneflow/core/job/resource_desc.cpp | 11 ++++++----- oneflow/core/job/runtime.cpp | 5 +++-- .../core/job_rewriter/insert_nccl_logical_op_pass.cpp | 10 ++++------ oneflow/core/job_rewriter/job_completer.cpp | 4 ++-- .../nccl_logical_chain_strict_order_pass.cpp | 4 ++-- .../core/job_rewriter/nccl_logical_op_fusion_pass.cpp | 4 ++-- 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake index b37535367e1..ac03505f3d5 100644 --- a/cmake/oneflow.cmake +++ b/cmake/oneflow.cmake @@ -547,6 +547,7 @@ if(BUILD_PYTHON) PATTERN "oneflow/core/register/register_manager.h" PATTERN "oneflow/core/register/runtime_register_desc.h" PATTERN "oneflow/core/register/tensor_slice_view.h" + PATTERN "oneflow/core/register/tensor_slice_copier.h" PATTERN "oneflow/core/ndarray/xpu_util.h" PATTERN "oneflow/core/rpc/include/base.h" PATTERN "oneflow/core/rpc/include/ctrl.h" @@ -558,6 +559,7 @@ if(BUILD_PYTHON) PATTERN "oneflow/core/operator/operator.h" PATTERN "oneflow/core/operator/operator_util.h" PATTERN "oneflow/core/operator/op_conf_util.h" + PATTERN "oneflow/core/operator/nccl_send_recv_boxing_op_util.h" PATTERN "oneflow/core/graph/compute_task_node.h" PATTERN "oneflow/core/graph/copy_task_node.h" PATTERN "oneflow/core/graph/exec_graph.h" diff --git a/oneflow/core/job/resource_desc.cpp b/oneflow/core/job/resource_desc.cpp index be9940fe571..081807de571 100644 --- a/oneflow/core/job/resource_desc.cpp +++ b/oneflow/core/job/resource_desc.cpp @@ -71,11 +71,12 @@ CollectiveBoxingConf ResourceDesc::collective_boxing_conf() const { } bool ResourceDesc::nccl_use_compute_stream() const { -#if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 - return resource_.nccl_use_compute_stream(); -#else - return false; -#endif + // #if defined(WITH_CUDA) && NCCL_VERSION_CODE > 2700 + // return resource_.nccl_use_compute_stream(); + // #else + // return false; + // #endif + return true; } void ResourceDesc::DumpCudnnConf(const JobConfigProto& job_conf) { diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index 068484c7a78..1eaeb3d67fa 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -70,9 +70,10 @@ Runtime::Runtime( Singleton::Get()->AddPlan(plan); collective_boxing_scheduler_plan_token_ = Singleton::Get()->AddPlan(plan); -#ifdef WITH_CUDA + // #ifdef WITH_CUDA + // Singleton::Get()->CreateCommFromPlan(plan); + // #endif // WITH_CUDA Singleton::Get()->CreateCommFromPlan(plan); -#endif // WITH_CUDA } std::vector source_tasks; source_tasks.reserve(plan.task().size()); diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index abdb0b596fd..1d9d54d288d 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -16,7 +16,7 @@ limitations under the License. #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/common/util.h" #include "oneflow/core/job/nd_sbp_util.h" -#ifdef WITH_CUDA +// #ifdef WITH_CUDA #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/instructions_builder.h" @@ -146,7 +146,8 @@ void FindAllConnectedSubgraphForGpuExecOrder(std::vector> CHECK(visited.insert(seed_node).second); const ParallelDesc& seed_parallel_desc = seed_node->parallel_desc(); // NOTE(chengcheng): ONLY consider GPU op and parallel num > 1. - if (seed_parallel_desc.device_type() != DeviceType::kCUDA) { continue; } + // if (seed_parallel_desc.device_type() != DeviceType::kCUDA) { continue; } + if (seed_parallel_desc.device_type() != DeviceType::kNPU) { continue; } if (seed_parallel_desc.parallel_num() <= 1) { continue; } // NOTE(chengcheng): using fastest time shape for merge acc into bw subgraph. if (!SharedPtrShapeEqual(GetOpNodeFastestTimeShape(seed_node), seed_time_shape)) { continue; } @@ -486,7 +487,6 @@ bool TryBuildNcclLogicalOpConf(OperatorConf* ret, const OpNode* src_node, const int64_t scope_symbol_id = CHECK_JUST(BuildScopeWithReducedParallelDesc( src_node->op().op_conf().scope_symbol_id(), *src_reduced_parallel_desc)); - if (src_reduced_hierarchy->NumAxes() == 1 && dst_reduced_hierarchy->NumAxes() == 1) { return TryBuildNcclBy1DHierarchy(ret, src_reduced_nd_sbp->sbp_parallel(0), dst_reduced_nd_sbp->sbp_parallel(0), lbn, scope_symbol_id, @@ -786,7 +786,6 @@ Maybe InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* } else { auto_parallel::StraightenOpGraph(op_graph, &ordered_op_nodes); } - HashMap op_node2global_order; for (int32_t global_order = 0; global_order < ordered_op_nodes.size(); global_order++) { op_node2global_order.emplace(ordered_op_nodes[global_order], global_order); @@ -844,7 +843,6 @@ Maybe InsertNcclLogicalOpPass::Apply(const OpGraph& op_graph, JobBuilder* for (auto& pair : placement2subgraphs) { PlacementNcclSubGraghsInfo& info = pair.second; - // NOTE(chengcheng): insert nccl ops for each subgraph int64_t stream_offset = 0; int64_t total_op_num = 0; @@ -883,4 +881,4 @@ REGISTER_JOB_PASS("InsertNcclLogicalOpPass", InsertNcclLogicalOpPass); } // namespace oneflow -#endif // WITH_CUDA +// #endif // WITH_CUDA diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp index dc605d8fb96..c3b638d9f22 100644 --- a/oneflow/core/job_rewriter/job_completer.cpp +++ b/oneflow/core/job_rewriter/job_completer.cpp @@ -153,7 +153,7 @@ Maybe JobCompleter::Complete(Job* job) { compile_tc->Count("[GraphCompile]" + job_name + " SystemOpFillJobNamePass", 1, true); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " DumpBlobParallelConfPass", 1, true); -#ifdef WITH_CUDA + // #ifdef WITH_CUDA if (Singleton::Get()->nccl_use_compute_stream()) { // NOTE(chengcheng): this pass need as last pass for insert correct op with nccl boxing. JUST(JobPass4Name("InsertNcclLogicalOpPass")(job, &job_pass_ctx)); @@ -169,7 +169,7 @@ Maybe JobCompleter::Complete(Job* job) { JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); compile_tc->Count("[GraphCompile]" + job_name + " DumpBlobParallelConfPass", 1, true); } -#endif // WITH_CUDA + // #endif // WITH_CUDA JUST(JobPass4Name("LogicalChainPass")(job, &job_pass_ctx)); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); diff --git a/oneflow/core/job_rewriter/nccl_logical_chain_strict_order_pass.cpp b/oneflow/core/job_rewriter/nccl_logical_chain_strict_order_pass.cpp index 5fd82c36a2b..90d5bd81908 100644 --- a/oneflow/core/job_rewriter/nccl_logical_chain_strict_order_pass.cpp +++ b/oneflow/core/job_rewriter/nccl_logical_chain_strict_order_pass.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef WITH_CUDA +// #ifdef WITH_CUDA #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/framework.h" @@ -210,4 +210,4 @@ REGISTER_JOB_PASS("NcclLogicalChainStrictOrderPass", NcclLogicalChainStrictOrder } // namespace oneflow -#endif // WITH_CUDA +// #endif // WITH_CUDA diff --git a/oneflow/core/job_rewriter/nccl_logical_op_fusion_pass.cpp b/oneflow/core/job_rewriter/nccl_logical_op_fusion_pass.cpp index 7556d91d035..96faefdebd4 100644 --- a/oneflow/core/job_rewriter/nccl_logical_op_fusion_pass.cpp +++ b/oneflow/core/job_rewriter/nccl_logical_op_fusion_pass.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef WITH_CUDA +// #ifdef WITH_CUDA #include "oneflow/core/auto_parallel/auto_memory.h" #include "oneflow/core/job/nd_sbp_util.h" #include "oneflow/core/framework/framework.h" @@ -293,4 +293,4 @@ REGISTER_JOB_PASS("NcclLogicalOpFusionPass", NcclLogicalOpFusionPass); } // namespace oneflow -#endif // WITH_CUDA +// #endif // WITH_CUDA From 7e3821fc1bfd8575c8ab48744fc0e12b0849fe47 Mon Sep 17 00:00:00 2001 From: luyang Date: Tue, 10 Dec 2024 09:11:13 +0000 Subject: [PATCH 4/4] revert Runtime CreateCommFromPlan --- oneflow/core/job/runtime.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/oneflow/core/job/runtime.cpp b/oneflow/core/job/runtime.cpp index 1eaeb3d67fa..068484c7a78 100644 --- a/oneflow/core/job/runtime.cpp +++ b/oneflow/core/job/runtime.cpp @@ -70,10 +70,9 @@ Runtime::Runtime( Singleton::Get()->AddPlan(plan); collective_boxing_scheduler_plan_token_ = Singleton::Get()->AddPlan(plan); - // #ifdef WITH_CUDA - // Singleton::Get()->CreateCommFromPlan(plan); - // #endif // WITH_CUDA +#ifdef WITH_CUDA Singleton::Get()->CreateCommFromPlan(plan); +#endif // WITH_CUDA } std::vector source_tasks; source_tasks.reserve(plan.task().size());