Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Compute bounds for Index scalars in lowered kernel #3850

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Feb 7, 2025

This extends #3599 by also computing the minimal dtype required by the expressions in the lowered kernel. Like in #3599, we cast from nvfuser_index_t to int32_t when passing coords to the TMA expression. However, unlike #3599 we actually verify that this is safe to do by checking the bounds of the inputs to those casts. This way we can safely use 64-bit indexing with TMA and know that we will not get silently incorrect results. Also, we will more commonly use 32-bit indexing because with TMA we often do not have extremely large values for index variables since TMA allows us to do multi-dimensional indexing.

Fixes #3601

TODO: add a few tests

Copy link

github-actions bot commented Feb 7, 2025

Review updated until commit b861294

Description

  • Added computation of bounds for index scalars in lowered kernels.

  • Introduced BoundedInt struct for handling inclusive bounds of integers.

  • Implemented ScalarBoundsCalculator class to compute bounds for scalars.

  • Ensured safe casting from nvfuser_index_t to int32_t for TMA expressions.


Changes walkthrough 📝

Relevant files
Enhancement
index_compute.cpp
Cast TMA box coordinates to int32_t                                           

csrc/index_compute.cpp

  • Included ir/builder.h.
  • Added casting of TMA box coordinates to int32_t.
  • +9/-0     
    executor.cpp
    Compute and validate index types after lowering                   

    csrc/runtime/executor.cpp

  • Included additional headers for expression evaluation and type
    handling.
  • Introduced BoundedInt struct for handling integer bounds.
  • Implemented ScalarBoundsCalculator class for computing scalar bounds.
  • Added logic to compute and validate index types after lowering.
  • +554/-7 
    matmul_utils.cpp
    Remove 64-bit indexing check for Hopper matmul                     

    csrc/scheduler/matmul_utils.cpp

    • Removed checks for 64-bit indexing with Hopper matmul.
    +0/-8     

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The boundByDataType method in ScalarBoundsCalculator does not handle cases where the bounds are not initialized. The initialized flag is used to set the initial bounds, but there is no check to ensure that bounds_ is not empty before accessing its elements.

    BoundedInt ret;
    bool initialized = false;
    for (auto& [val, b] : bounds_) {
      if (val->dtype() != dtype) {
        continue;
      }
      if (!initialized) {
        ret = b;
        initialized = true;
      } else {
        ret.min = std::min(ret.min, b.min);
        ret.max = std::max(ret.max, b.max);
      }
      if (b.min < std::numeric_limits<int32_t>::min() ||
    Performance Concern

    The bitwise operations in BoundedInt do not handle negative numbers correctly. The logic assumes that the range of numbers is always positive, which may not be the case for all index computations.

    // Consider a number x=0bABCDE. If min(x)=max(x), then each of the bits A, B,
    // C, D, and E are fixed. However, if there is a range of values possible then
    // a subset of these bits could take on either 0 or 1. Suppose the range of x
    // is [0b01010, 0b01100]. Then we know that A=0, B=1, and C, D, and E can have
    // either value. Generally speaking, for numbers lying between two positive
    // integers, we know the lower-most K many bits are not fixed, where K is
    // PRECISION-(number of high bits in common). We can compute the largest K
    // between this and other, then we know that the XOR between these two values
    // can have any value for that many lower bits and all the higher bits are
    // determined by XORing the two min (or max) bounds with one another.
    //
    // [Note on twos-complement negative integers]
    // Since twos-complement negative integers can be envisioned as simply
    // stacking (without flipping) the negative values at the right side of the
    // positive values, we can apply the same algorithm regardless of signedness.
    BoundedInt operator^(const BoundedInt& other) const {
      // New interval has this many fixed bits
      int64_t fixed_bits = std::min(commonHighBits(), other.commonHighBits());
      // Mask everything below the higher fixed_bits
      int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
      int64_t new_min = (min ^ other.min) & (~low_mask); // 0b01000
      int64_t new_max = new_min + low_mask; // 0b01111
      return {new_min, new_max};
    }
    
    BoundedInt operator&(const BoundedInt& other) const {
      // New interval has this many fixed bits
      int64_t fixed_bits = std::min(commonHighBits(), other.commonHighBits());
      // Mask everything below the higher fixed_bits
      int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
      int64_t new_min = (min & other.min) & (~low_mask); // 0b01000
      int64_t new_max = new_min + low_mask; // 0b01111
      return {new_min, new_max};
    }
    
    BoundedInt operator|(const BoundedInt& other) const {
      // New interval has this many fixed bits
      int64_t fixed_bits = std::min(commonHighBits(), other.commonHighBits());
      // Mask everything below the higher fixed_bits
      int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
      int64_t new_min = (min | other.min) & (~low_mask); // 0b01000
      int64_t new_max = new_min + low_mask; // 0b01111
      return {new_min, new_max};
    }
    
    BoundedInt operator~() const {
      // New interval has this many fixed bits
      int64_t fixed_bits = commonHighBits();
      // Mask everything below the higher fixed_bits
      int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
      int64_t new_min = (~min) & (~low_mask); // 0b01000
      int64_t new_max = new_min + low_mask; // 0b01111
      return {new_min, new_max};
    }
    
    Code Complexity

    The ScalarBoundsCalculator class is quite complex and includes many operations that could be simplified or optimized. For example, the commonHighBits method could be optimized using bit manipulation techniques.

      //! Returns the number of high bits that must be common among all integers in
      //! this interval
      int64_t commonHighBits() const {
    // #if __cplusplus < 202002L
    #if true
        // XOR and view result as unsigned, so that right shift will be _logical_
        // instead of arithmetic.
        uint64_t different_bits = (*reinterpret_cast<const uint64_t*>(&max)) ^
            (*reinterpret_cast<const uint64_t*>(&min));
        // TODO: add countl_zero to csrc/C++20/ somewhere for C++17 backward
        // compatibility
        int64_t fixed_bits = 64L;
        while (different_bits != 0L) {
          different_bits >>= 1;
          fixed_bits--;
        }
        return fixed_bits;
    #else
        int64_t different_bits = b.max ^ b.min;
        return (int64_t)std::countl_zero(different_bits);
    #endif
      }
    
      // For bitwise operations, we consider the range of each bit independently.
      // Consider a number x=0bABCDE. If min(x)=max(x), then each of the bits A, B,
      // C, D, and E are fixed. However, if there is a range of values possible then
      // a subset of these bits could take on either 0 or 1. Suppose the range of x
      // is [0b01010, 0b01100]. Then we know that A=0, B=1, and C, D, and E can have
      // either value. Generally speaking, for numbers lying between two positive
      // integers, we know the lower-most K many bits are not fixed, where K is
      // PRECISION-(number of high bits in common). We can compute the largest K
      // between this and other, then we know that the XOR between these two values
      // can have any value for that many lower bits and all the higher bits are
      // determined by XORing the two min (or max) bounds with one another.
      //
      // [Note on twos-complement negative integers]
      // Since twos-complement negative integers can be envisioned as simply
      // stacking (without flipping) the negative values at the right side of the
      // positive values, we can apply the same algorithm regardless of signedness.
      BoundedInt operator^(const BoundedInt& other) const {
        // New interval has this many fixed bits
        int64_t fixed_bits = std::min(commonHighBits(), other.commonHighBits());
        // Mask everything below the higher fixed_bits
        int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
        int64_t new_min = (min ^ other.min) & (~low_mask); // 0b01000
        int64_t new_max = new_min + low_mask; // 0b01111
        return {new_min, new_max};
      }
    
      BoundedInt operator&(const BoundedInt& other) const {
        // New interval has this many fixed bits
        int64_t fixed_bits = std::min(commonHighBits(), other.commonHighBits());
        // Mask everything below the higher fixed_bits
        int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
        int64_t new_min = (min & other.min) & (~low_mask); // 0b01000
        int64_t new_max = new_min + low_mask; // 0b01111
        return {new_min, new_max};
      }
    
      BoundedInt operator|(const BoundedInt& other) const {
        // New interval has this many fixed bits
        int64_t fixed_bits = std::min(commonHighBits(), other.commonHighBits());
        // Mask everything below the higher fixed_bits
        int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
        int64_t new_min = (min | other.min) & (~low_mask); // 0b01000
        int64_t new_max = new_min + low_mask; // 0b01111
        return {new_min, new_max};
      }
    
      BoundedInt operator~() const {
        // New interval has this many fixed bits
        int64_t fixed_bits = commonHighBits();
        // Mask everything below the higher fixed_bits
        int64_t low_mask = (1 << fixed_bits) - 1; // 0b00111
        int64_t new_min = (~min) & (~low_mask); // 0b01000
        int64_t new_max = new_min + low_mask; // 0b01111
        return {new_min, new_max};
      }
    

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
    Labels
    None yet
    Projects
    None yet
    Development

    Successfully merging this pull request may close these issues.

    Compute index type by bounding index expressions
    1 participant