#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/custom_class.h>

namespace torch {
namespace jit {

class CUDAEvent;

// This class is a wrapper around c10::cuda::CUDAStream.
// It is needed because TorchBind does not support all of the argument types
// for c10::cuda::CUDAStream. For more details, please refer to
// c10/cuda/CUDAStream.h.
class CUDAStream final : public CustomClassHolder {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CUDAStream(
      c10::optional<c10::Device> device = c10::nullopt,
      int64_t priority = 0) {
    constexpr int64_t PRIORITY_INDEX = 0;
    c10::DeviceIndex device_index =
        device.has_value() ? device->index() : c10::cuda::current_device();
    stream_ = std::make_unique<c10::cuda::CUDAStream>(
        c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device_index));
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CUDAStream(c10::cuda::CUDAStream s) {
    stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
  }

  bool query() {
    return stream_->query();
  }

  c10::intrusive_ptr<CUDAEvent> recordEvent(
      c10::intrusive_ptr<CUDAEvent> event);

  void synchronize() {
    stream_->synchronize();
  }

  void waitEvent(c10::intrusive_ptr<CUDAEvent> event);

  void waitStream(c10::intrusive_ptr<CUDAStream> stream);

  /// Get the CUDA device index that this stream is associated with.
  int64_t device_index() const {
    return stream_->device_index();
  }

  /// Get the full Device that this stream is associated with. The Device
  /// is guaranteed to be a CUDA device.
  c10::Device device() const {
    return stream_->device();
  }

  /// Return the stream ID corresponding to this particular stream.
  int64_t id() const {
    return stream_->id();
  }

  /// Pack a CUDAStream to an int64_t representation. The CUDAStream can be
  /// unpacked using unpack(). The format of the packed value is unspecified
  /// and may be changed.
  int64_t pack() const {
    return stream_->pack();
  }

 private:
  std::unique_ptr<c10::cuda::CUDAStream> stream_;

  friend class CUDAEvent;
};
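
// A minimal, hedged usage sketch (illustrative only; `example_wait_on_side_stream`
// is a hypothetical helper, not part of the original header). It shows the
// intended pattern for the wrapper: make the current stream wait for all work
// already queued on a side stream. Assumes a CUDA-capable build with an
// initialized CUDA context.
inline void example_wait_on_side_stream() {
  auto side = c10::make_intrusive<CUDAStream>(c10::cuda::getStreamFromPool());
  auto current =
      c10::make_intrusive<CUDAStream>(c10::cuda::getCurrentCUDAStream());
  // ... enqueue kernels on `side` here ...
  // waitStream() records an event on `side` and blocks `current` on it
  // (see the definitions further down in this file).
  current->waitStream(side);
}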

// This class is a wrapper around at::cuda::CUDAEvent.
// It is needed because TorchBind does not support all of the argument types
// for at::cuda::CUDAEvent. For more details, please refer to
// aten/src/ATen/cuda/CUDAEvent.h.
class CUDAEvent final : public CustomClassHolder {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CUDAEvent(
      bool enable_timing = false,
      bool blocking = false,
      bool interprocess = false) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int flags = cudaEventDisableTiming;
    if (enable_timing) {
      flags = cudaEventDefault;
    }
    if (blocking) {
      flags |= cudaEventBlockingSync;
    }
    if (interprocess) {
      TORCH_CHECK(!enable_timing);
      flags |= cudaEventInterprocess;
    }
    event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
  }
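
  // Note: per the TORCH_CHECK above, `interprocess` requires
  // `enable_timing == false`; CUDA only allows cudaEventInterprocess in
  // combination with cudaEventDisableTiming.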

  double elapsedTime(c10::intrusive_ptr<CUDAEvent> end) {
    return event_->elapsed_time(*end->event_);
  }

  std::string ipcHandle() {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    cudaIpcEventHandle_t handle;
    event_->ipc_handle(&handle);
    std::string str_handle((const char*)&handle, sizeof(handle));
    return str_handle;
  }
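
  // Note (hedged): the returned string is just the raw bytes of the
  // cudaIpcEventHandle_t. On the receiving side it can be recovered with,
  // e.g.,
  //   cudaIpcEventHandle_t handle;
  //   TORCH_CHECK(payload.size() == sizeof(handle));
  //   std::memcpy(&handle, payload.data(), sizeof(handle));
  // where `payload` stands for the string produced by ipcHandle().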

  bool query() {
    return event_->query();
  }

  void record(c10::intrusive_ptr<CUDAStream> stream);

  void synchronize() {
    event_->synchronize();
  }

  void wait(c10::intrusive_ptr<CUDAStream> stream);

 private:
  void recordInternal(CUDAStream* stream);
  std::unique_ptr<at::cuda::CUDAEvent> event_;

  friend class CUDAStream;
};
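
// A minimal, hedged usage sketch (illustrative only; `example_time_on_stream`
// is a hypothetical helper, not part of the original header): time a stretch
// of work on a stream with a pair of timing-enabled events. Assumes a
// CUDA-capable build.
inline double example_time_on_stream(c10::intrusive_ptr<CUDAStream> stream) {
  auto start = c10::make_intrusive<CUDAEvent>(/*enable_timing=*/true);
  auto end = c10::make_intrusive<CUDAEvent>(/*enable_timing=*/true);
  start->record(stream);
  // ... enqueue the work to be measured on `stream` here ...
  end->record(stream);
  end->synchronize(); // block until `end` has completed on the GPU
  return start->elapsedTime(end); // milliseconds, via cudaEventElapsedTime
}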

c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
    c10::intrusive_ptr<CUDAEvent> event) {
  if (!event) {
    event = c10::make_intrusive<CUDAEvent>();
  }
  event->recordInternal(this);
  return event;
}

void CUDAStream::waitEvent(c10::intrusive_ptr<CUDAEvent> event) {
  event->event_->block(*stream_);
}

void CUDAStream::waitStream(c10::intrusive_ptr<CUDAStream> stream) {
  auto ev = c10::make_intrusive<CUDAEvent>();
  stream->recordEvent(ev);
  waitEvent(ev);
}

void CUDAEvent::record(c10::intrusive_ptr<CUDAStream> stream) {
  event_->record(*stream->stream_);
}

void CUDAEvent::recordInternal(CUDAStream* stream) {
  event_->record(*stream->stream_);
}

void CUDAEvent::wait(c10::intrusive_ptr<CUDAStream> stream) {
  event_->block(*stream->stream_);
}

TORCH_LIBRARY(cuda, m) {
  auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
      torch::init<c10::optional<c10::Device>, int64_t>(),
      "",
      {torch::arg("device") = c10::nullopt, torch::arg("priority") = 0});
  auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
      torch::init<bool, bool, bool>(),
      "",
      {torch::arg("enable_timing") = false,
       torch::arg("blocking") = false,
       torch::arg("interprocess") = false});

  stream_class.def("query", &CUDAStream::query)
      .def("record_event", &CUDAStream::recordEvent)
      .def("synchronize", &CUDAStream::synchronize)
      .def("wait_event", &CUDAStream::waitEvent)
      .def("wait_stream", &CUDAStream::waitStream)
      .def("device_index", &CUDAStream::device_index)
      .def_property("device", &CUDAStream::device)
      .def("pack", &CUDAStream::pack)
      .def("id", &CUDAStream::id);

  event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
      .def("query", &CUDAEvent::query)
      .def("record", &CUDAEvent::record)
      .def("synchronize", &CUDAEvent::synchronize)
      .def("wait", &CUDAEvent::wait);
};
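
// A minimal, hedged sketch (illustrative only; `example_make_stream_ivalue` is
// a hypothetical helper, not part of the original header). Once the
// registrations above have run, the bound class can also be constructed from
// C++ as a TorchBind custom class and handed to TorchScript as an IValue; from
// Python/TorchScript the same class is visible as torch.classes.cuda.Stream.
inline c10::IValue example_make_stream_ivalue() {
  // Uses the defaults declared above: device = c10::nullopt, priority = 0.
  return torch::make_custom_class<CUDAStream>();
}
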
} // namespace jit
} // namespace torch