Interpreter + MX floating point types #2685
Thanks for the issue. I will get back to you soon.
The mismatch between the StableHLO interpreter's dot_general and Python's @ (matmul) operator for low-bit matrix multiplication stems from how intermediate results are computed and stored in each case.

It's important to note that this behavior of the StableHLO interpreter is expected. As a reference interpreter, it prioritizes clarity and correctness in demonstrating the fundamental operations of dot_general, even if that means slightly different results compared to highly optimized implementations. In any case, feel free to share if there are any specific use cases you'd like to explore further using the StableHLO interpreter.

Detailed explanation: let's consider the evaluation using the StableHLO interpreter, which will output:

Comparing this with the output of:

Output:
The f8E8M0FNU type is supposed to be used with a shared scale factor (one 8-bit scale for every N 8-bit elements), and as far as I can see (in the StableHLO code, the MLIR code, and LLVM's APFloat code) there is no scale value used anywhere. I can only assume that MLIR expects the dialect that uses this type to handle the scaling when lowering (or interpreting) from this type into something more normal from LLVM's perspective.

A little dive through the code... StableHLO's

And then

Which calls

Which constructs a Storage object using this constructor:

So, after all of that, I think what is happening is that the interpreter performs IEEE fp32 calculations but only stores 8 bits of the result each time, which explains why the interpreter gets a result that is closer to NumPy's fp32 result.
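To make that concrete, the following is a rough NumPy + ml_dtypes sketch of the behaviour described above; it is not the interpreter's actual code. It assumes ml_dtypes 0.5.0's `float8_e8m0fnu` as a stand-in for the MX element type (the same idea applies to the other MX types) and uses the input values from the issue.

```python
import numpy as np
import ml_dtypes

# Assumption: float8_e8m0fnu stands in for the MX element type under discussion.
mxtype = ml_dtypes.float8_e8m0fnu


def round_to_mx(x):
    # Compute in fp32, then store into the 8-bit MX type and read it back,
    # mimicking "fp32 maths, 8-bit storage" at every step.
    return np.asarray(x, dtype=np.float32).astype(mxtype).astype(np.float32)


# The f32 inputs from the issue, quantized to the MX type up front.
a = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32).astype(mxtype).astype(np.float32)
b = np.array([[0.5, 0.6]], dtype=np.float32).astype(mxtype).astype(np.float32)

# (1) Interpreter-style: every product and partial sum is rounded to 8 bits.
prod0 = round_to_mx(b[0, 0] * a[0, 0])
prod1 = round_to_mx(b[0, 1] * a[1, 0])
per_step_rounded = round_to_mx(prod0 + prod1)

# (2) fp32 accumulation: only the final result is (optionally) rounded.
fp32_accum = b[0, 0] * a[0, 0] + b[0, 1] * a[1, 0]

print("per-step rounding :", per_step_rounded)
print("fp32 accumulation :", fp32_accum, "-> rounded once:", round_to_mx(fp32_accum))
```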
@samkellett First, thanks for the explanations.
That makes sense! One point that drew my attention was that the result of

```python
# arg0_mxtype / arg1_mxtype are the issue's f32 inputs cast to the MX dtype `mxtype`.
decomposed_matmul_mxtype = np.array(
    [[
        arg1_mxtype[0][0] * arg0_mxtype[0][0] + arg1_mxtype[0][1] * arg0_mxtype[1][0],
        arg1_mxtype[0][0] * arg0_mxtype[0][1] + arg1_mxtype[0][1] * arg0_mxtype[1][1],
    ]], dtype=mxtype)
```

matches that of the StableHLO interpreter, and that the result of

```python
np.array(
    [[
        arg1_mxtype[0][0].astype(np.float32) * arg0_mxtype[0][0].astype(np.float32)
        + arg1_mxtype[0][1].astype(np.float32) * arg0_mxtype[1][0].astype(np.float32),
        arg1_mxtype[0][0].astype(np.float32) * arg0_mxtype[0][1].astype(np.float32)
        + arg1_mxtype[0][1].astype(np.float32) * arg0_mxtype[1][1].astype(np.float32),
    ]], dtype=np.float32)
```

matches that of

Clearly there is some precision loss in the former, which, as you pointed out, is due to fewer storage bits being used (at least in the StableHLO case). cc @reedwm
In any case, could you tell us a bit more about what you're hoping to achieve? We'd be happy to help you get there.
This makes sense, but perhaps StableHLO should use FP32 precision for dots whose inputs are narrower than FP32. In practice, you typically want FP32 accumulation for numeric stability, and implementations like XLA use FP32 accumulation.
While it's true that f8E8M0FNU is intended to be used as a scale, there is no scaling logic within StableHLO itself (besides the quantized types, which aren't used here). From StableHLO's perspective, f8E8M0FNU is just another dtype (although a somewhat unusual one, in that it only has powers of two). If the user wants to use the dtype as a scale, they need to insert the appropriate multiply and divide ops to implement such scaling.
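As an illustration of what "insert the appropriate multiply and divide ops" could look like outside of StableHLO, here is a rough NumPy + ml_dtypes sketch of MX-style block scaling. The block size, element type (`float4_e2m1fn`) and the scale heuristic are illustrative assumptions, not something specified by StableHLO.

```python
import numpy as np
import ml_dtypes


def quantize_block(block, elem_dtype=ml_dtypes.float4_e2m1fn):
    # One power-of-two scale per block, stored as float8_e8m0fnu; the elements
    # themselves are stored in a narrow type.  The 6.0 below is the largest
    # finite value of float4_e2m1fn and is only an illustrative choice.
    amax = float(np.max(np.abs(block)))
    raw_scale = amax / 6.0 if amax > 0 else 1.0
    # Casting to float8_e8m0fnu rounds the scale to a power of two.
    scale = np.array(raw_scale, dtype=np.float32).astype(ml_dtypes.float8_e8m0fnu)
    elems = (block / scale.astype(np.float32)).astype(elem_dtype)
    return scale, elems


def dequantize_block(scale, elems):
    # The "multiply op" the user would have to express explicitly.
    return scale.astype(np.float32) * elems.astype(np.float32)


block = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)
scale, elems = quantize_block(block)
print("shared scale       :", scale.astype(np.float32))
print("quantized elements :", elems.astype(np.float32))
print("dequantized block  :", dequantize_block(scale, elems))
```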
Could be a version disagreement (I'm using ml_dtypes==0.5.0 and StableHLO==):

```
decomposed_matmul_mxtype = [[0.125 0.5]]
arg1_mxtype @ arg0_mxtype = [[0.1875 0.375 ]]
stablehlo_interpreter = [[2.500000e-01, 5.000000e-01]]
```
What I'm hoping to achieve is to get some understanding of the accuracy loss when using this new type, and also of the compilation pipelines available to me.

Is this a desired design, putting the scaling responsibility on the user? I expected

I think some of this does touch on last year's discussion about how we deal with all these float variants.
Right now, the responsibility is on the user to insert the appropriate scaling ops. ml_dtypes
I agree with respect to StableHLO itself, but I do think that the interpreter should either not support the MX types (error on seeing them) or implement its own version of automatic scaling, rather than the current behaviour, which is to produce the wrong answer.
I've been exploring the MX types (added in #2582); however, in a basic example I seem to be getting an incorrect result.
I'm taking a simple single matmul example, represented in MLIR as:
We can run this with the stablehlo interpreter with some sample input data:
stablehlo-translate one_matmul_mxint8.mlir --interpret --args="[dense<[[0.1, 0.2], [0.3, 0.4]]> : tensor<2x2xf32>, dense<[[0.5, 0.6]]> : tensor<1x2xf32>]"
I'm using the `ml_dtypes` package as my source of truth (it is used in the MLIR unit tests). However, when I run (what I believe to be) an equivalent NumPy + ml_dtypes example, I get a different result:

When I return the converted `%0` and `%1` tensors, they look as expected, but the result of the `dot_general` operation does not match:

StableHLO interpreter:
NumPy:
NumPy (using fp32 instead of mxtype):
I'm not sure if I'm missing something in the NumPy example, if I'm missing an appropriate attribute on the `DotGeneralOp`, or if there's a bug in the StableHLO interpreter. I get the same answer using an `f32` or `f8E8M0FNU` accumulation type.
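For reference, here is a minimal sketch of the kind of NumPy + ml_dtypes comparison described above. The exact MX dtype used in the original example isn't reproduced here, so ml_dtypes 0.5.0's `float8_e8m0fnu` is assumed purely for illustration, and the matmul is accumulated in fp32 rather than relying on matmul support for the custom dtype; the input values are the ones passed to `stablehlo-translate`.

```python
import numpy as np
import ml_dtypes

# Assumption: stand-in for the MX element type used in the original example.
mxtype = ml_dtypes.float8_e8m0fnu

arg0 = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
arg1 = np.array([[0.5, 0.6]], dtype=np.float32)

# Mirror the convert ops (%0, %1) in the MLIR example.
arg0_mxtype = arg0.astype(mxtype)
arg1_mxtype = arg1.astype(mxtype)

# Matmul on the quantized values, accumulated in fp32 ...
result_fp32 = arg1_mxtype.astype(np.float32) @ arg0_mxtype.astype(np.float32)

# ... and the same result rounded back to the MX type at the end, for
# comparison against the value printed by the StableHLO interpreter.
result_mxtype = result_fp32.astype(mxtype)

print("fp32 accumulation      :", result_fp32)
print("rounded back to mxtype :", result_mxtype.astype(np.float32))
```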