diff --git a/onnx/numpy_helper.py b/onnx/numpy_helper.py index 3e6c391ed1f..be32595e3d5 100644 --- a/onnx/numpy_helper.py +++ b/onnx/numpy_helper.py @@ -234,8 +234,8 @@ def unpacked_float4e2m1_to_float32(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.f """ # x is stored in 4 LSB of int sign = np.where(np.bitwise_and(x, 0x08), -1, 1) - mantissa = x & 0x01 - exponent = (x & 0x06) >> 1 + mantissa = (x & 0x01).astype(np.float32) + exponent = ((x & 0x06) >> 1).astype(np.float32) val = np.where( exponent == 0, diff --git a/onnx/test/test_backend_reference.py b/onnx/test/test_backend_reference.py index 71ce7651ebf..e60a806b39b 100644 --- a/onnx/test/test_backend_reference.py +++ b/onnx/test/test_backend_reference.py @@ -141,10 +141,13 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): "|test_cast_no_saturate_FLOAT16_to_FLOAT8" "|test_cast_BFLOAT16_to_FLOAT" "|test_castlike_BFLOAT16_to_FLOAT" + "|test_cast_FLOAT_to_FLOAT4" + "|test_cast_FLOAT16_to_FLOAT4" "|test_quantizelinear_e4m3" "|test_quantizelinear_e5m2" "|test_quantizelinear_uint4" "|test_quantizelinear_int4" + "|test_quantizelinear_float4e2m1" ")" )