diff --git a/caveclient/endpoints.py b/caveclient/endpoints.py
index 58ec9fcd..2abf6218 100644
--- a/caveclient/endpoints.py
+++ b/caveclient/endpoints.py
@@ -308,8 +308,12 @@
     + "/{datastack_name}/precomputed/skeleton/{skvn}/info",
     "get_cache_contents_via_skvn_ridprefixes": skeleton_v1
     + "/{datastack_name}/precomputed/skeleton/query_cache/{skeleton_version}/{root_id_prefixes}/{limit}",
+    # DEPRECATED: This GET endpoint will be removed in the future.
+    # Please use the POST endpoint instead.
     "skeletons_exist_via_skvn_rids": skeleton_v1
     + "/{datastack_name}/precomputed/skeleton/exists/{skeleton_version}/{root_ids}",
+    "skeletons_exist_via_skvn_rids_as_post": skeleton_v1
+    + "/{datastack_name}/precomputed/skeleton/exists",
     "get_skeleton_via_rid": skeleton_v1
     + "/{datastack_name}/precomputed/skeleton/{root_id}",
     "get_skeleton_via_skvn_rid": skeleton_v1
diff --git a/caveclient/skeletonservice.py b/caveclient/skeletonservice.py
index d2b4f63e..bb6a535d 100644
--- a/caveclient/skeletonservice.py
+++ b/caveclient/skeletonservice.py
@@ -24,8 +24,9 @@
 SERVER_KEY = "skeleton_server_address"
 
+MAX_SKELETONS_EXISTS_QUERY_SIZE = 1000
 MAX_BULK_ASYNCHRONOUS_SKELETONS = 10000
-BULK_ASYNC_SKELETONS_BATCH_SIZE = 100
+BULK_SKELETONS_BATCH_SIZE = 100
 
 
 class NoL2CacheException(Exception):
@@ -219,7 +220,26 @@ def decompressBytesToDict(inputBytes):
         inputBytesStrDict = json.loads(inputBytesStr)
         return inputBytesStrDict
 
-    def _build_endpoint(
+    def _build_skeletons_exist_endpoint(
+        self,
+        root_ids: List,
+        datastack_name: str,
+        skeleton_version: int,
+        post: bool = False,
+    ):
+        endpoint_mapping = self.default_url_mapping
+        endpoint_mapping["datastack_name"] = datastack_name
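+        # The GET endpoint encodes the root ids and skeleton version in the URL
+        # path; the POST endpoint uses a fixed path and receives them in the
+        # JSON request body instead.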
+        if not post:
+            endpoint_mapping["root_ids"] = ",".join([str(v) for v in root_ids])
+            endpoint_mapping["skeleton_version"] = skeleton_version
+            endpoint = "skeletons_exist_via_skvn_rids"
+        else:
+            endpoint = "skeletons_exist_via_skvn_rids_as_post"
+
+        url = self._endpoints[endpoint].format_map(endpoint_mapping)
+        return url
+
+    def _build_get_skeleton_endpoint(
         self,
         root_id: int,
         datastack_name: str,
@@ -379,6 +399,11 @@ def skeletons_exist(
         """
         Confirm or deny that a set of root ids have H5 skeletons in the cache.
         """
+        if self._server_version < Version("0.9.0"):
+            logging.warning(
+                "Server version is old and only supports GET interactions for skeleton existence queries. Consider upgrading to a newer server version to enable POST interactions."
+            )
+
         if datastack_name is None:
             datastack_name = self._datastack_name
         assert datastack_name is not None
@@ -389,32 +414,54 @@
                 f"Unknown skeleton version: {skeleton_version}. Valid options: {valid_skeleton_versions}"
             )
 
+        if isinstance(root_ids, int):
+            root_ids = str(root_ids)
         if isinstance(root_ids, np.ndarray):
             root_ids = root_ids.tolist()
+        if not isinstance(root_ids, List):  # If not a list, it can only be a string at this point
+            root_ids = [root_ids]
 
-        if isinstance(root_ids, int):
-            root_ids = str(root_ids)
-        elif isinstance(root_ids, List):
-            root_ids = ",".join([str(v) for v in root_ids])
+        if len(root_ids) > MAX_SKELETONS_EXISTS_QUERY_SIZE:
+            logging.warning(
+                f"The number of root_ids exceeds the current limit of {MAX_SKELETONS_EXISTS_QUERY_SIZE}. Only the first {MAX_SKELETONS_EXISTS_QUERY_SIZE} will be processed."
+            )
+            root_ids = root_ids[:MAX_SKELETONS_EXISTS_QUERY_SIZE]
 
-        endpoint_mapping = self.default_url_mapping
-        endpoint_mapping["datastack_name"] = datastack_name
-        endpoint_mapping["root_ids"] = root_ids
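+        # Check the cache in batches of BULK_SKELETONS_BATCH_SIZE root ids,
+        # merging each batch's results into a single dict keyed by root id.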
+        results = {}
+        for batch in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE):
+            rids_one_batch = root_ids[batch : batch + BULK_SKELETONS_BATCH_SIZE]
 
-        endpoint_mapping["skeleton_version"] = skeleton_version
-        url = self._endpoints["skeletons_exist_via_skvn_rids"].format_map(
-            endpoint_mapping
-        )
+            if self._server_version < Version("0.9.0"):
+                url = self._build_skeletons_exist_endpoint(
+                    rids_one_batch, datastack_name, skeleton_version
+                )
+                response = self.session.get(url)
+                self.raise_for_status(response, log_warning=log_warning)
+            else:
+                url = self._build_skeletons_exist_endpoint(
+                    rids_one_batch, datastack_name, skeleton_version, True
+                )
+                data = {
+                    "root_ids": rids_one_batch,
+                    "skeleton_version": skeleton_version,
+                }
+                response = self.session.post(url, json=data)
+                response = handle_response(response, as_json=False)
 
-        response = self.session.get(url)
-        self.raise_for_status(response, log_warning=log_warning)
+            result_json = response.json()
+            if isinstance(result_json, dict):
+                # Convert string keys to ints
+                results.update({int(key): value for key, value in result_json.items()})
+            elif isinstance(result_json, bool):
+                assert len(rids_one_batch) == 1
+                results[int(rids_one_batch[0])] = result_json
+            else:
+                raise ValueError(f"Unexpected response type: {type(result_json)}")
 
-        result_json = response.json()
-        if isinstance(result_json, bool):
+        if len(results) == 1:
             # When investigating a single root id, this returns a single bool, not a dict, list, etc.
-            return result_json
-        result_json_w_ints = {int(key): value for key, value in result_json.items()}
-        return result_json_w_ints
+            return list(results.values())[0]
+        return results
 
     @cached(TTLCache(maxsize=32, ttl=3600))
     def get_precomputed_skeleton_info(
@@ -511,7 +558,7 @@ def get_skeleton(
             skeleton_versions = self.get_versions()
             skeleton_version = sorted(skeleton_versions)[-1]
 
-        url = self._build_endpoint(
+        url = self._build_get_skeleton_endpoint(
             root_id, datastack_name, skeleton_version, endpoint_format
         )
@@ -685,6 +732,11 @@
         if not self.fc.l2cache.has_cache():
             raise NoL2CacheException("SkeletonClient requires an L2Cache.")
 
+        if self._server_version < Version("0.8.0"):
+            logging.warning(
+                "Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
+            )
+
         if skeleton_version is None:
             logging.warning(
                 "The optional nature of the 'skeleton_version' parameter will be deprecated in the future. Please specify a skeleton version."
@@ -697,11 +749,6 @@
             raise ValueError(
                 f"root_ids must be a list or numpy array of root_ids, not a {type(root_ids)}"
             )
-
-        if self._server_version < Version("0.8.0"):
-            logging.warning(
-                "Server version is old and only supports GET interactions for bulk async skeletons. Consider upgrading to a newer server version to enable POST interactions."
-            )
 
         if len(root_ids) > MAX_BULK_ASYNCHRONOUS_SKELETONS:
             logging.warning(
@@ -714,14 +761,15 @@
         # So consider reverting to the unbatched approach in the future.
         estimated_async_time_secs_upper_bound_sum = 0
-        for batch in range(0, len(root_ids), BULK_ASYNC_SKELETONS_BATCH_SIZE):
-            rids_one_batch = root_ids[batch : batch + BULK_ASYNC_SKELETONS_BATCH_SIZE]
+        for batch in range(0, len(root_ids), BULK_SKELETONS_BATCH_SIZE):
+            rids_one_batch = root_ids[batch : batch + BULK_SKELETONS_BATCH_SIZE]
 
             if self._server_version < Version("0.8.0"):
                 url = self._build_bulk_async_endpoint(
                     rids_one_batch, datastack_name, skeleton_version
                 )
                 response = self.session.get(url)
+                self.raise_for_status(response, log_warning=log_warning)
             else:
                 url = self._build_bulk_async_endpoint(
                     rids_one_batch, datastack_name, skeleton_version, post=True