diff --git a/python/tests/op_mappers/test_roll_op.py b/python/tests/op_mappers/test_roll_op.py index 55a731b5cf..efdf90555f 100644 --- a/python/tests/op_mappers/test_roll_op.py +++ b/python/tests/op_mappers/test_roll_op.py @@ -102,6 +102,15 @@ def init_input_data(self): self.shifts = [121, 122] +class TestRollCase7(TestRollOp): + def init_input_data(self): + self.feed_data = { + 'x': self.random([10, 2, 3], 'float32'), + } + self.axis = [0, 1, 1, -2, 2, 2, 1, -1] + self.shifts = [1, 2, 3, 4, 5, 6, 7, 8] + + class TestRollAxesEmpty(TestRollOp): def set_op_attrs(self): return {"shifts": self.shifts, "axis": []}