Skip to content

Commit

Permalink
update cast, castlike, Q/DQ
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Yao <[email protected]>
  • Loading branch information
yuanyao-nv committed Aug 24, 2024
1 parent c057d17 commit 46dc2f7
Show file tree
Hide file tree
Showing 139 changed files with 1,057 additions and 48 deletions.
276 changes: 271 additions & 5 deletions docs/Changelog.md

Large diffs are not rendered by default.

162 changes: 144 additions & 18 deletions docs/Operators.md

Large diffs are not rendered by default.

126 changes: 124 additions & 2 deletions docs/TestCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -2350,6 +2350,10 @@ test_cases = [
("INT4", "FLOAT"),
("INT4", "FLOAT16"),
("INT4", "INT8"),
("FLOAT4E2M1", "FLOAT"),
("FLOAT4E2M1", "FLOAT16"),
("FLOAT", "FLOAT4E2M1"),
("FLOAT16", "FLOAT4E2M1"),
]

vect_float32_to_float8e4m3 = np.vectorize(float32_to_float8e4m3)
Expand Down Expand Up @@ -2566,7 +2570,57 @@ for from_type, to_type in test_cases:
output_type_proto = onnx.helper.make_tensor_type_proto(
getattr(TensorProto, to_type), input_shape
)
elif from_type == "FLOAT4E2M1" or to_type == "FLOAT4E2M1":
np_fp32 = np.array(
[
"0.48",
"0.25",
"1.05",
"-3.5",
"-8",
"9",
"1000000",
"1e-7",
"NaN",
"INF",
"+INF",
"-INF",
"-4",
"0.01",
"-0.0",
],
dtype=np.float32,
)
input_shape = (3, 5)
if from_type == "FLOAT":
input_values = np_fp32
input = make_tensor(
"x", TensorProto.FLOAT, input_shape, input_values.tolist()
)
elif from_type == "FLOAT16":
input_values = np_fp32.astype(np.float16).astype(np.float32)
input = make_tensor(
"x", TensorProto.FLOAT16, input_shape, input_values.tolist()
)
elif from_type == "FLOAT4E2M1":
input = make_tensor(
"x", TensorProto.FLOAT4E2M1, input_shape, np_fp32.tolist()
)
else:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)

if to_type not in ("FLOAT", "FLOAT16", "FLOAT4E2M1"):
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
expected = unpacked_float4e2m1_to_float32(
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
)
output = make_tensor(
"y", getattr(TensorProto, to_type), input_shape, expected.tolist()
)
elif from_type != "STRING":
input = np.random.random_sample(shape).astype(
helper.tensor_dtype_to_np_dtype(getattr(TensorProto, from_type))
Expand Down Expand Up @@ -5317,7 +5371,7 @@ expect(node, inputs=[x], outputs=[y], name="test_depthtospace_example")


### DequantizeLinear
There are 11 test cases, listed as following:
There are 12 test cases, listed as following:
<details>
<summary>axis</summary>

Expand Down Expand Up @@ -5554,6 +5608,32 @@ expect(
)
```

</details>
<details>
<summary>float4e2m1</summary>

```python
node = onnx.helper.make_node(
"DequantizeLinear",
inputs=["x", "x_scale", "x_zero_point"],
outputs=["y"],
axis=0,
)

# scalar zero point and scale
x = make_tensor("x", TensorProto.FLOAT4E2M1, [5], [0, 1, -1, 1.5, -4])
x_scale = np.float32(2)
x_zero_point = make_tensor("x_zero_point", TensorProto.FLOAT4E2M1, (1,), [0])
y = np.array([0, 2, -2, 3, -8], dtype=np.float32)

expect(
node,
inputs=[x, x_scale, x_zero_point],
outputs=[y],
name="test_dequantizelinear_float4e2m1",
)
```

</details>
<details>
<summary>int16</summary>
Expand Down Expand Up @@ -13937,7 +14017,7 @@ for quant_type_name in ["uint8", "int8"]:


### QuantizeLinear
There are 10 test cases, listed as following:
There are 11 test cases, listed as following:
<details>
<summary>axis</summary>

Expand Down Expand Up @@ -14151,6 +14231,48 @@ expect(
)
```

</details>
<details>
<summary>float4e2m1</summary>

```python
node = onnx.helper.make_node(
"QuantizeLinear",
inputs=["x", "y_scale", "y_zero_point"],
outputs=["y"],
axis=0,
)

x = np.array(
[
[0.0, 2.5, 4.8, 8.6],
[-30, -20, 6, 9],
[-0.0, -2.5, -4.8, -8.6],
]
).astype(np.float32)

y_scale = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
y_zero_point = make_tensor(
"y_zero_point",
TensorProto.FLOAT4E2M1,
y_scale.shape,
np.zeros_like(y_scale),
)
y = make_tensor(
"y",
TensorProto.FLOAT4E2M1,
x.shape,
[0, 1, 2, 4, -6, -6, 2, 3, 0, -0.5, -1, -2],
)

expect(
node,
inputs=[x, y_scale, y_zero_point],
outputs=[y],
name="test_quantizelinear_float4e2m1",
)
```

</details>
<details>
<summary>int16</summary>
Expand Down
60 changes: 59 additions & 1 deletion onnx/backend/test/case/node/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
make_tensor,
tensor_dtype_to_field,
)
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32
from onnx.numpy_helper import (
float8e4m3_to_float32,
float8e5m2_to_float32,
unpacked_float4e2m1_to_float32,
)


class Cast(Base):
Expand Down Expand Up @@ -62,6 +66,10 @@ def export() -> None:
("INT4", "FLOAT"),
("INT4", "FLOAT16"),
("INT4", "INT8"),
("FLOAT4E2M1", "FLOAT"),
("FLOAT4E2M1", "FLOAT16"),
("FLOAT", "FLOAT4E2M1"),
("FLOAT16", "FLOAT4E2M1"),
]

vect_float32_to_float8e4m3 = np.vectorize(float32_to_float8e4m3)
Expand Down Expand Up @@ -278,7 +286,57 @@ def export() -> None:
output_type_proto = onnx.helper.make_tensor_type_proto(
getattr(TensorProto, to_type), input_shape
)
elif from_type == "FLOAT4E2M1" or to_type == "FLOAT4E2M1":
np_fp32 = np.array(
[
"0.48",
"0.25",
"1.05",
"-3.5",
"-8",
"9",
"1000000",
"1e-7",
"NaN",
"INF",
"+INF",
"-INF",
"-4",
"0.01",
"-0.0",
],
dtype=np.float32,
)
input_shape = (3, 5)
if from_type == "FLOAT":
input_values = np_fp32
input = make_tensor(
"x", TensorProto.FLOAT, input_shape, input_values.tolist()
)
elif from_type == "FLOAT16":
input_values = np_fp32.astype(np.float16).astype(np.float32)
input = make_tensor(
"x", TensorProto.FLOAT16, input_shape, input_values.tolist()
)
elif from_type == "FLOAT4E2M1":
input = make_tensor(
"x", TensorProto.FLOAT4E2M1, input_shape, np_fp32.tolist()
)
else:
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)

if to_type not in ("FLOAT", "FLOAT16", "FLOAT4E2M1"):
raise ValueError(
f"Conversion from {from_type} to {to_type} is not tested."
)
expected = unpacked_float4e2m1_to_float32(
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
)
output = make_tensor(
"y", getattr(TensorProto, to_type), input_shape, expected.tolist()
)
elif from_type != "STRING":
input = np.random.random_sample(shape).astype(
helper.tensor_dtype_to_np_dtype(getattr(TensorProto, from_type))
Expand Down
22 changes: 22 additions & 0 deletions onnx/backend/test/case/node/dequantizelinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,28 @@ def export_int4() -> None:
name="test_dequantizelinear_int4",
)

# Backend test case: DequantizeLinear with a FLOAT4E2M1 input tensor.
# Builds the node, the input tensors and the expected output, then registers
# them under the name "test_dequantizelinear_float4e2m1" via expect().
@staticmethod
def export_float4e2m1() -> None:
node = onnx.helper.make_node(
"DequantizeLinear",
inputs=["x", "x_scale", "x_zero_point"],
outputs=["y"],
axis=0,
)

# A single scale and a single zero point are used for the whole tensor:
# x_scale is a NumPy float32 scalar, while x_zero_point is a one-element
# FLOAT4E2M1 tensor of shape (1,) holding 0 (not a rank-0 scalar).
x = make_tensor("x", TensorProto.FLOAT4E2M1, [5], [0, 1, -1, 1.5, -4])
x_scale = np.float32(2)
x_zero_point = make_tensor("x_zero_point", TensorProto.FLOAT4E2M1, (1,), [0])
# Expected output as float32: each listed value equals (x - 0) * 2.
y = np.array([0, 2, -2, 3, -8], dtype=np.float32)

expect(
node,
inputs=[x, x_scale, x_zero_point],
outputs=[y],
name="test_dequantizelinear_float4e2m1",
)

@staticmethod
def export_blocked() -> None:
node = onnx.helper.make_node(
Expand Down
38 changes: 38 additions & 0 deletions onnx/backend/test/case/node/quantizelinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,44 @@ def export_int4() -> None:
name="test_quantizelinear_int4",
)

# Backend test case: QuantizeLinear producing a FLOAT4E2M1 output tensor.
# Builds the node, a 3x4 float32 input, per-entry scales/zero points and the
# expected output, then registers them as "test_quantizelinear_float4e2m1".
@staticmethod
def export_float4e2m1() -> None:
node = onnx.helper.make_node(
"QuantizeLinear",
inputs=["x", "y_scale", "y_zero_point"],
outputs=["y"],
axis=0,
)

# 3 rows x 4 columns of float32 input; the last row starts with negative zero.
x = np.array(
[
[0.0, 2.5, 4.8, 8.6],
[-30, -20, 6, 9],
[-0.0, -2.5, -4.8, -8.6],
]
).astype(np.float32)

# Three scales with axis=0 and three input rows — presumably one scale per
# row of x; confirm against the QuantizeLinear specification.
y_scale = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
# Zero point: FLOAT4E2M1 tensor of zeros with the same shape as y_scale.
y_zero_point = make_tensor(
"y_zero_point",
TensorProto.FLOAT4E2M1,
y_scale.shape,
np.zeros_like(y_scale),
)
# Expected output, given as a flat list of 12 values for x.shape (3, 4).
# NOTE(review): -30 / 3 = -10 and -20 / 3 are both expected as -6, which
# looks like saturation to the largest-magnitude FLOAT4E2M1 value — confirm.
y = make_tensor(
"y",
TensorProto.FLOAT4E2M1,
x.shape,
[0, 1, 2, 4, -6, -6, 2, 3, 0, -0.5, -1, -2],
)

expect(
node,
inputs=[x, y_scale, y_zero_point],
outputs=[y],
name="test_quantizelinear_float4e2m1",
)

@staticmethod
def export_blocked_asymmetric() -> None:
node = onnx.helper.make_node(
Expand Down
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_DOUBLE_to_FLOAT/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

*'�o�h�x�������������������B��Bx
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
�w�By
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_FLOAT16_to_INT4/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
�w�Bx
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
�w�Bx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_FLOAT_to_DOUBLE/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
�w�By
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_FLOAT_to_INT4/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_FLOAT_to_STRING/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_FLOAT_to_UINT4/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_INT4_to_FLOAT/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_INT4_to_FLOAT16/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_INT4_to_INT8/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_STRING_to_FLOAT/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_cast_UINT4_to_UINT8/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_dequantizelinear/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* :Bx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_quantizelinear/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_quantizelinear_axis/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_quantizelinear_e5m2/model.onnx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* d�T��By
Binary file modified onnx/backend/test/data/node/test_quantizelinear_int16/model.onnx
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_quantizelinear_int4/model.onnx
Binary file not shown.
Binary file not shown.
Binary file modified onnx/backend/test/data/node/test_quantizelinear_uint4/model.onnx
Binary file not shown.
11 changes: 9 additions & 2 deletions onnx/defs/operator_sets.h
Original file line number Diff line number Diff line change
Expand Up @@ -1291,11 +1291,18 @@ class OpSet_Onnx_ver22 {
};

// Iterate over schema from ai.onnx version 23
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Cast);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, CastLike);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, DequantizeLinear);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, QuantizeLinear);

// Operator set for ai.onnx opset version 23.
class OpSet_Onnx_ver23 {
 public:
  // Calls `fn` once per opset-23 schema, in this order: Cast, CastLike,
  // DequantizeLinear, QuantizeLinear. Each schema is passed as an rvalue
  // obtained from GetOpSchema<...>().
  //
  // The former "(void)fn;" placeholder and its TODO are gone: `fn` is now
  // used, so the unused-parameter suppression is no longer needed.
  static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Cast)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, CastLike)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, DequantizeLinear)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, QuantizeLinear)>());
  }
};

Expand Down
Loading

0 comments on commit 46dc2f7

Please sign in to comment.