Skip to content

Commit

Permalink
Force callstacks to only run on untainted threads (#7914)
Browse files Browse the repository at this point in the history
* Force callstacks to only run on untainted threads

* Fix build warning

* Fix ctor

* Change name

* Fix key check
  • Loading branch information
wiktork authored Feb 4, 2025
1 parent 385367a commit 403f362
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 12 deletions.
47 changes: 45 additions & 2 deletions src/Profilers/MonitorProfiler/Communication/CommandServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ CommandServer::CommandServer(const std::shared_ptr<ILogger>& logger, ICorProfile
HRESULT CommandServer::Start(
const std::string& path,
std::function<HRESULT(const IpcMessage& message)> callback,
std::function<HRESULT(const IpcMessage& message)> validateMessageCallback)
std::function<HRESULT(const IpcMessage& message)> validateMessageCallback,
std::function<HRESULT(unsigned short commandSet, bool& unmanagedOnly)> unmanagedOnlyCallback)
{
if (_shutdown.load())
{
Expand All @@ -35,10 +36,12 @@ HRESULT CommandServer::Start(

_callback = callback;
_validateMessageCallback = validateMessageCallback;
_unmanagedOnlyCallback = unmanagedOnlyCallback;

IfFailLogRet_(_logger, _server.Bind(path));
_listeningThread = std::thread(&CommandServer::ListeningThread, this);
_clientThread = std::thread(&CommandServer::ClientProcessingThread, this);
_unmanagedOnlyThread = std::thread(&CommandServer::UnmanagedOnlyProcessingThread, this);
return S_OK;
}

Expand All @@ -48,10 +51,12 @@ void CommandServer::Shutdown()
if (_shutdown.compare_exchange_strong(shutdown, true))
{
_clientQueue.Complete();
_unmanagedOnlyQueue.Complete();
_server.Shutdown();

_listeningThread.join();
_clientThread.join();
_unmanagedOnlyThread.join();
}
}

Expand Down Expand Up @@ -110,7 +115,15 @@ void CommandServer::ListeningThread()

if (doEnqueueMessage)
{
_clientQueue.Enqueue(message);
bool unmanagedOnly = false;
if (SUCCEEDED(_unmanagedOnlyCallback(message.CommandSet, unmanagedOnly)) && unmanagedOnly)
{
_unmanagedOnlyQueue.Enqueue(message);
}
else
{
_clientQueue.Enqueue(message);
}
}
}
}
Expand All @@ -134,6 +147,36 @@ void CommandServer::ClientProcessingThread()
//We are complete, discard all messages
break;
}

// DispatchMessage in the callback serializes all callbacks.
hr = _callback(message);
if (hr != S_OK)
{
_logger->Log(LogLevel::Warning, _LS("IpcMessage callback failed: 0x%08x"), hr);
}
}
}

void CommandServer::UnmanagedOnlyProcessingThread()
{
HRESULT hr = _profilerInfo->InitializeCurrentThread();

if (FAILED(hr))
{
_logger->Log(LogLevel::Error, _LS("Unable to initialize thread: 0x%08x"), hr);
return;
}

while (true)
{
IpcMessage message;
hr = _unmanagedOnlyQueue.BlockingDequeue(message);
if (hr != S_OK)
{
// We are complete, discard all messages
break;
}

hr = _callback(message);
if (hr != S_OK)
{
Expand Down
11 changes: 10 additions & 1 deletion src/Profilers/MonitorProfiler/Communication/CommandServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,34 @@ class CommandServer final
HRESULT Start(
const std::string& path,
std::function<HRESULT (const IpcMessage& message)> callback,
std::function<HRESULT (const IpcMessage& message)> validateMessageCallback);
std::function<HRESULT (const IpcMessage& message)> validateMessageCallback,
std::function<HRESULT (unsigned short commandSet, bool& unmanagedOnly)> unmanagedOnlyCallback);
void Shutdown();

private:
void ListeningThread();
void ClientProcessingThread();
void UnmanagedOnlyProcessingThread();

std::atomic_bool _shutdown;

std::function<HRESULT(const IpcMessage& message)> _callback;
std::function<HRESULT(const IpcMessage& message)> _validateMessageCallback;
std::function<HRESULT(unsigned short commandSet, bool& unmanagedOnly)> _unmanagedOnlyCallback;

IpcCommServer _server;

// We allocates two queues and two threads to process messages.
// UnmanagedOnlyQueue is dedicated to ICorProfiler api calls that cannot be called on threads that have previously invoked managed code, such as StackSnapshot.
// Other command sets such as StartupHook call managed code and therefore interfere with StackSnapshot calls.
BlockingQueue<IpcMessage> _clientQueue;
BlockingQueue<IpcMessage> _unmanagedOnlyQueue;

std::shared_ptr<ILogger> _logger;

std::thread _listeningThread;
std::thread _clientThread;
std::thread _unmanagedOnlyThread;

ComPtr<ICorProfilerInfo12> _profilerInfo;
};
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ bool MessageCallbackManager::TryRegister(unsigned short commandSet, ManagedMessa
return TryRegister(commandSet, [pCallback](const IpcMessage& message)-> HRESULT
{
return pCallback(message.Command, message.Payload.data(), message.Payload.size());
});
}, false);
}

bool MessageCallbackManager::TryRegister(unsigned short commandSet, std::function<HRESULT (const IpcMessage& message)> callback)
bool MessageCallbackManager::TryRegister(unsigned short commandSet, std::function<HRESULT (const IpcMessage& message)> callback, bool unmanagedOnly)
{
std::lock_guard<std::mutex> dispatchLock(m_dispatchMutex);
std::lock_guard<std::mutex> lookupLock(m_lookupMutex);
Expand All @@ -30,7 +30,7 @@ bool MessageCallbackManager::TryRegister(unsigned short commandSet, std::functio
return false;
}

m_callbacks[commandSet] = callback;
m_callbacks[commandSet] = CallbackInfo(unmanagedOnly, callback);
return true;
}

Expand Down Expand Up @@ -61,9 +61,22 @@ bool MessageCallbackManager::TryGetCallback(unsigned short commandSet, std::func
auto const& it = m_callbacks.find(commandSet);
if (it != m_callbacks.end())
{
callback = it->second;
callback = it->second.Callback;
return true;
}

return false;
}

HRESULT MessageCallbackManager::UnmanagedOnly(unsigned short commandSet, bool& unmanagedOnly)
{
std::lock_guard<std::mutex> lookupLock(m_lookupMutex);

auto const& it = m_callbacks.find(commandSet);
if (it != m_callbacks.end())
{
unmanagedOnly = it->second.UnmanagedOnly;
return S_OK;
}
return E_FAIL;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,33 @@

typedef HRESULT (STDMETHODCALLTYPE *ManagedMessageCallback)(UINT16, const BYTE*, UINT64);

struct CallbackInfo
{
CallbackInfo() = default;
CallbackInfo(bool unmanagedOnly, std::function<HRESULT (const IpcMessage& message)> callback)
: UnmanagedOnly(unmanagedOnly), Callback(callback)
{
}

bool UnmanagedOnly = false;
std::function<HRESULT (const IpcMessage& message)> Callback;
};

class MessageCallbackManager
{
public:
HRESULT DispatchMessage(const IpcMessage& message);
bool IsRegistered(unsigned short commandSet);
bool TryRegister(unsigned short commandSet, std::function<HRESULT (const IpcMessage& message)> callback);

// Some callbacks, such as the profiler, must run on their own thread due to ICorProfiler API restrictions.
// Setting unmanagedOnly to true will queue the work from the command set to a separate thread.
bool TryRegister(unsigned short commandSet, std::function<HRESULT (const IpcMessage& message)> callback, bool unmanagedOnly);
bool TryRegister(unsigned short commandSet, ManagedMessageCallback pCallback);
void Unregister(unsigned short commandSet);
HRESULT UnmanagedOnly(unsigned short commandSet, bool& unmanagedOnly);
private:
bool TryGetCallback(unsigned short commandSet, std::function<HRESULT (const IpcMessage& message)>& callback);
std::unordered_map<unsigned short, std::function<HRESULT (const IpcMessage& message)>> m_callbacks;
std::unordered_map<unsigned short, CallbackInfo> m_callbacks;
//
// Ideally we would use a single std::shared_mutex instead, but we are targeting C++11 without
// an easy way to upgrade to C++17 at this time, so we use two separate mutexes instead to
Expand Down
5 changes: 3 additions & 2 deletions src/Profilers/MonitorProfiler/MainProfiler/MainProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ HRESULT MainProfiler::InitializeCommandServer()
_commandServer = std::unique_ptr<CommandServer>(new CommandServer(m_pLogger, m_pCorProfilerInfo));
tstring socketPath = sharedPath + separator + instanceId + _T(".sock");

if (!g_MessageCallbacks.TryRegister(static_cast<unsigned short>(CommandSet::Profiler), [this](const IpcMessage& message)-> HRESULT { return this->ProfilerCommandSetCallback(message); }))
if (!g_MessageCallbacks.TryRegister(static_cast<unsigned short>(CommandSet::Profiler), [this](const IpcMessage& message)-> HRESULT { return this->ProfilerCommandSetCallback(message); }, true))
{
m_pLogger->Log(LogLevel::Error, _LS("Unable to register Profiler CommandSet callback."));
return E_FAIL;
Expand All @@ -250,7 +250,8 @@ HRESULT MainProfiler::InitializeCommandServer()
hr = _commandServer->Start(
to_string(socketPath),
[this](const IpcMessage& message)-> HRESULT { return this->MessageCallback(message); },
[this](const IpcMessage& message)-> HRESULT { return this->ValidateMessage(message); });
[this](const IpcMessage& message)-> HRESULT { return this->ValidateMessage(message); },
[](unsigned short commandSet, bool& unmanagedOnly)-> HRESULT { return g_MessageCallbacks.UnmanagedOnly(commandSet, unmanagedOnly);});
if (FAILED(hr))
{
g_MessageCallbacks.Unregister(static_cast<unsigned short>(CommandSet::Profiler));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ HRESULT ProbeInstrumentation::InitBackgroundService()
m_probeManagementThread = thread(&ProbeInstrumentation::WorkerThread, this);
//
// Create a dedicated thread for managed callbacks.
// Performing the callbacks will taint the calling thread preventing it
// Performing the callbacks will prevent the calling thread
// from using certain ICorProfiler APIs marked as unsafe.
// Those functions will fail with CORPROF_E_UNSUPPORTED_CALL_SEQUENCE.
//
Expand Down

0 comments on commit 403f362

Please sign in to comment.