Skeleton dev #290

Merged · 4 commits · Dec 18, 2024
4 changes: 4 additions & 0 deletions caveclient/endpoints.py
@@ -308,8 +308,12 @@
+ "/{datastack_name}/precomputed/skeleton/{skvn}/info",
"get_cache_contents_via_skvn_ridprefixes": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/query_cache/{skeleton_version}/{root_id_prefixes}/{limit}",
# TODO: DEPRECATED: This endpoint is deprecated and will be removed in the future.
# Please use the POST endpoint in the future.
"skeletons_exist_via_skvn_rids": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/exists/{skeleton_version}/{root_ids}",
"skeletons_exist_via_skvn_rids_as_post": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/exists",
"get_skeleton_via_rid": skeleton_v1
+ "/{datastack_name}/precomputed/skeleton/{root_id}",
"get_skeleton_via_skvn_rid": skeleton_v1
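For context on the deprecation note above: the old route carries the root ids in the URL path, while the new route accepts them in a JSON body. A minimal sketch of the difference, using a hypothetical server address, datastack name, and root ids (none of these values come from this PR):

```python
import requests

base = "https://example.org/skeletoncache/api/v1"  # hypothetical server
datastack = "my_datastack"  # hypothetical datastack name
root_ids = [101, 102]  # placeholder ids

# Deprecated GET route: root ids are packed into the URL path.
get_url = f"{base}/{datastack}/precomputed/skeleton/exists/4/" + ",".join(
    str(r) for r in root_ids
)
resp = requests.get(get_url)

# Replacement POST route: ids travel in the JSON body, avoiding URL-length limits.
post_url = f"{base}/{datastack}/precomputed/skeleton/exists"
resp = requests.post(post_url, json={"root_ids": root_ids, "skeleton_version": 4})
```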
104 changes: 76 additions & 28 deletions caveclient/skeletonservice.py
@@ -24,8 +24,9 @@

SERVER_KEY = "skeleton_server_address"

MAX_SKELETONS_EXISTS_QUERY_SIZE = 1000
MAX_BULK_ASYNCHRONOUS_SKELETONS = 10000
BULK_ASYNC_SKELETONS_BATCH_SIZE = 100
BULK_SKELETONS_BATCH_SIZE = 100


class NoL2CacheException(Exception):
@@ -219,7 +220,26 @@ def decompressBytesToDict(inputBytes):
        inputBytesStrDict = json.loads(inputBytesStr)
        return inputBytesStrDict

    def _build_endpoint(
    def _build_skeletons_exist_endpoint(
        self,
        root_ids: List,
        datastack_name: str,
        skeleton_version: int,
        post: bool = False,
    ):
        endpoint_mapping = self.default_url_mapping
        endpoint_mapping["datastack_name"] = datastack_name
        if not post:
            endpoint_mapping["root_ids"] = ",".join([str(v) for v in root_ids])
            endpoint_mapping["skeleton_version"] = skeleton_version
            endpoint = "skeletons_exist_via_skvn_rids"
        else:
            endpoint = "skeletons_exist_via_skvn_rids_as_post"

        url = self._endpoints[endpoint].format_map(endpoint_mapping)
        return url
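A standalone toy version of the routing choice this helper makes, runnable on its own; the function name, base URL, and datastack name are illustrative, not part of the client's public API:

```python
from typing import List


def exists_url(base: str, datastack: str, root_ids: List[int],
               skeleton_version: int, post: bool = False) -> str:
    # POST keeps the URL short: the ids travel in the request body instead.
    if post:
        return f"{base}/{datastack}/precomputed/skeleton/exists"
    ids = ",".join(str(v) for v in root_ids)
    return f"{base}/{datastack}/precomputed/skeleton/exists/{skeleton_version}/{ids}"


print(exists_url("https://server/api/v1", "ds", [1, 2, 3], 4))
print(exists_url("https://server/api/v1", "ds", [1, 2, 3], 4, post=True))
```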

    def _build_get_skeleton_endpoint(
        self,
        root_id: int,
        datastack_name: str,
@@ -379,6 +399,11 @@ def skeletons_exist(
"""
Confirm or deny that a set of root ids have H5 skeletons in the cache.
"""
if self._server_version < Version("0.9.0"):
logging.warning(
"Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
)

        if datastack_name is None:
            datastack_name = self._datastack_name
        assert datastack_name is not None
@@ -389,32 +414,54 @@
f"Unknown skeleton version: {skeleton_version}. Valid options: {valid_skeleton_versions}"
)

        if isinstance(root_ids, int):
            root_ids = str(root_ids)
        if isinstance(root_ids, np.ndarray):
            root_ids = root_ids.tolist()
        if not isinstance(root_ids, List):  # If not a list, it can only be a string at this point
            root_ids = [root_ids]

        if isinstance(root_ids, int):
            root_ids = str(root_ids)
        elif isinstance(root_ids, List):
            root_ids = ",".join([str(v) for v in root_ids])
        if len(root_ids) > MAX_SKELETONS_EXISTS_QUERY_SIZE:
            logging.warning(
                f"The number of root_ids exceeds the current limit of {MAX_SKELETONS_EXISTS_QUERY_SIZE}. Only the first {MAX_SKELETONS_EXISTS_QUERY_SIZE} will be processed."
            )
            root_ids = root_ids[:MAX_SKELETONS_EXISTS_QUERY_SIZE]

        endpoint_mapping = self.default_url_mapping
        endpoint_mapping["datastack_name"] = datastack_name
        endpoint_mapping["root_ids"] = root_ids
        results = {}
        for batch in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE):
            rids_one_batch = root_ids[batch : batch + BULK_SKELETONS_BATCH_SIZE]

        endpoint_mapping["skeleton_version"] = skeleton_version
        url = self._endpoints["skeletons_exist_via_skvn_rids"].format_map(
            endpoint_mapping
        )
            if self._server_version < Version("0.9.0"):
                url = self._build_skeletons_exist_endpoint(
                    rids_one_batch, datastack_name, skeleton_version
                )
                response = self.session.get(url)
                self.raise_for_status(response, log_warning=log_warning)
            else:
                url = self._build_skeletons_exist_endpoint(
                    rids_one_batch, datastack_name, skeleton_version, True
                )
                data = {
                    "root_ids": rids_one_batch,
                    "skeleton_version": skeleton_version,
                }
                response = self.session.post(url, json=data)
                response = handle_response(response, as_json=False)

        response = self.session.get(url)
        self.raise_for_status(response, log_warning=log_warning)
            result_json = response.json()
            if isinstance(result_json, dict):
                # Convert string keys to ints
                results.update({int(key): value for key, value in result_json.items()})
            elif isinstance(result_json, bool):
                assert len(rids_one_batch) == 1
                results[int(rids_one_batch[0])] = result_json
            else:
                raise ValueError(f"Unexpected response type: {type(result_json)}")

        result_json = response.json()
        if isinstance(result_json, bool):
        if len(results) == 1:
            # When investigating a single root id, this returns a single bool, not a dict, list, etc.
            return result_json
        result_json_w_ints = {int(key): value for key, value in result_json.items()}
        return result_json_w_ints
            return list(results.values())[0]
        return results
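A usage sketch of the refactored method; the datastack name and root ids are placeholders, and the return shapes follow the code above (a dict keyed by int root id for multiple ids, a bare bool for a single id):

```python
from caveclient import CAVEclient

client = CAVEclient("minnie65_public")  # placeholder datastack

# Multiple ids: returns {root_id (int): exists (bool), ...}, batched under the hood.
exists = client.skeleton.skeletons_exist(
    root_ids=[864691135123456789, 864691135987654321],  # placeholder ids
    skeleton_version=4,
)

# A single id: the method returns a bare bool instead of a dict.
one = client.skeleton.skeletons_exist(root_ids=864691135123456789)
```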

    @cached(TTLCache(maxsize=32, ttl=3600))
    def get_precomputed_skeleton_info(
@@ -511,7 +558,7 @@ def get_skeleton(
            skeleton_versions = self.get_versions()
            skeleton_version = sorted(skeleton_versions)[-1]

        url = self._build_endpoint(
        url = self._build_get_skeleton_endpoint(
            root_id, datastack_name, skeleton_version, endpoint_format
        )
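A possible call through the renamed builder, assuming a reachable CAVE deployment; the datastack name, root id, and the output_format value are placeholders and may differ across caveclient releases:

```python
from caveclient import CAVEclient

client = CAVEclient("minnie65_public")  # placeholder datastack
sk = client.skeleton.get_skeleton(
    root_id=864691135123456789,  # placeholder id
    skeleton_version=4,
    output_format="dict",  # assumed option; check your client version
)
```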

@@ -685,6 +732,11 @@ def generate_bulk_skeletons_async(
        if not self.fc.l2cache.has_cache():
            raise NoL2CacheException("SkeletonClient requires an L2Cache.")

        if self._server_version < Version("0.8.0"):
            logging.warning(
                "Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
            )

        if skeleton_version is None:
            logging.warning(
                "The 'skeleton_version' parameter will become required in the future. Please specify a skeleton version."
@@ -697,11 +749,6 @@
            raise ValueError(
                f"root_ids must be a list or numpy array of root_ids, not a {type(root_ids)}"
            )

        if self._server_version < Version("0.8.0"):
            logging.warning(
                "Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
            )

        if len(root_ids) > MAX_BULK_ASYNCHRONOUS_SKELETONS:
            logging.warning(
@@ -714,14 +761,15 @@
        # So consider reverting to the unbatched approach in the future.

        estimated_async_time_secs_upper_bound_sum = 0
        for batch in range(0, len(root_ids), BULK_ASYNC_SKELETONS_BATCH_SIZE):
            rids_one_batch = root_ids[batch : batch + BULK_ASYNC_SKELETONS_BATCH_SIZE]
        for batch in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE):
            rids_one_batch = root_ids[batch : batch + BULK_SKELETONS_BATCH_SIZE]

            if self._server_version < Version("0.8.0"):
                url = self._build_bulk_async_endpoint(
                    rids_one_batch, datastack_name, skeleton_version
                )
                response = self.session.get(url)
                self.raise_for_status(response, log_warning=log_warning)
            else:
                url = self._build_bulk_async_endpoint(
                    rids_one_batch, datastack_name, skeleton_version, post=True
                )
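A self-contained sketch of the batching arithmetic used in both loops above: with the module constants, 10000 root ids split into batches of 100 yields at most 100 requests. The constants mirror the values in this diff; everything else is illustrative:

```python
# Constants as defined at the top of skeletonservice.py in this PR.
MAX_BULK_ASYNCHRONOUS_SKELETONS = 10000
BULK_SKELETONS_BATCH_SIZE = 100

root_ids = list(range(MAX_BULK_ASYNCHRONOUS_SKELETONS))  # placeholder ids
batches = [
    root_ids[i : i + BULK_SKELETONS_BATCH_SIZE]
    for i in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE)
]
assert len(batches) == 100  # 10000 ids / 100 per batch -> 100 requests
```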