diff --git a/caveclient/skeletonservice.py b/caveclient/skeletonservice.py
index 483a82de..8473043f 100644
--- a/caveclient/skeletonservice.py
+++ b/caveclient/skeletonservice.py
@@ -8,6 +8,7 @@
 from io import BytesIO, StringIO
 from typing import List, Literal, Optional, Union
 
+import numpy as np
 import pandas as pd
 from cachetools import TTLCache, cached
 from packaging.version import Version
@@ -22,6 +23,9 @@
 
 SERVER_KEY = "skeleton_server_address"
 
+MAX_BULK_ASYNCHRONOUS_SKELETONS = 10000
+BULK_ASYNC_SKELETONS_BATCH_SIZE = 100
+
 
 class NoL2CacheException(Exception):
     def __init__(self, value=""):
@@ -375,6 +379,9 @@ def skeletons_exist(
                 f"Unknown skeleton version: {skeleton_version}. Valid options: {valid_skeleton_versions}"
             )
 
+        if isinstance(root_ids, np.ndarray):
+            root_ids = root_ids.tolist()
+
         if isinstance(root_ids, int):
             root_ids = str(root_ids)
         elif isinstance(root_ids, List):
@@ -669,20 +676,57 @@ def generate_bulk_skeletons_async(
             )
             skeleton_version = -1
 
-        url = self._build_bulk_async_endpoint(
-            root_ids, datastack_name, skeleton_version
-        )
-        response = self.session.get(url)
-        self.raise_for_status(response, log_warning=log_warning)
+        if isinstance(root_ids, np.ndarray):
+            root_ids = root_ids.tolist()
+        if not isinstance(root_ids, list):
+            raise ValueError(
+                f"root_ids must be a list or numpy array of root_ids, not a {type(root_ids)}"
+            )
 
-        estimated_async_time_secs_upper_bound = float(response.text)
+        if len(root_ids) > MAX_BULK_ASYNCHRONOUS_SKELETONS:
+            logging.warning(
+                f"The number of root_ids exceeds the current limit of {MAX_BULK_ASYNCHRONOUS_SKELETONS}. Only the first {MAX_BULK_ASYNCHRONOUS_SKELETONS} will be processed."
+            )
+            root_ids = root_ids[:MAX_BULK_ASYNCHRONOUS_SKELETONS]
 
-        if verbose_level >= 1:
-            logging.info(
-                f"Queued asynchronous skeleton generation for root_ids: {root_ids}"
+        estimated_async_time_secs_upper_bound_sum = 0
+        for batch in range(0, len(root_ids), BULK_ASYNC_SKELETONS_BATCH_SIZE):
+            rids_one_batch = root_ids[batch : batch + BULK_ASYNC_SKELETONS_BATCH_SIZE]
+
+            url = self._build_bulk_async_endpoint(
+                rids_one_batch, datastack_name, skeleton_version
             )
-            logging.info(
-                f"Upper estimate to generate {len(root_ids)} skeletons: {estimated_async_time_secs_upper_bound} seconds"
+            response = self.session.get(url)
+            self.raise_for_status(response, log_warning=log_warning)
+
+            estimated_async_time_secs_upper_bound = float(response.text)
+            estimated_async_time_secs_upper_bound_sum += (
+                estimated_async_time_secs_upper_bound
+            )
+
+            if verbose_level >= 1:
+                logging.info(
+                    f"Queued asynchronous skeleton generation for one batch of root_ids: {rids_one_batch}"
+                )
+                logging.info(
+                    f"Upper estimate to generate one batch of {len(rids_one_batch)} skeletons: {estimated_async_time_secs_upper_bound} seconds"
+                )
+
+        if estimated_async_time_secs_upper_bound_sum < 60:
+            estimate_time_str = (
+                f"{estimated_async_time_secs_upper_bound_sum:.0f} seconds"
+            )
+        elif estimated_async_time_secs_upper_bound_sum < 3600:
+            estimate_time_str = (
+                f"{(estimated_async_time_secs_upper_bound_sum / 60):.1f} minutes"
+            )
+        # With a 10000-skeleton limit, the maximum time is about 12 hours, so we don't need to check for more than that.
+        # elif estimated_async_time_secs_upper_bound_sum < 86400:
+        else:
+            estimate_time_str = (
+                f"{(estimated_async_time_secs_upper_bound_sum / 3600):.1f} hours"
             )
+        # else:
+        #     estimate_time_str = f"{(estimated_async_time_secs_upper_bound_sum / 86400):.2f} days"
 
-        return f"Upper estimate to generate {len(root_ids)} skeletons: {estimated_async_time_secs_upper_bound} seconds"
+        return f"Upper estimate to generate all {len(root_ids)} skeletons: {estimate_time_str}"
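For reference, a minimal usage sketch of the batched behavior introduced above, assuming the usual CAVEclient entry point (client.skeleton); the datastack name and root ids below are hypothetical placeholders:

    import numpy as np
    from caveclient import CAVEclient

    client = CAVEclient("minnie65_public")  # hypothetical datastack name

    # Numpy arrays are now accepted alongside lists; inputs longer than
    # MAX_BULK_ASYNCHRONOUS_SKELETONS (10000) are truncated with a warning.
    root_ids = np.array([864691135014128278, 864691135473427535])  # hypothetical ids

    # Requests are issued in batches of BULK_ASYNC_SKELETONS_BATCH_SIZE (100),
    # and the return value is now a human-readable summed time estimate.
    print(client.skeleton.generate_bulk_skeletons_async(root_ids))
    # e.g. "Upper estimate to generate all 2 skeletons: 30 seconds"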