Skip to content

Commit

Permalink
Leaves many (#273)
Browse files Browse the repository at this point in the history
* adding leaves many

* add test

* fix test

* fix test again

* move to seperate method

* cleaning up doc string

* fixing dictionary return

* fix formatting

* fixing test
  • Loading branch information
fcollman authored Dec 4, 2024
1 parent 775d4ad commit be76e20
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
43 changes: 41 additions & 2 deletions caveclient/chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,43 @@ def get_tabular_change_log(self, root_ids, filtered=True) -> dict:

return changelog_dict

def get_leaves_many(self, root_ids, bounds=None, stop_layer: int = None) -> dict:
"""Get all supervoxels for a list of root IDs.
Parameters
----------
root_ids : Iterable
Root IDs to query.
bounds: np.array or None, optional
If specified, returns supervoxels within a 3x2 numpy array of bounds
``[[minx,maxx],[miny,maxy],[minz,maxz]]``. If None, finds all supervoxels.
stop_layer: int, optional
If specified, returns chunkedgraph nodes at layer `stop_layer`
default will be `stop_layer=1` (supervoxels).
Returns
-------
dict
Dict relating ids to contacts
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["leaves_many"].format_map(endpoint_mapping)
data = json.dumps({"node_ids": root_ids}, cls=BaseEncoder)
query_d = {}
if bounds is not None:
query_d["bounds"] = package_bounds(bounds)
if stop_layer is not None:
query_d["stop_layer"] = int(stop_layer)

response = self.session.post(
url,
data=data,
params=query_d,
headers={"Content-Type": "application/json"},
)
data_d = handle_response(response)
return {np.int64(k): np.int64(v) for k, v in data_d.items()}

def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray:
"""Get all supervoxels for a root ID.
Expand All @@ -458,13 +495,15 @@ def get_leaves(self, root_id, bounds=None, stop_layer: int = None) -> np.ndarray
Array of supervoxel IDs (or node ids if `stop_layer>1`).
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id
url = self._endpoints["leaves_from_root"].format_map(endpoint_mapping)

query_d = {}
if bounds is not None:
query_d["bounds"] = package_bounds(bounds)
if stop_layer is not None:
query_d["stop_layer"] = int(stop_layer)

endpoint_mapping["root_id"] = root_id
url = self._endpoints["leaves_from_root"].format_map(endpoint_mapping)
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response)["leaf_ids"])

Expand Down
1 change: 1 addition & 0 deletions caveclient/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
"leaves_from_root": pcg_v1 + "/table/{table_id}/node/{root_id}/leaves",
"do_merge": pcg_v1 + "/table/{table_id}/merge",
"get_roots": pcg_v1 + "/table/{table_id}/roots_binary",
"leaves_many": pcg_v1 + "/table/{table_id}/node/leaves_many",
"merge_log": pcg_v1 + "/table/{table_id}/root/{root_id}/merge_log",
"change_log": pcg_v1 + "/table/{table_id}/root/{root_id}/change_log",
"tabular_change_log": pcg_v1 + "/table/{table_id}/tabular_change_log_many",
Expand Down
20 changes: 20 additions & 0 deletions tests/test_chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ def test_get_leaves(self, myclient):
)
assert np.all(svids == svids_ret)

@responses.activate
def test_get_leaves_many(self, myclient):
endpoint_mapping = self._default_endpoint_map
root_ids = [864691135217871271, 864691135217871272]
url = chunkedgraph_endpoints_v1["leaves_many"].format_map(endpoint_mapping)

sv_dict = {
"864691135217871271": [97557743795364048, 75089979126506763],
"864691135217871272": [97557743795364049, 750899791265067632],
}

data = {"node_ids": root_ids}
responses.add(
responses.POST, json=sv_dict, url=url, match=[json_params_matcher(data)]
)

svids_ret = myclient.chunkedgraph.get_leaves_many(root_ids)
for k, v in svids_ret.items():
assert np.all(v == sv_dict[str(k)])

@responses.activate
def test_get_root(self, myclient):
endpoint_mapping = self._default_endpoint_map
Expand Down

0 comments on commit be76e20

Please sign in to comment.