-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathpointer.hpp
330 lines (291 loc) · 9.25 KB
/
pointer.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/pointer_base.hpp> // cute::iter_adaptor
#include <cute/pointer_sparse.hpp>
#include <cute/container/array_subbyte.hpp> // cute::subbyte_iterator
#include <cute/numeric/integral_constant.hpp> // cute::true_type, cute::false_type
#include <cute/numeric/numeric_types.hpp> // sizeof_bits
namespace cute
{
//
// recast_ptr<T> -- Produce an iterator that views memory as a sequence of T.
//   Ordinary types simply yield T*. Two families of types need extra care:
//   Subbyte Types (uint2_t, uint4_t, ...):
//     Elements are smaller than a byte, so the result is a subbyte_iterator<T>,
//     which resolves each element within byte-addressed memory.
//   Sparse Types (sparse_elem<int S, class T>):
//     One physical element represents S logical elements, so the result is a
//     sparse pointer (built by make_sparse_ptr) emulating access to those S
//     logical elements.
//

/// Recast a mutable untyped pointer into an iterator over NewT.
/// The categories are tested in order: sparse first, then subbyte, otherwise a
/// plain reinterpret_cast'ed NewT*.
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void* ptr)
{
  if constexpr (is_sparse<NewT>::value) {
    // Wrap the typed physical pointer together with its compile-time sparsity.
    constexpr int kSparsity = NewT::sparsity;
    NewT* physical = reinterpret_cast<NewT*>(ptr);
    return make_sparse_ptr<kSparsity>(physical);
  } else if constexpr (cute::is_subbyte_v<NewT>) {
    // Sub-byte element: let subbyte_iterator locate elements inside each byte.
    return subbyte_iterator<NewT>(ptr);
  } else {
    // Byte-addressable element: a plain typed pointer is sufficient.
    return reinterpret_cast<NewT*>(ptr);
  }
  CUTE_GCC_UNREACHABLE;
}
/// Recast a read-only untyped pointer into an iterator over NewT const.
/// The categories are tested in order: sparse first, then subbyte, otherwise a
/// plain reinterpret_cast'ed NewT const*.
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(void const* ptr)
{
  if constexpr (is_sparse<NewT>::value) {
    // Wrap the typed (const) physical pointer with its compile-time sparsity.
    constexpr int kSparsity = NewT::sparsity;
    NewT const* physical = reinterpret_cast<NewT const*>(ptr);
    return make_sparse_ptr<kSparsity>(physical);
  } else if constexpr (cute::is_subbyte_v<NewT>) {
    // Sub-byte element: the iterator's value type carries the const.
    return subbyte_iterator<NewT const>(ptr);
  } else {
    // Byte-addressable element: a plain typed const pointer is sufficient.
    return reinterpret_cast<NewT const*>(ptr);
  }
  CUTE_GCC_UNREACHABLE;
}
// nullptr_t overload: a bare nullptr converts equally well to void* and
// void const*, so recast_ptr<T>(nullptr) needs this overload to be unambiguous.
// It forwards a null pointer already typed as NewT*.
template <class NewT>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(decltype(nullptr)) { // nullptr_t
  using TypedPtr = NewT*;
  return recast_ptr<NewT>(TypedPtr{nullptr});
}
//
// gmem_ptr
//
/// Wrapper that tags an iterator with the gmem memory-space marker.
/// All iterator behavior is inherited from iter_adaptor; only the type differs.
template <class Iter>
struct gmem_ptr : iter_adaptor<Iter, gmem_ptr<Iter>> {
  // Reuse every iter_adaptor constructor as-is.
  using iter_adaptor<Iter, gmem_ptr<Iter>>::iter_adaptor;
};

/// is_gmem<T>: true when T is a gmem_ptr, or when T exposes a nested ::iterator
/// type for which is_gmem (recursively) holds; false otherwise.
template <class T, class = void>
struct is_gmem : false_type {};

// Direct match: the type itself carries the gmem tag.
template <class Iter>
struct is_gmem<gmem_ptr<Iter>> : true_type {};

// Wrapped iterator: look through ::iterator when the type provides one.
template <class Wrapper>
struct is_gmem<Wrapper, void_t<typename Wrapper::iterator>> : is_gmem<typename Wrapper::iterator> {};

/// Convenience value form of is_gmem.
template <class T>
constexpr bool is_gmem_v = is_gmem<T>::value;
/// Tag an iterator with the gmem marker. Idempotent: an iterator that is_gmem
/// already recognizes is returned unchanged rather than wrapped a second time.
template <class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(Iterator iter) {
  if constexpr (not is_gmem<Iterator>::value) {
    return gmem_ptr<Iterator>{iter};
  } else {
    return iter;
  }
  CUTE_GCC_UNREACHABLE;
}
/// Build a gmem-tagged iterator over T from a mutable untyped pointer.
/// Usage: make_gmem_ptr<float>(raw_ptr)
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(void* ptr) {
  auto typed = recast_ptr<T>(ptr);
  return make_gmem_ptr(typed);
}

/// Build a gmem-tagged iterator over T const from a read-only untyped pointer.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(void const* ptr) {
  auto typed = recast_ptr<T const>(ptr);
  return make_gmem_ptr(typed);
}

/// nullptr_t overload: make_gmem_ptr<float>(nullptr) would otherwise be
/// ambiguous between the void* and void const* overloads above.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_gmem_ptr(decltype(nullptr)) { // nullptr_t
  auto typed = recast_ptr<T>(nullptr);
  return make_gmem_ptr(typed);
}
/// Recasting keeps the gmem tag: unwrap the tagged iterator, recast the
/// underlying iterator to NewT, and re-apply the gmem tag to the result.
template <class NewT, class Inner>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(gmem_ptr<Inner> const& ptr)
{
  return make_gmem_ptr(recast_ptr<NewT>(ptr.get()));
}
//
// smem_ptr
//
/// Wrapper that tags an iterator with the smem memory-space marker.
/// All iterator behavior is inherited from iter_adaptor; only the type differs.
template <class Iter>
struct smem_ptr : iter_adaptor<Iter, smem_ptr<Iter>> {
  // Reuse every iter_adaptor constructor as-is.
  using iter_adaptor<Iter, smem_ptr<Iter>>::iter_adaptor;
};

/// is_smem<T>: true when T is a smem_ptr, or when T exposes a nested ::iterator
/// type for which is_smem (recursively) holds; false otherwise.
template <class T, class = void>
struct is_smem : false_type {};

// Direct match: the type itself carries the smem tag.
template <class Iter>
struct is_smem<smem_ptr<Iter>> : true_type {};

// Wrapped iterator: look through ::iterator when the type provides one.
template <class Wrapper>
struct is_smem<Wrapper, void_t<typename Wrapper::iterator>> : is_smem<typename Wrapper::iterator> {};

/// Convenience value form of is_smem.
template <class T>
constexpr bool is_smem_v = is_smem<T>::value;
/// Tag an iterator with the smem marker. Idempotent: an iterator that is_smem
/// already recognizes is returned unchanged rather than wrapped a second time.
template <class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(Iterator iter) {
  if constexpr (not is_smem<Iterator>::value) {
    return smem_ptr<Iterator>{iter};
  } else {
    return iter;
  }
  CUTE_GCC_UNREACHABLE;
}
/// Shortcut for the common case of a swizzled smem iterator:
/// first apply the smem tag to `ptr`, then hand the tagged iterator and the
/// swizzle functor `sw` to make_swizzle_ptr.
template <class Iterator, class Swizzle>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(Iterator ptr, Swizzle sw)
{
  return make_swizzle_ptr(make_smem_ptr(ptr),
                          sw);
}
/// Build a smem-tagged iterator over T from a mutable untyped pointer.
/// Usage: make_smem_ptr<float>(raw_ptr)
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(void* ptr) {
  auto typed = recast_ptr<T>(ptr);
  return make_smem_ptr(typed);
}

/// Build a smem-tagged iterator over T const from a read-only untyped pointer.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(void const* ptr) {
  auto typed = recast_ptr<T const>(ptr);
  return make_smem_ptr(typed);
}

/// nullptr_t overload: make_smem_ptr<float>(nullptr) would otherwise be
/// ambiguous between the void* and void const* overloads above.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(decltype(nullptr)) { // nullptr_t
  auto typed = recast_ptr<T>(nullptr);
  return make_smem_ptr(typed);
}
/// Recasting keeps the smem tag: unwrap the tagged iterator, recast the
/// underlying iterator to NewT, and re-apply the smem tag to the result.
template <class NewT, class Inner>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(smem_ptr<Inner> const& ptr)
{
  return make_smem_ptr(recast_ptr<NewT>(ptr.get()));
}
//
// rmem_ptr
//
/// Wrapper that tags an iterator with the rmem memory-space marker.
/// All iterator behavior is inherited from iter_adaptor; only the type differs.
template <class Iter>
struct rmem_ptr : iter_adaptor<Iter, rmem_ptr<Iter>> {
  // Reuse every iter_adaptor constructor as-is.
  using iter_adaptor<Iter, rmem_ptr<Iter>>::iter_adaptor;
};

/// is_rmem<T>: the fallback classification. Any type that is neither gmem nor
/// smem counts as rmem (e.g. a raw T*, which has no nested ::iterator).
template <class T, class = void>
struct is_rmem : bool_constant<!is_gmem<T>::value && !is_smem<T>::value> {};

// An explicit rmem_ptr is always rmem.
template <class Iter>
struct is_rmem<rmem_ptr<Iter>> : true_type {};

/// Convenience value form of is_rmem.
template <class T>
constexpr bool is_rmem_v = is_rmem<T>::value;
/// Tag an iterator with the rmem marker. Idempotent: anything is_rmem already
/// accepts is returned unchanged. Because is_rmem is true for every non-gmem,
/// non-smem type, a plain pointer passes through unwrapped.
template <class Iterator>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(Iterator iter) {
  if constexpr (not is_rmem<Iterator>::value) {
    return rmem_ptr<Iterator>{iter};
  } else {
    return iter;
  }
  CUTE_GCC_UNREACHABLE;
}
/// Build an rmem-tagged iterator over T from a mutable untyped pointer.
/// Usage: make_rmem_ptr<float>(raw_ptr)
/// Note: make_rmem_ptr(Iterator) leaves iterators that is_rmem already accepts
/// unwrapped, so for a plain T the result may simply be the recast T*.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(void* ptr) {
  return make_rmem_ptr(recast_ptr<T>(ptr));
}

/// Build an rmem-tagged iterator over T const from a read-only untyped pointer.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(void const* ptr) {
  return make_rmem_ptr(recast_ptr<T const>(ptr));
}

/// nullptr_t overload for make_rmem_ptr<float>(nullptr) disambiguation.
/// Without it, nullptr converts equally well to void* and void const*, so the
/// call is ambiguous and fails to compile. This brings rmem in line with the
/// make_gmem_ptr / make_smem_ptr nullptr overloads.
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_rmem_ptr(decltype(nullptr)) { // nullptr_t
  return make_rmem_ptr(recast_ptr<T>(nullptr));
}
/// Recasting keeps the rmem tag: unwrap the tagged iterator, recast the
/// underlying iterator to NewT, and re-apply the rmem tag to the result.
template <class NewT, class Inner>
CUTE_HOST_DEVICE constexpr
auto
recast_ptr(rmem_ptr<Inner> const& ptr)
{
  return make_rmem_ptr(recast_ptr<NewT>(ptr.get()));
}
//
// Display utilities
//
/// Print a tagged iterator: the memory-space prefix followed by whatever
/// print() emits for the wrapped iterator.
template <class T>
CUTE_HOST_DEVICE void print(gmem_ptr<T> ptr)
{
  printf("gmem_");
  print(ptr.get());
}

template <class T>
CUTE_HOST_DEVICE void print(smem_ptr<T> ptr)
{
  printf("smem_");
  print(ptr.get());
}

template <class T>
CUTE_HOST_DEVICE void print(rmem_ptr<T> ptr)
{
  printf("rmem_");
  print(ptr.get());
}
#if !defined(__CUDACC_RTC__)
// Stream a tagged iterator as "<tag>_[<bits>b]", where <bits> is
// sizeof_bits of the iterator's value type. The address is not written.
template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> /*ptr*/)
{
  int const bits = int(sizeof_bits<iter_value_t<T>>::value);
  return os << "gmem_[" << bits << "b]";
}

template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr<T> /*ptr*/)
{
  int const bits = int(sizeof_bits<iter_value_t<T>>::value);
  return os << "smem_[" << bits << "b]";
}

template <class T>
CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> /*ptr*/)
{
  int const bits = int(sizeof_bits<iter_value_t<T>>::value);
  return os << "rmem_[" << bits << "b]";
}
#endif // !defined(__CUDACC_RTC__)
} // end namespace cute