Skip to content

Commit

Permalink
post synapse detection training implemented; added unit tests for syn…
Browse files Browse the repository at this point in the history
…apses
  • Loading branch information
xiuliren committed Jan 5, 2022
1 parent 1305f8a commit 8f95065
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 42 deletions.
10 changes: 6 additions & 4 deletions chunkflow/chunk/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import cc3d
from cloudvolume.lib import yellow, Bbox

from chunkflow.lib.bounding_boxes import BoundingBox
from chunkflow.lib.bounding_boxes import BoundingBox, Cartesian

# from typing import Tuple
# Offset = Tuple[int, int, int]
Expand All @@ -30,11 +30,13 @@ class Chunk(NDArrayOperatorsMixin):
and `examples<https://docs.scipy.org/doc/numpy/user/basics.dispatch.html#module-numpy.doc.dispatch>`_.
:param array: the data array chunk in a big dataset
:param voxel_offset: the offset of this array chunk. 3 numbers: z, y, x
:param voxel_size: the size of each voxel, normally with unit of nm. 3 numbers: z, y, x.
:param voxel_offset (Cartesian): the offset of this array chunk. 3 numbers: z, y, x
:param voxel_size (Cartesian): the size of each voxel, normally with unit of nm. 3 numbers: z, y, x.
:return: a new chunk with array data and global offset
"""
def __init__(self, array: np.ndarray, voxel_offset: tuple = None, voxel_size: tuple = None):
def __init__(self, array: np.ndarray,
voxel_offset: Cartesian = None,
voxel_size: Cartesian = None):
assert isinstance(array, np.ndarray) or isinstance(array, Chunk)
self.array = array
if voxel_offset is None:
Expand Down
79 changes: 66 additions & 13 deletions chunkflow/lib/bounding_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from collections import UserList, namedtuple
from math import ceil
from typing import Union
from numbers import Number

from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -40,27 +42,70 @@ def __add__(self, offset: Union[Cartesian, int]):
else:
return Cartesian(*[x+o for x, o in zip(self, offset)])

def __mul__(self, m: Number) -> Cartesian:
return Cartesian(*[x*m for x in self])

def __floordiv__(self, d: int):
return Cartesian(*[x // d for x in self])

def __truediv__(self, d: Number):
return Cartesian(*[x/d for x in self])

def __mod__(self, d: int):
return Cartesian(*[x%d for x in self])

def __lt__(self, other: Cartesian) -> bool:
if self.x < other.x and self.y < other.y and self.z < other.z:
return True
else:
return False

def __le__(self, other: Cartesian) -> bool:
if self.x <= other.x and self.y <= other.y and self.z <= other.z:
return True
else:
return False

def __gt__(self, other: Cartesian) -> bool:
if self.x > other.x and self.y > other.y and self.z > other.z:
return True
else:
return False

def __ge__(self, other: Cartesian) -> bool:
if self.z >= other.z and self.y >= other.y and self.x >= other.x:
return True
else:
return False

def __ne__(self, other: Cartesian) -> bool:
if self.z != other.z and self.y != other.y and self.x != other.x:
return True
else:
return False

# def __isub__(self, other: Union[Cartesian,Number]) -> Cartesian:



@property
def vec(self):
return Vec(*self)


class BoundingBox(Bbox):
def __init__(self, min_corner: list, max_corner: list, dtype=None, voxel_size: tuple = None):
super().__init__(min_corner, max_corner, dtype=dtype)
self._voxel_size = voxel_size

@classmethod
def from_corners(cls, minpt: Cartesian, maxpt: Cartesian):
def __init__(self,
minpt: Union[list, Cartesian],
maxpt: Union[list, Cartesian],
dtype=None,
voxel_size: Cartesian = None):
if isinstance(minpt, Cartesian):
minpt = minpt.vec

if isinstance(maxpt, Cartesian):
maxpt = maxpt.vec
return cls(minpt, maxpt)
super().__init__(minpt, maxpt, dtype=dtype)
self._voxel_size = voxel_size

@classmethod
def from_bbox(cls, bbox: Bbox, voxel_size: tuple = None):
Expand Down Expand Up @@ -90,24 +135,28 @@ def from_center(cls, center: Cartesian, extent: int):
extent (int): the range to extent, like radius
"""
minpt = center - extent
maxpt = center + extent
return cls.from_corners(minpt, maxpt)
# the maxpt is not inclusive, so we need +1
maxpt = center + extent + 1
return cls(minpt, maxpt)

def __repr__(self):
return f'BoundingBox({self.minpt}, {self.maxpt}, dtype={self.dtype}, voxel_size={self.voxel_size})'

def clone(self):
bbox = Bbox(self.minpt, self.maxpt, dtype=self.dtype)
bbox = bbox.clone()
return BoundingBox.from_bbox(bbox, voxel_size=self.voxel_size)

def adjust(self, size: Union[int, tuple, list, Vec]):
def adjust(self, size: Union[Cartesian, int, tuple, list, Vec]):
if size is None:
logging.warn('adjusting bounding box size is None!')
return self

if not isinstance(size, int):
assert 3 == len(size)
self.minpt -= size
self.maxpt += size
assert len(size)==3 or len(size)==6
size = Vec(*size)
self.minpt -= size[:3]
self.maxpt += size[-3:]
return self

def union(self, bbox2):
Expand All @@ -132,6 +181,10 @@ def contains(self, point: Union[tuple, Vec, list]):
(self.maxpt >= Vec(*point)))) and np.all(
np.asarray((self.minpt <= Vec(*point))))

@property
def shape(self):
return Cartesian(*(self.maxpt - self.minpt))

@property
def voxel_size(self):
return self._voxel_size
Expand Down
146 changes: 123 additions & 23 deletions chunkflow/lib/synapses.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
from __future__ import annotations

import os
import json
from typing import List
from copy import deepcopy
from functools import cached_property
from collections import defaultdict

import numpy as np
import h5py

from .bounding_boxes import BoundingBox
from chunkflow.lib.bounding_boxes import BoundingBox, Cartesian


class Synapses():
def __init__(self, pre: np.ndarray, pre_confidence: np.ndarray = None,
post: np.ndarray = None, post_confidence: np.ndarray = None,
resolution: tuple = None) -> None:
resolution: Cartesian = None) -> None:
"""Synapses containing T-bars and post-synapses
Args:
pre (np.ndarray): T-bar points, Nx3, z,y,x, the coordinate should be physical coordinate rather than voxels.
pre_confidence (np.ndarray, optional): confidence of T-bar detection. Defaults to None.
post (np.ndarray, optional): [description]. Defaults to None.
resolution (tuple, optional): [description]. Defaults to None.
resolution (Cartesian, optional): [description]. Defaults to None.
"""
assert pre.ndim == 2
assert pre.shape[1] == 3
Expand Down Expand Up @@ -53,16 +56,16 @@ def __init__(self, pre: np.ndarray, pre_confidence: np.ndarray = None,
self.post_confidence = post_confidence

@classmethod
def from_dict(cls, synapses: dict):
def from_dict(cls, dc: dict):
"""Synapses as a dictionary
Args:
synapses (dict): the whole synapses in a dictionary
"""
order = synapses['order']
resolution = synapses['resolution']
del synapses['order']
del synapses['resolution']
order = dc['order']
resolution = dc['resolution']
del dc['order']
del dc['resolution']

pre_num = len(synapses)
pre = np.zeros((pre_num, 3), dtype=np.int32)
Expand Down Expand Up @@ -98,11 +101,11 @@ def from_dict(cls, synapses: dict):
@classmethod
def from_json(cls, fname: str, resolution: tuple = None):
with open(fname, 'r') as file:
synapses = json.load(file)
syns = json.load(file)

if resolution is not None:
synapses['resolution'] = resolution
return cls.from_dict(synapses)
syns['resolution'] = resolution
return cls.from_dict(syns)

@classmethod
def from_h5(cls, fname: str, resolution: tuple = None, c_order: bool = True):
Expand Down Expand Up @@ -137,6 +140,11 @@ def from_h5(cls, fname: str, resolution: tuple = None, c_order: bool = True):
post_confidence=post_confidence, resolution=resolution)

def to_h5(self, fname: str) -> None:
"""save to a HDF5 file
Args:
fname (str): the file name to be saved
"""
assert fname.endswith(".h5") or fname.endswith(".hdf5")
with h5py.File(fname, "w") as hf:

Expand Down Expand Up @@ -165,6 +173,27 @@ def from_file(cls, fname: str, resolution: tuple = None, c_order: bool = True):
else:
raise ValueError(f'only support JSON and HDF5 file, but got {fname}')

def __eq__(self, other: Synapses) -> bool:
"""compare two synapses.
Note that we do not compare the confidence here!
Args:
other (Synapses): the other Synapses instance
Returns:
bool: whether the pre and post are the same
"""
if np.array_equal(self.pre, other.pre):
if self.post is None and other.post is None:
return True
elif self.post is not None and other.post is not None and np.array_equal(
self.post, other.post):
return True
else:
return False
else:
return False

@property
def post_coordinates(self) -> np.ndarray:
"""the coordinate array
Expand All @@ -182,20 +211,32 @@ def pre_num(self) -> int:
def post_num(self) -> int:
return self.post.shape[0]

@property
def pre_bounding_box(self) -> BoundingBox:
bbox = BoundingBox.from_points(self.pre)
# the end point is exclusive
bbox.adjust((0,0,0, 1,1,1))
return bbox

def post_bounding_box(self) -> BoundingBox:
bbox = BoundingBox.from_points(self.post_coordinates)
# the right direction is exclusive
bbox.adjust((0,0,0, 1,1,1))
return bbox

@property
def bounding_box(self) -> BoundingBox:
bbox = self.pre_bounding_box
bbox.union(self.post_bounding_box)
return bbox

@property
def pre_with_physical_coordinate(self) -> np.ndarray:
if self.resolution is not None:
return self.pre * self.resolution
else:
return self.pre

@property
def bounding_box(self) -> BoundingBox:
bbox = BoundingBox.from_points(self.pre)
bbox_post = BoundingBox.from_points(self.post[:, 1:])
bbox.union(bbox_post)
return bbox


@property
def post_with_physical_coordinate(self):
""" post synapses with physical coordinate. Note that the first column is the index of
Expand All @@ -213,11 +254,14 @@ def post_with_physical_coordinate(self):

@cached_property
def pre_index2post_indices(self):
pi2pi = defaultdict(list)
# pi2pi = defaultdict(list)
pi2pi = []
for idx in range(self.pre_num):
# find the post synapses for this presynapse
post_indices = np.argwhere(self.post[:, 0]==idx)
pi2pi[idx].append(post_indices)
post_indices = np.nonzero(self.post[:, 0]==idx)
assert len(post_indices) == 1
post_indices = post_indices[0].tolist()
pi2pi.append(post_indices)

return pi2pi

Expand All @@ -230,4 +274,60 @@ def distances_from_pre_to_post(self):
pre = self.pre[pre_idx, :]
distances[post_idx] = np.linalg.norm(pre - post)
return distances


@property
def pre_indices_without_post(self) -> List[int]:
"""presynapse indices that do not have post synapses
Returns:
[list]: a list of presynapse indices
"""
pi2pi = self.pre_index2post_indices
pre_indices = []
for pre_index in range(self.pre_num):
post_indices = pi2pi[pre_index]
if len(post_indices) == 0:
pre_indices.append(pre_index)
return pre_indices

def remove_pre(self, indices: List[int]):
"""remove or delete presynapses according to a list of indices
Note that we need to update the post synapse as well!
Args:
indices (List[int]): the presynapse indices
"""# update the presynapse indices in post
# old presynapse index to new presynapse index
old2new = np.ones((self.pre_num,), dtype=np.int64)
old2new[indices] = 0
old2new = np.cumsum(old2new) - 1

self.pre = np.delete(self.pre, indices, axis=0)

post_indices = np.isin(self.post[:, 0], indices, assume_unique=True)
self.post = np.delete(self.post, post_indices, axis=0)
for idx in range(self.post_num):
self.post[idx, 0] = old2new[self.post[idx, 0]]

def remove_synapses_without_post(self):
"""remove synapse without post synapse target
Returns:
None: remove in place
"""
selected_pre_indices = self.pre_indices_without_post

# remove the selected presynapses
self.remove_pre(selected_pre_indices)



if __name__ == '__main__':
synapses = Synapses.from_h5(
os.path.expanduser(
'~/dropbox/40_gt/21_wasp_synapses/Sp2,6848-7448_5690-6290_7038-7638.h5'
)
)
assert len(synapses.pre_index2post_indices[0]) > 1
breakpoint()
Loading

0 comments on commit 8f95065

Please sign in to comment.