forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcopy_traits_sm90_tma.hpp
1326 lines (1160 loc) · 55.1 KB
/
copy_traits_sm90_tma.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cuda.h>
#endif
#include <cute/atom/copy_traits_sm90_tma_swizzle.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/algorithm/prefetch.hpp>
#include <cute/numeric/integral_ratio.hpp>
namespace cute
{
/// Auxiliary parameters carried next to a TMA descriptor inside the TMA Copy_Traits.
///
/// @tparam GmemTmaBasisStrides_  Strides for gmem mode -> TMA coord mode (may be dynamic).
/// @tparam TmaGmemBasis_         Layout for TMA box shape -> gmem mode(s); must be static.
/// @tparam TmaSwizzle_           TMA swizzle (always a Swizzle<B,M,S>); must be static.
template <class GmemTmaBasisStrides_, class TmaGmemBasis_, class TmaSwizzle_>
struct AuxTmaParams {
  // Exported type names.
  using GmemStrides  = GmemTmaBasisStrides_;
  using TmaGmemBasis = TmaGmemBasis_;
  using TmaSwizzle   = TmaSwizzle_;

  // The basis and swizzle are purely type-level information.
  static_assert(is_static<TmaGmemBasis>::value);
  static_assert(is_static<TmaSwizzle>::value);

  // The only data member: the (possibly dynamic) gmem strides.
  GmemStrides g_stride_;
};
// Utility for unpacking TMA_LOAD arguments into a CopyOp.
//
// Executable TMA-load traits inherit from TMA_LOAD_Unpack<CopyOp> to obtain the copy_unpack
// friend below. The source tensor's iterator exposes the TMA coordinates as `coord_`, and the
// stored `traits.opargs_` are forwarded first:
//   prefetch ops : CopyOp::copy(opargs..., coords...)
//   load ops     : CopyOp::copy(opargs..., smem_dst_ptr, coords...)
template <class CopyOp>
struct TMA_LOAD_Unpack
{
  template <class... Args,
            class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits<CopyOp, Args...> const& traits,
              Tensor<TS,SLayout>           const& src,
              Tensor<TD,DLayout>                & dst)
  {
    auto tma_coord = src.data().coord_;

    if constexpr (not detail::is_prefetch<CopyOp>) {
      // Real loads land in shared memory; the smem address sits between opargs and coords.
      static_assert(is_smem<TD>::value, "SM90_TMA_LOAD requires the destination be shared memory.");
      void* smem_ptr = cute::raw_pointer_cast(dst.data());
#if 0
      auto [c0,c1,c2,c3,c4] = append<5>(tma_coord, 0);
      printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
             threadIdx.x, threadIdx.y, threadIdx.z,
             blockIdx.x, blockIdx.y, blockIdx.z,
             int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), smem_ptr);
#endif
      return detail::copy_explode<CopyOp>(traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
                                          make_tuple(smem_ptr), seq<0>{},
                                          tma_coord, tuple_seq<decltype(tma_coord)>{});
    } else {
      // Prefetches have no destination operand.
      return detail::copy_explode<CopyOp>(traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
                                          tma_coord, tuple_seq<decltype(tma_coord)>{});
    }
  }
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD ///////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// Distinct tag selecting the executable (descriptor + mbarrier) Copy_Traits of SM90_TMA_LOAD.
struct SM90_TMA_LOAD_OP : public SM90_TMA_LOAD {};
// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar
// Use .with(tma_mbar) to construct an executable version
//
// Holds the TmaDescriptor by value plus AuxParams (whose g_stride_ builds the coordinate
// tensor). copy_unpack is deleted here, so a copy only compiles on the result of .with().
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Owned descriptor, then the auxiliary parameters.
  TmaDescriptor tma_desc_;
  using AuxParams = AuxParams_;
  AuxParams     aux_params_;

  /// @return pointer to the descriptor stored in this object.
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  /// Bind an mbarrier using this object's own descriptor.
  /// multicast_mask is accepted only so both load atoms expose the same .with() signature.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
  with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
    return with(&tma_desc_, tma_mbar, multicast_mask);
  }

  /// Bind an mbarrier together with a caller-supplied descriptor
  /// (temp. overload for grouped gemm / ptr-array gemm). multicast_mask is ignored.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
  with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
    return {{}, {new_tma_desc, &tma_mbar}};
  }

  /// Counting tensor of TMA coordinates over g_shape, strided by aux_params_.g_stride_.
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
    auto coord_layout = make_layout(g_shape, aux_params_.g_stride_);
    return make_counting_tensor(coord_layout);
  }

  // Intentionally unusable: call .with() first to obtain an executable traits object.
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst) = delete;
};
// The executable SM90_TMA_LOAD with tma_desc and tma_mbar
//
// Returned by Copy_Traits<SM90_TMA_LOAD,...>::with(). Its copy_unpack is inherited from
// TMA_LOAD_Unpack, which forwards opargs_ ahead of the smem pointer and the TMA coordinates.
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
    : TMA_LOAD_Unpack<SM90_TMA_LOAD_OP>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Leading arguments of SM90_TMA_LOAD_OP::copy.
  tuple<TmaDescriptor const*,   // TMA descriptor
        uint64_t*>              // smem mbarrier
  const opargs_;
};
// The prefetch for SM90_TMA_LOAD with tma_desc
//
// Needs only a descriptor pointer; it can be built from any traits object exposing tma_desc_.
template <class NumBitsPerTMA, class... Args>
struct Copy_Traits<SM90_TMA_LOAD::PREFETCH, NumBitsPerTMA, Args...>
    : TMA_LOAD_Unpack<SM90_TMA_LOAD::PREFETCH>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // The sole argument of SM90_TMA_LOAD::PREFETCH::copy ahead of the coordinates.
  tuple<TmaDescriptor const*> const opargs_;

  // Implicitly convertible from another Copy_Traits; points at that object's descriptor,
  // so the source traits must outlive this one.
  template <class... CopyArgs>
  CUTE_HOST_DEVICE
  Copy_Traits(Copy_Traits<CopyArgs...> const& traits)
    : opargs_({&traits.tma_desc_}) {}
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_LOAD_MULTICAST /////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// Distinct tag selecting the executable (descriptor + mbarrier + mask) Copy_Traits of SM90_TMA_LOAD_MULTICAST.
struct SM90_TMA_LOAD_MULTICAST_OP : public SM90_TMA_LOAD_MULTICAST {};
// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar
// Use .with(tma_mbar, multicast_mask) to construct an executable version
//
// Holds the TmaDescriptor by value plus AuxParams. copy_unpack is deleted here, so a copy
// only compiles on the result of .with().
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Owned descriptor, then the auxiliary parameters.
  TmaDescriptor tma_desc_;
  using AuxParams = AuxParams_;
  AuxParams     aux_params_;

  /// @return pointer to the descriptor stored in this object.
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  /// Bind an mbarrier and multicast mask using this object's own descriptor.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
  with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
    return with(&tma_desc_, tma_load_mbar, multicast_mask);
  }

  /// Bind an mbarrier and multicast mask with a caller-supplied descriptor
  /// (temp. overload for grouped gemm / ptr-array gemm).
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
  with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
    return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask}};
  }

  /// Counting tensor of TMA coordinates over g_shape, strided by aux_params_.g_stride_.
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
    auto coord_layout = make_layout(g_shape, aux_params_.g_stride_);
    return make_counting_tensor(coord_layout);
  }

  // Intentionally unusable: call .with() first to obtain an executable traits object.
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst) = delete;
};
// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask
//
// Returned by Copy_Traits<SM90_TMA_LOAD_MULTICAST,...>::with(). Its copy_unpack is inherited
// from TMA_LOAD_Unpack, which forwards opargs_ ahead of the smem pointer and TMA coordinates.
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
    : TMA_LOAD_Unpack<SM90_TMA_LOAD_MULTICAST_OP>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Leading arguments of SM90_TMA_LOAD_MULTICAST_OP::copy.
  tuple<TmaDescriptor const*,   // TMA descriptor
        uint64_t*,              // smem mbarrier
        uint16_t>               // multicast mask
  const opargs_;
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_STORE //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// The executable SM90_TMA_STORE with tma_desc
//
// Already executable (no mbarrier needed): copy_unpack issues
// SM90_TMA_STORE::copy(desc, smem_src, coords...).
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Owned descriptor, then the auxiliary parameters.
  TmaDescriptor tma_desc_;
  using AuxParams = AuxParams_;
  AuxParams     aux_params_;

  /// @return pointer to the descriptor stored in this object.
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  /// Counting tensor of TMA coordinates over g_shape, strided by aux_params_.g_stride_.
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
    auto coord_layout = make_layout(g_shape, aux_params_.g_stride_);
    return make_counting_tensor(coord_layout);
  }

  // src: smem tensor; dst: TMA coordinate tensor (its iterator exposes coord_).
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst)
  {
    static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
    //static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_STORE");  // TMA spoofed src tensor

    auto dst_coord = dst.data().coord_;
    void const* const smem_src = cute::raw_pointer_cast(src.data());
    void const* const tma_desc = &(traits.tma_desc_);
#if 0
    auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
    printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
           threadIdx.x, threadIdx.y, threadIdx.z,
           blockIdx.x, blockIdx.y, blockIdx.z,
           int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), smem_src);
#endif
    return detail::copy_explode<SM90_TMA_STORE>(make_tuple(tma_desc, smem_src), seq<0,1>{},
                                                dst_coord, tuple_seq<decltype(dst_coord)>{});
  }
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_REDUCE_ADD //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// The executable SM90_TMA_REDUCE_ADD with tma_desc
//
// Already executable: copy_unpack issues SM90_TMA_REDUCE_ADD::copy(&tma_desc_, smem_src, coords...).
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM90_TMA_REDUCE_ADD, NumBitsPerTMA, AuxParams_>
{
  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Owned descriptor, then the auxiliary parameters.
  TmaDescriptor tma_desc_;
  using AuxParams = AuxParams_;
  AuxParams     aux_params_;

  /// @return pointer to the descriptor stored in this object.
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  /// Counting tensor of TMA coordinates over g_shape, strided by aux_params_.g_stride_.
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
    auto coord_layout = make_layout(g_shape, aux_params_.g_stride_);
    return make_counting_tensor(coord_layout);
  }

  // Expands dst_coord into individual coordinates (indices Is...) and issues the reduction.
  template <class Coord, int... Is>
  CUTE_HOST_DEVICE constexpr
  void
  copy_unpack_(void const* const src_ptr,
               Coord const& dst_coord, seq<Is...>) const
  {
#if 0
    auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
    printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
           threadIdx.x, threadIdx.y, threadIdx.z,
           blockIdx.x, blockIdx.y, blockIdx.z,
           int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
#endif
    SM90_TMA_REDUCE_ADD::copy(&tma_desc_, src_ptr, get<Is>(dst_coord)...);
  }

  // copy_unpack dispatch for this Copy_Traits.
  // src: smem tensor; dst: TMA coordinate tensor (its iterator exposes coord_).
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr
  void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst)
  {
    static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_REDUCE_ADD");
    //static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD");  // TMA spoofed src tensor

    auto dst_coord = dst.data().coord_;
    void const* const smem_src = cute::raw_pointer_cast(src.data());
    traits.copy_unpack_(smem_src, dst_coord, tuple_seq<decltype(dst_coord)>{});
  }
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// BULK COPY //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// Bulk (non-tensor) global -> shared copy of NumBitsPerTMA bits.
// OpArgs is empty until .with(mbar) binds the completion mbarrier (then OpArgs = uint64_t*).
template <class NumBitsPerTMA, class... OpArgs>
struct Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, OpArgs...>
{
  static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0,
                "Bulk Copy requires copy vector size align to 16B.");

  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  // Element 0 (once set): uint64_t* bulk-load memory barrier.
  cute::tuple<OpArgs...> bulk_load_mbar_;

  /// @return a copy of these traits that records the given memory barrier.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, uint64_t*>
  with(uint64_t& bulk_mbar) const {
    return {{&bulk_mbar}};
  }

  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr
  void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst)
  {
    static_assert(is_same<cute::tuple<OpArgs...>, cute::tuple<uint64_t*>>::value,
                  "Extra arguments not set. Set .with() before use.");
    static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_BULK_COPY_G2S");
    static_assert(is_smem<TD>::value, "Expected smem dst for SM90_BULK_COPY_G2S");

    auto          mbar      = get<0>(traits.bulk_load_mbar_);
    int32_t const num_bytes = int32_t(NumBitsPerTMA::value / 8);
    SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), mbar,
                             raw_pointer_cast(dst.data()), num_bytes);
  }
};
// Bulk prefetch of NumBitsPerTMA bits from global memory; reuses the G2S layouts and alignment check.
template <class NumBitsPerTMA, class... Args>
struct Copy_Traits<SM90_BULK_COPY_G2S::PREFETCH, NumBitsPerTMA, Args...>
    : Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA>
{
  // Implicitly convertible from any Copy_Traits; nothing from the source object is needed.
  template <class... CopyArgs>
  CUTE_HOST_DEVICE
  Copy_Traits(Copy_Traits<CopyArgs...> const&) {}

  // Only the gmem source address and byte count are used.
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr
  void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst)
  {
    static_assert(is_gmem<TS>::value, "Expected gmem src for SM90_BULK_PREFETCH");
    int32_t const num_bytes = int32_t(NumBitsPerTMA::value / 8);
    SM90_BULK_COPY_G2S::PREFETCH::copy(raw_pointer_cast(src.data()), num_bytes);
  }
};
// Bulk (non-tensor) shared -> global copy of NumBitsPerTMA bits; needs no extra arguments.
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_BULK_COPY_S2G, NumBitsPerTMA>
{
  static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0,
                "Bulk Copy requires copy vector size align to 16B.");

  // A single logical thread.
  using ThrID     = Layout<_1>;
  // (thr,val) -> bit for src and dst: one thread owning NumBitsPerTMA bits.
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  using RefLayout = SrcLayout;

  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr
  void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst)
  {
    static_assert(is_smem<TS>::value, "Expected smem src for SM90_BULK_COPY_S2G");
    static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_BULK_COPY_S2G");

    int32_t const num_bytes = int32_t(NumBitsPerTMA::value / 8);
    SM90_BULK_COPY_S2G::copy(raw_pointer_cast(src.data()),
                             raw_pointer_cast(dst.data()),
                             num_bytes);
  }
};
//
// Placeholder for the bulk copy algorithm's default, auto-vectorizing behavior
//
// The layouts are degenerate (size 1, stride 0); this type only carries optional op arguments.
template <class... OpArgs>
struct Copy_Traits<SM90_BULK_COPY_AUTO, OpArgs...>
{
  // Logical thread id -> thread idx (one thread).
  using ThrID     = Layout<_1>;
  // Degenerate (thr,val) -> bit maps.
  using SrcLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>;
  using DstLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>;
  using RefLayout = SrcLayout;

  // SM90_BULK_COPY_AUTO arguments.
  // Element 0 (once set via with()): uint64_t* bulk-load memory barrier, used for G2S loads.
  cute::tuple<OpArgs...> opargs_;

  /// @return a copy of these traits that records the given memory barrier.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM90_BULK_COPY_AUTO, uint64_t*>
  with(uint64_t& bulk_mbar) const {
    return {{&bulk_mbar}};
  }
};
//
// MAKE_TMA_COPY and related
//
namespace detail {
// Custom version of coalesce that greedily combines modes only up to size-256
// Look at each element and the back of the stack (in order of priority)
// back(NewLayout) get<I>(OldLayout)
// s0:d0 _1:d1 => continue
// _1:d0 s1:d1 => replace_back s1:d1
// s0:d0 s1:s0*d0 => replace_back s0*s1:d0 if s0*s1 <= 256
// s0:d0 s1:d1 => append s1:d1
//
// Implemented as compile-time recursion on I: each step consumes mode I of (old_shape, old_stride)
// and returns the result of recursing with I+1 on an updated (new_shape, new_stride) accumulator.
//
// @tparam I           index of the next old mode to consume (callers start at 1, seeding the accumulator with mode 0)
// @param  old_shape   flat source shape
// @param  old_stride  flat source stride
// @param  new_shape   accumulated shape so far (a scalar for the seed, a tuple once a mode is appended)
// @param  new_stride  accumulated stride so far, congruent with new_shape
// @return Layout<_1,_0> if the accumulated shape is static 1, else Layout<NewShape,NewStride>
//
// Note: the merge branch tests `is_constant<true, ...>`, so modes are merged only when the
// stride/size condition is decidable at compile time; dynamic shapes/strides fall through to append.
//
// @pre OldShape and OldStride are flat
template <int I, class OldShape, class OldStride, class NewShape, class NewStride>
CUTE_HOST_DEVICE constexpr
auto
coalesce_256_impl(OldShape const& old_shape, OldStride const& old_stride,
NewShape const& new_shape, NewStride const& new_stride)
{
if constexpr (I == rank_v<OldShape>) {
// Base case, we're done
if constexpr (is_constant<1, NewShape>::value) {
return Layout<_1,_0>{};
} else {
return Layout<NewShape,NewStride>{new_shape,new_stride};
}
} else if constexpr (is_constant<1, decltype(get<I>(old_shape))>::value) {
// shape<I>(layout) == _1, skip it and continue
return coalesce_256_impl<I+1>(old_shape, old_stride, new_shape, new_stride);
} else if constexpr (is_constant<1, NewShape>::value) {
// Replace our shape-1 with anything (Can only happen on input new_shape/new_stride)
// The whole accumulator is replaced by the scalar mode I (not just its back element).
return coalesce_256_impl<I+1>(old_shape, old_stride, get<I>(old_shape), get<I>(old_stride));
} else if constexpr (is_constant<true, decltype(back(new_shape) * back(new_stride) == get<I>(old_stride) &&
get<I>(old_shape) * back(new_shape) <= Int<256>{})>::value) {
// Merge modes because the shapes and strides match and the merge is 256 or less
// (mode I continues exactly where the back mode ends: d1 == s0*d0). The back stride is kept.
return coalesce_256_impl<I+1>(old_shape, old_stride,
replace_back(new_shape, get<I>(old_shape) * back(new_shape)),
new_stride);
} else {
// Can't replace or merge, so append a new mode
return coalesce_256_impl<I+1>(old_shape, old_stride,
append(new_shape, get<I>(old_shape)),
append(new_stride, get<I>(old_stride)));
}
CUTE_GCC_UNREACHABLE;
}
// Combine all the modes that are possible to combine, never letting a merged mode exceed 256.
// Does not respect the profile of the layout, but does preserve total size.
//
// @param  layout  any layout; it is flattened before coalescing
// @return a coalesced layout (Layout<_1,_0> if every mode has static size 1)
template <class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
coalesce_256(Layout<Shape,Stride> const& layout)
{
  // Hierarchy is discarded: work on the flat shape/stride.
  auto shapes  = flatten(layout.shape());
  auto strides = flatten(layout.stride());

  // Seed the accumulator with mode 0 and fold in modes 1..rank-1.
  auto seed_shape  = get<0>(shapes);
  auto seed_stride = get<0>(strides);
  return coalesce_256_impl<1>(shapes, strides, seed_shape, seed_stride);
}
// Build the TMA "gbasis": a layout whose shape is the TMA box shape and whose strides are
// basis elements naming the gmem mode(s) each TMA box dimension walks.
//
// @tparam TmaInternalType  element type the TMA moves; gtensor is recast to it, so box
//                          extents are in units of TmaInternalType
// @param  gtensor    the original gmem tensor
// @param  slayout    the smem tile layout; only its non-swizzled portion is inverted here
// @param  cta_v_map  maps an smem tile coordinate (CTA value index) to a (hierarchical) gmem mode
// @return tma_box_shape -> gmem-mode layout with at most 5 top-level modes (TMA rank limit)
//
// Outline:
//   1. invert the non-swizzled slayout and compose with cta_v_map: smem idx -> gmem mode;
//   2. keep the leading modes whose basis value is static 1 (static_assert if there are none);
//   3. recast to TmaInternalType and coalesce, merging modes only while the result stays <= 256;
//   4. append, with extent 1, every gmem mode (non-unit size, non-zero stride) not already in the box;
//   5. group the trailing modes so at most 5 remain.
template <class TmaInternalType,
class GEngine, class GLayout,
class SShape, class SStride,
class VShape, class VStride>
CUTE_HOST_DEVICE constexpr
auto
construct_tma_gbasis(Tensor<GEngine,GLayout> const& gtensor, // The original GMEM Tensor
Layout<SShape,SStride> const& slayout, // The layout of SMEM
Layout<VShape,VStride> const& cta_v_map) // smem_idx to hier gmode
{
//
// TMA parameter checking
//
CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)),
"TMA requires CTA_Tile and SLayout top-level shape equivalence.");
#if 0
print("gtensor : "); print(gtensor); print("\n");
print("slayout : "); print(slayout); print("\n");
print("cta_v_map : "); print(cta_v_map); print("\n");
#endif
//
// TMA slayout manipulation
//
// Invert the smem to get the largest contiguous vector in the smem layout
// smem idx -> smem coord
auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout));
// Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode
// smem idx -> gmem mode
auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout));
#if 0
print("inv_smem_layout : "); print(inv_smem_layout); print("\n");
print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n");
#endif
//
// TMA gtensor truncation
//
// Truncate any incompatibilities -- no starting in the middle of gmodes
// smem_rank = index of the first mode whose stride basis value is not a static 1.
auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) {
[[maybe_unused]] auto v = basis_value(e);
return not is_constant<1,decltype(v)>{};
});
static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?");
// Keep only the static-1 basis modes into gmem
auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full);
#if 0
print("smem_rank : "); print(smem_rank); print("\n");
print("sidx2gmode : "); print(sidx2gmode); print("\n");
#endif
//
// TMA gtensor manipulation
//
// The smem vector is the same units as gtensor, so compose first and then recast
// tma_val_idx:gmem_strides
auto tile_gstride = recast<TmaInternalType>(gtensor.compose(sidx2gmode)).layout();
// Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType)
// Note: coalesce_256 only limits merges; a single pre-existing mode larger than 256 is not split.
// tma_box_shape:gmem_strides
auto tma_gstride = coalesce_256(tile_gstride);
// Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes
auto gbasis = make_identity_layout(shape(gtensor));
auto tile_gbasis_tmp = gbasis.compose(sidx2gmode);
// Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape
// tma_box_shape:gmem_mode
auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp));
// "Coalesce" the tile basis into a compatible shape with the tma_gstride
auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride))));
// Recast the original tensor for shape/stride inspections
Tensor gtensor_T = recast<TmaInternalType>(gtensor);
// Find missing bases that don't appear in tile_gbasis
// (only gmem modes with non-unit size and non-zero stride are candidates)
auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)),
flatten(stride(gbasis)),
[&](auto s, auto d, auto e)
{
if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) {
return cute::tuple<>{}; // If size-1 or stride-0, then don't append
} else {
using E = decltype(e);
auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; });
if constexpr (decltype(has_e)::value) {
return cute::tuple<>{}; // If d was found, then don't append
} else {
return cute::tuple<E>(e); // Else, this is missing so append
}
}
});
// Append the remaining basis modes that contribute to the TMA with size-1
auto tile_gbasis_remaining_shape = repeat<rank(tile_gbasis_remaining_stride)>(Int<1>{});
auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )),
tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride)));
// Group the trailing modes to make this max rank-5 -- TMA rank limitation
// tma_box_shape:gmem_mode
auto tma_gbasis = group<cute::min(rank(tma_gbasis_full),4),-1>(tma_gbasis_full);
// NOTE(review): in the debug dump below, the "tile_gbasis" label actually prints tma_gbasis_tile.
#if 0
print("tile_gstride : "); print(tile_gstride); print("\n");
print("tma_gstride : "); print(tma_gstride); print("\n");
print("gbasis : "); print(gbasis); print("\n");
print("tile_gbasis : "); print(tma_gbasis_tile); print("\n");
print("tma_gbasis : "); print(tma_gbasis); print("\n");
#endif
return tma_gbasis;
}
// Fill the per-TMA-dimension gmem problem shape and stride from gtensor.
//
// For every TMA mode i, tma_gbasis_stride names the gmem mode(s) it covers:
//   - one gmem mode: copy that mode's shape and stride directly;
//   - several gmem modes: fold them into a single (shape, stride) whose stride is the gcd of the
//     contributing strides and whose shape spans all of them (recurrence below).
// Strides are taken from gtensor, i.e. in units of its value_type (TmaInternalType), not bytes.
//
// @param gtensor            gmem tensor already recast to TmaInternalType
// @param tma_gbasis_stride  tuple: TMA mode idx -> gmem mode basis element(s)
// @param gmem_prob_shape    out: one entry per TMA mode (entries past rank(tma_gbasis_stride) untouched)
// @param gmem_prob_stride   out: one entry per TMA mode, in TmaInternalType units
//
// Precondition for multi-mode TMA dimensions: gmem_prob_stride[i] should be 0 on entry, so the
// first contributing mode seeds the gcd (make_tma_copy_desc initializes the strides to 0).
template <class GEngine, class GLayout,
class TmaGmemBasisStride,
class ShapeT, size_t TmaRank>
CUTE_HOST_DEVICE constexpr
void
fill_tma_gmem_shape_stride(Tensor<GEngine,GLayout> const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType
TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s)
cute::array<ShapeT, TmaRank> & gmem_prob_shape, // Tma Shapes, uint32_t or uint64_t
cute::array<uint64_t, TmaRank> & gmem_prob_stride) // Tma Strides, in units of TmaInternalType
{
static_assert(is_tuple<TmaGmemBasisStride>::value);
static_assert(is_same<uint32_t, ShapeT>::value || is_same<uint64_t, ShapeT>::value);
// NOTE(review): this alias is not referenced in the body below.
using TmaInternalType = typename GEngine::value_type;
constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value;
static_assert(TmaRank >= tma_rank);
auto gmem_shape = shape(gtensor);
auto gmem_stride = stride(gtensor);
// Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides
for_each(make_seq<tma_rank>{}, [&](auto i) {
constexpr int tma_i_rank = decltype(rank<i>(tma_gbasis_stride))::value;
if constexpr (tma_i_rank == 1) {
// Trivial contribution of this gmem mode to this tma mode
auto ej = unwrap(get<i>(tma_gbasis_stride));
gmem_prob_shape[i] = basis_get(ej, gmem_shape);
gmem_prob_stride[i] = basis_get(ej, gmem_stride);
} else {
// Apply a recurrence to each gmem mode that contributes to this tma mode
for_each(get<i>(tma_gbasis_stride), [&](auto ej) {
// Problem shape
uint64_t shape_j = basis_get(ej, gmem_shape);
// Problem stride (in units of TmaInternalType, not bytes)
uint64_t stride_j = basis_get(ej, gmem_stride);
uint64_t old_stride = gmem_prob_stride[i];
// New common stride; a zero stride_j leaves it unchanged (gcd(d, 0) == d).
gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j);
if (gmem_prob_stride[i] != 0) {
// Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1
// i.e. the extent (in gcd-stride steps) that reaches the farthest element of every mode so far.
gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i])
+ (shape_j-1) * (stride_j / gmem_prob_stride[i])
+ 1;
} else {
// All contributing strides so far are 0: just take this mode's shape.
gmem_prob_shape[i] = shape_j;
}
});
}
});
}
// Overload for an existing Copy_Traits.
// The traits object only supplies Aux::TmaGmemBasis (a static layout), whose stride gives the
// TMA-mode -> gmem-mode mapping; the work is delegated to the tensor-based overload.
template <class GEngine, class GLayout,
          class Op, class Bits, class Aux,
          class ShapeT, size_t TmaRank>
CUTE_HOST_DEVICE constexpr
void
fill_tma_gmem_shape_stride(Copy_Traits<Op,Bits,Aux>       const& tma_traits,
                           Tensor<GEngine,GLayout>        const& gtensor,           // Gmem Shapes and Strides, value_type = TmaInternalType
                           cute::array<ShapeT,   TmaRank>      & gmem_prob_shape,   // Tma Shapes, uint32_t or uint64_t
                           cute::array<uint64_t, TmaRank>      & gmem_prob_stride)  // Tma Strides
{
  auto basis_stride = stride(typename Aux::TmaGmemBasis{});
  fill_tma_gmem_shape_stride(gtensor, basis_stride, gmem_prob_shape, gmem_prob_stride);
}
// Use a sidx2gmode to read through the GMEM tensor
// and construct a TMA Descriptor for the resulting instruction
// At the same time, construct the Tma Tensor's Stride to generate
// the TMA coordinates that the instruction consumes.
//
// Parameters:
//   gtensor       : the original GMEM tensor; it is recast to TmaInternalType before
//                   its shapes/strides are inspected.
//   tma_gbasis    : layout whose rank is the TMA rank; its stride maps each TMA mode to
//                   the GMEM mode(s) it covers, and its mode sizes give the smem box extents.
//   swizzle       : smem swizzle function, translated to a CUtensorMapSwizzle.
//   num_multicast : number of CTAs in the multicast; the smem box is divided by this
//                   factor starting from the last TMA mode.
// Returns:
//   cute::tuple of (TmaDescriptor, AuxTmaParams holding the TMA coordinate basis stride).
// Note:
//   The descriptor is only encoded (cuTensorMapEncodeTiled) when compiling with
//   CUDA >= 12 outside of NVRTC; otherwise it is returned value-initialized.
//
template <class TmaInternalType,
          class GEngine, class GLayout,
          class TShape, class TStride,
          int B, int M, int S>
CUTE_HOST_RTC
auto
make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor,         // The original GMEM Tensor
                   Layout<TShape,TStride>  const& tma_gbasis,      // TMA mode -> GMEM mode mapping
                   Swizzle<B,M,S>          const& swizzle,         // Swizzle fn on smem_idx
                   uint32_t                       num_multicast)   // The number of CTAs in multicasting
{
  //
  // TMA desc creation
  //

  constexpr int tma_dim = decltype(rank(tma_gbasis))::value;

  //
  // TMA gmem desc info
  //

  // Recast the original tensor for shape/stride inspections
  Tensor gtensor_T = recast<TmaInternalType>(gtensor);

  void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data());
  auto gmem_layout = gtensor_T.layout();

  cute::array<uint64_t, 5> gmem_prob_shape  = {1,1,1,1,1};
  cute::array<uint64_t, 5> gmem_prob_stride = {0,0,0,0,0};

  fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride);

  assert((reinterpret_cast<uint64_t>(gmem_address) & 0b1111) == 0);  // Address must be 16B-aligned

  for (size_t i = 0; i < 5; ++i) {
    assert(gmem_prob_shape[i] >= (uint64_t(1)));                     // Size must be min 1
    assert(gmem_prob_shape[i] <= (uint64_t(1) << 32));               // Size must be max 2^32
  }

  // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element).
  assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem");

  // convert strides to byte strides
  for (uint64_t& stride_i : gmem_prob_stride) {
    stride_i = (stride_i * sizeof_bits_v<TmaInternalType>) / 8;
  }

  // Assert the byte strides. Tma Descriptor uses byte strides (stride[0] is implicit).
  for (size_t i = 1; i < 5; ++i) {
    assert((gmem_prob_stride[i]) < (uint64_t(1) << 40));             // Stride must be less than 2^40
    assert((gmem_prob_stride[i] & 0b1111) == 0);                     // Stride must be multiple of 16B (128b)
  }

  //
  // TMA smem desc info
  //

  cute::array<uint32_t, 5> smem_box_shape  = {1,1,1,1,1};
  cute::array<uint32_t, 5> smem_box_stride = {1,1,1,1,1};
  // The smem box is simply given by the sizes of the modes in tma_gbasis
  for_each(make_seq<tma_dim>{}, [&](auto i) {
    smem_box_shape[i] *= size<i>(tma_gbasis);
  });

  // Finally, truncate the tma box by the num_multicast.
  // Walk the TMA modes from last to first, dividing each box extent by the remaining
  // multicast factor. The `i > 0` bound keeps the index in range when num_multicast
  // exceeds the total box size (reported by the assert after the loop) instead of
  // letting the unsigned index wrap past mode 0 into an out-of-bounds access.
  uint32_t multicast = num_multicast;
  for (uint32_t i = tma_dim; i > 0 && multicast > 1; ) {
    --i;
    assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0);
    uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]);
    smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast);
    multicast = new_mult;
  }
  assert(multicast <= 1 && "num_multicast exceeds the number of elements in the TMA box");

  for (size_t i = 0; i < 5; ++i) {
    assert(smem_box_shape[i]  >= (uint32_t(1)));                     // Size must be min 1
    assert(smem_box_shape[i]  <= (uint32_t(1) << 8));                // Size must be max 2^8 = 256
    assert(smem_box_stride[i] >= (uint32_t(1)));                     // Stride must be min 1
    assert(smem_box_stride[i] <= (uint32_t(8)));                     // Stride must be max 2^3 = 8
  }

  //
  // Construct the descriptor
  //

  TmaDescriptor tma_desc{};

  //
  // TMA general info
  //

#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)

  CUtensorMapDataType     tma_format      = TMA::to_CUtensorMapDataType<TmaInternalType>();
  CUtensorMapInterleave   tma_interleave  = CU_TENSOR_MAP_INTERLEAVE_NONE;
  CUtensorMapL2promotion  tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
  CUtensorMapFloatOOBfill tma_oobFill     = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;

  // TMA smem swizzle type
  CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
  CUresult result = cuTensorMapEncodeTiled(
      &tma_desc,
      tma_format,
      tma_dim,
      gmem_address,
      gmem_prob_shape.data(),
      gmem_prob_stride.data() + 1,  // gmem_prob_stride[0] implicitly 1
      smem_box_shape.data(),
      smem_box_stride.data(),
      tma_interleave,
      smem_swizzle,
      tma_l2Promotion,
      tma_oobFill);

  if (result != CUDA_SUCCESS) {
    std::cerr << "TMA Desc Addr: " << &tma_desc
              << "\nformat " << tma_format
              << "\ndim " << tma_dim
              << "\ngmem_address " << gmem_address
              << "\nglobalDim " << gmem_prob_shape
              << "\nglobalStrides " << gmem_prob_stride
              << "\nboxDim " << smem_box_shape
              << "\nelementStrides " << smem_box_stride
              << "\ninterleave " << tma_interleave
              << "\nswizzle " << smem_swizzle
              << "\nl2Promotion " << tma_l2Promotion
              << "\noobFill " << tma_oobFill << std::endl;
    std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl;
    assert(false);
  }

#endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)

  auto recast_ratio = cute::trait_ratio(sizeof_bits<typename GEngine::value_type>{},
                                        sizeof_bits<             TmaInternalType>{});

  auto gbasis = make_basis_like(shape(gtensor));

  // Finally, get the inverse permutation of the E<i> bases for the mocked gmem stride
  auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) {
    auto si = basis_get(ei,  shape(gmem_layout));
    auto di = basis_get(ei, stride(gmem_layout));
    if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) {
      return Int<0>{};                  // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA
    } else {
      auto tma_gmem_basis_stride = stride(tma_gbasis);
      // Find j such that E<i> is in stride<j>(tma_gbasis)
      using EI = decltype(ei);
      [[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); });
      if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) {
        return Int<0>{};               // If not-found, return arithmetic identity -- no contribution to the TMA
      } else
      if constexpr (decltype(j == Int<0>{})::value) {
        auto scale = recast_ratio * basis_get(ei, stride(gtensor));
        return E<j>{} * scale;         // Return TMA Coord basis -- with a recast scale factor
      } else
      if constexpr (decltype(rank<j>(tma_gmem_basis_stride) == Int<1>{})::value) {
        return E<j>{};                 // Return TMA Coord basis -- known scale of Int<1>{}
      } else {
        int32_t scale = ceil_div(int32_t(di * sizeof_bits_v<TmaInternalType> / cute::max(gmem_prob_stride[j], uint64_t{16})), 8);
        return E<j>{} * scale;         // Return TMA Coord basis -- with a dynamic scale factor
      }
    }
  });

#if 0
  print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n");
#endif

  using AuxParams = AuxTmaParams<decltype(gmem_tma_basis_stride),
                                 decltype(tma_gbasis),
                                 decltype(swizzle)>;
  return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride});
}
template <class TmaInternalType,
class CopyOp,
class GEngine, class GLayout,
class SLayout,
class VShape, class VStride>
CUTE_HOST_RTC
auto
make_tma_copy_atom(CopyOp,