Skip to content

Commit

Permalink
multithreaded materialization initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ceesem committed Mar 12, 2024
1 parent da717c3 commit 414ea72
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 40 deletions.
60 changes: 39 additions & 21 deletions caveclient/materializationengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
import pyarrow as pa
import pytz
from concurrent.futures import ThreadPoolExecutor
from cachetools import TTLCache, cached
from IPython.display import HTML

Expand Down Expand Up @@ -247,25 +248,28 @@ def __init__(
over_client=over_client,
)
self._datastack_name = datastack_name
if version is None:
version = self.most_recent_version()
self._version = version
if cg_client is None:
if self.fc is not None:
self.cg_client = self.fc.chunkedgraph
else:
self.cg_client = cg_client
self._cg_client = cg_client
self.synapse_table = synapse_table
self.desired_resolution = desired_resolution
self._tables = None
self._views = None

@property
def datastack_name(self):
    """Return the datastack name stored on this client at construction."""
    name = self._datastack_name
    return name

@property
def cg_client(self):
    """Return the chunkedgraph client, resolving it on first access.

    If no client was stored explicitly, the ``chunkedgraph`` attribute of
    the framework client (``self.fc``) is used and remembered.

    Raises
    ------
    ValueError
        If no client was stored and ``self.fc`` is None.
    """
    if self._cg_client is not None:
        return self._cg_client
    if self.fc is None:
        raise ValueError("No chunkedgraph client specified")
    self._cg_client = self.fc.chunkedgraph
    return self._cg_client

@property
def version(self):
    """Return the materialization version, looking it up on first access.

    When no version has been stored yet, ``most_recent_version()`` is
    called once and its result is remembered on the instance.
    """
    if self._version is not None:
        return self._version
    self._version = self.most_recent_version()
    return self._version

@property
Expand Down Expand Up @@ -328,18 +332,6 @@ def get_versions(self, datastack_name=None, expired=False):
self.raise_for_status(response)
return response.json()

@property
def tables(self):
    """Return the table manager, building it from ``self.fc`` on first access."""
    if self._tables is not None:
        return self._tables
    self._tables = TableManager(self.fc)
    return self._tables

@property
def views(self):
    """Return the view manager, building it from ``self.fc`` on first access."""
    if self._views is not None:
        return self._views
    self._views = ViewManager(self.fc)
    return self._views

def get_tables(self, datastack_name=None, version=None):
"""Gets a list of table names for a datastack
Expand Down Expand Up @@ -1878,6 +1870,32 @@ def _assemble_attributes(
class MaterializationClientV3(MaterializationClientV2):
def __init__(self, *args, **kwargs):
    """Initialize the client and eagerly build its table and view managers.

    After the parent initializer runs, the four requests whose results feed
    ``TableManager`` and ``ViewManager`` are submitted to a thread pool so
    that they run concurrently rather than one after another.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to the parent class initializer.

    Raises
    ------
    Exception
        Whatever ``get_tables_metadata``, ``get_views`` or
        ``get_view_schemas`` raised in its worker thread; it is re-raised
        here by ``Future.result()``.
    """
    super(MaterializationClientV3, self).__init__(*args, **kwargs)

    # Leaving the ``with`` block waits for every submitted job to finish.
    with ThreadPoolExecutor(max_workers=4) as executor:
        table_meta_job = executor.submit(self.get_tables_metadata)
        schema_job = executor.submit(self.fc.schema.schema_definition_all)
        view_meta_job = executor.submit(self.get_views)
        view_schema_job = executor.submit(self.get_view_schemas)

    # A failed bulk schema request must not abort client construction:
    # handing ``None`` to TableManager lets its schema-cache population
    # fetch (and fall back) on its own.
    try:
        schema_definitions = schema_job.result()
    except Exception:
        schema_definitions = None

    self.tables = TableManager(
        self.fc, table_meta_job.result(), schema_definitions
    )
    self.views = ViewManager(
        self.fc, view_meta_job.result(), view_schema_job.result()
    )

@cached(cache=TTLCache(maxsize=100, ttl=60 * 60 * 12))
def get_tables_metadata(
Expand Down
68 changes: 51 additions & 17 deletions caveclient/tools/table_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def combine_names(tableA, namesA, tableB, namesB, suffixes):
return final_namesA + final_namesB, table_map, rename_map


def get_all_table_metadata(client):
meta = client.materialize.get_tables_metadata()
def get_all_table_metadata(client, meta=None):
if meta is None:
meta = client.materialize.get_tables_metadata()
tables = []
for m in meta:
if m.get("annotation_table"):
Expand Down Expand Up @@ -116,14 +117,21 @@ def _schema_key(schema_name, client, **kwargs):
key = keys.hashkey(schema_name, str(allow_types))
return key

def populate_schema_cache(client, schema_definitions=None):
    """Warm the column-info cache for every known schema.

    Parameters
    ----------
    client
        Client whose ``schema`` sub-client is queried when no definitions
        are supplied; it is also passed through to ``get_col_info``.
    schema_definitions : dict, optional
        Mapping of schema name to schema definition. When None, all
        definitions are requested in one call; if that call fails, the
        schema names are listed instead and each is mapped to None.
    """
    if schema_definitions is None:
        try:
            schema_definitions = client.schema.schema_definition_all()
        except Exception:
            # Best-effort fallback when the bulk request fails. Catching
            # Exception (not a bare except) keeps KeyboardInterrupt and
            # SystemExit from being swallowed.
            schema_definitions = {sn: None for sn in client.schema.get_schemas()}
    for schema_name, schema_definition in schema_definitions.items():
        get_col_info(schema_name, client, schema_definition=schema_definition)

def populate_table_cache(client, metadata=None):
    """Seed the ``table_metadata`` cache with one entry per table.

    Parameters
    ----------
    client
        Client used to fetch the metadata when none is supplied, and passed
        through to ``table_metadata``.
    metadata : dict, optional
        Mapping of table name to that table's metadata. When None it is
        obtained from ``get_all_table_metadata(client)``.
    """
    table_lookup = metadata
    if table_lookup is None:
        table_lookup = get_all_table_metadata(client)
    for table_name, table_meta in table_lookup.items():
        table_metadata(table_name, client, meta=table_meta)

@cached(cache=_schema_cache, key=_schema_key)
def get_col_info(
schema_name,
Expand Down Expand Up @@ -286,18 +294,18 @@ def get_table_info(

_metadata_cache = TTLCache(maxsize=128, ttl=86_400)


def _metadata_key(tn, client, **kwargs):
    """Build the cache key for ``table_metadata`` from the table name alone.

    ``client`` and any keyword arguments are accepted so the signature
    matches the cached function, but they do not take part in the key.
    """
    return keys.hashkey(tn)


@cached(cache=_metadata_cache, key=_metadata_key)
def table_metadata(table_name, client, meta=None):
    """Return the metadata for one table, fetching it only if not supplied.

    Parameters
    ----------
    table_name : str
        Name of the table.
    client
        Client whose ``materialize.get_table_metadata`` is called when
        ``meta`` is None.
    meta : dict, optional
        Already-fetched metadata for the table. It is updated in place.

    Returns
    -------
    dict
        The metadata, with a ``"schema"`` entry added from ``"schema_type"``
        when it was missing.
    """
    with warnings.catch_warnings():
        # Warnings raised while fetching are suppressed.
        warnings.simplefilter(action="ignore")
        if meta is None:
            meta = client.materialize.get_table_metadata(table_name)
    meta.setdefault("schema", meta.get("schema_type"))
    return meta
Expand Down Expand Up @@ -545,6 +553,29 @@ def query(
desired_resolution=None,
get_counts=False,
):
"""Query views through the table interface
Parameters
----------
select_columns : list[str], optional
Specification of columns to return, by default None
offset : int, optional
Integer offset from the beginning of the table to return, by default None.
Used when tables are too large to return in one query.
limit : int, optional
Maximum number of rows to return, by default None
split_positions : bool, optional
If true, returns each point coordinate as a separate column, by default False
materialization_version : int, optional
Query a specified materialization version, by default None
metadata : bool, optional
If true includes query and table metadata in the .attrs property of the returned dataframe, by default True
desired_resolution : list[int], optional
Sets the 3d point resolution in nm, by default None.
If default, uses the values in the table directly.
get_counts : bool, optional
Only return number of rows in the query, by default False
"""
logger.warning(
"The `client.materialize.views` interface is experimental and might experience breaking changes before the feature is stabilized."
)
Expand Down Expand Up @@ -612,11 +643,12 @@ def make_query_filter_view(view_name, meta, schema, client):
class TableManager(object):
"""Use schema definitions to generate query filters for each table."""

def __init__(self, client, metadata=None, schema=None):
    """Collect table metadata, warm the caches, and attach one query filter per table.

    Parameters
    ----------
    client
        Client used for any lookups and handed to each query filter.
    metadata : optional
        Pre-fetched table metadata, forwarded to ``get_all_table_metadata``
        as ``meta``.
    schema : optional
        Pre-fetched schema definitions, forwarded to
        ``populate_schema_cache`` as ``schema_definitions``.
    """
    self._client = client
    self._table_metadata = get_all_table_metadata(self._client, meta=metadata)
    self._tables = sorted(self._table_metadata.keys())

    # Fill both caches before building the per-table query filters.
    populate_schema_cache(client, schema_definitions=schema)
    populate_table_cache(client, metadata=self._table_metadata)

    for table_name in self._tables:
        query_filter = make_query_filter(
            table_name, self._table_metadata[table_name], client
        )
        setattr(self, table_name, query_filter)

Expand All @@ -628,9 +660,12 @@ def __repr__(self):


class ViewManager(object):
def __init__(self, client):
def __init__(self, client, view_metadata=None, view_schema=None):
self._client = client
self._view_metadata, view_schema = get_all_view_metadata(self._client)
if view_metadata is None or view_schema is None:
view_metadata, view_schema = get_all_view_metadata(self._client)
else:
self._view_metadata = view_metadata
self._views = sorted(list(self._view_metadata.keys()))
for vn in self._views:
setattr(
Expand All @@ -647,4 +682,3 @@ def __getitem__(self, key):
def __repr__(self):
return str(self._views)


3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@ cachetools>=4.2.1
ipython
networkx
jsonschema
attrs>=21.3.0
cachetools>=4
attrs>=21.3.0

0 comments on commit 414ea72

Please sign in to comment.