Skip to content

Commit

Permalink
add type check for input name of resolve_datatype()
Browse files Browse the repository at this point in the history
  • Loading branch information
makoeppel committed Jun 21, 2024
1 parent db969e6 commit 9fec426
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/qonnx/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ def get_canonical_name(self):


def resolve_datatype(name):

if not isinstance(name, str):
raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'")

_special_types = {
"BINARY": IntType(1, False),
"BIPOLAR": BipolarType(),
Expand Down
39 changes: 39 additions & 0 deletions tests/core/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np

from qonnx.core.datatype import DataType
from qonnx.core.datatype import resolve_datatype


def test_datatypes():
Expand Down Expand Up @@ -97,3 +98,41 @@ def test_smallest_possible():
assert DataType.get_smallest_possible(-1) == DataType["BIPOLAR"]
assert DataType.get_smallest_possible(-3) == DataType["INT3"]
assert DataType.get_smallest_possible(-3.2) == DataType["FLOAT32"]


def test_resolve_datatype():
assert resolve_datatype("BIPOLAR")
assert resolve_datatype("BINARY")
assert resolve_datatype("TERNARY")
assert resolve_datatype("UINT2")
assert resolve_datatype("UINT3")
assert resolve_datatype("UINT4")
assert resolve_datatype("UINT8")
assert resolve_datatype("UINT16")
assert resolve_datatype("UINT32")
assert resolve_datatype("INT2")
assert resolve_datatype("INT3")
assert resolve_datatype("INT4")
assert resolve_datatype("INT8")
assert resolve_datatype("INT16")
assert resolve_datatype("INT32")
assert resolve_datatype("BINARY")
assert resolve_datatype("FLOAT32")


def test_input_type_error():
# test with invalid input to check if the TypeError works
try:
resolve_datatype(123) # This should raise a TypeError
except TypeError as e:
pass
else:
print("Test with invalid input failed: No TypeError was raised.")

# test with invalid input to check if the TypeError works
try:
resolve_datatype(1.23) # This should raise a TypeError
except TypeError as e:
pass
else:
print("Test with invalid input failed: No TypeError was raised.")

0 comments on commit 9fec426

Please sign in to comment.