diff --git a/docs/qonnx-custom-ops/quant_op.md b/docs/qonnx-custom-ops/quant_op.md
index 02d115fb..68029406 100644
--- a/docs/qonnx-custom-ops/quant_op.md
+++ b/docs/qonnx-custom-ops/quant_op.md
@@ -21,7 +21,7 @@ This operator is not part of the ONNX standard and is not currently versioned.
narrow : int (default is 0)
Defines if the value range should be interpreted as narrow, when signed=1. E.g. at 8b regular=[-128, 127] vs narrow=[-127, 127].
rounding_mode : string (default is "ROUND")
-Defines how rounding should be applied during quantization. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".
+Defines how rounding should be applied during quantization. Avaiable options are ROUND, CEIL, FLOOR, UP, DOWN, HALF_UP, HALF_DOWN. The rounding modes are described in the table bellow. The names of rounding modes can be upper case or lower case.
#### Inputs
@@ -46,6 +46,24 @@ This operator is not part of the ONNX standard and is not currently versioned.
+#### Rounding modes
+
+rounding modes
+
+| **Number \ ROUNDING_MODE** | ROUND=HALF_EVEN | CEIL | FLOOR | UP | DOWN | HALF_UP | HALF_DOWN |
+|----------------------------|-----------------|------|-------|----|------|---------|-----------|
+| 5.5 | 6 | 6 | 5 | 6 | 5 | 6 | 5 |
+| 2.5 | 2 | 3 | 2 | 3 | 2 | 3 | 2 |
+| 1.6 | 2 | 2 | 1 | 2 | 1 | 2 | 2 |
+| 1.1 | 1 | 2 | 1 | 2 | 1 | 1 | 1 |
+| 1.0 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
+| -1.0 | -1 | -1 | -1 | -1 | -1 | -1 | -1 |
+| -1.1 | -1 | -1 | -2 | -2 | -1 | -1 | -1 |
+| -1.6 | -2 | -1 | -2 | -2 | -1 | -2 | -2 |
+| -2.5 | -2 | -2 | -3 | -3 | -2 | -3 | -2 |
+| -5.5 | -6 | -5 | -6 | -6 | -5 | -6 | -5 |
+
+
#### Examples
Quant
diff --git a/src/qonnx/custom_op/general/quant.py b/src/qonnx/custom_op/general/quant.py
index b5cdf332..f81495d2 100644
--- a/src/qonnx/custom_op/general/quant.py
+++ b/src/qonnx/custom_op/general/quant.py
@@ -135,12 +135,32 @@ def resolve_rounding_mode(mode_string):
"""Resolve the rounding mode string of Quant and Trunc ops
to the corresponding numpy functions."""
normalized_mode_string = mode_string.upper()
- if normalized_mode_string == "ROUND":
+ if normalized_mode_string == "ROUND" or normalized_mode_string == "HALF_EVEN":
return np.round
elif normalized_mode_string == "CEIL":
return np.ceil
elif normalized_mode_string == "FLOOR":
return np.floor
+ elif normalized_mode_string == "UP":
+
+ def round_up(x):
+ return np.sign(x) * np.ceil(np.abs(x))
+
+ return round_up
+ elif normalized_mode_string == "DOWN":
+ return np.fix
+ elif normalized_mode_string == "HALF_UP":
+
+ def round_half_up(x):
+ return np.sign(x) * np.floor(np.abs(x) + 0.5)
+
+ return round_half_up
+ elif normalized_mode_string == "HALF_DOWN":
+
+ def round_half_down(x):
+ return np.sign(x) * np.ceil(np.abs(x) - 0.5)
+
+ return round_half_down
else:
raise ValueError(f"Could not resolve rounding mode called: {normalized_mode_string}")
diff --git a/tests/custom_op/test_runding_mode.py b/tests/custom_op/test_runding_mode.py
new file mode 100644
index 00000000..eb48d644
--- /dev/null
+++ b/tests/custom_op/test_runding_mode.py
@@ -0,0 +1,23 @@
+import pytest
+
+import numpy as np
+
+from qonnx.custom_op.general.quant import resolve_rounding_mode
+
+
+@pytest.mark.parametrize(
+ "rmode,exp",
+ [
+ ("ROUND", np.array([6, 2, 2, 1, 1, -1, -1, -2, -2, -6])),
+ ("CEIL", np.array([6, 3, 2, 2, 1, -1, -1, -1, -2, -5])),
+ ("FLOOR", np.array([5, 2, 1, 1, 1, -1, -2, -2, -3, -6])),
+ ("UP", np.array([6, 3, 2, 2, 1, -1, -2, -2, -3, -6])),
+ ("DOWN", np.array([5, 2, 1, 1, 1, -1, -1, -1, -2, -5])),
+ ("HALF_UP", np.array([6, 3, 2, 1, 1, -1, -1, -2, -3, -6])),
+ ("HALF_DOWN", np.array([5, 2, 2, 1, 1, -1, -1, -2, -2, -5])),
+ ],
+)
+def test_rounding_modes(rmode, exp):
+ test_array = np.array([5.5, 2.5, 1.6, 1.1, 1.0, -1.0, -1.1, -1.6, -2.5, -5.5])
+ rounding_fn = resolve_rounding_mode(rmode)
+ assert np.array_equal(rounding_fn(test_array), exp)