Skip to content

Commit

Permalink
feat: collection dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Jan 24, 2025
1 parent 660c537 commit dead60f
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 4 deletions.
41 changes: 41 additions & 0 deletions tests/unit/training/datasets/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# pylint: disable=invalid-name

import pytest

from zetta_utils.db_annotations import annotation, collection, layer, layer_group
from zetta_utils.layer.db_layer.firestore import build_firestore_layer


@pytest.fixture(scope="session")
def annotations_db(firestore_emulator):
    """Point the annotations module at a firestore-emulator-backed layer."""
    annotation.ANNOTATIONS_DB = build_firestore_layer(
        annotation.DB_NAME, project=firestore_emulator
    )
    return annotation.ANNOTATIONS_DB


@pytest.fixture(scope="session")
def collections_db(firestore_emulator):
    """Emulator-backed collections DB, wiped of any pre-existing entries."""
    collection.COLLECTIONS_DB = build_firestore_layer(
        collection.DB_NAME, project=firestore_emulator
    )
    existing = collection.read_collections()
    collection.delete_collections(list(existing))
    return collection.COLLECTIONS_DB


@pytest.fixture(scope="session")
def layer_groups_db(firestore_emulator):
    """Emulator-backed layer-groups DB, wiped of any pre-existing entries."""
    layer_group.LAYER_GROUPS_DB = build_firestore_layer(
        layer_group.DB_NAME, project=firestore_emulator
    )
    existing = layer_group.read_layer_groups()
    layer_group.delete_layer_groups(list(existing))
    return layer_group.LAYER_GROUPS_DB


@pytest.fixture(scope="session")
def layers_db(firestore_emulator):
    """Emulator-backed layers DB; deletes leftover layers one at a time.

    NOTE(review): unlike collections/layer groups there appears to be no bulk
    delete for layers, so each one is removed individually.
    """
    layer.LAYERS_DB = build_firestore_layer(layer.DB_NAME, project=firestore_emulator)
    for layer_id in layer.read_layers():
        layer.delete_layer(layer_id)
    return layer.LAYERS_DB
96 changes: 96 additions & 0 deletions tests/unit/training/datasets/test_collection_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# pylint: disable=unused-argument,redefined-outer-name

import os
import pathlib

import pytest

from zetta_utils.db_annotations import annotation, collection, layer, layer_group
from zetta_utils.training.datasets.collection_dataset import build_collection_dataset

THIS_DIR = pathlib.Path(__file__).parent.resolve()
INFOS_DIR = THIS_DIR / "../../assets/infos/"
LAYER_X1_PATH = "file://" + os.path.join(INFOS_DIR, "layer_x1")
LAYER_X2_PATH = "file://" + os.path.join(INFOS_DIR, "layer_x2")


@pytest.fixture
def dummy_dataset_x0():
    """Create a test collection with one layer group and two bbox annotations.

    Sets up a collection named ``collection_x0`` containing a layer group of
    three layers and two identical axis-aligned bounding-box annotations
    (0..128 cube), then yields the collection id. All created DB entities are
    deleted on teardown.
    """
    user = "john_doe"
    collection_id = collection.add_collection("collection_x0", user, "this is a test")
    # Clear out anything left over from a previous run of this fixture.
    layer_groups = layer_group.read_layer_groups(collection_ids=[collection_id])
    annotations = annotation.read_annotations(collection_ids=[collection_id])
    layer_group.delete_layer_groups(list(layer_groups.keys()))
    annotation.delete_annotations(list(annotations.keys()))

    layer_id0 = layer.add_layer("layer0", LAYER_X1_PATH, "this is a test")
    layer_id1 = layer.add_layer("layer1", LAYER_X2_PATH, "this is a test")
    layer_id2 = layer.add_layer("layer2", LAYER_X2_PATH, "this is a test")

    layer_group_id = layer_group.add_layer_group(
        name="layer_group_x0",
        collection_id=collection_id,
        user=user,
        layers=[layer_id0, layer_id1, layer_id2],
        comment="this is a test",
    )

    # Two identical bbox annotations: a dataset built from this collection
    # contains each chunk twice. (Originally this was the same call block
    # duplicated verbatim; a loop makes the intent explicit.)
    bbox_spec = {
        "pointA": [0, 0, 0],
        "pointB": [128, 128, 128],
        "type": "axis_aligned_bounding_box",
        "id": "6fdfd685cc440a6106a089113869f5043cb18c2c",
    }
    for _ in range(2):
        annotation.add_annotations(
            # Copy the spec in case the parser mutates its input.
            annotation.parse_ng_annotations([dict(bbox_spec)]),
            collection_id=collection_id,
            layer_group_id=layer_group_id,
        )
    yield collection_id

    layer_group.delete_layer_group(layer_group_id)
    layer_groups = layer_group.read_layer_groups(collection_ids=[collection_id])
    annotations = annotation.read_annotations(collection_ids=[collection_id])
    layer_group.delete_layer_groups(list(layer_groups.keys()))
    annotation.delete_annotations(list(annotations.keys()))
    collection.delete_collection(collection_id=collection_id)


def test_simple(
    firestore_emulator,
    annotations_db,
    collections_db,
    layer_groups_db,
    layers_db,
    dummy_dataset_x0,
):
    """Build a dataset from the dummy collection; check size and layer keys."""
    dataset = build_collection_dataset(
        collection_name="collection_x0",
        resolution=[8, 8, 8],
        chunk_size=[1, 1, 1],
        chunk_stride=[1, 1, 1],
        layer_rename_map={"layer0": "layer00"},
        per_layer_read_procs={},
    )
    # Two identical 128^3 annotations at stride 1 -> 2 * 4096 samples.
    assert len(dataset) == 2 * 4096
    first_sample = dataset[0]
    for expected_key in ("layer00", "layer1", "layer2"):
        assert expected_key in first_sample
4 changes: 2 additions & 2 deletions zetta_utils/layer/layer_set/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=missing-docstring
from __future__ import annotations

from typing import TypeVar
from typing import Mapping, TypeVar

import attrs

Expand All @@ -16,7 +16,7 @@
class LayerSetBackend(
Backend[IndexT, dict[str, DataT], dict[str, DataWriteT]]
): # pylint: disable=too-few-public-methods
layers: dict[str, Layer[IndexT, DataT, DataWriteT]]
layers: Mapping[str, Layer[IndexT, DataT, DataWriteT]]

def read(self, idx: IndexT) -> dict[str, DataT]:
return {k: v.read_with_procs(idx) for k, v in self.layers.items()}
Expand Down
4 changes: 2 additions & 2 deletions zetta_utils/layer/layer_set/build.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=missing-docstring
from __future__ import annotations

from typing import Iterable
from typing import Iterable, Mapping

from typeguard import typechecked

Expand All @@ -14,7 +14,7 @@
@builder.register("build_layer_set")
@typechecked
def build_layer_set(
layers: dict[str, Layer],
layers: Mapping[str, Layer],
readonly: bool = False,
index_procs: Iterable[IndexProcessor] = (),
read_procs: Iterable[LayerSetDataProcT] = (),
Expand Down
2 changes: 2 additions & 0 deletions zetta_utils/training/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Traning datasets."""

from . import joint_dataset, layer_dataset, sample_indexers
from .joint_dataset import JointDataset
from .layer_dataset import LayerDataset
from .sample_indexers import RandomIndexer, VolumetricStridedIndexer
from .collection_dataset import build_collection_dataset
91 changes: 91 additions & 0 deletions zetta_utils/training/datasets/collection_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import json
import os
from typing import Sequence

import fsspec
from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation
from typeguard import typechecked

from zetta_utils import builder, db_annotations
from zetta_utils.geometry.bbox import BBox3D
from zetta_utils.layer.layer_set.build import build_layer_set
from zetta_utils.layer.tools_base import DataProcessor
from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer
from zetta_utils.layer.volumetric.layer import VolumetricLayer
from zetta_utils.training.datasets.joint_dataset import JointDataset
from zetta_utils.training.datasets.layer_dataset import LayerDataset
from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import (
VolumetricStridedIndexer,
)


def _get_z_resolution(layers: dict[str, VolumetricLayer]) -> float:
    """Return the smallest Z resolution present in *every* given layer.

    Reads each layer's precomputed ``info`` file, collects the set of Z
    resolutions per layer, intersects them across layers, and returns the
    minimum shared value.

    :param layers: Mapping of layer name to volumetric layer.
    :raises ValueError: If the layers share no common Z resolution
        (``min`` of an empty intersection).
    """
    z_resolutions = {}
    for layer_name, layer in layers.items():
        # BUG FIX: ``.strip("precomputed://")`` strips a *character set* from
        # both ends of the string (it would also eat e.g. a leading or
        # trailing 'p'/'r'/'e'/'d' of the actual path), not the scheme prefix.
        # ``removeprefix`` removes exactly the literal prefix, once.
        path = layer.backend.name.removeprefix("precomputed://")
        info_path = os.path.join(path, "info")
        with fsspec.open(info_path) as f:
            info = json.loads(f.read())
        z_resolutions[layer_name] = {e["resolution"][-1] for e in info["scales"]}
    return min(set.intersection(*z_resolutions.values()))


@builder.register("build_collection_dataset")
@typechecked
def build_collection_dataset(
    collection_name: str,
    resolution: Sequence[float],
    chunk_size: Sequence[int],
    chunk_stride: Sequence[int],
    layer_rename_map: dict[str, str],
    per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None,
    shared_read_procs: Sequence[DataProcessor] = tuple(),
    tags: list[str] | None = None,
    flexible_z: bool = True,
) -> JointDataset:
    """Build a training dataset from the bbox annotations of a collection.

    Each axis-aligned bounding-box annotation in the collection becomes one
    ``LayerDataset`` (a strided chunk index over the bbox, over the layer set
    of the annotation's layer group); all are merged vertically into a single
    ``JointDataset``.

    :param collection_name: DB annotations collection to read.
    :param resolution: XYZ sampling resolution; Z may be overridden per layer
        group when ``flexible_z`` is set.
    :param chunk_size: Per-sample chunk size in voxels.
    :param chunk_stride: Stride between chunks in voxels.
    :param layer_rename_map: Renames DB layer names to sample keys.
    :param per_layer_read_procs: Optional read processors keyed by (renamed)
        layer name.
    :param shared_read_procs: Read processors applied to the whole layer set.
    :param tags: If given, only annotations with these tags are used.
    :param flexible_z: Use the smallest Z resolution common to all layers of
        the group instead of ``resolution[-1]``.
    """
    datasets = {}
    annotations = db_annotations.read_annotations(
        collection_ids=[collection_name], tags=tags, union=False
    )
    # layer group id -> (layer name -> layer); caches built layers so
    # annotations sharing a layer group reuse the same layer objects.
    layer_group_map: dict[str, dict[str, VolumetricLayer]] = {}

    per_layer_read_procs_dict = {}
    if per_layer_read_procs is not None:
        per_layer_read_procs_dict = per_layer_read_procs

    for i, annotation in enumerate(annotations.values()):
        # BUG FIX: only bbox annotations define a sampling region. The
        # original computed ``bbox`` inside this check but built the dataset
        # unconditionally, so a non-bbox annotation raised NameError (first
        # iteration) or silently reused the previous annotation's bbox.
        if not isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation):
            continue

        if annotation.layer_group not in layer_group_map:
            layer_group = db_annotations.read_layer_group(annotation.layer_group)
            db_layers = db_annotations.read_layers(layer_ids=layer_group.layers)
            layers = {}
            for layer in db_layers:
                name = layer.name
                if name in layer_rename_map:
                    name = layer_rename_map[name]
                read_procs = per_layer_read_procs_dict.get(name, [])
                layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs)
            layer_group_map[annotation.layer_group] = layers
        else:
            layers = layer_group_map[annotation.layer_group]

        z_resolution = resolution[-1]
        if flexible_z:
            z_resolution = _get_z_resolution(layers)
        this_resolution = [resolution[0], resolution[1], z_resolution]

        bbox = BBox3D.from_ng_bbox(annotation.ng_annotation, (1, 1, 1))
        datasets[str(i)] = LayerDataset(
            layer=build_layer_set(layers=layers, read_procs=shared_read_procs),
            sample_indexer=VolumetricStridedIndexer(
                resolution=this_resolution,
                chunk_size=chunk_size,
                stride=chunk_stride,
                mode="shrink",
                bbox=bbox,
            ),
        )
    dset = JointDataset(mode="vertical", datasets=datasets)
    return dset

0 comments on commit dead60f

Please sign in to comment.