Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoboyan091 committed Jan 20, 2025
1 parent ce46b00 commit 66716e0
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 42 deletions.
30 changes: 16 additions & 14 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "webgpu_context.h"
#include <memory>
#include <cmath>

Expand All @@ -27,8 +28,8 @@
namespace onnxruntime {
namespace webgpu {

void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() {
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
Expand Down Expand Up @@ -154,18 +155,17 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
} else {
query_type_ = TimestampQueryType::None;
}
});

if (enable_pix_capture_) {
if (enable_pix_capture) {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
Adapter(),
Device());
// set pix frame generator
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
Adapter(),
Device());
#else
ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)");
#endif // ENABLE_PIX_FOR_WEBGPU_EP
}
}
});
}

Status WebGpuContext::Wait(wgpu::Future f) {
Expand Down Expand Up @@ -666,11 +666,13 @@ void WebGpuContext::Flush() {
num_pending_dispatches_ = 0;
}

void WebGpuContext::OnRunEnd() {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
void WebGpuContext::GeneratePIXFrame() {
pix_frame_generator_->GeneratePIXFrame();
}
if (pix_frame_generator_) {
pix_frame_generator_->GeneratePIXFrame();
}
#endif // ENABLE_PIX_FOR_WEBGPU_EP
}

std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
std::mutex WebGpuContextFactory::mutex_;
Expand Down Expand Up @@ -734,7 +736,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode, config.enable_pix_capture));
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
Expand Down
14 changes: 4 additions & 10 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class WebGpuContextFactory {
// Class WebGpuContext includes all necessary resources for the context.
class WebGpuContext final {
public:
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type);
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture);

Status Wait(wgpu::Future f);

Expand Down Expand Up @@ -131,12 +131,7 @@ class WebGpuContext final {
void EndProfiling(TimePoint, profiling::Events& events, profiling::Events& cached_events);

Status Run(ComputeContext& context, const ProgramBase& program);

bool IsPixCaptureEnabled() const { return enable_pix_capture_; }

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
void GeneratePIXFrame();
#endif // ENABLE_PIX_FOR_WEBGPU_EP
void OnRunEnd();

private:
enum class TimestampQueryType {
Expand All @@ -145,8 +140,8 @@ class WebGpuContext final {
AtPasses
};

WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode, bool enable_pix_capture)
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, enable_pix_capture_(enable_pix_capture), query_type_{TimestampQueryType::None} {}
WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);

std::vector<const char*> GetEnabledAdapterToggles() const;
Expand Down Expand Up @@ -212,7 +207,6 @@ class WebGpuContext final {
const uint32_t max_num_pending_dispatches_ = 16;

// profiling
bool enable_pix_capture_;
TimestampQueryType query_type_;
wgpu::QuerySet query_set_;
wgpu::Buffer query_resolve_buffer_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,11 +857,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
context_.CollectProfilingData(profiler_->Events());
}

#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
if (context_.IsPixCaptureEnabled()) {
context_.GeneratePIXFrame();
}
#endif // ENABLE_PIX_FOR_WEBGPU_EP
context_.OnRunEnd();

return Status::OK();
}
Expand Down
27 changes: 14 additions & 13 deletions onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
validation_mode,
};

std::string enable_pix_capture_str;
if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) {
if (enable_pix_capture_str == kEnablePIXCapture_ON) {
context_config.enable_pix_capture = true;
} else if (enable_pix_capture_str == kEnablePIXCapture_OFF) {
context_config.enable_pix_capture = false;
} else {
ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str);
}
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << context_config.enable_pix_capture;

//
// STEP.3 - prepare parameters for WebGPU context initialization.
//
Expand Down Expand Up @@ -233,6 +221,19 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled);
LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode;

bool enable_pix_capture = false;
std::string enable_pix_capture_str;
if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) {
if (enable_pix_capture_str == kEnablePIXCapture_ON) {
enable_pix_capture = true;
} else if (enable_pix_capture_str == kEnablePIXCapture_OFF) {
enable_pix_capture = false;
} else {
ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str);
}
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << enable_pix_capture;

//
// STEP.4 - start initialization.
//
Expand All @@ -241,7 +242,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);

// Create WebGPU device and initialize the context.
context.Initialize(buffer_cache_config, backend_type);
context.Initialize(buffer_cache_config, backend_type, enable_pix_capture);

// Create WebGPU EP factory.
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));
Expand Down

0 comments on commit 66716e0

Please sign in to comment.