-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfloat8.h
1278 lines (1058 loc) · 36.6 KB
/
float8.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines classes for using 8-bit floating-point types (E4M3 and E5M2) in host or
device code.
*/
/*
Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain
existing integrations of CUTLASS require C++11 host compilers.
Until this requirement can be lifted, certain headers with this annotation are required
to remain consistent with C++11 syntax.
C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`.
*/
#pragma once
// FP8 types are available starting CUDA 11.8+
// CUDA_FP8_ENABLED is set when compiling with nvcc 11.8 or newer. It gates the
// <cuda_fp8.h> include and the __nv_fp8_e4m3 / __nv_fp8_e5m2 interop constructors
// and assignment operators. When the compiler is not nvcc, the __CUDACC_VER_*
// macros are undefined and evaluate to 0 inside #if, so the flag stays unset.
#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
#define CUDA_FP8_ENABLED 1
#endif
// CUDA_PTX_FP8_CVT_ENABLED is set only during device compilation, and only for:
//   - __CUDA_ARCH__ >= 900 with CUDA 11.8+
//   - __CUDA_ARCH__ == 890 with CUDA 12.1+
// When set, from_float / from_half / to_half / to_float use inline-PTX `cvt`
// instructions. Otherwise they fall back to the software conversions in float8_base.
#if defined(__CUDA_ARCH__)
# if (__CUDA_ARCH__ >= 900)
# if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
# define CUDA_PTX_FP8_CVT_ENABLED 1
# endif // (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))
# elif (__CUDA_ARCH__ == 890)
# if (__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))
# define CUDA_PTX_FP8_CVT_ENABLED 1
# endif // (__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))
# endif // (__CUDA_ARCH__ >= 900)
#endif // defined(__CUDA_ARCH__)
#ifdef __GNUC__
// Ignore checks on reinterpret-casts that are being used for bitcasts.
// NOTE: this is not wrapped in `#pragma GCC diagnostic push/pop`, so the warning
// stays disabled for the rest of every translation unit that includes this header.
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDACC_RTC__)
#include "cutlass/floating_point_nvrtc.h"
#else
//
// Standard Library headers belong here to avoid conflicts with NVRTC.
//
#include <cmath>
#include <limits>
#include <cstdint>
#include <cstring>
#endif
#ifdef CUDA_FP8_ENABLED
#include <cuda_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include "cutlass/cutlass.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// FP8 has two possible encodings: E4M3 and E5M2.
// Bit layout: bit 7 is the sign, followed by the exponent bits, then the mantissa bits.
//
// E4M3 : 7 | 6 5 4 3 | 2 1 0
// E5M2 : 7 | 6 5 4 3 2 | 1 0
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Selects which 8-bit floating-point layout float8_base<> implements.
enum class FloatEncoding {
  E4M3 = 0,  ///< 1 sign bit, 4 exponent bits, 3 mantissa bits
  E5M2 = 1   ///< 1 sign bit, 5 exponent bits, 2 mantissa bits
};
/// Storage, encoding constants, and software FP32 <-> FP8 conversions shared by
/// float_e4m3_t and float_e5m2_t.
///
/// \tparam T  the 8-bit encoding (E4M3 or E5M2) whose constants and conversions are generated.
template<FloatEncoding T>
struct alignas(1) float8_base {

  static constexpr bool IS_E4M3 = (T == FloatEncoding::E4M3);
  static constexpr bool IS_E5M2 = (T == FloatEncoding::E5M2);

  // Number of Bits representing mantissa and exponents
  static constexpr int FP32_NUM_BITS = 32;
  static constexpr int FP32_NUM_EXPONENT_BITS = 8;
  static constexpr int FP32_NUM_MANTISSA_BITS = 23;
  static constexpr uint32_t FP32_NAN = 0x7fffffff;
  static constexpr uint32_t FP32_INFINITY_MASK = 0x7f800000;
  static constexpr int FP32_MAX_EXPONENT = 127;
  static constexpr int FP32_MIN_EXPONENT = -126;
  static constexpr int FP32_EXPONENT_BIAS = 127;

  static constexpr int FP16_NUM_BITS = 16;
  static constexpr int FP16_NUM_EXPONENT_BITS = 5;
  static constexpr int FP16_NUM_MANTISSA_BITS = 10;
  static constexpr uint16_t FP16_NAN = 0x7fff;
  static constexpr uint16_t FP16_INFINITY_MASK = 0x7c00;
  static constexpr int FP16_MAX_EXPONENT = 15;
  static constexpr int FP16_MIN_EXPONENT = -14;
  static constexpr int FP16_EXPONENT_BIAS = 15;

  static constexpr int FP8_NUM_BITS = 8;
  static constexpr int FP8_NUM_EXPONENT_BITS = IS_E4M3 ? 4 : 5;
  static constexpr int FP8_NUM_MANTISSA_BITS = IS_E4M3 ? 3 : 2;
  // Positive NaN encoding used by both formats (E4M3: 0.1111.111, E5M2: 0.11111.11).
  // E4M3 has no infinity encoding; E5M2's +Inf is 0x7c, not this value.
  static constexpr uint8_t FP8_NAN = 0x7f;
  static constexpr uint8_t FP8_INFINITY_MASK = IS_E4M3 ? 0x78 : 0x7c;
  static constexpr int FP8_MAX_EXPONENT = IS_E4M3 ? 7 : 15;
  static constexpr int FP8_MIN_EXPONENT = IS_E4M3 ? -6 : -14;
  static constexpr int FP8_EXPONENT_BIAS = IS_E4M3 ? 7 : 15;

  static constexpr uint8_t FP8_EXPONENT_MASK = (1 << FP8_NUM_EXPONENT_BITS) - 1;
  static constexpr uint8_t FP8_MANTISSA_MASK = (1 << FP8_NUM_MANTISSA_BITS) - 1;

  // Largest finite magnitude, used for saturation: 448 for E4M3, 57344 for E5M2.
  static constexpr uint8_t FP8_MAX_FLT = (IS_E4M3 ? 0x7e : 0x7b);

  // 256 in float
  static constexpr uint32_t FP8_SAT_VAL_FP32 = 0x43800000;

  //
  // Data members
  //

  /// Data container
  uint8_t storage;

  /// Ctors.
  CUTLASS_HOST_DEVICE
  float8_base() : storage(0) { }

  /// Returns true when `flt` is neither infinite nor NaN (FP32 exponent field not all ones).
  CUTLASS_HOST_DEVICE
  static bool isfinite(float flt) {
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    return (s & FP32_INFINITY_MASK) < FP32_INFINITY_MASK;
  }

  /// Returns true when `flt` is a NaN of either sign (exponent all ones, mantissa nonzero).
  CUTLASS_HOST_DEVICE
  static bool isnan(float flt) {
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    return (s & FP32_NAN) > FP32_INFINITY_MASK;
  }

  /// Returns true when `flt` is +Inf or -Inf (exponent all ones, mantissa zero; sign ignored).
  CUTLASS_HOST_DEVICE
  static bool isinf(float flt) {
    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    return (s & FP32_NAN) == FP32_INFINITY_MASK;
  }

  /// FP32 -> FP8 conversion - rounds to nearest even (software implementation).
  ///
  /// NaN maps to FP8_NAN. +/-Inf and magnitudes beyond the FP8 range saturate to
  /// +/-FP8_MAX_FLT ("satfinite"). Values below the smallest subnormal round to zero.
  ///
  /// \param flt  value to convert
  /// \return     the FP8 bit pattern for the selected encoding
  CUTLASS_HOST_DEVICE
  static uint8_t convert_float_to_fp8(float const& flt) {

    uint32_t s;

#if defined(__CUDA_ARCH__)
    s = reinterpret_cast<uint32_t const &>(flt);
#else
    std::memcpy(&s, &flt, sizeof(s));
#endif

    // Extract the bits in the FP32 type
    uint8_t sign = uint8_t((s >> 24 & 0x80));
    int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS;
    int mantissa = s & 0x7fffff;
    uint8_t u = 0;

    // NaN => NaN
    if (isnan(flt)) {
      return FP8_NAN;
    }

    // Inf => MAX_FLT (satfinite).
    // Together with the NaN check above, this covers every input whose FP32 exponent
    // field is all ones, so from here on exp <= FP32_MAX_EXPONENT.
    if (isinf(flt)) {
      return sign | FP8_MAX_FLT;
    }

    int sticky_bit = 0;

    if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) {
      // normal fp32 to normal fp8
      exp = exp + FP8_EXPONENT_BIAS;
      u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS);
      u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)));
    } else if (exp < FP8_MIN_EXPONENT) {
      // normal single-precision to subnormal float8-precision representation
      int rshift = (FP8_MIN_EXPONENT - exp);
      if (rshift < FP32_NUM_BITS) {
        // Make the implicit leading one explicit before shifting it into the subnormal field.
        mantissa |= (1 << FP32_NUM_MANTISSA_BITS);
        // Mask of the bits about to be shifted out. It is built in uint32_t because
        // rshift can reach 31, and (1 << 31) overflows a signed int (undefined behaviour).
        uint32_t const shifted_out_mask = (uint32_t(1) << rshift) - 1u;
        sticky_bit = ((uint32_t(mantissa) & shifted_out_mask) != 0);
        mantissa = (mantissa >> rshift);
        u = (uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK);
      } else {
        mantissa = 0;
        u = 0;
      }
    } else {
      // Exponent > FP8_MAX_EXPONENT - this is a special case done to match HW.
      // For E4M3, 0x4380_0000 to 0x43e0_0000 maps from 256 to 448 and does not saturate / inf.
      if (exp == (FP8_MAX_EXPONENT + 1)) {
        uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
        if (mantissa_tmp < FP8_MANTISSA_MASK) {
          exp = exp + FP8_EXPONENT_BIAS;
          u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
        } else {
          // satfinite
          return (sign | FP8_MAX_FLT);
        }
      } else {
        // satfinite
        return (sign | FP8_MAX_FLT);
      }
    }

    // Round to nearest, ties to even. Increment when the first discarded bit is set
    // and either a lower discarded bit is set or the kept LSB is odd.
    int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1);
    int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1);
    sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0);

    if (round_bit && (sticky_bit || (u & 1))) {
      u = uint8_t(u + 1);
    }

    // satfinite: any magnitude above FP8_MAX_FLT (including a rounding carry into the
    // NaN / Inf encodings) clamps to the largest finite value.
    if (u > FP8_MAX_FLT) {
      u = uint8_t(sign | FP8_MAX_FLT);
    }

    u |= sign;

    return u;
  }

  /// Converts a fp8 value stored as a uint8_t to a float.
  ///
  /// NaN encodings map to the canonical positive FP32 NaN (FP32_NAN). For E5M2 only,
  /// an all-ones exponent with zero mantissa maps to a signed infinity.
  CUTLASS_HOST_DEVICE
  static float convert_fp8_to_float(uint8_t const& x) {

    uint8_t const &f8 = x;
    uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
    uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
    uint32_t mantissa = f8 & FP8_MANTISSA_MASK;
    uint32_t f = (sign << (FP32_NUM_BITS - 1));

    if (IS_E4M3 && exp == FP8_EXPONENT_MASK && mantissa == FP8_MANTISSA_MASK) {
      // E4M3 treats only S.1111.111 as NaN; all other nonzero exponents are normal numbers.
      f = FP32_NAN;
    }
    else if (exp > 0 && (IS_E4M3 || exp < (FP8_MAX_EXPONENT + FP8_EXPONENT_BIAS + 1))) {
      // normal
      exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS);
      f = f |
          (exp << FP32_NUM_MANTISSA_BITS) |
          (mantissa << (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
    } else if (exp == 0) {
      if (mantissa) {
        // subnormal: normalize by shifting until the implicit-one position is set
        exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS) + 1;
        while ((mantissa & (1 << FP8_NUM_MANTISSA_BITS)) == 0) {
          mantissa <<= 1;
          exp--;
        }
        mantissa &= FP8_MANTISSA_MASK;
        f = f |
            (exp << FP32_NUM_MANTISSA_BITS) |
            (mantissa << (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
      } else {
        // sign-preserving zero
      }
    } else {
      // E5M2 only: exponent field all ones
      if (mantissa == 0) {
        // Sign-preserving infinity
        f = (f | FP32_INFINITY_MASK);
      } else {
        // Canonical NaN
        f = FP32_NAN;
      }
    }

#if defined(__CUDA_ARCH__)
    return reinterpret_cast<float const&>(f);
#else
    float flt;
    std::memcpy(&flt, &f, sizeof(flt));
    return flt;
#endif
  }
};
// Forward declaration of float_e5m2_t so that float_e4m3_t can declare its
// `explicit float_e4m3_t(float_e5m2_t)` converting constructor while float_e5m2_t
// is still incomplete. That constructor is only declared inside float_e4m3_t;
// its definition has to appear after float_e5m2_t is fully defined.
struct float_e5m2_t;
///////////////////////////////////////////////////////////////
///
/// floating-point 8 type : E4M3
///
///////////////////////////////////////////////////////////////
struct alignas(1) float_e4m3_t : float8_base<FloatEncoding::E4M3> {

  using Base = float8_base<FloatEncoding::E4M3>;

  /// Largest unbiased exponent of a normal E4M3 value.
  static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT;

  //
  // Static conversion operators
  //

  /// Constructs from an uint8_t bit pattern (no numeric conversion).
  CUTLASS_HOST_DEVICE
  static float_e4m3_t bitcast(uint8_t x) {
    float_e4m3_t f;
    f.storage = x;
    return f;
  }

  /// FP32 -> E4M3 conversion - rounds to nearest even, saturating to the largest finite value.
  CUTLASS_HOST_DEVICE
  static float_e4m3_t from_float(float const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp;
    float y = float();
    // cvt stores the conversion of the last source operand (flt) in the low byte of tmp.
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt));

    return bitcast(uint8_t(tmp & 0xffu));
#else
    return bitcast(Base::convert_float_to_fp8(flt));
#endif
  }

  /// FP16 -> E4M3 conversion - rounds to nearest even, saturating to the largest finite value.
  CUTLASS_HOST_DEVICE
  static float_e4m3_t from_half(half const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp = 0;
    // The half is placed in the low lane of an f16x2 (upper lane zero); its result is the low byte.
    uint32_t bits = reinterpret_cast<uint16_t const &>(flt);
    asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits));

    return bitcast(uint8_t(tmp & 0xffu));
#else
    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
#endif
  }

  // E4M3 -> half
  CUTLASS_HOST_DEVICE
  static half to_half(float_e4m3_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return reinterpret_cast<half2 const &>(packed).x;
#else
    return __float2half(Base::convert_fp8_to_float(x.storage));
#endif
  }

  // E4M3 -> Float
  CUTLASS_HOST_DEVICE
  static float to_float(float_e4m3_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return __half2float(reinterpret_cast<half2 const &>(packed).x);
#else
    return Base::convert_fp8_to_float(x.storage);
#endif
  }

  //
  // Methods
  //

  /// Constructor inheritance
  using Base::Base;

  /// Default constructor (storage zero-initialized by float8_base)
  float_e4m3_t() = default;

#ifdef CUDA_FP8_ENABLED
  /// Conversion from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(__nv_fp8_e4m3 x) {
    storage = x.__x;
  }
#endif

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(float x) {
    storage = from_float(x).storage;
  }

  /// Half-precision conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(half x) {
    storage = from_half(x).storage;
  }

  /// Floating point conversion (narrows to float first)
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(double x): float_e4m3_t(float(x)) {
  }

  /// Integer conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(int x): float_e4m3_t(float(x)) {
  }

  /// Unsigned integer conversion
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(unsigned x): float_e4m3_t(float(x)) {
  }

  /// E5M2 conversion. Defined after float_e5m2_t is defined.
  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(float_e5m2_t x);

#ifdef CUDA_FP8_ENABLED
  /// Assignment from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  float_e4m3_t & operator=(__nv_fp8_e4m3 x) {
    storage = x.__x;
    return *this;
  }
#endif

  /// Converts to float
  CUTLASS_HOST_DEVICE
  operator float() const {
    return to_float(*this);
  }

  /// Converts to half
  CUTLASS_HOST_DEVICE
  operator half() const {
    return to_half(*this);
  }

  /// Converts to double
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(to_float(*this));
  }

  /// Converts to int.
  /// NOTE: the device path rounds to nearest (__half2int_rn) while the host path
  /// truncates toward zero (int(float)), so e.g. 1.5 converts differently.
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
#if defined(__CUDA_ARCH__)
    return __half2int_rn(to_half(*this));
#else
    return int(to_float(*this));
#endif
  }

  /// Casts to bool (via the int conversion, so host and device can differ for |x| < 1).
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
#if defined(__CUDA_ARCH__)
    return bool(__half2int_rn(to_half(*this)));
#else
    return bool(int(to_float(*this)));
#endif
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t& raw() {
    return storage;
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0);
  }

  /// Returns the biased exponent field
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK);
  }

  /// Returns the unbiased exponent: the biased field minus the E4M3 bias (7).
  /// Not adjusted for zero/subnormal encodings (biased field 0).
  CUTLASS_HOST_DEVICE
  int exponent() const {
    return exponent_biased() - Base::FP8_EXPONENT_BIAS;
  }

  /// Returns the mantissa
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(storage & Base::FP8_MANTISSA_MASK);
  }
};
///////////////////////////////////////////////////////////////
///
/// floating-point 8 type : E5M2
///
///////////////////////////////////////////////////////////////
struct alignas(1) float_e5m2_t : float8_base<FloatEncoding::E5M2> {

  using Base = float8_base<FloatEncoding::E5M2>;

  /// Largest unbiased exponent of a normal E5M2 value.
  static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT;

  //
  // Static conversion operators
  //

  /// Constructs from an uint8_t bit pattern (no numeric conversion).
  CUTLASS_HOST_DEVICE
  static float_e5m2_t bitcast(uint8_t x) {
    float_e5m2_t f;
    f.storage = x;
    return f;
  }

  /// FP32 -> E5M2 conversion - rounds to nearest even, saturating to the largest finite value.
  CUTLASS_HOST_DEVICE
  static float_e5m2_t from_float(float const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp;
    float y = float();
    // cvt stores the conversion of the last source operand (flt) in the low byte of tmp.
    asm volatile("cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt));

    return bitcast(uint8_t(tmp & 0xffu));
#else
    return bitcast(Base::convert_float_to_fp8(flt));
#endif
  }

  /// FP16 -> E5M2 conversion - rounds to nearest even, saturating to the largest finite value.
  CUTLASS_HOST_DEVICE
  static float_e5m2_t from_half(half const& flt) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t tmp = 0;
    // The half is placed in the low lane of an f16x2 (upper lane zero); its result is the low byte.
    uint32_t bits = reinterpret_cast<uint16_t const &>(flt);
    asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits));

    return bitcast(uint8_t(tmp & 0xffu));
#else
    return bitcast(Base::convert_float_to_fp8(__half2float(flt)));
#endif
  }

  // E5M2 -> half
  CUTLASS_HOST_DEVICE
  static half to_half(float_e5m2_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return reinterpret_cast<half2 const &>(packed).x;
#else
    return __float2half(Base::convert_fp8_to_float(x.storage));
#endif
  }

  // E5M2 -> Float
  CUTLASS_HOST_DEVICE
  static float to_float(float_e5m2_t const& x) {
#if defined(CUDA_PTX_FP8_CVT_ENABLED)
    uint16_t bits = x.storage;
    uint32_t packed;
    asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits));

    return __half2float(reinterpret_cast<half2 const &>(packed).x);
#else
    return Base::convert_fp8_to_float(x.storage);
#endif
  }

  //
  // Methods
  //

  /// Constructor inheritance
  using Base::Base;

  /// Default constructor (storage zero-initialized by float8_base)
  float_e5m2_t() = default;

#ifdef CUDA_FP8_ENABLED
  /// Conversion from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(__nv_fp8_e5m2 x) {
    storage = x.__x;
  }
#endif

  /// Floating point conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(float x) {
    storage = from_float(x).storage;
  }

  /// Half-precision conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(half x) {
    storage = from_half(x).storage;
  }

  /// Floating point conversion (narrows to float first)
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(double x): float_e5m2_t(float(x)) {
  }

  /// Integer conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(int x): float_e5m2_t(float(x)) {
  }

  /// Unsigned integer conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(unsigned x): float_e5m2_t(float(x)) {
  }

  /// E4M3 conversion
  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(float_e4m3_t x);

#ifdef CUDA_FP8_ENABLED
  /// Assignment from CUDA's FP8 type
  CUTLASS_HOST_DEVICE
  float_e5m2_t & operator=(__nv_fp8_e5m2 x) {
    storage = x.__x;
    return *this;
  }
#endif

  /// Converts to float
  CUTLASS_HOST_DEVICE
  operator float() const {
    return to_float(*this);
  }

  /// Converts to half
  CUTLASS_HOST_DEVICE
  operator half() const {
    return to_half(*this);
  }

  /// Converts to double
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(to_float(*this));
  }

  /// Converts to int.
  /// NOTE: the device path rounds to nearest (__half2int_rn) while the host path
  /// truncates toward zero (int(float)), so e.g. 1.5 converts differently.
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
#if defined(__CUDA_ARCH__)
    return __half2int_rn(to_half(*this));
#else
    return int(to_float(*this));
#endif
  }

  /// Casts to bool (via the int conversion, so host and device can differ for |x| < 1).
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
#if defined(__CUDA_ARCH__)
    return bool(__half2int_rn(to_half(*this)));
#else
    return bool(int(to_float(*this)));
#endif
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t& raw() {
    return storage;
  }

  /// Accesses raw internal state
  CUTLASS_HOST_DEVICE
  uint8_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0);
  }

  /// Returns the biased exponent field
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK);
  }

  /// Returns the unbiased exponent: the biased field minus the E5M2 bias (15).
  /// Not adjusted for zero/subnormal encodings (biased field 0).
  CUTLASS_HOST_DEVICE
  int exponent() const {
    return exponent_biased() - Base::FP8_EXPONENT_BIAS;
  }

  /// Returns the mantissa
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(storage & Base::FP8_MANTISSA_MASK);
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Arithmetic operators
//
///////////////////////////////////////////////////////////////////////////////////////////////////

// E4M3 comparisons are evaluated in FP32, so IEEE rules apply:
// every comparison involving NaN is false except `!=`, and +0 == -0.

CUTLASS_HOST_DEVICE
bool operator==(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a == b;
}

CUTLASS_HOST_DEVICE
bool operator!=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a != b;
}

CUTLASS_HOST_DEVICE
bool operator<(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a < b;
}

CUTLASS_HOST_DEVICE
bool operator<=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a <= b;
}

CUTLASS_HOST_DEVICE
bool operator>(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a > b;
}

CUTLASS_HOST_DEVICE
bool operator>=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a >= b;
}
// E4M3 arithmetic: widen both operands to FP32, compute, then round the result back
// to E4M3 through the explicit float constructor (round-to-nearest-even, satfinite).

CUTLASS_HOST_DEVICE
float_e4m3_t operator+(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  return float_e4m3_t(sum);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator-(float_e4m3_t const& lhs) {
  float const negated = -static_cast<float>(lhs);
  return float_e4m3_t(negated);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator-(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  return float_e4m3_t(difference);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator*(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  return float_e4m3_t(product);
}

CUTLASS_HOST_DEVICE
float_e4m3_t operator/(float_e4m3_t const& lhs, float_e4m3_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  return float_e4m3_t(quotient);
}
// E4M3 compound assignment: compute in FP32, round once back into `lhs`, return `lhs`.

CUTLASS_HOST_DEVICE
float_e4m3_t& operator+=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const result = static_cast<float>(lhs) + static_cast<float>(rhs);
  lhs = float_e4m3_t(result);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator-=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const result = static_cast<float>(lhs) - static_cast<float>(rhs);
  lhs = float_e4m3_t(result);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator*=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const result = static_cast<float>(lhs) * static_cast<float>(rhs);
  lhs = float_e4m3_t(result);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e4m3_t& operator/=(float_e4m3_t & lhs, float_e4m3_t const& rhs) {
  float const result = static_cast<float>(lhs) / static_cast<float>(rhs);
  lhs = float_e4m3_t(result);
  return lhs;
}
// E4M3 increment / decrement: add or subtract 1.0f in FP32, then round back to E4M3.

/// Prefix increment; returns the updated value.
CUTLASS_HOST_DEVICE
float_e4m3_t& operator++(float_e4m3_t & lhs) {
  lhs = float_e4m3_t(static_cast<float>(lhs) + 1.0f);
  return lhs;
}

/// Prefix decrement; returns the updated value.
CUTLASS_HOST_DEVICE
float_e4m3_t& operator--(float_e4m3_t & lhs) {
  lhs = float_e4m3_t(static_cast<float>(lhs) - 1.0f);
  return lhs;
}

/// Postfix increment; returns the value held before the update.
CUTLASS_HOST_DEVICE
float_e4m3_t operator++(float_e4m3_t & lhs, int) {
  float_e4m3_t const previous(lhs);
  lhs = float_e4m3_t(static_cast<float>(previous) + 1.0f);
  return previous;
}

/// Postfix decrement; returns the value held before the update.
CUTLASS_HOST_DEVICE
float_e4m3_t operator--(float_e4m3_t & lhs, int) {
  float_e4m3_t const previous(lhs);
  lhs = float_e4m3_t(static_cast<float>(previous) - 1.0f);
  return previous;
}
// E5M2 comparisons are evaluated in FP32, so IEEE rules apply:
// every comparison involving NaN is false except `!=`, and +0 == -0.

CUTLASS_HOST_DEVICE
bool operator==(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a == b;
}

CUTLASS_HOST_DEVICE
bool operator!=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a != b;
}

CUTLASS_HOST_DEVICE
bool operator<(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a < b;
}

CUTLASS_HOST_DEVICE
bool operator<=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a <= b;
}

CUTLASS_HOST_DEVICE
bool operator>(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a > b;
}

CUTLASS_HOST_DEVICE
bool operator>=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a >= b;
}
// E5M2 arithmetic: widen both operands to FP32, compute, then round the result back
// to E5M2 through the explicit float constructor (round-to-nearest-even, satfinite).

CUTLASS_HOST_DEVICE
float_e5m2_t operator+(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  return float_e5m2_t(sum);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator-(float_e5m2_t const& lhs) {
  float const negated = -static_cast<float>(lhs);
  return float_e5m2_t(negated);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator-(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  return float_e5m2_t(difference);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator*(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  return float_e5m2_t(product);
}

CUTLASS_HOST_DEVICE
float_e5m2_t operator/(float_e5m2_t const& lhs, float_e5m2_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  return float_e5m2_t(quotient);
}
// E5M2 compound assignment: compute in FP32, round once back into `lhs`, return `lhs`.

CUTLASS_HOST_DEVICE
float_e5m2_t& operator+=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const result = static_cast<float>(lhs) + static_cast<float>(rhs);
  lhs = float_e5m2_t(result);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator-=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const result = static_cast<float>(lhs) - static_cast<float>(rhs);
  lhs = float_e5m2_t(result);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator*=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const result = static_cast<float>(lhs) * static_cast<float>(rhs);
  lhs = float_e5m2_t(result);
  return lhs;
}

CUTLASS_HOST_DEVICE
float_e5m2_t& operator/=(float_e5m2_t & lhs, float_e5m2_t const& rhs) {
  float const result = static_cast<float>(lhs) / static_cast<float>(rhs);
  lhs = float_e5m2_t(result);
  return lhs;
}
/// E5M2 prefix increment: add 1.0f in FP32, round back to E5M2, return the updated value.
CUTLASS_HOST_DEVICE
float_e5m2_t& operator++(float_e5m2_t & lhs) {
  lhs = float_e5m2_t(static_cast<float>(lhs) + 1.0f);
  return lhs;
}