From c310db932d5520ab263668cc90fe47d156d636d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Mon, 13 Jan 2025 17:48:49 +0100 Subject: [PATCH] savepoint --- cmake/adjust_global_compile_flags.cmake | 2 +- .../quantization/dynamic_quantize_matmul.cc | 11 +++++--- onnxruntime/core/mlas/lib/qgemm.cpp | 6 ----- onnxruntime/core/providers/cpu/math/matmul.cc | 4 +++ onnxruntime/core/session/inference_session.cc | 26 +++++++++++++++++++ onnxruntime/core/util/math_cpu.cc | 2 ++ 6 files changed, 41 insertions(+), 10 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index dbbf685346532..d332d95d96789 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -45,7 +45,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO) # "-g3" generates DWARF format debug info. # NOTE: With debug info enabled, web assembly artifacts will be very huge (>1GB). So we offer an option to build without debug info. - set(CMAKE_CXX_FLAGS_DEBUG "-g3") + set(CMAKE_CXX_FLAGS_DEBUG "-g2") else() set(CMAKE_CXX_FLAGS_DEBUG "-g2") endif() diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index b711887690f47..fee4e6d3fa939 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -263,17 +263,22 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, //std::cout << "Calling f32Multiply\n"; // should split in parts and call ctx.ParallelFor just on the rows part +#if 0 // rowsA = M // width = K // colsB = N -#if 0 size_t rowsA = static_cast(helper.M()); if (rowsA > 1) { size_t width = static_cast(helper.K()); size_t colsB = static_cast(helper.N()); - const int8_t* b_data = static_cast(b_tensor->DataRaw()); + //std::cout << "Calling GeckoMatmulIntegerToFloat\n"; + //int threads = concurrency::ThreadPool::DegreeOfParallelism(ctx->GetOperatorThreadPool()); + //std::cout << "degree of parallelism: " << threads << "\n"; + //std::cout << "batch size: " << num_gemms << "\n"; + + GeckoMatmulIntegerToFloat(a_data, a_zp, @@ -291,7 +296,7 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, #endif MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool()); - //} + // } // /* diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 026a1215af42c..c5bbc9f93a9f2 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -61,7 +61,6 @@ Return Value: { const ptrdiff_t ThreadIdM = ThreadId / WorkBlock->ThreadCountN; const ptrdiff_t ThreadIdN = ThreadId % WorkBlock->ThreadCountN; - // // Partition the operation along the M dimension. // @@ -197,16 +196,11 @@ MlasGemmBatch( WorkBlock.ThreadCountN = 1; } TargetThreadCount = ThreadsPerGemm * BatchN; - //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << std::endl; - //std::cout << "TargetThreadCount: " << TargetThreadCount << std::endl; - //std::cout << "MaximumThreadCount: " << MaximumThreadCount << std::endl; - MlasTrySimpleParallel(ThreadPool, TargetThreadCount, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; - //std::cout << "gemm_i: " << gemm_i << " blk_i: " << blk_i << std::endl; MlasGemmQuantThreaded(&WorkBlock, &Shape, &DataParams[gemm_i], blk_i); }); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de908..b43d7344d7081 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -289,8 +289,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const { data[i].alpha = alpha_attr_; data[i].beta = 0.0f; } + +//auto start = std::chrono::steady_clock::now(); MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, M, N, K, data.data(), max_len, thread_pool); + //auto end = std::chrono::steady_clock::now(); + //std::cout << "MatMul," << std::chrono::duration_cast(end - start).count() << "," << max_len << std::endl; } return Status::OK(); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 223eed248800e..0bca1b38362fa 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1695,6 +1695,8 @@ common::Status InferenceSession::Initialize() { if (session_profiler_.IsEnabled()) { tp = session_profiler_.Start(); } +//std::cout << "session Initialize" << std::endl; + //auto startInit = std::chrono::steady_clock::now(); ORT_TRY { LOGS(*session_logger_, INFO) << "Initializing session."; @@ -1720,6 +1722,9 @@ common::Status InferenceSession::Initialize() { } // Verify that there are no external initializers in the graph if external data is disabled. + //std::cout << "session Initialize loading main graph" << std::endl; + + onnxruntime::Graph& graph = model_->MainGraph(); #ifdef DISABLE_EXTERNAL_INITIALIZERS const InitializedTensorSet& initializers = graph.GetAllInitializedTensors(); @@ -1767,6 +1772,8 @@ common::Status InferenceSession::Initialize() { TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity"); session_activity_started_ = true; #endif + //std::cout << "session Initialize - creating state" << std::endl; + // now that we have all the execution providers, create the session state session_state_ = std::make_unique( @@ -1824,6 +1831,10 @@ common::Status InferenceSession::Initialize() { }(); if (!loading_ort_format) { + //std::cout << "session Initialize not using ort" << std::endl; + + + #if !defined(ORT_MINIMAL_BUILD) const auto minimal_build_opt_config_value = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigMinimalBuildOptimizations, ""); @@ -1845,6 +1856,10 @@ common::Status InferenceSession::Initialize() { *session_logger_)); #ifdef USE_DML + // std::cout << "session Initialize using DML" << std::endl; + + + const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); if (dmlExecutionProvider) { @@ -1900,10 +1915,16 @@ common::Status InferenceSession::Initialize() { #endif // apply any transformations to the main graph and any subgraphs + //auto start = std::chrono::steady_clock::now(); ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, saving_ort_format)); + //auto end = std::chrono::steady_clock::now(); + //std::cout << "Graph transformations took " << std::chrono::duration_cast(end - start).count() << " ms" << std::endl; // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. + //start = std::chrono::steady_clock::now(); ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + //end = std::chrono::steady_clock::now(); + //std::cout << "Graph resolution took " << std::chrono::duration_cast(end - start).count() << " ms" << std::endl; // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP. // @@ -2052,6 +2073,9 @@ common::Status InferenceSession::Initialize() { "Loading anything other than ORT format models is not enabled in this build.")); #endif // !defined(ORT_MINIMAL_BUILD) } else { + //std::cout << "session Initialize - loading ort" << std::endl; + + ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, *session_state_, session_options_.config_options, *session_logger_)); @@ -2171,6 +2195,8 @@ common::Status InferenceSession::Initialize() { } } + //auto endInitialization = std::chrono::steady_clock::now(); + //std::cout << "session Initialize - Initialization time: " << std::chrono::duration_cast(endInitialization - startInit).count() << " ms" << std::endl; return status; } #if defined(_MSC_VER) && !defined(__clang__) diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 983321593a92b..89fd6aac943c2 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -15,6 +15,8 @@ */ // Modifications Copyright (c) Microsoft. +#include +#include #include "core/util/math_cpuonly.h" #include "core/util/math.h" #include "core/framework/float16.h"