Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multischema endpoint #158

Merged
merged 8 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions caveclient/emannotationschemas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from .base import ClientBase, _api_endpoints, handle_response
from .endpoints import schema_api_versions, schema_endpoints_common
from .auth import AuthClient
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
over_client=over_client,
)

def get_schemas(self):
def get_schemas(self) -> list[str]:
"""Get the available schema types

Returns
Expand All @@ -77,8 +78,8 @@ def get_schemas(self):
url = self._endpoints["schema"].format_map(endpoint_mapping)
response = self.session.get(url)
return handle_response(response)

def schema_definition(self, schema_type):
def schema_definition(self, schema_type: str) -> dict[str]:
"""Get the definition of a specified schema_type

Parameters
Expand All @@ -97,6 +98,38 @@ def schema_definition(self, schema_type):
response = self.session.get(url)
return handle_response(response)

def schema_definition_multi(self, schema_types: list[str]) -> dict:
"""Get the definition of multiple schema_types

Parameters
----------
schema_types : list
List of schema names

Returns
-------
dict
Dictionary of schema definitions. Keys are schema names, values are definitions.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["schema_definition_multi"].format_map(endpoint_mapping)
data={'schema_names': ','.join(schema_types)}
response = self.session.post(url, params=data)
return handle_response(response)

def schema_definition_all(self) -> dict[str]:
"""Get the definition of all schema_types

Returns
-------
dict
Dictionary of schema definitions. Keys are schema names, values are definitions.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["schema_definition_all"].format_map(endpoint_mapping)
response = self.session.get(url)
return handle_response(response)


client_mapping = {
1: SchemaClientLegacy,
Expand Down
2 changes: 2 additions & 0 deletions caveclient/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@
schema_endpoints_v2 = {
"schema": schema_v2 + "/type",
"schema_definition": schema_v2 + "/type/{schema_type}",
"schema_definition_multi": schema_v2 + "/types",
"schema_definition_all": schema_v2 + "/types_all",
}

schema_api_versions = {1: schema_endpoints_v1, 2: schema_endpoints_v2}
Expand Down
67 changes: 46 additions & 21 deletions caveclient/materializationengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
import pyarrow as pa
import pytz
from concurrent.futures import ThreadPoolExecutor
from cachetools import TTLCache, cached
from IPython.display import HTML

Expand Down Expand Up @@ -247,25 +248,28 @@ def __init__(
over_client=over_client,
)
self._datastack_name = datastack_name
if version is None:
version = self.most_recent_version()
self._version = version
if cg_client is None:
if self.fc is not None:
self.cg_client = self.fc.chunkedgraph
else:
self.cg_client = cg_client
self._cg_client = cg_client
self.synapse_table = synapse_table
self.desired_resolution = desired_resolution
self._tables = None
self._views = None

@property
def datastack_name(self):
return self._datastack_name

@property
def cg_client(self):
if self._cg_client is None:
if self.fc is not None:
self._cg_client = self.fc.chunkedgraph
else:
raise ValueError("No chunkedgraph client specified")
return self._cg_client

@property
def version(self):
if self._version is None:
self._version = self.most_recent_version()
return self._version

@property
Expand Down Expand Up @@ -328,18 +332,6 @@ def get_versions(self, datastack_name=None, expired=False):
self.raise_for_status(response)
return response.json()

@property
def tables(self):
if self._tables is None:
self._tables = TableManager(self.fc)
return self._tables

@property
def views(self):
if self._views is None:
self._views = ViewManager(self.fc)
return self._views

def get_tables(self, datastack_name=None, version=None):
"""Gets a list of table names for a datastack

Expand Down Expand Up @@ -1878,6 +1870,39 @@ def _assemble_attributes(
class MaterializationClientV3(MaterializationClientV2):
def __init__(self, *args, **kwargs):
super(MaterializationClientV3, self).__init__(*args, **kwargs)
metadata = []
with ThreadPoolExecutor(max_workers=4) as executor:
metadata.append(
executor.submit(
self.get_tables_metadata,
)
)
metadata.append(
executor.submit(
self.fc.schema.schema_definition_all
)
)
metadata.append(
executor.submit(
self.get_views
)
)
metadata.append(
executor.submit(
self.get_view_schemas
)
)
if self.fc is not None:
tables = TableManager(self.fc, metadata[0].result(), metadata[1].result())
else:
tables = None
self.tables = tables

if self.fc is not None:
views = ViewManager(self.fc, metadata[2].result(), metadata[3].result())
else:
views = None
self.views = views

@cached(cache=TTLCache(maxsize=100, ttl=60 * 60 * 12))
def get_tables_metadata(
Expand Down
90 changes: 79 additions & 11 deletions caveclient/tools/table_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def combine_names(tableA, namesA, tableB, namesB, suffixes):
return final_namesA + final_namesB, table_map, rename_map


def get_all_table_metadata(client):
meta = client.materialize.get_tables_metadata()
def get_all_table_metadata(client, meta=None):
if meta is None:
meta = client.materialize.get_tables_metadata()
tables = []
for m in meta:
if m.get("annotation_table"):
Expand Down Expand Up @@ -116,6 +117,20 @@ def _schema_key(schema_name, client, **kwargs):
key = keys.hashkey(schema_name, str(allow_types))
return key

def populate_schema_cache(client, schema_definitions=None):
if schema_definitions is None:
try:
schema_definitions = client.schema.schema_definition_all()
except:
schema_definitions = {sn:None for sn in client.schema.get_schemas()}
for schema_name, schema_definition in schema_definitions.items():
get_col_info(schema_name, client, schema_definition=schema_definition)

def populate_table_cache(client, metadata=None):
if metadata is None:
metadata = get_all_table_metadata(client)
for tn, meta in metadata.items():
table_metadata(tn, client, meta=meta)

@cached(cache=_schema_cache, key=_schema_key)
def get_col_info(
Expand All @@ -126,8 +141,12 @@ def get_col_info(
allow_types=ALLOW_COLUMN_TYPES,
add_fields=["id"],
omit_fields=[],
schema_definition=None,
):
schema = client.schema.schema_definition(schema_name)
if schema_definition is None:
schema = client.schema.schema_definition(schema_name)
else:
schema = schema_definition.copy()
sp_name = f"#/definitions/{spatial_point}"
unbd_sp_name = f"#/definitions/{unbound_spatial_point}"
n_sp = 0
Expand Down Expand Up @@ -275,18 +294,18 @@ def get_table_info(

_metadata_cache = TTLCache(maxsize=128, ttl=86_400)


def _metadata_key(tn, client):
def _metadata_key(tn, client, **kwargs):
key = keys.hashkey(tn)
return key


@cached(cache=_metadata_cache, key=_metadata_key)
def table_metadata(table_name, client):
def table_metadata(table_name, client, meta=None):
"Caches getting table metadata"
with warnings.catch_warnings():
warnings.simplefilter(action="ignore")
meta = client.materialize.get_table_metadata(table_name)
if meta is None:
meta = client.materialize.get_table_metadata(table_name)
if "schema" not in meta:
meta["schema"] = meta.get("schema_type")
return meta
Expand Down Expand Up @@ -534,6 +553,29 @@ def query(
desired_resolution=None,
get_counts=False,
):
"""Query views through the table interface

Parameters
----------
select_columns : list[str], optional
Specification of columns to return, by default None
offset : int, optional
Integer offset from the beginning of the table to return, by default None.
Used when tables are too large to return in one query.
limit : int, optional
Maximum number of rows to return, by default None
split_positions : bool, optional
If true, returns each point coordinate as a separate column, by default False
materialization_version : int, optional
Query a specified materialization version, by default None
metadata : bool, optional
If true includes query and table metadata in the .attrs property of the returned dataframe, by default True
desired_resolution : list[int], optional
Sets the 3d point resolution in nm, by default None.
If default, uses the values in the table directly.
get_counts : bool, optional
Only return number of rows in the query, by default False
"""
logger.warning(
"The `client.materialize.views` interface is experimental and might experience breaking changes before the feature is stabilized."
)
Expand Down Expand Up @@ -601,24 +643,39 @@ def make_query_filter_view(view_name, meta, schema, client):
class TableManager(object):
"""Use schema definitions to generate query filters for each table."""

def __init__(self, client):
def __init__(self, client, metadata=None, schema=None):
self._client = client
self._table_metadata = get_all_table_metadata(self._client)
self._table_metadata = get_all_table_metadata(self._client, meta=metadata)
self._tables = sorted(list(self._table_metadata.keys()))
populate_schema_cache(client, schema_definitions=schema)
populate_table_cache(client, metadata=self._table_metadata)
for tn in self._tables:
setattr(self, tn, make_query_filter(tn, self._table_metadata[tn], client))

def __getitem__(self, key):
return getattr(self, key)

def __contains__(self, key):
return key in self._tables

def __repr__(self):
return str(self._tables)

@property
def table_names(self):
return self._tables

def __len__(self):
return len(self._tables)


class ViewManager(object):
def __init__(self, client):
def __init__(self, client, view_metadata=None, view_schema=None):
self._client = client
self._view_metadata, view_schema = get_all_view_metadata(self._client)
if view_metadata is None or view_schema is None:
view_metadata, view_schema = get_all_view_metadata(self._client)
else:
self._view_metadata = view_metadata
self._views = sorted(list(self._view_metadata.keys()))
for vn in self._views:
setattr(
Expand All @@ -632,5 +689,16 @@ def __init__(self, client):
def __getitem__(self, key):
return getattr(self, key)

def __contains__(self, key):
return key in self._views

def __repr__(self):
return str(self._views)

@property
def table_names(self):
return self._views

def __len__(self):
return len(self._views)

3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@ cachetools>=4.2.1
ipython
networkx
jsonschema
attrs>=21.3.0
cachetools>=4
attrs>=21.3.0
Loading