From 542a5cc351d7b39eb5c4fe0e710fb5d6b9ca98c9 Mon Sep 17 00:00:00 2001
From: JoeStrout
Date: Fri, 13 Dec 2024 15:26:10 +0000
Subject: [PATCH] Several more optimizations to AnnotationLayer.

---
 .../db_annotations/precomp_annotations.py | 167 +++++++++++++++---
 1 file changed, 141 insertions(+), 26 deletions(-)

diff --git a/zetta_utils/db_annotations/precomp_annotations.py b/zetta_utils/db_annotations/precomp_annotations.py
index 9bbc5c1b1..35378ccd6 100644
--- a/zetta_utils/db_annotations/precomp_annotations.py
+++ b/zetta_utils/db_annotations/precomp_annotations.py
@@ -15,7 +15,6 @@
 import json
 import os
 import struct
-from itertools import product
 from math import ceil
 from random import shuffle
 from typing import IO, Literal, Optional, Sequence
@@ -47,6 +46,8 @@ def path_join(*paths: str):
 
 
 class LineAnnotation:
+    BYTES_PER_ENTRY = 24  # start (3 floats), end (3 floats)
+
     def __init__(self, line_id: int, start: Sequence[float], end: Sequence[float]):
         """
         Initialize a LineAnnotation instance.
@@ -81,6 +82,8 @@ def write(self, output: IO[bytes]):
         """
         output.write(struct.pack("<3f", *self.start))
         output.write(struct.pack("<3f", *self.end))
+        # NOTE: if you change or add to the above, be sure to also
+        # change BYTES_PER_ENTRY accordingly.
 
     @staticmethod
     def read(in_stream: IO[bytes]):
@@ -108,8 +111,9 @@ class SpatialEntry:
     chunk_size: 3-element list or tuple defining the size of a chunk in X, Y, and Z (in voxels).
     grid_shape: 3-element list/tuple defining how many chunks there are in each dimension.
     key: a string e.g. "spatial1" used as a prefix for the chunk file on disk.
-    limit: affects how the data is subsampled for display; should generally be the number
-        of annotations in this chunk, or 1 for no subsampling.  It's confusing, but see:
+    limit: affects how the data is subsampled for display; should generally be the max number
+        of annotations in any chunk at this level, or 1 for no subsampling.  It's confusing, but
+        see:
         https://github.com/google/neuroglancer/issues/227#issuecomment-2246350747
     """
 
@@ -180,6 +184,20 @@ def write_lines(file_or_gs_path: str, lines: Sequence[LineAnnotation], randomize
     write_bytes(file_or_gs_path, buffer.getvalue())
 
 
+def count_lines_in_file(file_or_gs_path: str) -> int:
+    """
+    Provide a count (or at least a very good estimate) of the number of lines in
+    the given line chunk file, as quickly as possible.
+    """
+    # We could open the file and read the count in the first 8 bytes.
+    # But even faster is to just calculate it from the file length.
+    cf = CloudFile(file_or_gs_path)
+    fileLen = cf.size()
+    if fileLen is None:
+        return 0
+    return round((fileLen - 8) / (LineAnnotation.BYTES_PER_ENTRY + 8))
+
+
 def read_bytes(file_or_gs_path: str):
     """
     Read bytes from a local file or Google Cloud Storage.
@@ -532,25 +550,91 @@ def write_annotations(self, annotations: Sequence[LineAnnotation], all_levels: b
             level_key = f"spatial{level}"
             level_dir = path_join(self.path, level_key)
             if is_local_filesystem(self.path):
-                os.makedirs(level_dir)
-            for x, y, z in product(
-                range(grid_shape[0]), range(grid_shape[1]), range(grid_shape[2])
-            ):
-                chunk_start = self.index.start + Vec3D(x, y, z) * chunk_size
-                chunk_end = chunk_start + chunk_size
-                chunk_bounds = VolumetricIndex.from_coords(
-                    chunk_start, chunk_end, self.index.resolution
+                os.makedirs(level_dir, exist_ok=True)
+
+            # Yes, there are terser ways to do this 3D iteration in Python,
+            # but they result in having to filter the full set of annotations
+            # for every chunk, which turns out to be very slow.  Much faster
+            # is to do one level at a time, filtering at each level, so that
+            # the final filters don't have to wade through as much data.
+            for x in range(grid_shape[0]):
+                print(f"x = {x} of {grid_shape[0]}", end="", flush=True)
+                split_by_x = VolumetricIndex.from_coords(
+                    (
+                        self.index.start[0] + x * chunk_size[0],
+                        self.index.start[1],
+                        self.index.start[2],
+                    ),
+                    (
+                        self.index.start[0] + (x + 1) * chunk_size[0],
+                        self.index.stop[1],
+                        self.index.stop[2],
+                    ),
+                    self.index.resolution,
                 )
                 # pylint: disable=cell-var-from-loop
-                chunk_data: list[LineAnnotation] = list(
-                    filter(lambda d: d.in_bounds(chunk_bounds), annotations)
+                split_data_by_x: list[LineAnnotation] = list(
+                    filter(lambda d: d.in_bounds(split_by_x), annotations)
                 )
-                if not chunk_data:
+                print(f": {len(split_data_by_x)} lines")
+                if not split_data_by_x:
                     continue
-                anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
-                chunk_data += read_lines(anno_file_path)
-                limit = max(limit, len(chunk_data))
-                write_lines(anno_file_path, chunk_data)
+                for y in range(grid_shape[1]):
+                    print(f" y = {y} of {grid_shape[1]}", end="", flush=True)
+                    split_by_y = VolumetricIndex.from_coords(
+                        (
+                            self.index.start[0] + x * chunk_size[0],
+                            self.index.start[1] + y * chunk_size[1],
+                            self.index.start[2],
+                        ),
+                        (
+                            self.index.start[0] + (x + 1) * chunk_size[0],
+                            self.index.start[1] + (y + 1) * chunk_size[1],
+                            self.index.stop[2],
+                        ),
+                        self.index.resolution,
+                    )
+                    # pylint: disable=cell-var-from-loop
+                    split_data_by_y: list[LineAnnotation] = list(
+                        filter(lambda d: d.in_bounds(split_by_y), split_data_by_x)
+                    )
+                    print(f": {len(split_data_by_y)} lines")
+                    if not split_data_by_y:
+                        continue
+                    for z in range(grid_shape[2]):
+                        split_by_z = VolumetricIndex.from_coords(
+                            (
+                                self.index.start[0] + x * chunk_size[0],
+                                self.index.start[1] + y * chunk_size[1],
+                                self.index.start[2] + z * chunk_size[2],
+                            ),
+                            (
+                                self.index.start[0] + (x + 1) * chunk_size[0],
+                                self.index.start[1] + (y + 1) * chunk_size[1],
+                                self.index.start[2] + (z + 1) * chunk_size[2],
+                            ),
+                            self.index.resolution,
+                        )
+                        # Sanity check: manually compute the single-chunk bounds.
+                        # It should be equal to split_by_z.
+                        chunk_start = self.index.start + Vec3D(x, y, z) * chunk_size
+                        chunk_end = chunk_start + chunk_size
+                        chunk_bounds = VolumetricIndex.from_coords(
+                            chunk_start, chunk_end, self.index.resolution
+                        )
+                        # print(f'split_by_z: {split_by_z}')
+                        # print(f'chunk_bounds: {chunk_bounds}')
+                        assert chunk_bounds == split_by_z
+                        # pylint: disable=cell-var-from-loop
+                        chunk_data: list[LineAnnotation] = list(
+                            filter(lambda d: d.in_bounds(chunk_bounds), split_data_by_y)
+                        )
+                        if not chunk_data:
+                            continue
+                        anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
+                        chunk_data += read_lines(anno_file_path)
+                        limit = max(limit, len(chunk_data))
+                        write_lines(anno_file_path, chunk_data)
 
     def read_all(self, spatial_level: int = -1, filter_duplicates: bool = True):
         """
@@ -582,6 +666,25 @@ def read_all(self, spatial_level: int = -1, filter_duplicates: bool = True):
         result = list(result_dict.values())
         return result
 
+    def find_max_size(self, spatial_level: int = -1):
+        """
+        Find the maximum number of entries in any chunk at the given level.
+        """
+        level = spatial_level if spatial_level >= 0 else len(self.chunk_sizes) + spatial_level
+        bounds_size = self.index.shape
+        chunk_size = Vec3D(*self.chunk_sizes[level])
+        grid_shape = ceil(bounds_size / chunk_size)
+        level_key = f"spatial{level}"
+        level_dir = path_join(self.path, level_key)
+        result = 0
+        for x in range(0, grid_shape[0]):
+            for y in range(0, grid_shape[1]):
+                for z in range(0, grid_shape[2]):
+                    anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
+                    line_count = count_lines_in_file(anno_file_path)
+                    result = max(result, line_count)
+        return result
+
     def read_in_bounds(self, index: VolumetricIndex, strict: bool = False):
         """
         Return all annotations within the given bounds (index).  If strict is
@@ -619,14 +722,26 @@ def post_process(self):
         This is useful after writing out a bunch of data with
         write_annotations(data, False), which writes to only the lowest-level chunks.
         """
-        # read data (from lowest level chunks)
-        all_data = self.read_all()
-
-        # write data to all levels EXCEPT the last one
-        levels_to_write = range(0, len(self.chunk_sizes) - 1)
-        spatial_entries = subdivide(
-            all_data, self.index, self.chunk_sizes, self.path, levels_to_write
-        )
+        if len(self.chunk_sizes) == 1:
+            # Special case: only one chunk size, no subdivision.
+            # In this case, we can cheat considerably.
+            # Just iterate over the spatial entry files, getting the line
+            # count in each one, and keep track of the max.
+            max_line_count = self.find_max_size(0)
+            print(f"Found max_line_count = {max_line_count}")
+            spatial_entries = self.get_spatial_entries(max_line_count)
+        else:
+            # Multiple chunk sizes means we have to start at the lowest
+            # level, and re-subdivide it at each higher level.
+
+            # read data (from lowest level chunks)
+            all_data = self.read_all()
+
+            # write data to all levels EXCEPT the last one
+            levels_to_write = range(0, len(self.chunk_sizes) - 1)
+            spatial_entries = subdivide(
+                all_data, self.index, self.chunk_sizes, self.path, levels_to_write
+            )
 
         # rewrite the info file, with the updated spatial entries
         self.write_info_file(spatial_entries)
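
A note on the size arithmetic in count_lines_in_file: the formula assumes each chunk file is an 8-byte line count, followed by BYTES_PER_ENTRY (24) bytes of geometry per line, followed by one 8-byte id per line, so a file holding n lines is 8 + 32*n bytes long. The sketch below is not part of the patch; it just round-trips that assumed layout to show why round((fileLen - 8) / (BYTES_PER_ENTRY + 8)) recovers the count.

    BYTES_PER_ENTRY = 24  # start (3 floats) + end (3 floats), as in LineAnnotation

    def expected_file_size(line_count: int) -> int:
        # assumed layout: 8-byte count header, then 24 bytes of geometry
        # and one 8-byte id per line
        return 8 + line_count * (BYTES_PER_ENTRY + 8)

    def count_from_file_size(file_size: int) -> int:
        # same formula count_lines_in_file uses, minus the CloudFile size lookup
        return round((file_size - 8) / (BYTES_PER_ENTRY + 8))

    for n in (0, 1, 7, 1000):
        assert count_from_file_size(expected_file_size(n)) == n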
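
The write_annotations rewrite is easier to read outside diff form: rather than filtering the full annotation list once per chunk (the old product() loop), it narrows the candidate list once per X slab, then once per Y slab within that, so the innermost Z filter only scans items already known to be nearby. Below is a standalone sketch of that level-at-a-time idea using plain tuples as stand-ins for LineAnnotation and VolumetricIndex; the names are illustrative only, not the zetta_utils API.

    from collections import defaultdict

    def bin_points(points, grid_shape, chunk_size):
        chunks = defaultdict(list)
        for x in range(grid_shape[0]):
            x_lo, x_hi = x * chunk_size[0], (x + 1) * chunk_size[0]
            # filter the full set only once per X slab
            in_x = [p for p in points if x_lo <= p[0] < x_hi]
            if not in_x:
                continue
            for y in range(grid_shape[1]):
                y_lo, y_hi = y * chunk_size[1], (y + 1) * chunk_size[1]
                # filter only the X slab, not everything
                in_xy = [p for p in in_x if y_lo <= p[1] < y_hi]
                if not in_xy:
                    continue
                for z in range(grid_shape[2]):
                    z_lo, z_hi = z * chunk_size[2], (z + 1) * chunk_size[2]
                    in_xyz = [p for p in in_xy if z_lo <= p[2] < z_hi]
                    if in_xyz:
                        chunks[(x, y, z)] = in_xyz
        return chunks

    # two points share chunk (0, 1, 3); the third lands in (2, 0, 0)
    print(bin_points([(5, 12, 30), (5, 13, 31), (25, 2, 3)],
                     grid_shape=(3, 2, 4), chunk_size=(10, 10, 10)))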