From 66716e05c42fa24509881a63462daa2e6c6bbf29 Mon Sep 17 00:00:00 2001 From: shaoboyan Date: Thu, 16 Jan 2025 12:48:04 +0800 Subject: [PATCH] Address comments --- .../core/providers/webgpu/webgpu_context.cc | 30 ++++++++++--------- .../core/providers/webgpu/webgpu_context.h | 14 +++------ .../webgpu/webgpu_execution_provider.cc | 6 +--- .../webgpu/webgpu_provider_factory.cc | 27 +++++++++-------- 4 files changed, 35 insertions(+), 42 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 5c43186273868..c04a7b765538a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "webgpu_context.h" #include #include @@ -27,8 +28,8 @@ namespace onnxruntime { namespace webgpu { -void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) { - std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() { +void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) { + std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() { // Create wgpu::Adapter if (adapter_ == nullptr) { #if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN) @@ -154,18 +155,17 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi } else { query_type_ = TimestampQueryType::None; } - }); - - if (enable_pix_capture_) { + if (enable_pix_capture) { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) - // set pix frame generator - pix_frame_generator_ = std::make_unique(instance_, - Adapter(), - Device()); + // set pix frame generator + pix_frame_generator_ = std::make_unique(instance_, + Adapter(), + Device()); #else ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)"); #endif // ENABLE_PIX_FOR_WEBGPU_EP - } + } + }); } Status WebGpuContext::Wait(wgpu::Future f) { @@ -666,11 +666,13 @@ void WebGpuContext::Flush() { num_pending_dispatches_ = 0; } +void WebGpuContext::OnRunEnd() { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) -void WebGpuContext::GeneratePIXFrame() { - pix_frame_generator_->GeneratePIXFrame(); -} + if (pix_frame_generator_) { + pix_frame_generator_->GeneratePIXFrame(); + } #endif // ENABLE_PIX_FOR_WEBGPU_EP +} std::unordered_map WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; @@ -734,7 +736,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co auto it = contexts_.find(context_id); if (it == contexts_.end()) { GSL_SUPPRESS(r.11) - auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, config.validation_mode, config.enable_pix_capture)); + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, config.validation_mode)); it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first; } else if (context_id != 0) { ORT_ENFORCE(it->second.context->instance_.Get() == instance && diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 0558df071a82c..9cb648eab93be 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -77,7 +77,7 @@ class WebGpuContextFactory { // Class WebGpuContext includes all necessary resources for the context. class WebGpuContext final { public: - void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type); + void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture); Status Wait(wgpu::Future f); @@ -131,12 +131,7 @@ class WebGpuContext final { void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events); Status Run(ComputeContext& context, const ProgramBase& program); - - bool IsPixCaptureEnabled() const { return enable_pix_capture_; } - -#if defined(ENABLE_PIX_FOR_WEBGPU_EP) - void GeneratePIXFrame(); -#endif // ENABLE_PIX_FOR_WEBGPU_EP + void OnRunEnd(); private: enum class TimestampQueryType { @@ -145,8 +140,8 @@ class WebGpuContext final { AtPasses }; - WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode, bool enable_pix_capture) - : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, enable_pix_capture_(enable_pix_capture), query_type_{TimestampQueryType::None} {} + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode) + : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); std::vector GetEnabledAdapterToggles() const; @@ -212,7 +207,6 @@ class WebGpuContext final { const uint32_t max_num_pending_dispatches_ = 16; // profiling - bool enable_pix_capture_; TimestampQueryType query_type_; wgpu::QuerySet query_set_; wgpu::Buffer query_resolve_buffer_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 56b3ab525c0f6..346cbd2f2ceb4 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -857,11 +857,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti context_.CollectProfilingData(profiler_->Events()); } -#if defined(ENABLE_PIX_FOR_WEBGPU_EP) - if (context_.IsPixCaptureEnabled()) { - context_.GeneratePIXFrame(); - } -#endif // ENABLE_PIX_FOR_WEBGPU_EP + context_.OnRunEnd(); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 09d73951bc3b3..8980014eeb6a3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -160,18 +160,6 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( validation_mode, }; - std::string enable_pix_capture_str; - if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) { - if (enable_pix_capture_str == kEnablePIXCapture_ON) { - context_config.enable_pix_capture = true; - } else if (enable_pix_capture_str == kEnablePIXCapture_OFF) { - context_config.enable_pix_capture = false; - } else { - ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str); - } - } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << context_config.enable_pix_capture; - // // STEP.3 - prepare parameters for WebGPU context initialization. // @@ -233,6 +221,19 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode; + bool enable_pix_capture = false; + std::string enable_pix_capture_str; + if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) { + if (enable_pix_capture_str == kEnablePIXCapture_ON) { + enable_pix_capture = true; + } else if (enable_pix_capture_str == kEnablePIXCapture_OFF) { + enable_pix_capture = false; + } else { + ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << enable_pix_capture; + // // STEP.4 - start initialization. // @@ -241,7 +242,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config); // Create WebGPU device and initialize the context. - context.Initialize(buffer_cache_config, backend_type); + context.Initialize(buffer_cache_config, backend_type, enable_pix_capture); // Create WebGPU EP factory. return std::make_shared(context_id, context, std::move(webgpu_ep_config));