Skip to content

Commit

Permalink
Fix (tests): correct QuantTensor logic in quant_rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 13, 2023
1 parent 500d473 commit 3e6920a
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions tests/brevitas/nn/test_nn_quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,23 @@ def test_quant_lstm_rnn_full(model_input, current_cases):
assert isinstance(output, torch.Tensor)

if h is not None:
if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
if return_quant_tensor and kwargs['io_quant'] is not None:
assert isinstance(h, QuantTensor)
else:
assert isinstance(h, torch.Tensor)

if c is not None:
if kwargs['signed_act'] is None or not kwargs['return_quant_tensor']:
if not kwargs['bidirectional']:
if not kwargs['return_quant_tensor'] and kwargs['num_layers'] == 1:
if not kwargs['return_quant_tensor']:
assert isinstance(c, torch.Tensor)
elif kwargs['return_quant_tensor'] and kwargs['signed_act'] is None and kwargs[
'num_layers'] == 2:
assert isinstance(c, torch.Tensor)
else:
if kwargs['num_layers'] == 2 and kwargs['signed_act'] is None:
assert isinstance(c, torch.Tensor)
else:
assert isinstance(c, QuantTensor)
else:
if kwargs['num_layers'] == 2 and kwargs['signed_act'] is not None:
assert isinstance(c, QuantTensor)
else:
assert isinstance(c, torch.Tensor)
else:
assert isinstance(c, torch.Tensor)
else:
assert isinstance(c, QuantTensor)

Expand Down Expand Up @@ -182,13 +179,13 @@ def test_quant_lstm_rnn(model_input, current_cases):
assert isinstance(output, torch.Tensor)

if h is not None:
if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
if return_quant_tensor and kwargs['io_quant'] is not None:
assert isinstance(h, QuantTensor)
else:
assert isinstance(h, torch.Tensor)

if c is not None:
if (return_quant_tensor or kwargs['num_layers'] == 2) and kwargs['io_quant'] is not None:
if return_quant_tensor and kwargs['io_quant'] is not None:
assert isinstance(c, QuantTensor)
else:
assert isinstance(c, torch.Tensor)
Expand Down

0 comments on commit 3e6920a

Please sign in to comment.