-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
660c537
commit dead60f
Showing
6 changed files
with
234 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# pylint: disable=invalid-name | ||
|
||
import pytest | ||
|
||
from zetta_utils.db_annotations import annotation, collection, layer, layer_group | ||
from zetta_utils.layer.db_layer.firestore import build_firestore_layer | ||
|
||
|
||
@pytest.fixture(scope="session")
def annotations_db(firestore_emulator):
    """Point the annotation module's DB at a layer built on the Firestore emulator."""
    annotation.ANNOTATIONS_DB = build_firestore_layer(
        annotation.DB_NAME, project=firestore_emulator
    )
    return annotation.ANNOTATIONS_DB
|
||
|
||
@pytest.fixture(scope="session")
def collections_db(firestore_emulator):
    """Point the collection module's DB at the emulator and delete existing collections."""
    collection.COLLECTIONS_DB = build_firestore_layer(
        collection.DB_NAME, project=firestore_emulator
    )
    # Start the session from a clean slate: remove whatever collections are already stored.
    existing_ids = list(collection.read_collections().keys())
    collection.delete_collections(existing_ids)
    return collection.COLLECTIONS_DB
|
||
|
||
@pytest.fixture(scope="session")
def layer_groups_db(firestore_emulator):
    """Point the layer group module's DB at the emulator and delete existing layer groups."""
    layer_group.LAYER_GROUPS_DB = build_firestore_layer(
        layer_group.DB_NAME, project=firestore_emulator
    )
    # Start the session from a clean slate: remove whatever layer groups are already stored.
    existing_ids = list(layer_group.read_layer_groups().keys())
    layer_group.delete_layer_groups(existing_ids)
    return layer_group.LAYER_GROUPS_DB
|
||
|
||
@pytest.fixture(scope="session")
def layers_db(firestore_emulator):
    """Point the layer module's DB at the emulator and delete existing layers."""
    layer.LAYERS_DB = build_firestore_layer(layer.DB_NAME, project=firestore_emulator)
    # Layers are removed one at a time, passing each entry yielded by `read_layers()`.
    for existing in layer.read_layers():
        layer.delete_layer(existing)
    return layer.LAYERS_DB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# pylint: disable=unused-argument,redefined-outer-name | ||
|
||
import os | ||
import pathlib | ||
|
||
import pytest | ||
|
||
from zetta_utils.db_annotations import annotation, collection, layer, layer_group | ||
from zetta_utils.training.datasets.collection_dataset import build_collection_dataset | ||
|
||
# Absolute directory containing this test module.
THIS_DIR = pathlib.Path(__file__).parent.resolve()
# Assets directory with layer info folders, located relative to this module.
INFOS_DIR = THIS_DIR / "../../assets/infos/"
# `file://` paths of the two test layers registered by the fixtures below.
LAYER_X1_PATH = "file://" + os.path.join(INFOS_DIR, "layer_x1")
LAYER_X2_PATH = "file://" + os.path.join(INFOS_DIR, "layer_x2")
|
||
|
||
@pytest.fixture
def dummy_dataset_x0():
    """Create collection ``collection_x0`` with one layer group and two box annotations.

    Yields the collection id. After the test, the layer group, any remaining layer
    groups and annotations of the collection, and the collection itself are deleted.
    """
    user = "john_doe"
    collection_id = collection.add_collection("collection_x0", user, "this is a test")

    def _purge_collection_contents():
        # Read both listings first, then delete layer groups followed by annotations.
        stale_groups = layer_group.read_layer_groups(collection_ids=[collection_id])
        stale_annotations = annotation.read_annotations(collection_ids=[collection_id])
        layer_group.delete_layer_groups(list(stale_groups.keys()))
        annotation.delete_annotations(list(stale_annotations.keys()))

    _purge_collection_contents()

    layer_ids = [
        layer.add_layer("layer0", LAYER_X1_PATH, "this is a test"),
        layer.add_layer("layer1", LAYER_X2_PATH, "this is a test"),
        layer.add_layer("layer2", LAYER_X2_PATH, "this is a test"),
    ]

    layer_group_id = layer_group.add_layer_group(
        name="layer_group_x0",
        collection_id=collection_id,
        user=user,
        layers=layer_ids,
        comment="this is a test",
    )

    # The same bounding box is registered twice, so the collection holds two annotations.
    for _ in range(2):
        ng_boxes = [
            {
                "pointA": [0, 0, 0],
                "pointB": [128, 128, 128],
                "type": "axis_aligned_bounding_box",
                "id": "6fdfd685cc440a6106a089113869f5043cb18c2c",
            }
        ]
        annotation.add_annotations(
            annotation.parse_ng_annotations(ng_boxes),
            collection_id=collection_id,
            layer_group_id=layer_group_id,
        )

    yield collection_id

    layer_group.delete_layer_group(layer_group_id)
    _purge_collection_contents()
    collection.delete_collection(collection_id=collection_id)
|
||
|
||
def test_simple(
    firestore_emulator,
    annotations_db,
    collections_db,
    layer_groups_db,
    layers_db,
    dummy_dataset_x0,
):
    """Build a dataset from ``collection_x0`` and check its length and sample keys."""
    dataset = build_collection_dataset(
        collection_name="collection_x0",
        resolution=[8, 8, 8],
        chunk_size=[1, 1, 1],
        chunk_stride=[1, 1, 1],
        layer_rename_map={"layer0": "layer00"},
        per_layer_read_procs={},
    )
    # Two annotations in the collection, 4096 samples expected from each.
    assert len(dataset) == 4096 * 2

    first_sample = dataset[0]
    # "layer0" is renamed via `layer_rename_map`; the other layers keep their names.
    for expected_key in ("layer00", "layer1", "layer2"):
        assert expected_key in first_sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
"""Traning datasets.""" | ||
|
||
from . import joint_dataset, layer_dataset, sample_indexers | ||
from .joint_dataset import JointDataset | ||
from .layer_dataset import LayerDataset | ||
from .sample_indexers import RandomIndexer, VolumetricStridedIndexer | ||
from .collection_dataset import build_collection_dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import json | ||
import os | ||
from typing import Sequence | ||
|
||
import fsspec | ||
from neuroglancer.viewer_state import AxisAlignedBoundingBoxAnnotation | ||
from typeguard import typechecked | ||
|
||
from zetta_utils import builder, db_annotations | ||
from zetta_utils.geometry.bbox import BBox3D | ||
from zetta_utils.layer.layer_set.build import build_layer_set | ||
from zetta_utils.layer.tools_base import DataProcessor | ||
from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer | ||
from zetta_utils.layer.volumetric.layer import VolumetricLayer | ||
from zetta_utils.training.datasets.joint_dataset import JointDataset | ||
from zetta_utils.training.datasets.layer_dataset import LayerDataset | ||
from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import ( | ||
VolumetricStridedIndexer, | ||
) | ||
|
||
|
||
def _get_z_resolution(layers: dict[str, VolumetricLayer]) -> float:
    """Return the smallest z resolution that is available in every given layer.

    For each layer, the ``info`` file under the layer's backend name is read and
    the last component of every scale's ``resolution`` is collected. The result
    is the minimum of the values common to all layers.

    :param layers: Mapping from layer name to layer.
    :return: The smallest z resolution shared by all layers.
    :raises ValueError: If the layers have no z resolution in common.
    """
    z_resolutions = {}
    for layer_name, layer in layers.items():
        # Drop only the leading scheme. `str.strip` would treat the argument as a set of
        # characters and could also remove leading/trailing characters of the path itself.
        base_path = layer.backend.name.removeprefix("precomputed://")
        info_path = os.path.join(base_path, "info")
        with fsspec.open(info_path) as f:
            info = json.loads(f.read())
        z_resolutions[layer_name] = {e["resolution"][-1] for e in info["scales"]}

    common = set.intersection(*z_resolutions.values())
    if not common:
        raise ValueError(
            f"Layers {list(z_resolutions)} do not share a common z resolution: {z_resolutions}"
        )
    return min(common)
|
||
|
||
@builder.register("build_collection_dataset")
@typechecked
def build_collection_dataset(
    collection_name: str,
    resolution: Sequence[float],
    chunk_size: Sequence[int],
    chunk_stride: Sequence[int],
    layer_rename_map: dict[str, str],
    per_layer_read_procs: dict[str, Sequence[DataProcessor]] | None = None,
    shared_read_procs: Sequence[DataProcessor] = tuple(),
    tags: list[str] | None = None,
    flexible_z: bool = True,
) -> JointDataset:
    """Build a joint dataset with one strided layer dataset per bounding-box annotation.

    Annotations are read from the annotation DB for the given collection. For each
    axis-aligned bounding-box annotation, the layers of its layer group are built
    and combined into a layer set, which is indexed over the annotation's box.

    :param collection_name: Passed as the only collection id when reading annotations.
    :param resolution: Resolution for the indexers. The z component is replaced
        when ``flexible_z`` is set.
    :param chunk_size: Chunk size passed to each ``VolumetricStridedIndexer``.
    :param chunk_stride: Stride passed to each ``VolumetricStridedIndexer``.
    :param layer_rename_map: Mapping from a layer's name in the DB to the name it
        should have in the dataset. Layers not listed keep their DB name.
    :param per_layer_read_procs: Read processors per layer, keyed by the layer's
        name after renaming.
    :param shared_read_procs: Read processors applied to the layer set as a whole.
    :param tags: Tags forwarded to the annotation query (with ``union=False``).
    :param flexible_z: When set, the z resolution is obtained from
        ``_get_z_resolution`` for the annotation's layers instead of ``resolution``.
    :return: A ``JointDataset`` in ``"vertical"`` mode whose member datasets are
        keyed by the annotation's position in the query result.
    """
    datasets = {}
    annotations = db_annotations.read_annotations(
        collection_ids=[collection_name], tags=tags, union=False
    )
    per_layer_read_procs_dict = per_layer_read_procs if per_layer_read_procs is not None else {}

    # layer group->layer_name->layer
    layer_group_map: dict[str, dict[str, VolumetricLayer]] = {}
    # layer group->z resolution. Cached so that the layers' info files are read once per
    # layer group rather than once per annotation.
    z_resolution_map: dict[str, float] = {}

    for i, annotation in enumerate(annotations.values()):
        # Only bounding-box annotations define a region to index. Skipping the rest here
        # guarantees `bbox` below always belongs to the current annotation.
        if not isinstance(annotation.ng_annotation, AxisAlignedBoundingBoxAnnotation):
            continue

        group_id = annotation.layer_group
        if group_id not in layer_group_map:
            layer_group = db_annotations.read_layer_group(group_id)
            db_layers = db_annotations.read_layers(layer_ids=layer_group.layers)
            layers = {}
            for layer in db_layers:
                name = layer_rename_map.get(layer.name, layer.name)
                read_procs = per_layer_read_procs_dict.get(name, [])
                layers[name] = build_cv_layer(path=layer.source, read_procs=read_procs)
            layer_group_map[group_id] = layers
        layers = layer_group_map[group_id]

        z_resolution = resolution[-1]
        if flexible_z:
            if group_id not in z_resolution_map:
                z_resolution_map[group_id] = _get_z_resolution(layers)
            z_resolution = z_resolution_map[group_id]

        this_resolution = [resolution[0], resolution[1], z_resolution]
        bbox = BBox3D.from_ng_bbox(annotation.ng_annotation, (1, 1, 1))

        datasets[str(i)] = LayerDataset(
            layer=build_layer_set(layers=layers, read_procs=shared_read_procs),
            sample_indexer=VolumetricStridedIndexer(
                resolution=this_resolution,
                chunk_size=chunk_size,
                stride=chunk_stride,
                mode="shrink",
                bbox=bbox,
            ),
        )
    dset = JointDataset(mode="vertical", datasets=datasets)
    return dset