diff --git a/connectomes/connectivity/flywire.py b/connectomes/connectivity/flywire.py
new file mode 100644
index 0000000..03219c9
--- /dev/null
+++ b/connectomes/connectivity/flywire.py
@@ -0,0 +1,159 @@
+# A collection of tools to interface with various connectome backends.
+#
+# Copyright (C) 2021 Philipp Schlegel
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+
+from functools import partial
+
+import numpy as np
+import pandas as pd
+
+from .base import ConnectivitySource
+from ..utils.flywire import get_cave_client, retry
+
+
+class FlywireConnectivitySource(ConnectivitySource):
+    def __init__(self, dataset='production'):
+        self.dataset = dataset
+        # Get the CAVE client
+        self.client = get_cave_client(dataset=self.dataset)
+
+    def get_edges(self, sources, targets=None, min_score=10, batch_size=20):
+        """Fetch edges between sources and targets.
+
+        Parameters
+        ----------
+        sources : int | list of int
+            Root ID(s) of source neurons.
+        targets : int | list of int | None
+            Root ID(s) of target neurons. If ``None``, will use
+            ``sources`` as targets.
+        min_score : int
+            Minimum cleft score for a synapse to be counted.
+        batch_size : int
+            Number of root IDs per query batch.
+
+        Returns
+        -------
+        edges : pandas.DataFrame
+            Adjacency matrix with ``sources`` as rows and ``targets``
+            as columns.
+
+        """
+        if targets is None:
+            targets = sources
+
+        mat = self.client.materialize.version
+        # Note: we need the synapse `id` to drop duplicates later on
+        columns = ['pre_pt_root_id', 'post_pt_root_id', 'cleft_score', 'id']
+        func = partial(retry(self.client.materialize.query_table),
+                       table=self.client.materialize.synapse_table,
+                       materialization_version=mat,
+                       select_columns=columns)
+
+        syn = []
+        for i in range(0, len(sources), batch_size):
+            source_batch = sources[i:i+batch_size]
+            for k in range(0, len(targets), batch_size):
+                target_batch = targets[k:k+batch_size]
+
+                this = func(filter_in_dict=dict(post_pt_root_id=target_batch,
+                                                pre_pt_root_id=source_batch))
+
+                # We need to drop the .attrs (which contain meta data from
+                # queries) - otherwise we run into issues when concatenating
+                this.attrs = {}
+
+                if not this.empty:
+                    syn.append(this)
+
+        # Combine results from batches
+        if len(syn):
+            syn = pd.concat(syn, axis=0, ignore_index=True)
+        else:
+            # Return an all-zero adjacency matrix if no synapses were found
+            edges = pd.DataFrame(np.zeros((len(sources), len(targets))),
+                                 index=sources, columns=targets)
+            edges.index.name = 'source'
+            edges.columns.name = 'target'
+            return edges
+
+        # Depending on how queries were batched, we need to drop
+        # duplicate synapses
+        syn.drop_duplicates('id', inplace=True)
+
+        # Rename some of those columns
+        syn.rename({'post_pt_root_id': 'post', 'pre_pt_root_id': 'pre'},
+                   axis=1, inplace=True)
+
+        # Next we need to run some clean-up:
+        # Drop below-threshold connections
+        if min_score:
+            syn = syn[syn.cleft_score >= min_score]
+
+        # Aggregate synapses into connections
+        cn = syn.groupby(['pre', 'post'], as_index=False).size()
+        cn.columns = ['source', 'target', 'weight']
+
+        # Pivot into an adjacency matrix
+        edges = cn.pivot(index='source', columns='target', values='weight')
+
+        # Reindex to match the requested order and add any missing neurons
+        edges = edges.reindex(index=sources, columns=targets).fillna(0)
+        return edges
+
+    def get_synapses(self, x, transmitters=False, batch_size=20):
+        """Retrieve synapses for given neurons.
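+
+        This fetches both incoming and outgoing synapses of the given
+        neurons and drops duplicates based on the synapse ``id``.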
+
+        Parameters
+        ----------
+        x : int | list of int
+            Root ID(s) to fetch synapses for.
+        transmitters : bool, optional
+            If True, also fetch the neurotransmitter predictions for
+            each synapse.
+        batch_size : int
+            Number of root IDs per query batch.
+
+        Returns
+        -------
+        synapses : pandas.DataFrame
+
+        """
+        mat = self.client.materialize.version
+
+        columns = ['pre_pt_root_id', 'post_pt_root_id', 'cleft_score',
+                   'pre_pt_position', 'post_pt_position', 'id']
+
+        if transmitters:
+            columns += ['gaba', 'ach', 'glut', 'oct', 'ser', 'da']
+
+        func = partial(retry(self.client.materialize.query_table),
+                       table=self.client.materialize.synapse_table,
+                       split_positions=True,
+                       materialization_version=mat,
+                       select_columns=columns)
+
+        syn = []
+        for i in range(0, len(x), batch_size):
+            batch = x[i:i+batch_size]
+            # Query both incoming and outgoing synapses for this batch
+            syn.append(func(filter_in_dict=dict(post_pt_root_id=batch)))
+            syn.append(func(filter_in_dict=dict(pre_pt_root_id=batch)))
+
+        # Drop the .attrs (which contain meta data from queries) to avoid
+        # issues when concatenating
+        for df in syn:
+            df.attrs = {}
+
+        # Combine results from batches
+        syn = pd.concat(syn, axis=0, ignore_index=True)
+
+        # Depending on how queries were batched, we need to drop
+        # duplicate synapses
+        syn.drop_duplicates('id', inplace=True)
+
+        # Rename some of those columns
+        syn.rename({'post_pt_root_id': 'post',
+                    'pre_pt_root_id': 'pre',
+                    'post_pt_position_x': 'post_x',
+                    'post_pt_position_y': 'post_y',
+                    'post_pt_position_z': 'post_z',
+                    'pre_pt_position_x': 'pre_x',
+                    'pre_pt_position_y': 'pre_y',
+                    'pre_pt_position_z': 'pre_z',
+                    },
+                   axis=1, inplace=True)
+        return syn
diff --git a/connectomes/datasets/__init__.py b/connectomes/datasets/__init__.py
index 03ba293..711c1e4 100644
--- a/connectomes/datasets/__init__.py
+++ b/connectomes/datasets/__init__.py
@@ -20,11 +20,13 @@ from abc import ABC, abstractmethod
 
 from ..meshes.neu import NeuPrintMeshSource
+from ..meshes.flywire import FlywireMeshSource
 from ..skeletons.neu import NeuPrintSkeletonSource
 from ..segmentation.cloudvol import CloudVolSegmentationSource
 from ..connectivity.neu import NeuPrintConnectivitySource
+from ..connectivity.flywire import FlywireConnectivitySource
 from ..annotations.neu import NeuPrintAnnotationSource
-
+from ..utils.flywire import get_chunkedgraph_secret, set_chunkedgraph_secret
 
 
 @functools.lru_cache
 def get(dataset, *args, **kwargs):
@@ -57,7 +59,7 @@ class HemiBrain(BaseDataSet):
     Parameters
     ----------
     version : str
-        Version to use. Defaults to the currently lates (1.2.1).
+        Version to use. Defaults to the latest version (currently 1.2.1).
     server : str
         The server to use. Defaults to the public service.
@@ -104,5 +106,43 @@ def check_token(self):
             raise ValueError(msg)
 
 
+class FAFB(BaseDataSet):
+    """Interface with the FlyWire FAFB dataset.
+
+    Parameters
+    ----------
+    dataset : str
+        Version of the dataset to use. Defaults to 'production'.
+
+    References
+    ----------
+    Zheng, Z. et al. A complete electron microscopy volume of the brain of
+    adult Drosophila melanogaster. Cell 174, 730-743 (2018)
+
+    """
+    def __init__(self, dataset='production'):
+        # Check if credentials are set
+        self.check_token()
+        self.dataset = dataset
+
+        self.mesh = FlywireMeshSource(self.dataset)
+        self.connectivity = FlywireConnectivitySource(self.dataset)
+        self.segmentation = None  # TODO
+        self.annotations = None  # TODO
+
+        self.reference = 'Zheng, Z. et al. (2018)'
+
+    def __str__(self):
+        return f'FlyWire FAFB dataset ({self.dataset})'
+
+    def check_token(self):
+        """Check whether the FAFB FlyWire token is set correctly.
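+
+        Raises a ``ValueError`` if no token is found; use ``set_token``
+        to set one.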
+        """
+        get_chunkedgraph_secret()
+
+    def set_token(self, token: str):
+        """Convenience method for setting the FAFB FlyWire token."""
+        set_chunkedgraph_secret(token)
+
+
 # Add more datasets here
-DATASETS = {'hemibrain': HemiBrain}
+DATASETS = {'hemibrain': HemiBrain, 'fafb': FAFB}
diff --git a/connectomes/meshes/flywire.py b/connectomes/meshes/flywire.py
new file mode 100644
index 0000000..7779e51
--- /dev/null
+++ b/connectomes/meshes/flywire.py
@@ -0,0 +1,60 @@
+# A collection of tools to interface with various connectome backends.
+#
+# Copyright (C) 2021 Philipp Schlegel
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+import numpy as np
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from .base import MeshSource
+from ..utils.flywire import parse_volume, is_iterable
+
+
+class FlywireMeshSource(MeshSource):
+    def __init__(self, dataset='production'):
+        self.dataset = dataset
+
+    def get(self, x, threads=2, omit_failures=False, **kwargs):
+        """Fetch meshes for given neuron IDs.
+
+        Parameters
+        ----------
+        x : int | list of int
+            Defines which meshes to fetch. Can be:
+              - a single root ID (integer)
+              - a list of root IDs
+        threads : int | None, optional
+            Number of threads to use for fetching meshes in parallel.
+            Use ``None`` or ``1`` to fetch sequentially.
+        omit_failures : bool, optional
+            If True, failed mesh downloads are returned as ``None``
+            instead of raising.
+
+        """
+        vol = parse_volume(self.dataset)
+        if is_iterable(x):
+            x = np.asarray(x, dtype=np.int64)
+            if not threads or threads == 1:
+                return [self.get(id_, omit_failures=omit_failures,
+                                 threads=None, **kwargs) for id_ in x]
+            else:
+                if not isinstance(threads, int):
+                    raise TypeError('`threads` must be int or `None`, '
+                                    f'got "{type(threads)}".')
+                with ThreadPoolExecutor(max_workers=threads) as executor:
+                    futures = {executor.submit(self.get, n,
+                                               omit_failures=omit_failures,
+                                               threads=None): n for n in x}
+
+                    results = []
+                    for f in as_completed(futures):
+                        results.append(f.result())
+                return results
+
+        x = np.int64(x)
+        try:
+            mesh = vol.mesh.get(x, remove_duplicate_vertices=True)[x]
+        except BaseException:
+            # Honour `omit_failures` for failed downloads
+            if omit_failures:
+                return None
+            raise
+        return mesh
diff --git a/connectomes/utils/flywire.py b/connectomes/utils/flywire.py
new file mode 100644
index 0000000..61d73a7
--- /dev/null
+++ b/connectomes/utils/flywire.py
@@ -0,0 +1,362 @@
+# A collection of tools to interface with various connectome backends.
+#
+# Copyright (C) 2022 Philipp Schlegel
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
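+
+# This module collects the FlyWire-specific plumbing used by the mesh and
+# connectivity sources: credential handling (get_chunkedgraph_secret /
+# set_chunkedgraph_secret), cached CAVE clients (get_cave_client),
+# CloudVolume setup (parse_volume) and small helpers such as retry().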
+
+import functools
+import json
+import os
+import six
+import time
+import requests
+import warnings
+
+from urllib.parse import urlparse
+from collections.abc import Iterable
+from typing import Any
+from pathlib import Path
+from importlib import reload
+
+from caveclient import CAVEclient
+import cloudvolume as cv
+import datetime as dt
+import numpy as np
+import pandas as pd
+
+__all__ = ['set_chunkedgraph_secret', 'get_chunkedgraph_secret',
+           'get_cave_client', 'is_iterable']
+
+FLYWIRE_DATASETS = {'production': 'fly_v31',
+                    'sandbox': 'fly_v26'}
+
+CAVE_DATASETS = {'production': 'flywire_fafb_production',
+                 'sandbox': 'flywire_fafb_sandbox'}
+
+# Initialize without a volume
+fw_vol = None
+cave_clients = {}
+
+# Data stuff
+fp = Path(__file__).parent
+data_path = fp.parent / 'data'
+area_ids = None
+vol_names = None
+
+
+def is_iterable(x: Any) -> bool:
+    """Test if input is iterable (but not a string or DataFrame).
+
+    Examples
+    --------
+    >>> from connectomes.utils.flywire import is_iterable
+    >>> is_iterable(['a'])
+    True
+    >>> is_iterable('a')
+    False
+    >>> is_iterable({'a': 1})
+    True
+
+    """
+    return (isinstance(x, Iterable)
+            and not isinstance(x, (six.string_types, pd.DataFrame)))
+
+
+def is_url(x):
+    """Check if URL is valid."""
+    try:
+        result = urlparse(x)
+        return all([result.scheme, result.netloc, result.path])
+    except BaseException:
+        return False
+
+
+def make_iterable(x, force_type=None) -> np.ndarray:
+    """Force input into a numpy array.
+
+    For dicts, keys will be turned into an array.
+
+    Examples
+    --------
+    >>> from connectomes.utils.flywire import make_iterable
+    >>> make_iterable(1)
+    array([1])
+    >>> make_iterable([1])
+    array([1])
+    >>> make_iterable({'a': 1})
+    array(['a'], dtype='<U1')
+
+    """
+    if not is_iterable(x):
+        x = [x]
+
+    if isinstance(x, (dict, set)):
+        x = list(x)
+
+    return np.asarray(x, dtype=force_type)
+
+
+def get_cave_client(dataset='production', token=None, force_new=False):
+    """Get (cached) CAVE client for given dataset.
+
+    Parameters
+    ----------
+    dataset : str
+        Dataset to create the client for ('production' or 'sandbox').
+    token : str, optional
+        Your chunkedgraph secret. If not provided, will try reading it
+        via cloud-volume.
+    force_new : bool
+        If True, force creation of a new client.
+
+    Returns
+    -------
+    CAVEclient
+
+    """
+    if not token:
+        token = get_chunkedgraph_secret()
+
+    datastack = CAVE_DATASETS.get(dataset, dataset)
+
+    # Renew cached clients after 30 minutes (their token may have expired)
+    if datastack in cave_clients:
+        age = dt.datetime.now() - cave_clients[datastack].birth_day
+        if age > dt.timedelta(minutes=30):
+            force_new = True
+
+    if datastack not in cave_clients or force_new:
+        cave_clients[datastack] = CAVEclient(datastack, auth_token=token)
+        cave_clients[datastack].birth_day = dt.datetime.now()
+
+    return cave_clients[datastack]
+
+
+def get_chunkedgraph_secret(domain='prod.flywire-daf.com'):
+    """Get chunked graph secret.
+
+    Parameters
+    ----------
+    domain : str
+        Domain to get the secret for. Only relevant for
+        ``cloudvolume>=3.11.0``.
+
+    Returns
+    -------
+    token : str
+
+    """
+    if hasattr(cv.secrets, 'cave_credentials'):
+        token = cv.secrets.cave_credentials(domain).get('token', None)
+        if not token:
+            raise ValueError(f'No chunkedgraph secret for domain {domain} '
+                             'found. Use `set_chunkedgraph_secret` to set it.')
+    else:
+        try:
+            token = cv.secrets.chunkedgraph_credentials['token']
+        except BaseException:
+            raise ValueError('No chunkedgraph secret found. Use '
+                             '`set_chunkedgraph_secret` to set it.')
+    return token
+
+
+def set_chunkedgraph_secret(token, filepath=None,
+                            domain='prod.flywire-daf.com'):
+    """Set chunked graph secret (called "cave credentials" now).
+
+    Parameters
+    ----------
+    token : str
+        Get your token from
+        https://globalv1.flywire-daf.com/auth/api/v1/refresh_token
+    filepath : str, optional
+        Path to the secret file. If not provided, will store in the
+        default path:
+        ``~/.cloudvolume/secrets/{domain}-cave-secret.json``
+    domain : str
+        The domain (incl. subdomain) this secret is for.
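+
+    Examples
+    --------
+    A minimal sketch (the token shown is a made-up placeholder):
+
+    >>> set_chunkedgraph_secret('abc123def456')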
+
+    """
+    assert isinstance(token, str), f'Token must be string, got "{type(token)}"'
+
+    if not filepath:
+        filepath = f'~/.cloudvolume/secrets/{domain}-cave-secret.json'
+    elif not filepath.endswith('.json'):
+        # Assume this is a folder and append the default filename
+        filepath = os.path.join(filepath, f'{domain}-cave-secret.json')
+
+    filepath = Path(filepath).expanduser()
+
+    # Make sure the parent directory actually exists
+    filepath.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(filepath, 'w+') as f:
+        json.dump({'token': token}, f)
+
+    # We need to reload cloudvolume for changes to take effect
+    reload(cv.secrets)
+    reload(cv)
+
+    # Should also reset the volume after setting the secret
+    global fw_vol
+    fw_vol = None
+
+    print(f'Token successfully stored in {filepath}')
+
+
+def parse_root_ids(x):
+    """Parse root IDs.
+
+    Always returns an array of integers.
+    """
+    if isinstance(x, (int, np.int64)):
+        ids = [x]
+    else:
+        ids = make_iterable(x, force_type=np.int64)
+
+    # Make sure we are working with proper numerical IDs
+    try:
+        return np.asarray(ids, dtype=np.int64)
+    except ValueError:
+        raise ValueError(f'Unable to convert given root IDs to integer: {ids}')
+
+
+def parse_volume(vol, **kwargs):
+    """Parse CloudVolume."""
+    global fw_vol
+    if 'CloudVolume' not in str(type(vol)):
+        if not isinstance(vol, str):
+            raise ValueError(f'Unable to initialize CloudVolume from "{type(vol)}"')
+
+        if not is_url(vol):
+            # We are assuming this is the dataset
+            # Map "production" and "sandbox" to their correct designations
+            vol = FLYWIRE_DATASETS.get(vol, vol)
+
+            # Below is supposedly the "old" api (/1.0/)
+            # vol = f'graphene://https://prodv1.flywire-daf.com/segmentation/1.0/{vol}'
+
+            # This is the new url
+            vol = f'graphene://https://prod.flywire-daf.com/segmentation/table/{vol}'
+
+            # This might eventually become the new url
+            # vol = f'graphene://https://prodv1.flywire-daf.com/segmentation_proc/table/{vol}'
+
+        if not vol.startswith('graphene://'):
+            vol = f'graphene://{vol}'
+
+        # Change default volume if necessary
+        if not fw_vol or getattr(fw_vol, 'path', None) != vol:
+            # Set and update defaults from kwargs
+            defaults = dict(mip=0,
+                            fill_missing=True,
+                            cache=False,
+                            use_https=True,  # this way a google secret is not needed
+                            progress=False)
+            defaults.update(kwargs)
+
+            # Check if a chunkedgraph secret exists
+            # This probably needs yanking!
+            secret = os.path.expanduser('~/.cloudvolume/secrets/chunkedgraph-secret.json')
+            if not os.path.isfile(secret):
+                # If there is no secret file but an environment variable, use that
+                if 'CHUNKEDGRAPH_SECRET' in os.environ and 'secrets' not in defaults:
+                    defaults['secrets'] = {'token': os.environ['CHUNKEDGRAPH_SECRET']}
+
+            fw_vol = cv.CloudVolume(vol, **defaults)
+            fw_vol.path = vol
+    else:
+        fw_vol = vol
+    return fw_vol
+
+
+def retry(func, retries=5, cooldown=2):
+    """Retry function on HTTPError.
+
+    This also suppresses UserWarnings (because we typically use this for
+    stuff like the L2 cache).
+
+    Parameters
+    ----------
+    func : callable
+        Function to wrap.
+    retries : int
+        Number of retries before we give up.
+    cooldown : int | float
+        Cooldown period in seconds between attempts. Each subsequent
+        retry waits an additional ``cooldown`` seconds.
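+
+    Examples
+    --------
+    A sketch of the intended use, mirroring how the connectivity source
+    wraps its materialization queries:
+
+    >>> client = get_cave_client()
+    >>> query = retry(client.materialize.query_table)
+    >>> syn = query(table=client.materialize.synapse_table)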
+
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        for i in range(1, retries + 1):
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                try:
+                    return func(*args, **kwargs)
+                except requests.HTTPError:
+                    # Re-raise if this was our last attempt
+                    if i >= retries:
+                        raise
+            time.sleep(cooldown * i)
+    return wrapper
+
+
+def parse_bounds(x):
+    """Parse bounds.
+
+    Parameters
+    ----------
+    x : (3, 2) array | (2, 3) array | None
+
+    Returns
+    -------
+    bounds : (3, 2) np.array
+
+    """
+    if x is None:
+        return None
+
+    x = np.asarray(x)
+
+    if x.ndim != 2 or x.shape not in [(3, 2), (2, 3)]:
+        raise ValueError('Must provide bounding box as (3, 2) or (2, 3) array, '
+                         f'got {x.shape}')
+
+    if x.shape == (2, 3):
+        x = x.T
+
+    return np.vstack((x.min(axis=1), x.max(axis=1))).T
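+
+
+# Illustration of parse_bounds (values are arbitrary): a (2, 3) input like
+# [[0, 0, 0], [100, 200, 50]] is transposed and min/max-sorted into the
+# canonical (3, 2) form [[0, 100], [0, 200], [0, 50]].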