From 414ea72935d42f84360d6951001951dcec16fc35 Mon Sep 17 00:00:00 2001 From: Casey Schneider-Mizell Date: Mon, 11 Mar 2024 22:48:46 -1000 Subject: [PATCH] multithreaded materialization initialization --- caveclient/materializationengine.py | 60 ++++++++++++++++--------- caveclient/tools/table_manager.py | 68 +++++++++++++++++++++-------- requirements.txt | 3 +- 3 files changed, 91 insertions(+), 40 deletions(-) diff --git a/caveclient/materializationengine.py b/caveclient/materializationengine.py index b2ebbfd3..e29a90e8 100644 --- a/caveclient/materializationengine.py +++ b/caveclient/materializationengine.py @@ -11,6 +11,7 @@ import pandas as pd import pyarrow as pa import pytz +from concurrent.futures import ThreadPoolExecutor from cachetools import TTLCache, cached from IPython.display import HTML @@ -247,25 +248,28 @@ def __init__( over_client=over_client, ) self._datastack_name = datastack_name - if version is None: - version = self.most_recent_version() self._version = version - if cg_client is None: - if self.fc is not None: - self.cg_client = self.fc.chunkedgraph - else: - self.cg_client = cg_client + self._cg_client = cg_client self.synapse_table = synapse_table self.desired_resolution = desired_resolution - self._tables = None - self._views = None @property def datastack_name(self): return self._datastack_name + + @property + def cg_client(self): + if self._cg_client is None: + if self.fc is not None: + self._cg_client = self.fc.chunkedgraph + else: + raise ValueError("No chunkedgraph client specified") + return self._cg_client @property def version(self): + if self._version is None: + self._version = self.most_recent_version() return self._version @property @@ -328,18 +332,6 @@ def get_versions(self, datastack_name=None, expired=False): self.raise_for_status(response) return response.json() - @property - def tables(self): - if self._tables is None: - self._tables = TableManager(self.fc) - return self._tables - - @property - def views(self): - if self._views is None: - self._views = ViewManager(self.fc) - return self._views - def get_tables(self, datastack_name=None, version=None): """Gets a list of table names for a datastack @@ -1878,6 +1870,32 @@ def _assemble_attributes( class MaterializationClientV3(MaterializationClientV2): def __init__(self, *args, **kwargs): super(MaterializationClientV3, self).__init__(*args, **kwargs) + metadata = [] + with ThreadPoolExecutor(max_workers=4) as executor: + metadata.append( + executor.submit( + self.get_tables_metadata, + ) + ) + metadata.append( + executor.submit( + self.fc.schema.schema_definition_all + ) + ) + metadata.append( + executor.submit( + self.get_views + ) + ) + metadata.append( + executor.submit( + self.get_view_schemas + ) + ) + tables = TableManager(self.fc, metadata[0].result(), metadata[1].result()) + self.tables = tables + views = ViewManager(self.fc, metadata[2].result(), metadata[3].result()) + self.views = views @cached(cache=TTLCache(maxsize=100, ttl=60 * 60 * 12)) def get_tables_metadata( diff --git a/caveclient/tools/table_manager.py b/caveclient/tools/table_manager.py index 46d03b02..c374040e 100644 --- a/caveclient/tools/table_manager.py +++ b/caveclient/tools/table_manager.py @@ -57,8 +57,9 @@ def combine_names(tableA, namesA, tableB, namesB, suffixes): return final_namesA + final_namesB, table_map, rename_map -def get_all_table_metadata(client): - meta = client.materialize.get_tables_metadata() +def get_all_table_metadata(client, meta=None): + if meta is None: + meta = client.materialize.get_tables_metadata() tables = [] for m in meta: if m.get("annotation_table"): @@ -116,14 +117,21 @@ def _schema_key(schema_name, client, **kwargs): key = keys.hashkey(schema_name, str(allow_types)) return key -def populate_schema_cache(client): - try: - schema_definitions = client.schema.schema_definition_all() - except: - schema_definitions = {sn:None for sn in client.schema.get_schemas()} +def populate_schema_cache(client, schema_definitions=None): + if schema_definitions is None: + try: + schema_definitions = client.schema.schema_definition_all() + except: + schema_definitions = {sn:None for sn in client.schema.get_schemas()} for schema_name, schema_definition in schema_definitions.items(): get_col_info(schema_name, client, schema_definition=schema_definition) +def populate_table_cache(client, metadata=None): + if metadata is None: + metadata = get_all_table_metadata(client) + for tn, meta in metadata.items(): + table_metadata(tn, client, meta=meta) + @cached(cache=_schema_cache, key=_schema_key) def get_col_info( schema_name, @@ -286,18 +294,18 @@ def get_table_info( _metadata_cache = TTLCache(maxsize=128, ttl=86_400) - -def _metadata_key(tn, client): +def _metadata_key(tn, client, **kwargs): key = keys.hashkey(tn) return key @cached(cache=_metadata_cache, key=_metadata_key) -def table_metadata(table_name, client): +def table_metadata(table_name, client, meta=None): "Caches getting table metadata" with warnings.catch_warnings(): warnings.simplefilter(action="ignore") - meta = client.materialize.get_table_metadata(table_name) + if meta is None: + meta = client.materialize.get_table_metadata(table_name) if "schema" not in meta: meta["schema"] = meta.get("schema_type") return meta @@ -545,6 +553,29 @@ def query( desired_resolution=None, get_counts=False, ): + """Query views through the table interface + + Parameters + ---------- + select_columns : list[str], optional + Specification of columns to return, by default None + offset : int, optional + Integer offset from the beginning of the table to return, by default None. + Used when tables are too large to return in one query. + limit : int, optional + Maximum number of rows to return, by default None + split_positions : bool, optional + If true, returns each point coordinate as a separate column, by default False + materialization_version : int, optional + Query a specified materialization version, by default None + metadata : bool, optional + If true includes query and table metadata in the .attrs property of the returned dataframe, by default True + desired_resolution : list[int], optional + Sets the 3d point resolution in nm, by default None. + If default, uses the values in the table directly. + get_counts : bool, optional + Only return number of rows in the query, by default False + """ logger.warning( "The `client.materialize.views` interface is experimental and might experience breaking changes before the feature is stabilized." ) @@ -612,11 +643,12 @@ def make_query_filter_view(view_name, meta, schema, client): class TableManager(object): """Use schema definitions to generate query filters for each table.""" - def __init__(self, client): + def __init__(self, client, metadata=None, schema=None): self._client = client - self._table_metadata = get_all_table_metadata(self._client) + self._table_metadata = get_all_table_metadata(self._client, meta=metadata) self._tables = sorted(list(self._table_metadata.keys())) - populate_schema_cache(client) + populate_schema_cache(client, schema_definitions=schema) + populate_table_cache(client, metadata=self._table_metadata) for tn in self._tables: setattr(self, tn, make_query_filter(tn, self._table_metadata[tn], client)) @@ -628,9 +660,12 @@ def __repr__(self): class ViewManager(object): - def __init__(self, client): + def __init__(self, client, view_metadata=None, view_schema=None): self._client = client - self._view_metadata, view_schema = get_all_view_metadata(self._client) + if view_metadata is None or view_schema is None: + view_metadata, view_schema = get_all_view_metadata(self._client) + else: + self._view_metadata = view_metadata self._views = sorted(list(self._view_metadata.keys())) for vn in self._views: setattr( @@ -647,4 +682,3 @@ def __getitem__(self, key): def __repr__(self): return str(self._views) - \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1b154679..489c93d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ cachetools>=4.2.1 ipython networkx jsonschema -attrs>=21.3.0 -cachetools>=4 \ No newline at end of file +attrs>=21.3.0 \ No newline at end of file