Skip to content

Commit

Permalink
Fix some sign conversion warnings (#1172)
Browse files Browse the repository at this point in the history
* Fix sign conversion warnings

* Fix type conversion warnings

* Fix sign conversion warnings

* Change smem_size_ to constexpr

* clang warnings

* undo cast change

* one miss change

* missing part

---------

Co-authored-by: Haicheng Wu <[email protected]>
  • Loading branch information
cyyever and hwu36 authored Nov 30, 2023
1 parent 99c4eeb commit 10b850f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 31 deletions.
5 changes: 3 additions & 2 deletions include/cutlass/fast_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,10 @@ int ceil_div(int a, int b) {
* log2_up/down codes?
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
CUTLASS_HOST_DEVICE int clz(value_t x) {
for (int i = 31; i >= 0; --i) {
if ((1 << i) & x) return 31 - i;
if ((1 << i) & x)
return value_t(31 - i);
}
return 32;
}
Expand Down
22 changes: 11 additions & 11 deletions include/cutlass/float8.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/***************************************************************************************************
/**************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
Expand Down Expand Up @@ -216,8 +216,8 @@ struct alignas(1) float8_base {

// Extract the bits in the FP32 type
uint8_t sign = uint8_t((s >> 24 & 0x80));
int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS);
int mantissa = s & 0x7fffff;
int32_t exp = int32_t((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS;
uint32_t mantissa = s & 0x7fffff;
uint8_t u = 0;

uint8_t const kF8_NaN = 0x7f;
Expand All @@ -233,7 +233,7 @@ struct alignas(1) float8_base {
}

// Special handling
if ( exp == -128 ) {
if (exp == -128) {
// int8 range is from -128 to 127
// So 255(inf) - 127(bias) = 128 - will show up as -128

Expand All @@ -248,8 +248,8 @@ struct alignas(1) float8_base {

if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) {
// normal fp32 to normal fp8
exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS));
u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS));
exp = exp + FP8_EXPONENT_BIAS;
u = uint8_t((uint32_t(exp) & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS);
u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)));
} else if(exp < FP8_MIN_EXPONENT) {
// normal single-precision to subnormal float8-precision representation
Expand All @@ -271,8 +271,8 @@ struct alignas(1) float8_base {
if( exp == (FP8_MAX_EXPONENT + 1) ) {
uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS));
if( mantissa_tmp < FP8_MANTISSA_MASK) {
exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS));
u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
exp = exp + FP8_EXPONENT_BIAS;
u = uint8_t(uint32_t(exp) << FP8_NUM_MANTISSA_BITS) | mantissa_tmp;
may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1));
} else {
// satfinite
Expand Down Expand Up @@ -316,9 +316,9 @@ struct alignas(1) float8_base {
uint32_t constexpr kF32_NaN = 0x7fffffff;

uint8_t const &f8 = x;
int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
int mantissa = f8 & FP8_MANTISSA_MASK;
uint32_t sign = (f8 >> (FP8_NUM_BITS - 1)) & 1;
uint32_t exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK;
uint32_t mantissa = f8 & FP8_MANTISSA_MASK;
unsigned f = (sign << (FP32_NUM_BITS-1));

if (IS_E4M3 && exp == 15 && mantissa == 0x7) {
Expand Down
15 changes: 3 additions & 12 deletions include/cutlass/gemm/device/gemm_universal_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ class GemmUniversalBase {
CUTLASS_THREAD_LOCAL static int sm_occupancy_;

/// Kernel dynamic shared memory allocation requirement
CUTLASS_THREAD_LOCAL static int smem_size_;
/// Update the kernel function's shared memory configuration for the current device
static constexpr size_t smem_size_ = sizeof(typename GemmKernel::SharedStorage);

/// Initialize static thread-local members for the thread's current device,
/// if necessary.
Expand Down Expand Up @@ -143,11 +144,8 @@ class GemmUniversalBase {
return Status::kErrorInternal;
}

// Update the kernel function's shared memory configuration for the current device
smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));

// If requires more than 48KB: configure for extended, dynamic shared memory
if (smem_size_ >= (48 << 10))
if constexpr (smem_size_ >= (48 << 10))
{
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
Expand Down Expand Up @@ -377,7 +375,6 @@ class GemmUniversalBase {
}
};


/////////////////////////////////////////////////////////////////////////////////////////////////
/// Static initializers
/////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -394,12 +391,6 @@ CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
template <typename GemmKernel_>
CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;

/// Kernel dynamic shared memory allocation requirement
template <typename GemmKernel_>
CUTLASS_THREAD_LOCAL int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;



/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
Expand Down
6 changes: 3 additions & 3 deletions include/cutlass/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,9 @@ struct alignas(2) half_t {
#endif

uint16_t const &h = x.storage;
int sign = ((h >> 15) & 1);
int exp = ((h >> 10) & 0x1f);
int mantissa = (h & 0x3ff);
uint32_t sign = ((h >> 15) & 1);
uint32_t exp = ((h >> 10) & 0x1f);
uint32_t mantissa = (h & 0x3ff);
unsigned f = 0;

if (exp > 0 && exp < 31) {
Expand Down
5 changes: 2 additions & 3 deletions include/cutlass/numeric_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {
// software implementation rounds toward nearest even
unsigned const& s = reinterpret_cast<unsigned const &>(flt);
uint16_t sign = uint16_t((s >> 16) & 0x8000);
int16_t exp = uint16_t(((s >> 23) & 0xff) - 127);
int32_t exp = int32_t((s >> 23) & 0xff) - 127;
int mantissa = s & 0x7fffff;
uint16_t u = 0;

Expand All @@ -386,8 +386,7 @@ struct NumericConverter<half_t, float, FloatRoundStyle::round_toward_zero> {

if (exp >= -14) {
// normal fp32 to normal fp16
exp = uint16_t(exp + uint16_t(15));
u = uint16_t(((exp & 0x1f) << 10));
u = uint16_t((uint32_t(exp + 15) & 0x1f) << 10);
u = uint16_t(u | (mantissa >> 13));
} else {
// normal single-precision to subnormal half_t-precision representation
Expand Down

0 comments on commit 10b850f

Please sign in to comment.