Skip to content

Commit

Permalink
refactor(kernel): 添加一个显存的全局缓存,避免反复的 h2d 拷贝
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 23, 2024
1 parent 20788a3 commit 41036a0
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions src/04kernel/src/graph.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
#include "kernel/graph.h"

namespace refactor {
struct DataKey {
Arc<hardware::Device> dev;
Arc<kernel::Blob> blob;
bool operator==(const DataKey &) const = default;// since C++20
};
}// namespace refactor

template<>
struct std::hash<refactor::DataKey> {
std::size_t operator()(refactor::DataKey const &s) const noexcept {
auto hd = std::hash<decltype(s.dev)>()(s.dev),
hb = std::hash<decltype(s.blob)>()(s.blob);
return hd ^ (hb << 1);
}
};

namespace refactor::kernel {

Graph::Graph(graph_topo::GraphTopo topology,
Expand Down Expand Up @@ -31,13 +48,19 @@ namespace refactor::kernel {
_internal.edges,
32);

static std::unordered_map<DataKey, Arc<hardware::Device::Blob>> CACHE;

for (auto i : range0_(edges_.size())) {
auto const &edge = _internal.edges[i];
edges_[i].name = edge.name;
if (edge.data) {
auto blob = device->malloc(edge.size);
blob->copyFromHost(edge.data->get<void>());
edges_[i].blob = std::move(blob);
auto it = CACHE.find({device, edge.data});
if (it == CACHE.end()) {
auto blob = device->malloc(edge.size);
blob->copyFromHost(edge.data->get<void>());
std::tie(it, std::ignore) = CACHE.emplace(DataKey{device, edge.data}, std::move(blob));
}
edges_[i].blob = it->second;
} else if (edges_[i].stackOffset == SIZE_MAX - 1) {
edges_[i].blob = device->malloc(edge.size);
}
Expand Down

0 comments on commit 41036a0

Please sign in to comment.