Several more optimizations to AnnotationLayer.
JoeStrout authored and nkemnitz committed Dec 20, 2024
1 parent 58801ff commit 542a5cc
Showing 1 changed file with 141 additions and 26 deletions.
167 changes: 141 additions & 26 deletions zetta_utils/db_annotations/precomp_annotations.py
@@ -15,7 +15,6 @@
import json
import os
import struct
from itertools import product
from math import ceil
from random import shuffle
from typing import IO, Literal, Optional, Sequence
@@ -47,6 +46,8 @@ def path_join(*paths: str):


class LineAnnotation:
BYTES_PER_ENTRY = 24 # start (3 floats), end (3 floats)

def __init__(self, line_id: int, start: Sequence[float], end: Sequence[float]):
"""
Initialize a LineAnnotation instance.
@@ -81,6 +82,8 @@ def write(self, output: IO[bytes]):
"""
output.write(struct.pack("<3f", *self.start))
output.write(struct.pack("<3f", *self.end))
# NOTE: if you change or add to the above, be sure to also
# change BYTES_PER_ENTRY accordingly.

@staticmethod
def read(in_stream: IO[bytes]):
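
For context (not part of the diff): the new BYTES_PER_ENTRY constant matches what write() above emits, namely two "<3f" packs of little-endian float32 values. A minimal sanity check of the 24-byte figure:

    import struct

    # start (3 float32) + end (3 float32) = 6 * 4 bytes = 24 bytes per line entry
    assert struct.calcsize("<3f") * 2 == 24
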
@@ -108,8 +111,9 @@ class SpatialEntry:
chunk_size: 3-element list or tuple defining the size of a chunk in X, Y, and Z (in voxels).
grid_shape: 3-element list/tuple defining how many chunks there are in each dimension.
key: a string e.g. "spatial1" used as a prefix for the chunk file on disk.
limit: affects how the data is subsampled for display; should generally be the number
of annotations in this chunk, or 1 for no subsampling. It's confusing, but see:
limit: affects how the data is subsampled for display; should generally be the max number
of annotations in any chunk at this level, or 1 for no subsampling. It's confusing, but
see:
https://github.com/google/neuroglancer/issues/227#issuecomment-2246350747
"""

@@ -180,6 +184,20 @@ def write_lines(file_or_gs_path: str, lines: Sequence[LineAnnotation], randomize
write_bytes(file_or_gs_path, buffer.getvalue())


def count_lines_in_file(file_or_gs_path: str) -> int:
"""
Provide a count (or at least a very good estimate) of the number of lines in
the given line chunk file, as quickly as possible.
"""
# We could open the file and read the count in the first 8 bytes.
# But even faster is to just calculate it from the file length.
cf = CloudFile(file_or_gs_path)
fileLen = cf.size()
if fileLen is None:
return 0
return round((fileLen - 8) / (LineAnnotation.BYTES_PER_ENTRY + 8))
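
To make the arithmetic concrete: the formula assumes each chunk file holds an 8-byte line count, then one 24-byte record per line (BYTES_PER_ENTRY), then one 8-byte ID per line, i.e. size = 8 + N * 32. A worked example with a made-up file size:

    file_len = 3208                             # hypothetical chunk file size in bytes
    n_lines = round((file_len - 8) / (24 + 8))
    assert n_lines == 100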


def read_bytes(file_or_gs_path: str):
"""
Read bytes from a local file or Google Cloud Storage.
@@ -532,25 +550,91 @@ def write_annotations(self, annotations: Sequence[LineAnnotation], all_levels: b
level_key = f"spatial{level}"
level_dir = path_join(self.path, level_key)
if is_local_filesystem(self.path):
os.makedirs(level_dir)
for x, y, z in product(
range(grid_shape[0]), range(grid_shape[1]), range(grid_shape[2])
):
chunk_start = self.index.start + Vec3D(x, y, z) * chunk_size
chunk_end = chunk_start + chunk_size
chunk_bounds = VolumetricIndex.from_coords(
chunk_start, chunk_end, self.index.resolution
os.makedirs(level_dir, exist_ok=True)

# Yes, there are terser ways to do this 3D iteration in Python,
# but they result in having to filter the full set of annotations
# for every chunk, which turns out to be very slow. Much faster
# is to go one axis at a time, filtering at each step, so that
# the innermost filters don't have to wade through as much data.
for x in range(grid_shape[0]):
print(f"x = {x} of {grid_shape[0]}", end="", flush=True)
split_by_x = VolumetricIndex.from_coords(
(
self.index.start[0] + x * chunk_size[0],
self.index.start[1],
self.index.start[2],
),
(
self.index.start[0] + (x + 1) * chunk_size[0],
self.index.stop[1],
self.index.stop[2],
),
self.index.resolution,
)
# pylint: disable=cell-var-from-loop
chunk_data: list[LineAnnotation] = list(
filter(lambda d: d.in_bounds(chunk_bounds), annotations)
split_data_by_x: list[LineAnnotation] = list(
filter(lambda d: d.in_bounds(split_by_x), annotations)
)
if not chunk_data:
print(f": {len(split_data_by_x)} lines")
if not split_data_by_x:
continue
anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
chunk_data += read_lines(anno_file_path)
limit = max(limit, len(chunk_data))
write_lines(anno_file_path, chunk_data)
for y in range(grid_shape[1]):
print(f" y = {y} of {grid_shape[1]}", end="", flush=True)
split_by_y = VolumetricIndex.from_coords(
(
self.index.start[0] + x * chunk_size[0],
self.index.start[1] + y * chunk_size[1],
self.index.start[2],
),
(
self.index.start[0] + (x + 1) * chunk_size[0],
self.index.start[1] + (y + 1) * chunk_size[1],
self.index.stop[2],
),
self.index.resolution,
)
# pylint: disable=cell-var-from-loop
split_data_by_y: list[LineAnnotation] = list(
filter(lambda d: d.in_bounds(split_by_y), split_data_by_x)
)
print(f": {len(split_data_by_y)} lines")
if not split_data_by_y:
continue
for z in range(grid_shape[2]):
split_by_z = VolumetricIndex.from_coords(
(
self.index.start[0] + x * chunk_size[0],
self.index.start[1] + y * chunk_size[1],
self.index.start[2] + z * chunk_size[2],
),
(
self.index.start[0] + (x + 1) * chunk_size[0],
self.index.start[1] + (y + 1) * chunk_size[1],
self.index.start[2] + (z + 1) * chunk_size[2],
),
self.index.resolution,
)
# Sanity check: manually compute the single-chunk bounds.
# It should be equal to split_by_z.
chunk_start = self.index.start + Vec3D(x, y, z) * chunk_size
chunk_end = chunk_start + chunk_size
chunk_bounds = VolumetricIndex.from_coords(
chunk_start, chunk_end, self.index.resolution
)
# print(f'split_by_z: {split_by_z}')
# print(f'chunk_bounds: {chunk_bounds}')
assert chunk_bounds == split_by_z
# pylint: disable=cell-var-from-loop
chunk_data: list[LineAnnotation] = list(
filter(lambda d: d.in_bounds(chunk_bounds), split_data_by_y)
)
if not chunk_data:
continue
anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
chunk_data += read_lines(anno_file_path)
limit = max(limit, len(chunk_data))
write_lines(anno_file_path, chunk_data)
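
The restructured loop above trades terseness for speed: instead of testing every annotation against every chunk, it narrows the candidate list once per X slab, again per (X, Y) column, and only then per chunk. A minimal, self-contained sketch of the same pattern (1-D ranges instead of VolumetricIndex, made-up data):

    from typing import NamedTuple

    class Pt(NamedTuple):
        x: float
        y: float
        z: float

    points = [Pt(5, 12, 3), Pt(40, 7, 9), Pt(41, 8, 2)]  # stand-ins for annotations
    extent, chunk = 50, 10                               # illustration values

    for x0 in range(0, extent, chunk):
        by_x = [p for p in points if x0 <= p.x < x0 + chunk]     # one cheap pass per slab
        if not by_x:
            continue
        for y0 in range(0, extent, chunk):
            by_y = [p for p in by_x if y0 <= p.y < y0 + chunk]   # filters a much smaller list
            if not by_y:
                continue
            for z0 in range(0, extent, chunk):
                in_chunk = [p for p in by_y if z0 <= p.z < z0 + chunk]
                if in_chunk:
                    print((x0, y0, z0), len(in_chunk))           # chunk coords and line count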

def read_all(self, spatial_level: int = -1, filter_duplicates: bool = True):
"""
@@ -582,6 +666,25 @@ def read_all(self, spatial_level: int = -1, filter_duplicates: bool = True):
result = list(result_dict.values())
return result

def find_max_size(self, spatial_level: int = -1):
"""
Find the maximum number of entries in any chunk at the given level.
"""
level = spatial_level if spatial_level >= 0 else len(self.chunk_sizes) + spatial_level
bounds_size = self.index.shape
chunk_size = Vec3D(*self.chunk_sizes[level])
grid_shape = ceil(bounds_size / chunk_size)
level_key = f"spatial{level}"
level_dir = path_join(self.path, level_key)
result = 0
for x in range(0, grid_shape[0]):
for y in range(0, grid_shape[1]):
for z in range(0, grid_shape[2]):
anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
line_count = count_lines_in_file(anno_file_path)
result = max(result, line_count)
return result
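
Because find_max_size relies only on count_lines_in_file (i.e. on file sizes), it can scan a whole level without reading any line data. A hedged usage sketch, with `layer` standing for an existing AnnotationLayer instance:

    max_count = layer.find_max_size(spatial_level=0)  # layer: an existing AnnotationLayer
    # max_count is the "limit" value to record for that level's spatial entry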

def read_in_bounds(self, index: VolumetricIndex, strict: bool = False):
"""
Return all annotations within the given bounds (index). If strict is
@@ -619,14 +722,26 @@ def post_process(self):
This is useful after writing out a bunch of data with
write_annotations(data, False), which writes to only the lowest-level chunks.
"""
# read data (from lowest level chunks)
all_data = self.read_all()

# write data to all levels EXCEPT the last one
levels_to_write = range(0, len(self.chunk_sizes) - 1)
spatial_entries = subdivide(
all_data, self.index, self.chunk_sizes, self.path, levels_to_write
)
if len(self.chunk_sizes) == 1:
# Special case: only one chunk size, no subdivision.
# In this case, we can cheat considerably.
# Just iterate over the spatial entry files, getting the line
# count in each one, and keep track of the max.
max_line_count = self.find_max_size(0)
print(f"Found max_line_count = {max_line_count}")
spatial_entries = self.get_spatial_entries(max_line_count)
else:
# Multiple chunk sizes means we have to start at the lowest
# level, and re-subdivide it at each higher level.

# read data (from lowest level chunks)
all_data = self.read_all()

# write data to all levels EXCEPT the last one
levels_to_write = range(0, len(self.chunk_sizes) - 1)
spatial_entries = subdivide(
all_data, self.index, self.chunk_sizes, self.path, levels_to_write
)

# rewrite the info file, with the updated spatial entries
self.write_info_file(spatial_entries)
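
Putting it together, the workflow this targets (per the post_process docstring above) is roughly the sketch below; `layer` and `batches` are hypothetical stand-ins, since the AnnotationLayer constructor is outside this diff:

    # layer: an AnnotationLayer already pointing at a precomputed annotation path
    for batch in batches:                                 # batches: iterable of LineAnnotation lists
        layer.write_annotations(batch, all_levels=False)  # touch only the lowest-level chunks
    layer.post_process()  # re-subdivide (or just recount) and rewrite the info file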