Flywire FAFB dataset added #5

Open · wants to merge 1 commit into main
159 changes: 159 additions & 0 deletions connectomes/connectivity/flywire.py
@@ -0,0 +1,159 @@
# A collection of tools to interface with various connectome backends.
#
# Copyright (C) 2021 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.


from functools import partial

import numpy as np
import pandas as pd

from .base import ConnectivitySource
from ..utils.flywire import get_cave_client, retry

NoneType = type(None)

class FlywireConnectivitySource(ConnectivitySource):
    def __init__(self, dataset='production'):
self.dataset = dataset
# Get the cave client
self.client = get_cave_client(dataset=self.dataset)

def get_edges(self, sources, targets=None, min_score=10, batch_size=20):
"""Fetch edges between sources and targets.

Parameters
----------
        sources : int | list of int
            Body ID(s) of sources.
        targets : int | list of int | None
            Body ID(s) of targets. If ``None``, defaults to ``sources``.
        min_score : int, optional
            Minimum ``cleft_score`` for a synapse to be counted.
        batch_size : int, optional
            Number of IDs per query batch.

        Returns
        -------
        edges : pandas.DataFrame
            Adjacency matrix with ``sources`` as rows and ``targets`` as
            columns; values are synapse counts.

"""
if targets is None:
targets = sources

        # Pin all queries to the current materialization version
        mat = self.client.materialize.version
        columns = ['pre_pt_root_id', 'post_pt_root_id', 'cleft_score']
        # Wrap the query in `retry` to survive transient server errors and
        # pre-fill the arguments shared by all batches
        func = partial(retry(self.client.materialize.query_table),
                       table=self.client.materialize.synapse_table,
                       materialization_version=mat,
                       select_columns=columns)

        syn = []
for i in range(0, len(sources), batch_size):
source_batch = sources[i:i+batch_size]
for k in range(0, len(targets), batch_size):
target_batch = targets[k:k+batch_size]

this = func(filter_in_dict=dict(post_pt_root_id=target_batch,
pre_pt_root_id=source_batch))

                # Drop the .attrs (query metadata); otherwise we run
                # into issues when concatenating
this.attrs = {}

if not this.empty:
syn.append(this)

# Combine results from batches
        if len(syn):
            syn = pd.concat(syn, axis=0, ignore_index=True)
        else:
            # No synapses found: return an all-zero adjacency matrix
            edges = pd.DataFrame(np.zeros((len(sources), len(targets))),
                                 index=sources, columns=targets)
            edges.index.name = 'source'
            edges.columns.name = 'target'
            return edges


# Depending on how queries were batched, we need to drop duplicate synapses
syn.drop_duplicates('id', inplace=True)

# Rename some of those columns
syn.rename({'post_pt_root_id': 'post', 'pre_pt_root_id': 'pre'},
axis=1, inplace=True)

        # Next we need to run some clean-up:
        # Drop below-threshold connections
if min_score:
syn = syn[syn.cleft_score >= min_score]

# Aggregate
cn = syn.groupby(['pre', 'post'], as_index=False).size()
cn.columns = ['source', 'target', 'weight']

# Pivot
edges = cn.pivot(index='source', columns='target', values='weight').fillna(0)

# Index to match order and add any missing neurons
edges = edges.reindex(index=sources, columns=targets).fillna(0)
return edges

def get_synapses(self, x, transmitters=False, batch_size=20):
"""Retrieve synapse for given neurons.

Parameters
----------
x : int | list of int | neu.NeuronCriteria
Body ID(s) to fetch synapses for. For more complicated
queries use neuprint-python's ``NeuronCriteria``. You can
use ``None`` to fetch all incoming edges of ``targets``.

"""
mat = self.client.materialize.version

columns = ['pre_pt_root_id', 'post_pt_root_id', 'cleft_score',
'pre_pt_position', 'post_pt_position', 'id']

if transmitters:
columns += ['gaba', 'ach', 'glut', 'oct', 'ser', 'da']

func = partial(retry(self.client.materialize.query_table),
table=self.client.materialize.synapse_table,
split_positions=True,
materialization_version=mat,
select_columns=columns)

        syn = []
        for i in range(0, len(x), batch_size):
            batch = x[i:i+batch_size]
            # Query both sides so we get incoming and outgoing synapses
            syn.append(func(filter_in_dict=dict(post_pt_root_id=batch)))
            syn.append(func(filter_in_dict=dict(pre_pt_root_id=batch)))

# Drop attrs to avoid issues when concatenating
for df in syn:
df.attrs = {}

# Combine results from batches
syn = pd.concat(syn, axis=0, ignore_index=True)

# Depending on how queries were batched, we need to drop duplicate synapses
syn.drop_duplicates('id', inplace=True)

# Rename some of those columns
syn.rename({'post_pt_root_id': 'post',
'pre_pt_root_id': 'pre',
'post_pt_position_x': 'post_x',
'post_pt_position_y': 'post_y',
'post_pt_position_z': 'post_z',
'pre_pt_position_x': 'pre_x',
'pre_pt_position_y': 'pre_y',
'pre_pt_position_z': 'pre_z',
},
axis=1, inplace=True)
return syn
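
For reviewers, a quick usage sketch of the new connectivity source. This is a
sketch only: it assumes a valid Flywire/CAVE token is already configured, and
the root IDs below are placeholders rather than real neurons.

# Minimal sketch -- placeholder root IDs, valid CAVE token assumed
from connectomes.connectivity.flywire import FlywireConnectivitySource

cs = FlywireConnectivitySource(dataset='production')
ids = [123456789, 987654321]  # placeholder root IDs

# Source x target adjacency matrix of synapse counts
adj = cs.get_edges(sources=ids, min_score=30)

# Synapse table (pre/post IDs + positions), with transmitter predictions
syn = cs.get_synapses(ids, transmitters=True)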
46 changes: 43 additions & 3 deletions connectomes/datasets/__init__.py
@@ -20,11 +20,13 @@
from abc import ABC, abstractmethod

from ..meshes.neu import NeuPrintMeshSource
from ..meshes.flywire import FlywireMeshSource
from ..skeletons.neu import NeuPrintSkeletonSource
from ..segmentation.cloudvol import CloudVolSegmentationSource
from ..connectivity.neu import NeuPrintConnectivitySource
from ..connectivity.flywire import FlywireConnectivitySource
from ..annotations.neu import NeuPrintAnnotationSource

from ..utils.flywire import get_chunkedgraph_secret, set_chunkedgraph_secret

@functools.lru_cache
def get(dataset, *args, **kwargs):
@@ -57,7 +59,7 @@ class HemiBrain(BaseDataSet):
Parameters
----------
version : str
-            Version to use. Defaults to the currently lates (1.2.1).
+            Version to use. Defaults to the currently latest (1.2.1).
server : str
The server to use. Defaults to the public service.

@@ -104,5 +106,43 @@ def check_token(self):
raise ValueError(msg)


class FAFB(BaseDataSet):
"""Interface with the Flywire FAFB dataset.

Parameters
----------
dataset : str
Version of dataset to use. Defaults to 'production'.

References
----------
Zheng, Z. et al. A complete electron microscopy volume of the brain of adult
Drosophila melanogaster. Cell 174, 730-743 (2018)
"""
def __init__(self, dataset='production'):
# Check if credentials are set
self.check_token()
self.dataset = dataset

self.mesh = FlywireMeshSource(self.dataset)
self.connectivity = FlywireConnectivitySource(self.dataset)
self.segmentation = None # TODO
self.annotations = None # TODO

self.reference = 'Zheng, Z. et al. (2018)'

def __str__(self):
        return f'FlyWire FAFB dataset ({self.dataset})'

def check_token(self):
"""Checks whether FAFB Flywire token is set correct.
"""
get_chunkedgraph_secret()

    def set_token(self, token: str):
        """Convenience method for setting the FAFB Flywire token."""
set_chunkedgraph_secret(token)

# Add more datasets here
-DATASETS = {'hemibrain': HemiBrain}
+DATASETS = {'hemibrain': HemiBrain, 'fafb': FAFB}
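
A sketch of how the new dataset is reached through the registry. This assumes
``get()`` simply instantiates the class registered in ``DATASETS`` (its body is
collapsed above); the root IDs are again placeholders.

# Minimal sketch -- FAFB() raises at construction if no Flywire token is set
from connectomes.datasets import get

fafb = get('fafb')   # cached via @functools.lru_cache
print(fafb)          # -> FlyWire FAFB dataset (production)

# The attached sources do the actual work
adj = fafb.connectivity.get_edges([123456789, 987654321])  # placeholder IDs
mesh = fafb.mesh.get(123456789)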
60 changes: 60 additions & 0 deletions connectomes/meshes/flywire.py
@@ -0,0 +1,60 @@
# A collection of tools to interface with various connectome backends.
#
# Copyright (C) 2021 Philipp Schlegel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np

from .base import MeshSource
from ..utils.flywire import parse_volume, is_iterable

class FlywireMeshSource(MeshSource):
def __init__(self, dataset='production'):
self.dataset = dataset

def get(self, x, threads=2, omit_failures=False, **kwargs):
"""Fetch meshes for given neuron id.

Parameters
----------
x : int | list
Defines which meshes to fetch. Can be:
- a body ID (integers)
- lists of body IDs
threads : bool | int, optional
Whether to use threads to fetch meshes in parallel.
omit_failures : bool, optional
Determine behaviour when mesh download
fails.
"""
vol = parse_volume(self.dataset)
if is_iterable(x):
x = np.asarray(x, dtype=np.int64)
            if not threads or threads == 1:
                # Fetch sequentially, forwarding `omit_failures`
                return [self.get(id_, omit_failures=omit_failures, **kwargs)
                        for id_ in x]
            else:
                if not isinstance(threads, int):
                    raise TypeError(f'`threads` must be int or `None`, got "{type(threads)}".')
                with ThreadPoolExecutor(max_workers=threads) as executor:
                    futures = {executor.submit(self.get, n,
                                               omit_failures=omit_failures,
                                               threads=None, **kwargs): n
                               for n in x}

                    # Collect as futures complete, then restore input order
                    results = {futures[f]: f.result() for f in as_completed(futures)}
                return [results[n] for n in x]
        # Fetch a single mesh; cloud-volume returns a dict indexed by ID
        x = np.int64(x)
        mesh = vol.mesh.get(x, remove_duplicate_vertices=True)[x]
        return mesh
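
Finally, a matching sketch for the mesh source. It assumes the chunkedgraph
secret is set so that ``parse_volume`` can open the segmentation volume; the
IDs are placeholders.

# Minimal sketch -- placeholder IDs, Flywire credentials assumed
from connectomes.meshes.flywire import FlywireMeshSource

ms = FlywireMeshSource(dataset='production')

single = ms.get(123456789)                          # a single mesh
batch = ms.get([123456789, 987654321], threads=2)   # list of meshes, input order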