diff --git a/tests/unit/db_annotations/test_precomp_annotations.py b/tests/unit/db_annotations/test_precomp_annotations.py
index 6c1375155..086ee134a 100644
--- a/tests/unit/db_annotations/test_precomp_annotations.py
+++ b/tests/unit/db_annotations/test_precomp_annotations.py
@@ -77,6 +77,38 @@ def test_round_trip():
     shutil.rmtree(file_dir)  # clean up when done
 
 
+def test_single_level():
+    temp_dir = os.path.expanduser("~/temp/test_precomp_anno")
+    os.makedirs(temp_dir, exist_ok=True)
+    file_dir = os.path.join(temp_dir, "single_level")
+
+    lines = [
+        LineAnnotation(line_id=1, start=(1640.0, 1308.0, 61.0), end=(1644.0, 1304.0, 57.0)),
+        LineAnnotation(line_id=2, start=(1502.0, 1709.0, 589.0), end=(1498.0, 1701.0, 589.0)),
+        LineAnnotation(line_id=3, start=(254.0, 68.0, 575.0), end=(258.0, 62.0, 575.0)),
+        LineAnnotation(line_id=4, start=(1061.0, 657.0, 507.0), end=(1063.0, 653.0, 502.0)),
+        LineAnnotation(line_id=5, start=(1298.0, 889.0, 315.0), end=(1295.0, 887.0, 314.0)),
+    ]
+    # Note: line 2 above, with the chunk_sizes below, will span 2 chunks, and so will
+    # be written out to both of them.
+
+    index = VolumetricIndex.from_coords([0, 0, 0], [2000, 2000, 600], Vec3D(10, 10, 40))
+
+    sf = AnnotationLayer(file_dir, index)
+    assert sf.chunk_sizes == [(2000, 2000, 600)]
+
+    chunk_sizes = [[500, 500, 300]]
+    sf = AnnotationLayer(file_dir, index, chunk_sizes)
+    os.makedirs(os.path.join(file_dir, "spatial0", "junkforcodecoverage"))
+    sf.clear()
+    sf.write_annotations([])  # (does nothing)
+    sf.write_annotations(lines)
+    sf.post_process()
+
+    chunk_path = os.path.join(file_dir, "spatial0", "2_1_1")
+    assert precomp_annotations.count_lines_in_file(chunk_path) == 2
+
+
 def test_edge_cases():
     with pytest.raises(ValueError):
         precomp_annotations.path_join()
diff --git a/zetta_utils/db_annotations/precomp_annotations.py b/zetta_utils/db_annotations/precomp_annotations.py
index 35378ccd6..c54a84d72 100644
--- a/zetta_utils/db_annotations/precomp_annotations.py
+++ b/zetta_utils/db_annotations/precomp_annotations.py
@@ -12,6 +12,7 @@
 """
 
 import io
+import itertools
 import json
 import os
 import struct
@@ -145,7 +146,7 @@ def write_bytes(file_or_gs_path: str, data: bytes):
     :param file_or_gs_path: path to file to write (local or GCS path)
     :param data: bytes to write
     """
-    if not "//" in file_or_gs_path:
+    if "//" not in file_or_gs_path:
         file_or_gs_path = "file://" + file_or_gs_path
     cf = CloudFile(file_or_gs_path)
     cf.put(data)
@@ -184,6 +185,14 @@ def write_lines(file_or_gs_path: str, lines: Sequence[LineAnnotation], randomize
     write_bytes(file_or_gs_path, buffer.getvalue())
 
 
+def line_count_from_file_size(file_size: int) -> int:
+    """
+    Provide a count (or at least a very good estimate) of the number of lines in
+    a line chunk file of the given size in bytes.
+    """
+    return round((file_size - 8) / (LineAnnotation.BYTES_PER_ENTRY + 8))
+
+
 def count_lines_in_file(file_or_gs_path: str) -> int:
     """
     Provide a count (or at least a very good estimate) of the number of lines in
@@ -192,10 +201,7 @@ def count_lines_in_file(file_or_gs_path: str) -> int:
     # We could open the file and read the count in the first 8 bytes.
     # But even faster is to just calculate it from the file length.
     cf = CloudFile(file_or_gs_path)
-    fileLen = cf.size()
-    if fileLen is None:
-        return 0
-    return round((fileLen - 8) / (LineAnnotation.BYTES_PER_ENTRY + 8))
+    return line_count_from_file_size(cf.size() or 0)
 
 
 def read_bytes(file_or_gs_path: str):
@@ -205,7 +211,7 @@ def read_bytes(file_or_gs_path: str):
     :param file_or_gs_path: path to file to read (local or GCS path)
     :return: bytes read from the file
     """
-    if not "//" in file_or_gs_path:
+    if "//" not in file_or_gs_path:
         file_or_gs_path = "file://" + file_or_gs_path
     cf = CloudFile(file_or_gs_path)
     return cf.get()
@@ -470,7 +476,7 @@ def delete(self):
         """
         # Delete all files under our path
         path = self.path
-        if not "//" in path:
+        if "//" not in path:
             path = "file://" + path
         cf = CloudFiles(path)
         cf.delete(cf.list())
@@ -558,7 +564,7 @@ def write_annotations(self, annotations: Sequence[LineAnnotation], all_levels: b
             # is to do one level at a time, filtering at each level, so that
             # the final filters don't have to wade through as much data.
             for x in range(grid_shape[0]):
-                print(f"x = {x} of {grid_shape[0]}", end="", flush=True)
+                logger.debug(f"x = {x} of {grid_shape[0]}")
                 split_by_x = VolumetricIndex.from_coords(
                     (
                         self.index.start[0] + x * chunk_size[0],
@@ -576,11 +582,11 @@ def write_annotations(self, annotations: Sequence[LineAnnotation], all_levels: b
                 split_data_by_x: list[LineAnnotation] = list(
                     filter(lambda d: d.in_bounds(split_by_x), annotations)
                 )
-                print(f": {len(split_data_by_x)} lines")
+                logger.debug(f": {len(split_data_by_x)} lines")
                 if not split_data_by_x:
                     continue
                 for y in range(grid_shape[1]):
-                    print(f" y = {y} of {grid_shape[1]}", end="", flush=True)
+                    logger.debug(f" y = {y} of {grid_shape[1]}")
                     split_by_y = VolumetricIndex.from_coords(
                         (
                             self.index.start[0] + x * chunk_size[0],
@@ -598,7 +604,7 @@ def write_annotations(self, annotations: Sequence[LineAnnotation], all_levels: b
                     split_data_by_y: list[LineAnnotation] = list(
                         filter(lambda d: d.in_bounds(split_by_y), split_data_by_x)
                     )
-                    print(f": {len(split_data_by_y)} lines")
+                    logger.debug(f": {len(split_data_by_y)} lines")
                     if not split_data_by_y:
                         continue
                     for z in range(grid_shape[2]):
@@ -622,8 +628,6 @@ def write_annotations(self, annotations: Sequence[LineAnnotation], all_levels: b
                         chunk_bounds = VolumetricIndex.from_coords(
                             chunk_start, chunk_end, self.index.resolution
                         )
-                        # print(f'split_by_z: {split_by_z}')
-                        # print(f'chunk_bounds: {chunk_bounds}')
                         assert chunk_bounds == split_by_z
                         # pylint: disable=cell-var-from-loop
                         chunk_data: list[LineAnnotation] = list(
@@ -676,14 +680,18 @@ def find_max_size(self, spatial_level: int = -1):
         grid_shape = ceil(bounds_size / chunk_size)
         level_key = f"spatial{level}"
         level_dir = path_join(self.path, level_key)
-        result = 0
-        for x in range(0, grid_shape[0]):
-            for y in range(0, grid_shape[1]):
-                for z in range(0, grid_shape[2]):
-                    anno_file_path = path_join(level_dir, f"{x}_{y}_{z}")
-                    line_count = count_lines_in_file(anno_file_path)
-                    result = max(result, line_count)
-        return result
+        if "//" not in level_dir:
+            level_dir = "file://" + level_dir
+        cf = CloudFiles(level_dir)
+        file_paths = [
+            f"{x}_{y}_{z}"
+            for x, y, z in itertools.product(
+                range(grid_shape[0]), range(grid_shape[1]), range(grid_shape[2])
+            )
+        ]
+        file_sizes = cf.size(file_paths)
+        max_file_size = max(x or 0 for x in file_sizes.values())
+        return line_count_from_file_size(max_file_size)
 
     def read_in_bounds(self, index: VolumetricIndex, strict: bool = False):
         """
@@ -728,7 +736,7 @@ def post_process(self):
             # Just iterate over the spatial entry files, getting the line
             # count in each one, and keep track of the max.
             max_line_count = self.find_max_size(0)
-            print(f"Found max_line_count = {max_line_count}")
+            # print(f"Found max_line_count = {max_line_count}")
            spatial_entries = self.get_spatial_entries(max_line_count)
         else:
             # Multiple chunk sizes means we have to start at the lowest
@@ -737,7 +745,7 @@
             # read data (from lowest level chunks)
             all_data = self.read_all()
 
-            # write data to all levels EXCEPT the last one
+            # subdivide as if writing data to all levels EXCEPT the last one
             levels_to_write = range(0, len(self.chunk_sizes) - 1)
             spatial_entries = subdivide(
                 all_data, self.index, self.chunk_sizes, self.path, levels_to_write
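
A minimal sketch (not part of the patch) of the arithmetic behind the new line_count_from_file_size helper, assuming the chunk layout this module already uses: an 8-byte line count header, then BYTES_PER_ENTRY bytes of geometry per line, plus an 8-byte id per line; the value 24 for BYTES_PER_ENTRY below is an illustrative assumption (two 3D endpoints stored as float32).

# assumed per-line geometry size: 2 endpoints x 3 coords x 4 bytes (float32)
BYTES_PER_ENTRY = 24

def encoded_size(num_lines: int) -> int:
    # 8-byte count header + geometry per line + 8-byte id per line
    return 8 + num_lines * (BYTES_PER_ENTRY + 8)

def line_count_from_file_size(file_size: int) -> int:
    # inverse of encoded_size, matching the formula added in the patch
    return round((file_size - 8) / (BYTES_PER_ENTRY + 8))

# round-trips for any count, so find_max_size can work from file sizes alone
assert all(line_count_from_file_size(encoded_size(n)) == n for n in range(100))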