From 9c727d9410298bd966a94431ea5a0207102e4779 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Mon, 9 Dec 2024 09:49:10 -0800 Subject: [PATCH] handling invalid uint64 values --- caveclient/chunkedgraph.py | 16 +++++++++++++--- tests/test_chunkedgraph.py | 5 +++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/caveclient/chunkedgraph.py b/caveclient/chunkedgraph.py index cc86bb47..c9ea73b2 100644 --- a/caveclient/chunkedgraph.py +++ b/caveclient/chunkedgraph.py @@ -1323,8 +1323,16 @@ def is_valid_nodes( OverflowError If any root_id is too large or negative to be represented as np.uint64. """ + node_ids = np.array(node_ids, dtype=np.int64) + invalid_mask = (node_ids == -1) | (node_ids == 0) + valid_mask = ~invalid_mask - node_ids = root_id_int_list_check(node_ids, make_unique=False) + valid_node_ids = node_ids[valid_mask] + + if valid_node_ids.size == 0: + return np.full(node_ids.shape, False, dtype=bool) + + valid_node_ids = root_id_int_list_check(valid_node_ids, make_unique=False) endpoint_mapping = self.default_url_mapping url = self._endpoints["valid_nodes"].format_map(endpoint_mapping) @@ -1349,15 +1357,17 @@ def is_valid_nodes( ) ) - data = {"node_ids": node_ids} + data = {"node_ids": valid_node_ids} r = handle_response( self.session.get( url, data=json.dumps(data, cls=BaseEncoder), params=query_d ) ) valid_ids = np.array(r["valid_roots"], np.uint64) + result = np.full(node_ids.shape, False, dtype=bool) + result[valid_mask] = np.isin(valid_node_ids, valid_ids) - return np.isin(node_ids, valid_ids) + return result @_check_version_compatibility( kwarg_use_constraints={ diff --git a/tests/test_chunkedgraph.py b/tests/test_chunkedgraph.py index 277126ae..2c63cd7f 100644 --- a/tests/test_chunkedgraph.py +++ b/tests/test_chunkedgraph.py @@ -1037,5 +1037,6 @@ def test_is_valid_nodes(self, myclient): json=return_data, match=[json_params_matcher(data)], ) - with pytest.raises(OverflowError): - out = myclient.chunkedgraph.is_valid_nodes(query_nodes) + + out = myclient.chunkedgraph.is_valid_nodes(query_nodes) + assert not np.any(out)