Skip to content

Commit

Permalink
wip2
Browse files Browse the repository at this point in the history
  • Loading branch information
fcollman committed Feb 6, 2024
1 parent 6b1a3a6 commit 6cc2a0b
Showing 1 changed file with 81 additions and 56 deletions.
137 changes: 81 additions & 56 deletions python/neuroglancer/write_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ def _get_dtype_for_properties(
return dtype


def shrink_to_uniform_size(arr: np.ndarray) -> np.ndarray:
    """Halve the largest component(s) of ``arr`` and keep the rest as-is.

    Every component is divided by two; a component is replaced by its half
    only when that half equals the largest half overall (i.e. the component
    is one of the largest entries of ``arr``). All other components keep
    their original value.

    Args:
        arr: Array of per-dimension sizes.

    Returns:
        Array of the same shape where each entry is either the original
        value or half of it.
    """
    halved = arr / 2
    largest_half = np.max(halved)
    # A half can never exceed the maximum, so this mask selects exactly the
    # entries whose half ties with the largest half.
    is_largest = halved >= largest_half
    return np.where(is_largest, halved, arr)


class AnnotationWriter:
annotations: list[Annotation]
related_annotations: list[dict[int, list[Annotation]]]
Expand Down Expand Up @@ -262,7 +269,9 @@ def __init__(
shape=(self.rank,), fill_value=float("-inf"), dtype=np.float32
)
self.related_annotations = [{} for _ in self.relationships]
self.rtree = rtree.index.Index()
p = rtree.index.Property()
p.dimension = self.rank
self.rtree = rtree.index.Index(properties=p)
self.max_annotations_per_chunk = max_annotations_per_chunk

def get_chunk_index(self, coords):
Expand Down Expand Up @@ -462,9 +471,7 @@ def make_spatial_index(self):
# which spatial index each annotation belongs to
spatial_indices = np.zeros(len(self.annotations))
chunk_indices = np.zeros((len(self.annotations), self.rank), dtype=np.int32)
while (annotations_remaining > self.max_annotations_per_chunk) or (
spatial_index == 0
):
while annotations_remaining > 0:
num_chunks = np.ceil(
(self.upper_bound - self.lower_bound) / chunk_size
).astype(int)
Expand Down Expand Up @@ -500,7 +507,9 @@ def make_spatial_index(self):
if len(chunk_annotations) > 0:
# pick self.max_annotations_per_chunk annotations to place in the chunk
# randomizing the order of the annotations
np.random.Generator().shuffle(chunk_annotations)
np.random.Generator(np.random.PCG64()).shuffle(
chunk_annotations
)

chunk_annotations = chunk_annotations[
: self.max_annotations_per_chunk
Expand All @@ -512,19 +521,11 @@ def make_spatial_index(self):
annotations_remaining = len(self.annotations) - np.sum(spatial_indices > 0)

# each component of chunk_size of each successive level should be either equal to, or half of,
# the corresponding component of the prior level chunk_size, whichever results in a more spatially isotropic chunk.


def shrink_to_uniform_size(arr: np.ndarray) -> np.ndarray:
    """Shrink a chunk size by halving some or all of its components.

    Each component of the result is either the original value or half of it:

    * If the largest half is still greater than the smallest original
      component, only the components whose half exceeds that smallest
      component are halved; the others are left unchanged.
    * Otherwise every component is halved.

    Args:
        arr: Array of per-dimension chunk sizes.

    Returns:
        Array of the same shape holding the shrunken chunk sizes.
    """
    half_chunk_size = arr / 2
    # np.minimum / np.maximum are element-wise binary ufuncs and raise
    # TypeError when given a single array; the reductions np.min / np.max
    # are what is needed here.
    min_size = np.min(arr)
    max_half_size = np.max(half_chunk_size)

    if max_half_size > min_size:
        # Only shrink the dimensions that remain larger than the smallest
        # dimension after halving.
        return np.where(half_chunk_size > min_size, half_chunk_size, arr)
    return half_chunk_size
# the corresponding component of the prior level chunk_size,
# whichever results in a more spatially isotropic chunk.
chunk_size = shrink_to_uniform_size(chunk_size)
spatial_indices = spatial_indices - 1
return chunk_sizes, grid_shapes, spatial_indices, chunk_indices

# query the rtree for the number of annotations in each chunk

Expand All @@ -539,54 +540,78 @@ def write(self, path: Union[str, pathlib.Path]):
"relationships": [],
"by_id": {"key": "by_id"},
}
total_ann_bytes = sum(len(a.encoded) for a in self.annotations)
bytes_per_annotation = len(self.annotations[0].encoded)
total_ann_bytes = len(self.annotations) * bytes_per_annotation
sharding_spec = choose_output_spec(len(self.annotations), total_ann_bytes)

# calculate the number of chunks in each dimension
num_chunks = np.ceil(
(self.upper_bound - self.lower_bound) / self.chunk_size
).astype(int)

# find the maximum number of annotations in any chunk
max_annotations = max(
len(annotations) for annotations in self.annotations_by_chunk.values()
)

# make directories
os.makedirs(path, exist_ok=True)
for relationship in self.relationships:
os.makedirs(os.path.join(path, f"rel_{relationship}"), exist_ok=True)
os.makedirs(os.path.join(path, "by_id"), exist_ok=True)
os.makedirs(os.path.join(path, "spatial0"), exist_ok=True)

total_chunks = len(self.annotations_by_chunk)
spatial_sharding_spec = choose_output_spec(
total_chunks, total_ann_bytes + 8 * len(self.annotations) + 8 * total_chunks
)
# initialize metadata for spatial index
metadata["spatial"] = [
{
"key": "spatial0",
"grid_shape": num_chunks.tolist(),
"chunk_size": [int(x) for x in self.chunk_size],
"limit": max_annotations,
}
]
# write annotations by spatial chunk
if spatial_sharding_spec is not None:
self._serialize_annotation_chunk_sharded(
os.path.join(path, "spatial0"),
self.annotations_by_chunk,
spatial_sharding_spec,
num_chunks.tolist(),
(
chunk_sizes,
grid_shapes,
spatial_indices,
chunk_indices,
) = self.make_spatial_index()

metadata["spatial"] = []
for i, grid_shape in enumerate(grid_shapes):
total_chunks = np.prod(grid_shape)

is_this_index = spatial_indices == i
n_anns = np.sum(is_this_index)

tot_bytes = n_anns * bytes_per_annotation

spatial_sharding_spec = choose_output_spec(
total_chunks,
total_ann_bytes + 8 * len(self.annotations) + 8 * total_chunks,
)
metadata["spatial"][0]["sharding"] = spatial_sharding_spec.to_json()
else:
for chunk_index, annotations in self.annotations_by_chunk.items():
chunk_name = "_".join([str(c) for c in chunk_index])
filepath = os.path.join(path, "spatial0", chunk_name)
with open(filepath, "wb") as f:
self._serialize_annotations(f, annotations)
# initialize metadata for spatial index
spatial_md = {
"key": f"spatial{i}",
"grid_shape": grid_shape,
"chunk_size": chunk_sizes[i],
"limit": self.max_annotations_per_chunk,
}

if spatial_sharding_spec is not None:
spatial_md["sharding"] = spatial_sharding_spec.to_json()

# organize the annotations in this chunk into a dictionary
# where the key is the chunk_index in chunk_indices
# and all the annotations in that chunk are the value

chunk_annotations = [
a for a, tf in zip(self.annotations, is_this_index) if tf
]
this_chunk_index = chunk_indices[is_this_index]
annotations_by_chunk = defaultdict(list)
# group by chunk_indices and store in a dictionary
for chunk_index, annotation in zip(this_chunk_index, chunk_annotations):
chunk_index = tuple(chunk_index)
annotations_by_chunk[chunk_index].append(annotation)

# write annotations by spatial chunk
if spatial_sharding_spec is not None:
self._serialize_annotation_chunk_sharded(
os.path.join(path, f"spatial{i}"),
annotations_by_chunk,
spatial_sharding_spec,
grid_shapes[i],
)
spatial_md["sharding"] = spatial_sharding_spec.to_json()
else:
for chunk_index, annotations in annotations_by_chunk.items():
chunk_name = "_".join([str(c) for c in chunk_index])
filepath = os.path.join(path, "spatial0", chunk_name)
with open(filepath, "wb") as f:
self._serialize_annotations(f, annotations)
metadata["spatial"].append(spatial_md)

# write annotations by id
if sharding_spec is not None:
Expand Down

0 comments on commit 6cc2a0b

Please sign in to comment.