Skip to content

Commit

Permalink
Fix enforce shape in docval when argument does not have len (#122)
Browse files Browse the repository at this point in the history
* Fix enforce shape in docval when arg does not have len

* make None always match shape, add tests
  • Loading branch information
rly authored Aug 1, 2019
1 parent e5b4eee commit 1cbbce6
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 13 deletions.
32 changes: 20 additions & 12 deletions src/hdmf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,21 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
fmt_val = (argname, type(argval).__name__, __format_type(arg['type']))
type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val)
if enforce_shape and 'shape' in arg:
while not hasattr(argval, '__len__'):
valshape = get_data_shape(argval)
while valshape is None:
if argval is None:
break
if not hasattr(argval, argname):
fmt_val = (argval, argname, arg['shape'])
value_errors.append("cannot check object '%s' for shape for '%s' "
value_errors.append("cannot check shape of object '%s' for argument '%s' "
"(expected shape '%s')" % fmt_val)
continue
break
# unpack, e.g. if TimeSeries is passed for arg 'data', then TimeSeries.data is checked
argval = getattr(argval, argname)
if hasattr(argval, '__len__') and not __shape_okay_multi(argval, arg['shape']):
fmt_val = (argname, get_data_shape(argval), arg['shape'])
value_errors.append("incorrect shape for '%s' (got '%s, expected '%s')" % fmt_val)
valshape = get_data_shape(argval)
if valshape is not None and not __shape_okay_multi(argval, arg['shape']):
fmt_val = (argname, valshape, arg['shape'])
value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val)
ret[argname] = argval
argsi += 1
arg = next(it)
Expand All @@ -215,17 +219,21 @@ def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True,
fmt_val = (argname, type(argval).__name__, __format_type(arg['type']))
type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val)
if enforce_shape and 'shape' in arg and argval is not None:
while not hasattr(argval, '__len__'):
valshape = get_data_shape(argval)
while valshape is None:
if argval is None:
break
if not hasattr(argval, argname):
fmt_val = (argval, argname, arg['shape'])
value_errors.append("cannot check object '%s' for shape for '%s' (expected shape '%s')"
value_errors.append("cannot check shape of object '%s' for argument '%s' (expected shape '%s')"
% fmt_val)
continue
break
# unpack, e.g. if TimeSeries is passed for arg 'data', then TimeSeries.data is checked
argval = getattr(argval, argname)
if hasattr(argval, '__len__') and not __shape_okay_multi(argval, arg['shape']):
fmt_val = (argname, get_data_shape(argval), arg['shape'])
value_errors.append("incorrect shape for '%s' (got '%s, expected '%s')" % fmt_val)
valshape = get_data_shape(argval)
if valshape is not None and not __shape_okay_multi(argval, arg['shape']):
fmt_val = (argname, valshape, arg['shape'])
value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val)
arg = next(it)
except StopIteration:
pass
Expand Down
147 changes: 146 additions & 1 deletion tests/unit/utils_test/test_docval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest2 as unittest
from six import text_type

from hdmf.utils import docval, fmt_docval_args, get_docval
from hdmf.utils import docval, fmt_docval_args, get_docval, popargs


class MyTestClass(object):
Expand Down Expand Up @@ -49,6 +49,55 @@ def basic_add2_kw(self, **kwargs):
return kwargs


class MyChainClass(MyTestClass):

@docval({'name': 'arg1', 'type': (str, 'MyChainClass'), 'doc': 'arg1 is a string or MyChainClass'},
{'name': 'arg2', 'type': ('array_data', 'MyChainClass'),
'doc': 'arg2 is array data or MyChainClass. it defaults to None', 'default': None},
{'name': 'arg3', 'type': ('array_data', 'MyChainClass'), 'doc': 'arg3 is array data or MyChainClass',
'shape': (None, 2)},
{'name': 'arg4', 'type': ('array_data', 'MyChainClass'),
'doc': 'arg3 is array data or MyChainClass. it defaults to None.', 'shape': (None, 2), 'default': None})
def __init__(self, **kwargs):
self._arg1, self._arg2, self._arg3, self._arg4 = popargs('arg1', 'arg2', 'arg3', 'arg4', kwargs)

@property
def arg1(self):
if isinstance(self._arg1, MyChainClass):
return self._arg1.arg1
else:
return self._arg1

@property
def arg2(self):
if isinstance(self._arg2, MyChainClass):
return self._arg2.arg2
else:
return self._arg2

@property
def arg3(self):
if isinstance(self._arg3, MyChainClass):
return self._arg3.arg3
else:
return self._arg3

@arg3.setter
def arg3(self, val):
self._arg3 = val

@property
def arg4(self):
if isinstance(self._arg4, MyChainClass):
return self._arg4.arg4
else:
return self._arg4

@arg4.setter
def arg4(self, val):
self._arg4 = val


class TestDocValidator(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -372,5 +421,101 @@ def test_get_docval_none_arg(self):
get_docval(self.test_obj.__init__, 'arg3')


class TestDocValidatorChain(unittest.TestCase):

def setUp(self):
self.obj1 = MyChainClass('base', [[1, 2], [3, 4], [5, 6]], [[10, 20]])
# note that self.obj1.arg3 == [[1, 2], [3, 4], [5, 6]]

def test_type_arg(self):
"""Test that passing an object for an argument that allows a specific type works"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]])
self.assertEqual(obj2.arg1, 'base')

def test_type_arg_wrong_type(self):
"""Test that passing an object for an argument that does not match a specific type raises an error"""
err_msg = r"incorrect type for 'arg1' \(got 'object', expected 'str or MyChainClass'\)"
with self.assertRaisesRegex(TypeError, err_msg):
MyChainClass(object(), [[10, 20], [30, 40], [50, 60]], [[10, 20]])

def test_shape_valid_unpack(self):
"""Test that passing an object for an argument with required shape tests the shape of object.argument"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]])
obj3 = MyChainClass(self.obj1, obj2, [[100, 200]])
self.assertListEqual(obj3.arg3, obj2.arg3)

def test_shape_invalid_unpack(self):
"""Test that passing an object for an argument with required shape and object.argument has an invalid shape
raises an error"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]])
# change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change
# shape after an object is initialized
obj2.arg3 = [10, 20, 30]

err_msg = r"incorrect shape for 'arg3' \(got '\(3,\)', expected '\(None, 2\)'\)"
with self.assertRaisesRegex(ValueError, err_msg):
MyChainClass(self.obj1, obj2, [[100, 200]])

def test_shape_none_unpack(self):
"""Test that passing an object for an argument with required shape and object.argument is None is OK"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]])
obj2.arg3 = None
obj3 = MyChainClass(self.obj1, obj2, [[100, 200]])
self.assertIsNone(obj3.arg3)

def test_shape_other_unpack(self):
"""Test that passing an object for an argument with required shape and object.argument is an object without
an argument attribute raises an error"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]])
obj2.arg3 = object()

err_msg = r"cannot check shape of object '<object object at .*>' for argument 'arg3' " \
r"\(expected shape '\(None, 2\)'\)"
with self.assertRaisesRegex(ValueError, err_msg):
MyChainClass(self.obj1, obj2, [[100, 200]])

def test_shape_valid_unpack_default(self):
"""Test that passing an object for an argument with required shape and a default value tests the shape of
object.argument"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]])
obj3 = MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2)
self.assertListEqual(obj3.arg4, obj2.arg4)

def test_shape_invalid_unpack_default(self):
"""Test that passing an object for an argument with required shape and a default value and object.argument has
an invalid shape raises an error"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]])
# change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change
# shape after an object is initialized
obj2.arg4 = [10, 20, 30]

err_msg = r"incorrect shape for 'arg4' \(got '\(3,\)', expected '\(None, 2\)'\)"
with self.assertRaisesRegex(ValueError, err_msg):
MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2)

def test_shape_none_unpack_default(self):
"""Test that passing an object for an argument with required shape and a default value and object.argument is
an object without an argument attribute raises an error"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]])
# change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change
# shape after an object is initialized
obj2.arg4 = None
obj3 = MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2)
self.assertIsNone(obj3.arg4)

def test_shape_other_unpack_default(self):
"""Test that passing an object for an argument with required shape and a default value and object.argument is
None is OK"""
obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]])
# change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change
# shape after an object is initialized
obj2.arg4 = object()

err_msg = r"cannot check shape of object '<object object at .*>' for argument 'arg4' " \
r"\(expected shape '\(None, 2\)'\)"
with self.assertRaisesRegex(ValueError, err_msg):
MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2)


if __name__ == '__main__':
unittest.main()

0 comments on commit 1cbbce6

Please sign in to comment.