From be76e20b4eda4672b2970ce9215456d4ce945311 Mon Sep 17 00:00:00 2001 From: Forrest Collman Date: Tue, 3 Dec 2024 18:57:02 -0800 Subject: [PATCH] Leaves many (#273) * adding leaves many * add test * fix test * fix test again * move to seperate method * cleaning up doc string * fixing dictionary return * fix formatting * fixing test --- caveclient/chunkedgraph.py | 43 ++++++++++++++++++++++++++++++++++++-- caveclient/endpoints.py | 1 + tests/test_chunkedgraph.py | 20 ++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/caveclient/chunkedgraph.py b/caveclient/chunkedgraph.py index 8f6b7030..2e62eb7f 100644 --- a/caveclient/chunkedgraph.py +++ b/caveclient/chunkedgraph.py @@ -438,6 +438,43 @@ def get_tabular_change_log(self, root_ids, filtered=True) -> dict: return changelog_dict + def get_leaves_many(self, root_ids, bounds=None, stop_layer: int = None) -> dict: + """Get all supervoxels for a list of root IDs. + + Parameters + ---------- + root_ids : Iterable + Root IDs to query. + bounds: np.array or None, optional + If specified, returns supervoxels within a 3x2 numpy array of bounds + ``[[minx,maxx],[miny,maxy],[minz,maxz]]``. If None, finds all supervoxels. + stop_layer: int, optional + If specified, returns chunkedgraph nodes at layer `stop_layer` + default will be `stop_layer=1` (supervoxels). + + Returns + ------- + dict + Dict relating ids to contacts + """ + endpoint_mapping = self.default_url_mapping + url = self._endpoints["leaves_many"].format_map(endpoint_mapping) + data = json.dumps({"node_ids": root_ids}, cls=BaseEncoder) + query_d = {} + if bounds is not None: + query_d["bounds"] = package_bounds(bounds) + if stop_layer is not None: + query_d["stop_layer"] = int(stop_layer) + + response = self.session.post( + url, + data=data, + params=query_d, + headers={"Content-Type": "application/json"}, + ) + data_d = handle_response(response) + return {np.int64(k): np.int64(v) for k, v in data_d.items()} + def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray: """Get all supervoxels for a root ID. @@ -458,13 +495,15 @@ def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray Array of supervoxel IDs (or node ids if `stop_layer>1`). """ endpoint_mapping = self.default_url_mapping - endpoint_mapping["root_id"] = root_id - url = self._endpoints["leaves_from_root"].format_map(endpoint_mapping) + query_d = {} if bounds is not None: query_d["bounds"] = package_bounds(bounds) if stop_layer is not None: query_d["stop_layer"] = int(stop_layer) + + endpoint_mapping["root_id"] = root_id + url = self._endpoints["leaves_from_root"].format_map(endpoint_mapping) response = self.session.get(url, params=query_d) return np.int64(handle_response(response)["leaf_ids"]) diff --git a/caveclient/endpoints.py b/caveclient/endpoints.py index 4bceba99..ca23bd65 100644 --- a/caveclient/endpoints.py +++ b/caveclient/endpoints.py @@ -158,6 +158,7 @@ "leaves_from_root": pcg_v1 + "/table/{table_id}/node/{root_id}/leaves", "do_merge": pcg_v1 + "/table/{table_id}/merge", "get_roots": pcg_v1 + "/table/{table_id}/roots_binary", + "leaves_many": pcg_v1 + "/table/{table_id}/node/leaves_many", "merge_log": pcg_v1 + "/table/{table_id}/root/{root_id}/merge_log", "change_log": pcg_v1 + "/table/{table_id}/root/{root_id}/change_log", "tabular_change_log": pcg_v1 + "/table/{table_id}/tabular_change_log_many", diff --git a/tests/test_chunkedgraph.py b/tests/test_chunkedgraph.py index f21e71d5..92a7bc8c 100644 --- a/tests/test_chunkedgraph.py +++ b/tests/test_chunkedgraph.py @@ -118,6 +118,26 @@ def test_get_leaves(self, myclient): ) assert np.all(svids == svids_ret) + @responses.activate + def test_get_leaves_many(self, myclient): + endpoint_mapping = self._default_endpoint_map + root_ids = [864691135217871271, 864691135217871272] + url = chunkedgraph_endpoints_v1["leaves_many"].format_map(endpoint_mapping) + + sv_dict = { + "864691135217871271": [97557743795364048, 75089979126506763], + "864691135217871272": [97557743795364049, 750899791265067632], + } + + data = {"node_ids": root_ids} + responses.add( + responses.POST, json=sv_dict, url=url, match=[json_params_matcher(data)] + ) + + svids_ret = myclient.chunkedgraph.get_leaves_many(root_ids) + for k, v in svids_ret.items(): + assert np.all(v == sv_dict[str(k)]) + @responses.activate def test_get_root(self, myclient): endpoint_mapping = self._default_endpoint_map