Skip to content

Commit

Permalink
Fix ObjectMapper convert_dtype function returning wrong type (#146)
Browse files Browse the repository at this point in the history
* Fix #145

* Fix tests to test for "is" instead of ==
  • Loading branch information
rly authored Sep 26, 2019
1 parent 62d487e commit ee84b31
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/hdmf/build/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def __check_edgecases(cls, spec, value):
if spec.dtype is None or spec.dtype == 'numeric' or type(value) in cls.__no_convert:
# infer type from value
if hasattr(value, 'dtype'): # covers numpy types, AbstractDataChunkIterator
return value, value.dtype
return value, value.dtype.type
if isinstance(value, (list, tuple)):
if len(value) == 0:
msg = "cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype."
Expand Down
32 changes: 16 additions & 16 deletions tests/unit/build_tests/test_io_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def convert_higher_precision_helper(self, spec_type, value_types):
with self.subTest(dtype=dtype):
ret = ObjectMapper.convert_dtype(spec, value)
self.assertTupleEqual(ret, match)
self.assertEqual(ret[0].dtype, match[1])
self.assertIs(ret[0].dtype, match[1])

def test_keep_higher_precision(self):
"""Test that passing a data type with a precision >= specified return the given type"""
Expand Down Expand Up @@ -572,7 +572,7 @@ def keep_higher_precision_helper(self, spec_type, value_types):
with self.subTest(dtype=dtype):
ret = ObjectMapper.convert_dtype(spec, value)
self.assertTupleEqual(ret, match)
self.assertEqual(ret[0].dtype, match[1])
self.assertIs(ret[0].dtype, match[1])

def test_no_spec(self):
spec_type = None
Expand All @@ -582,49 +582,49 @@ def test_no_spec(self):
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, int)
self.assertTupleEqual(ret, match)
self.assertEqual(type(ret[0][0]), match[1])
self.assertIs(type(ret[0][0]), match[1])

value = np.uint64(4)
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.uint64)
self.assertTupleEqual(ret, match)
self.assertEqual(type(ret[0]), match[1])
self.assertIs(type(ret[0]), match[1])

value = 'hello'
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, 'utf8')
self.assertTupleEqual(ret, match)
self.assertEqual(type(ret[0]), str)
self.assertIs(type(ret[0]), str)

value = bytes('hello', encoding='utf-8')
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, 'ascii')
self.assertTupleEqual(ret, match)
self.assertEqual(type(ret[0]), bytes)
self.assertIs(type(ret[0]), bytes)

value = DataChunkIterator(data=[1, 2, 3])
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.dtype(int))
match = (value, np.dtype(int).type)
self.assertTupleEqual(ret, match)
self.assertEqual(ret[0].dtype, match[1])
self.assertIs(ret[0].dtype.type, match[1])

value = DataChunkIterator(data=[1., 2., 3.])
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.dtype(float))
match = (value, np.dtype(float).type)
self.assertTupleEqual(ret, match)
self.assertEqual(ret[0].dtype, match[1])
self.assertIs(ret[0].dtype.type, match[1])

value = H5DataIO(np.arange(30).reshape(5, 2, 3))
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.dtype(int))
match = (value, np.dtype(int).type)
self.assertTupleEqual(ret, match)
self.assertEqual(ret[0].dtype, match[1])
self.assertIs(ret[0].dtype.type, match[1])

value = H5DataIO(['foo' 'bar'])
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, 'utf8')
self.assertTupleEqual(ret, match)
self.assertEqual(type(ret[0].data[0]), str)
self.assertIs(type(ret[0].data[0]), str)

def test_numeric_spec(self):
spec_type = 'numeric'
Expand All @@ -634,13 +634,13 @@ def test_numeric_spec(self):
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.uint64)
self.assertTupleEqual(ret, match)
self.assertEqual(type(ret[0]), match[1])
self.assertIs(type(ret[0]), match[1])

value = DataChunkIterator(data=[1, 2, 3])
ret = ObjectMapper.convert_dtype(spec, value)
match = (value, np.dtype(int))
match = (value, np.dtype(int).type)
self.assertTupleEqual(ret, match)
self.assertEqual(ret[0].dtype, match[1])
self.assertIs(ret[0].dtype.type, match[1])


if __name__ == '__main__':
Expand Down

0 comments on commit ee84b31

Please sign in to comment.