
Interpreter + MX floating point types #2685

Open
Wheest opened this issue Jan 8, 2025 · 9 comments

Wheest commented Jan 8, 2025

I've been exploring the MX types (added in #2582), but in a basic example I seem to be getting an incorrect result.

I'm taking a simple single matmul example, represented in MLIR as:

module @IrToHlo.6 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> {
    %0 = stablehlo.convert %arg0 : (tensor<2x2xf32>) -> tensor<2x2xf8E8M0FNU>
    %1 = stablehlo.convert %arg1 : (tensor<1x2xf32>) -> tensor<1x2xf8E8M0FNU>
    %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {accumulation_type = f8E8M0FNU} : (tensor<1x2xf8E8M0FNU>, tensor<2x2xf8E8M0FNU>) -> tensor<1x2xf8E8M0FNU>
    %3 = stablehlo.convert %2 : (tensor<1x2xf8E8M0FNU>) -> tensor<1x2xf32>
    return %3 : tensor<1x2xf32>
  }
}

We can run this with the stablehlo interpreter with some sample input data:

stablehlo-translate one_matmul_mxint8.mlir --interpret --args="[dense<[[0.1, 0.2], [0.3, 0.4]]> : tensor<2x2xf32>, dense<[[0.5, 0.6]]> : tensor<1x2xf32>]"

I'm using the ml_dtypes package as my source of truth (used in the MLIR unit tests).

However, when I run (what I believe to be) an equivalent NumPy+ml_dtypes example, I get a different result:

import numpy as np
from ml_dtypes import float8_e8m0fnu as mxtype


def stablehlo_emulation(arg0, arg1):
    """Equivalent to stablehlo program:

      ```mlir
    func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> {
      %0 = stablehlo.convert %arg0 : (tensor<2x2xf32>) -> tensor<2x2xf8E8M0FNU>
      %1 = stablehlo.convert %arg1 : (tensor<1x2xf32>) -> tensor<1x2xf8E8M0FNU>
      %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {accumulation_type = f8E8M0FNU} : (tensor<1x2xf8E8M0FNU>, tensor<2x2xf8E8M0FNU>) -> tensor<1x2xf8E8M0FNU>
      %3 = stablehlo.convert %2 : (tensor<1x2xf8E8M0FNU>) -> tensor<1x2xf32>
      return %3 : tensor<1x2xf32>
    }
    ```"""
    arg0_mxtype = arg0.astype(mxtype)
    arg1_mxtype = arg1.astype(mxtype)
    result_mxtype = arg1_mxtype @ arg0_mxtype
    result_fp32 = result_mxtype.astype(np.float32)
    return result_fp32


arg0 = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # tensor<2x2xf32>
arg1 = np.array([[0.5, 0.6]], dtype=np.float32)  # tensor<1x2xf32>
output = stablehlo_emulation(arg0, arg1)

print("Input tensor arg0 (fp32):")
print(arg0)
print("Input tensor arg1 (fp32):")
print(arg1)
print("Output tensor (fp32):")
print(output)
print("Non-mx type output:")
print(arg1 @ arg0)

When I return the converted %0 and %1 tensors, they look as expected, but the result of the dot_general operation does not match:

StableHLO interpreter:

tensor<1x2xf32> {
  [
    [2.500000e-01, 5.000000e-01]
  ]
}

NumPy:

[[0.1875 0.375]]

NumPy (using fp32 instead of mxtype):

[[0.23 0.34]]

I'm not sure if I'm missing something in the NumPy example, if I'm missing an appropriate attribute on the DotGeneralOp, or if there's a bug in the StableHLO interpreter. I get the same answer whether I use an f32 or an f8E8M0FNU accumulation type.

sdasgup3 self-assigned this Jan 8, 2025
sdasgup3 (Member) commented Jan 8, 2025

Thanks for the issue. I will get back to you soon.

sdasgup3 (Member) commented Jan 9, 2025

The mismatch between the StableHLO interpreter's dot_general and Python's @ (matmul) operator for low-bit matrix multiplication stems from:

  • Rounding Errors: The StableHLO interpreter uses a reference implementation for dot_general that closely follows the mathematical definition, performing each multiplication and addition step-by-step at the low-bit type (f8E8M0FNU in this case).
    Even small rounding errors at such a low-bit type, over several additions and multiplications, can accumulate and lead to a final result that differs from higher-precision calculations (demonstrated below at (3); a quick check right after this list shows which values the type can represent at all).
  • Optimized Implementations: Python's @ operator, which appears to be backed by optimized libraries or hardware-specific instructions, employs optimizations that minimize rounding errors.
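
For intuition, here is a quick check with ml_dtypes (an editorial sketch, not part of the original comment) showing that f8E8M0FNU can only represent powers of two, so every intermediate value in a low-bit accumulation gets snapped to one:

import numpy as np
from ml_dtypes import float8_e8m0fnu

vals = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.1875, 0.375], dtype=np.float32)
# Each value is rounded to some representable 2**k; the rounding direction for
# midpoints depends on the ml_dtypes version.
print(vals.astype(float8_e8m0fnu).astype(np.float32))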

It's important to note that this behavior of the StableHLO interpreter is expected. As a reference interpreter, it prioritizes clarity and correctness in demonstrating the fundamental operations of dot_general, even if it means slightly different results compared to highly optimized implementations.

In any case, feel free to share if there are any specific use cases you'd like to explore further using the StableHLO interpreter.

Detailed explanation:

Let's consider the evaluation using the StableHLO interpreter:

# Imports assumed for this snippet (NumPy plus the StableHLO Python bindings):
import numpy as np
from mlir import ir
from mlir.dialects import stablehlo as stablehlo_dialect

stablehlo_program_text = """
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>) -> (tensor<2x2xf8E8M0FNU>, tensor<1x2xf8E8M0FNU>, tensor<1x2xf8E8M0FNU>, tensor<1x2xf32>) {
    %0 = stablehlo.convert %arg0 : (tensor<2x2xf32>) -> tensor<2x2xf8E8M0FNU>
    %1 = stablehlo.convert %arg1 : (tensor<1x2xf32>) -> tensor<1x2xf8E8M0FNU>
    %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {accumulation_type = f8E8M0FNU} : (tensor<1x2xf8E8M0FNU>, tensor<2x2xf8E8M0FNU>) -> tensor<1x2xf8E8M0FNU>
    %3 = stablehlo.convert %2 : (tensor<1x2xf8E8M0FNU>) -> tensor<1x2xf32>
    return %0, %1, %2, %3 : tensor<2x2xf8E8M0FNU>, tensor<1x2xf8E8M0FNU>, tensor<1x2xf8E8M0FNU>, tensor<1x2xf32>
  }
"""

with ir.Context() as ctx:
  stablehlo_dialect.register_dialect(ctx)
  arg0 = ir.DenseElementsAttr.get(np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32))
  arg1 = ir.DenseElementsAttr.get(np.array([[0.5, 0.6]], dtype=np.float32))
  results = stablehlo_dialect.eval_module(ir.Module.parse(stablehlo_program_text, ctx), [arg0, arg1])
  print(results)

will output

[
  DenseFPElementsAttr(dense<[[1.250000e-01, 2.500000e-01], [2.500000e-01, 5.000000e-01]]> : tensor<2x2xf8E8M0FNU>), # arg0 in mxtype
  DenseFPElementsAttr(dense<5.000000e-01> : tensor<1x2xf8E8M0FNU>), # arg1 in mxtype
  DenseFPElementsAttr(dense<[[2.500000e-01, 5.000000e-01]]> : tensor<1x2xf8E8M0FNU>),  # result in mxtype
  DenseFPElementsAttr(dense<[[2.500000e-01, 5.000000e-01]]> : tensor<1x2xf32>)] # result in f32

Comparing this with the output of

import numpy as np
from ml_dtypes import float8_e8m0fnu as mxtype


def stablehlo_emulation(arg0, arg1):
    arg0_mxtype = arg0.astype(mxtype)
    arg1_mxtype = arg1.astype(mxtype)
    print("Input tensor arg0 (mx type):")
    print(arg0_mxtype)
    print("Input tensor arg1 (mx type):")
    print(arg1_mxtype)
    
    result_mxtype = arg1_mxtype @ arg0_mxtype
    print("Output tensor (mx type):")
    print(result_mxtype)

    decomposed_matmul_mxtype = np.array(
        [[
            arg1_mxtype[0][0]*arg0_mxtype[0][0] + arg1_mxtype[0][1]*arg0_mxtype[1][0],
            arg1_mxtype[0][0]*arg0_mxtype[0][1] + arg1_mxtype[0][1]*arg0_mxtype[1][1]
        ]], dtype=mxtype)
    print("decomposed_matmul (mx type):")
    print(decomposed_matmul_mxtype)

    decomposed_matmul_with_f32_accum = np.array(
        [[
            arg1_mxtype[0][0].astype(np.float32)*arg0_mxtype[0][0].astype(np.float32) + arg1_mxtype[0][1].astype(np.float32)*arg0_mxtype[1][0].astype(np.float32),
            arg1_mxtype[0][0].astype(np.float32)*arg0_mxtype[0][1].astype(np.float32) + arg1_mxtype[0][1].astype(np.float32)*arg0_mxtype[1][1].astype(np.float32)
        ]], dtype=np.float32)
    print("decomposed_matmul (with f32 accum):")
    print(decomposed_matmul_with_f32_accum)
    return result_mxtype


arg0 = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # tensor<2x2xf32>
arg1 = np.array([[0.5, 0.6]], dtype=np.float32)  # tensor<1x2xf32>
output = stablehlo_emulation(arg0, arg1)

Output

Input tensor arg0 (mx type):
[[0.125 0.25]
 [0.25 0.5]]
Input tensor arg1 (mx type):
[[0.5 0.5]]
Output tensor (mx type):
[[0.1875 0.375 ]]
decomposed_matmul (mx type):
[[0.25 0.5]]
decomposed_matmul (with f32 accum):
[[0.1875 0.375 ]]
  1. The mxtype'd input arguments for both stablehlo.dot_general and @ are exactly the same.
  2. The StableHLO interpreter output is the same as decomposed_matmul (the result of a sum of products at mxtype). This is exactly what the reference implementation of dot_general looks like in StableHLO. Even small rounding errors at such a low-bit type, for each addition and multiplication, can accumulate and lead to noticeable differences in the final result (a step-by-step sketch follows this list).
  3. Performing the same decomposed matmul with a higher accumulation type (float32), decomposed_matmul_with_f32_accum, matches the @ result at mxtype. This suggests that the @ operator uses optimized libraries or hardware-specific instructions for matrix multiplication that minimize rounding errors, and hence it does not match the naive decomposition.
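
A minimal sketch of that step-by-step low-bit accumulation (an editorial illustration with ml_dtypes, not the interpreter's actual code; the exact rounded values depend on the ml_dtypes version):

import numpy as np
from ml_dtypes import float8_e8m0fnu as mxtype

def dot_general_lowbit(lhs_mx, rhs_mx):
    # Every product and every partial sum is rounded back to mxtype before the
    # next step, mirroring an element-by-element sum of products at mxtype.
    m, k = lhs_mx.shape
    _, n = rhs_mx.shape
    out = np.empty((m, n), dtype=mxtype)
    for i in range(m):
        for j in range(n):
            prods = [mxtype(float(lhs_mx[i, p]) * float(rhs_mx[p, j])) for p in range(k)]
            acc = prods[0]
            for prod in prods[1:]:
                acc = mxtype(float(acc) + float(prod))  # round each partial sum
            out[i, j] = acc
    return out

arg0_mx = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32).astype(mxtype)
arg1_mx = np.array([[0.5, 0.6]], dtype=np.float32).astype(mxtype)
print(dot_general_lowbit(arg1_mx, arg0_mx))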

samkellett commented Jan 9, 2025

The f8E8M0FNU type is supposed to be used with a shared scale factor (one 8-bit scale for every N 8-bit elements), and as far as I can see (in the StableHLO code, MLIR code, and LLVM's APFloat code) there is no scale value used anywhere. I can only assume that MLIR expects the dialect that uses this type to handle this when lowering (or interpreting) from this type into something more normal from LLVM's perspective.

A little dive through the code...

StableHLO's Tensor::makeTensor constructs a tensor of these elements by just taking 8-bits from the value in the DenseAttr:

    auto floatValues = llvm::map_to_vector(
        attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
          return value.bitcastToAPInt().getZExtValue();
        });

And then Tensor::get recreates an APFloat:

    return Element(elementType,
                   APFloat(floatTy.getFloatSemantics(),
                           APInt(floatTy.getWidth(), *elementData)));

Which calls

  APFloat(const fltSemantics &Semantics, const APInt &I) : U(Semantics, I) {}

Which constructs a Storage object using this constructor:

    Storage(const fltSemantics &Semantics, ArgTypes &&... Args) {
      if (usesLayout<IEEEFloat>(Semantics)) {
        new (&IEEE) IEEEFloat(Semantics, std::forward<ArgTypes>(Args)...);
        return;
      }
      if (usesLayout<DoubleAPFloat>(Semantics)) {
        new (&Double) DoubleAPFloat(Semantics, std::forward<ArgTypes>(Args)...);
        return;
      }
      llvm_unreachable("Unexpected semantics");
    }

And usesLayout<IEEEFloat> returns true if Semantics is not the DoubleAPFloat semantics... therefore the APFloat created is an IEEEFloat (32-bit?) constructed from the 8 bits stored in the APInt object that StableHLO's Element object owns.

So, after all of that, I think what is happening is that we are doing IEEE fp32 calculations in the interpreter but only storing 8 bits of the result each time, which explains why the interpreter gets a result that is closer to NumPy's fp32 result (0.23 0.34) but worse.

sdasgup3 (Member) commented Jan 10, 2025

@samkellett First, thanks for the explanations.

after all of that I think that means that what is happening is that we are doing IEEE fp32 calculations in the intepreter but only storing 8-bits of the result

That makes sense!

One point that drew my attention: the result of

decomposed_matmul_mxtype = np.array(
        [[
            arg1_mxtype[0][0]*arg0_mxtype[0][0] + arg1_mxtype[0][1]*arg0_mxtype[1][0],
            arg1_mxtype[0][0]*arg0_mxtype[0][1] + arg1_mxtype[0][1]*arg0_mxtype[1][1]
        ]], dtype=mxtype)

matches that of the StableHLO interpreter, while the result of

np.array(
        [[
            arg1_mxtype[0][0].astype(np.float32)*arg0_mxtype[0][0].astype(np.float32) + arg1_mxtype[0][1].astype(np.float32)*arg0_mxtype[1][0].astype(np.float32),
            arg1_mxtype[0][0].astype(np.float32)*arg0_mxtype[0][1].astype(np.float32) + arg1_mxtype[0][1].astype(np.float32)*arg0_mxtype[1][1].astype(np.float32)
        ]], dtype=np.float32)

matches that of arg1_mxtype @ arg0_mxtype.

Clearly there is some precision loss in the former which, as you pointed out, is due to fewer storage bits being used (at least in the StableHLO case).

cc @reedwm

sdasgup3 (Member) commented

@Wheest

In any case, could you tell us a bit more about what you're hoping to achieve? We'd be happy to help you get there.

sdasgup3 assigned reedwm and unassigned reedwm Jan 10, 2025
reedwm (Member) commented Jan 10, 2025

Rounding Errors: The StableHLO interpreter uses a reference implementation for dot_general that closely follows the mathematical definition, performing each multiplication and addition step-by-step at the low-bit type (f8E8M0FNU in this case).

This makes sense, but perhaps StableHLO should use FP32 precision for dots whose inputs are narrower than FP32. In practice, you typically want FP32 accumulation for numeric stability, and implementations like XLA use FP32 accumulation.

The f8E8M0FNU is supposed to be used with a shared scale factor (1 8-bit scale for every N 8-bit element's) and as far as I can see (in the StableHLO code, MLIR code and LLVM's APFloat code) there is no scale value used anywhere.

While it's true that f8E8M0FNU is intended to be used as a scale, there is no scaling logic within StableHLO itself (besides the quantized types, which aren't used here). From StableHLO's perspective, f8E8M0FNU is just another dtype (although a somewhat unusual one in that it can only represent powers of two). If the user wants to use the dtype as a scale, they need to insert the appropriate multiply and divide ops to implement such scaling.
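
To make that concrete, here is a rough editorial sketch (not a StableHLO or ml_dtypes API; the scale formula is one common choice following the OCP MX spec, and edge cases such as all-zero blocks are ignored) of the scaling a user or framework would express with explicit ops around the dtype:

import numpy as np
from ml_dtypes import float8_e8m0fnu, float8_e4m3fn

def mx_quantize(block):
    # One shared power-of-two scale per block, stored as f8E8M0FNU, with the
    # scaled elements stored in a narrower element type (f8E4M3FN here).
    amax = np.max(np.abs(block))
    scale = np.float32(2.0 ** (np.floor(np.log2(amax)) - 8))  # emax of e4m3fn is 8
    scale_mx = float8_e8m0fnu(scale)  # a power of two, exactly representable
    elems = np.clip(block / np.float32(scale_mx), -448.0, 448.0).astype(float8_e4m3fn)
    return scale_mx, elems

def mx_dequantize(scale_mx, elems):
    # The divide-before / multiply-after steps are what a user would write
    # explicitly in StableHLO (e.g. divide and multiply ops around the dot).
    return np.float32(scale_mx) * elems.astype(np.float32)

block = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
scale_mx, elems = mx_quantize(block)
print(mx_dequantize(scale_mx, elems))  # approximately recovers the block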

The usesLayout<IEEEFloat> part in APFloat just means that it's not using the special DoubleAPFloat format which encodes a number as a pair of floating-point values (DoubleAPFloat is unrelated to C's double type, which is still represented with an IEEEFloat). Despite its name, IEEEFloat does not imply the type is an IEEE format, but can represent a variety of floating-point formats like FP32, FP64, and the various forms of FP8.

Wheest (Author) commented Jan 13, 2025

One point that drew my attention: the result of

decomposed_matmul_mxtype = np.array(
        [[
            arg1_mxtype[0][0]*arg0_mxtype[0][0] + arg1_mxtype[0][1]*arg0_mxtype[1][0],
            arg1_mxtype[0][0]*arg0_mxtype[0][1] + arg1_mxtype[0][1]*arg0_mxtype[1][1]
        ]], dtype=mxtype)

matches that of the StableHLO interpreter, while the result of

np.array(
        [[
            arg1_mxtype[0][0].astype(np.float32)*arg0_mxtype[0][0].astype(np.float32) + arg1_mxtype[0][1].astype(np.float32)*arg0_mxtype[1][0].astype(np.float32),
            arg1_mxtype[0][0].astype(np.float32)*arg0_mxtype[0][1].astype(np.float32) + arg1_mxtype[0][1].astype(np.float32)*arg0_mxtype[1][1].astype(np.float32)
        ]], dtype=np.float32)

matches that of arg1_mxtype @ arg0_mxtype.

Could be a version disagreement (I'm using ml_dtypes==0.5.0 and StableHLO==830b9787c58), but I'm not getting those results.

decomposed_matmul_mxtype = [[0.125 0.5]]
arg1_mxtype @ arg0_mxtype = [[0.1875 0.375 ]]
stablehlo_interpreter = [[2.500000e-01, 5.000000e-01]]

In any case, could you tell us a bit more about what you're hoping to achieve? We'd be happy to help you get there.

What I'm hoping to achieve is to get some understanding of the accuracy loss when using this new type, and also of the compilation pipelines available to me.

there is no scaling logic within StableHLO itself (besides the quantized types, which aren't used here). From StableHLO's perspective, f8E8M0FNU is just another dtype (although a somewhat unusual one in that it can only represent powers of two). If the user wants to use the dtype as a scale, they need to insert the appropriate multiply and divide ops to implement such scaling.

Is this a desired design, putting the scaling responsibility on the user? I expected stablehlo.convert to have similar behavior to ml_dtypes' .astype(mxtype).
There is ongoing discussion about whether UniformQuantizeOp/UniformDequantizeOp should be merged into convert; do the microscaling types need similar ops, or should it be the user's responsibility to insert the appropriate ops?

Despite its name, IEEEFloat does not imply the type is an IEEE format, but can represent a variety of floating-point formats like FP32, FP64, and the various forms of FP8.

I think some of this does touch on last year's discussion about how we deal with all these float variants.

reedwm (Member) commented Jan 13, 2025

there is no scaling logic within StableHLO itself (besides the quantized types, which aren't used here). From StableHLO's perspective, f8E8M0FNU is just another dtype (although a somewhat unusual one in that it can only represent powers of two). If the user wants to use the dtype as a scale, they need to insert the appropriate multiply and divide ops to implement such scaling.

Is this a desired design, putting the scaling responsibility on the user? I expected stablehlo.convert to have similar behavior to ml_dtypes' .astype(mxtype). There is ongoing discussion about whether UniformQuantizeOp/UniformDequantizeOp should be merged into convert; do the microscaling types need similar ops, or should it be the user's responsibility to insert the appropriate ops?

Right now, the responsibility is on the user to insert the appropriate scaling ops. ml_dtypes' .astype(mxtype) also behaves similarly: it does no automatic scaling, and the MX dtypes are just treated as normal dtypes. I think this design is sensible: frameworks like JAX can provide higher-level APIs that do automatic scaling, and the types are still very new, so I don't think we should introduce hard-coded scaling algorithms yet.

samkellett commented Jan 14, 2025

Right now, the responsibility is on the user to insert the appropriate scaling ops. ml_dtypes' .astype(mxtype) also behaves similarly: it does no automatic scaling, and the MX dtypes are just treated as normal dtypes. I think this design is sensible: frameworks like JAX can provide higher-level APIs that do automatic scaling, and the types are still very new, so I don't think we should introduce hard-coded scaling algorithms yet.

I agree with respect to StableHLO itself, but I do think the interpreter should either not support the MX types (error on seeing them) or implement its own version of automatic scaling, rather than the current behaviour, which is to produce the wrong answer.
