Skip to content

Commit

Permalink
Make gen_finn_dt_tensor consider the numpy type for INT and FIXED types
Browse files Browse the repository at this point in the history
  • Loading branch information
iksnagreb committed Jun 7, 2024
1 parent cadd6b2 commit 45cfa03
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/qonnx/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,12 @@ def gen_finn_dt_tensor(finn_dt, tensor_shape):
elif finn_dt == DataType["BINARY"]:
tensor_values = np.random.randint(2, size=tensor_shape)
elif "INT" in finn_dt.name or finn_dt == DataType["TERNARY"]:
tensor_values = np.random.randint(finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape)
tensor_values = np.random.randint(
finn_dt.min(), high=finn_dt.max() + 1, size=tensor_shape, dtype=finn_dt.to_numpy_dt()
)
elif "FIXED" in finn_dt.name:
int_dt = DataType["INT" + str(finn_dt.bitwidth())]
tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape)
tensor_values = np.random.randint(int_dt.min(), high=int_dt.max() + 1, size=tensor_shape, dtype=int_dt.to_numpy_dt())
tensor_values = tensor_values * finn_dt.scale_factor()
elif finn_dt == DataType["FLOAT32"]:
tensor_values = np.random.randn(*tensor_shape)
Expand Down

0 comments on commit 45cfa03

Please sign in to comment.