db_layer batching (#704)
akhileshh authored May 21, 2024
1 parent 414d044 commit 56d65e2
Showing 10 changed files with 232 additions and 64 deletions.
51 changes: 50 additions & 1 deletion tests/unit/layer/db_layer/datastore/test_backend.py
@@ -1,6 +1,8 @@
 # pylint: disable=redefined-outer-name
 
+import math
 import pickle
+from random import randint
 from typing import cast
 
 import pytest
@@ -59,7 +61,7 @@ def test_query_and_keys(datastore_emulator) -> None:
     layer = build_datastore_layer(datastore_emulator, datastore_emulator)
     _write_some_data(layer)
     col_filter = {"col1": ["val0"]}
-    result = layer.query(column_filter=col_filter)
+    result = layer.query(column_filter=col_filter, return_columns=("col0", "col1"))
     assert "key0" in result and len(result) == 1
 
     col_filter = {"col1": ["val1"]}
@@ -80,6 +82,53 @@ def test_delete_and_clear(datastore_emulator) -> None:
     assert len(layer.keys()) == 0
 
 
+def _test_batches(layer: DBLayer, batch_size: int, return_columns: tuple[str, ...] = ()):
+    ROW_COUNT = len(layer.keys())
+    batches = []
+    batch_count = int(math.ceil(len(layer) / batch_size))
+    for i in range(batch_count):
+        batches.append(layer.get_batch(i, batch_size, return_columns=return_columns))
+
+    batch_0 = layer.get_batch(0, batch_size)  # test with no return_columns
+    assert len(batches[0]) == len(batch_0)
+
+    batch_keys = [k for b in batches for k in b.keys()]
+    batch_sizes = [len(b) for b in batches]
+
+    assert sum(batch_sizes) == ROW_COUNT
+    assert len(batch_keys) == ROW_COUNT
+    assert len(set(batch_keys)) == ROW_COUNT
+
+    # test batch_size > len(layer), must return all rows
+    batch = layer.get_batch(0, ROW_COUNT + 1, return_columns=return_columns)
+    assert len(batch) == ROW_COUNT
+
+    # test out of bounds error
+    with pytest.raises(IndexError):
+        layer.get_batch(batch_count, batch_size)
+
+
+def test_batching(datastore_emulator, mocker) -> None:
+    layer = build_datastore_layer(datastore_emulator, datastore_emulator)
+
+    ROW_COUNT = 150
+    COLS = ("col_a", "col_b")
+    mocker.patch(
+        "zetta_utils.layer.db_layer.datastore.backend.DatastoreBackend.__len__",
+        return_value=ROW_COUNT,
+    )
+    rows: DBDataT = [
+        dict(zip(COLS, [randint(0, 1000), randint(0, 1000)])) for _ in range(1, ROW_COUNT + 1)
+    ]
+    row_keys = list(str(x) for x in range(1, ROW_COUNT + 1))  # cannot use 0 as key in Datastore.
+    layer[(row_keys, COLS)] = rows
+
+    assert len(layer.keys()) == ROW_COUNT
+
+    _test_batches(layer, 75, return_columns=COLS)  # test even batch size, total divisible by batch
+    _test_batches(layer, 45, return_columns=COLS)  # test odd batch size
+
+
 def test_with_changes(datastore_emulator) -> None:
     backend = DatastoreBackend(datastore_emulator, project=datastore_emulator)
     backend2 = backend.with_changes(namespace=backend.namespace, project=backend.project)
13 changes: 8 additions & 5 deletions zetta_utils/db_annotations/annotation.py
@@ -11,8 +11,8 @@
     PointAnnotation,
 )
 
-from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
-from zetta_utils.layer.db_layer.datastore import DatastoreBackend
+from zetta_utils.layer.db_layer import DBRowDataT
+from zetta_utils.layer.db_layer.datastore import DatastoreBackend, build_datastore_layer
 from zetta_utils.parsing.ngl_state import AnnotationKeys
 
 from . import constants
@@ -34,9 +34,12 @@
 )
 
 DB_NAME = "annotations"
-DB_BACKEND = DatastoreBackend(DB_NAME, project=constants.PROJECT, database=constants.DATABASE)
-DB_BACKEND.exclude_from_indexes = NON_INDEXED_COLS
-ANNOTATIONS_DB = build_db_layer(DB_BACKEND)
+ANNOTATIONS_DB = build_datastore_layer(
+    DB_NAME,
+    project=constants.PROJECT,
+    database=constants.DATABASE,
+    exclude_from_indexes=NON_INDEXED_COLS,
+)
 
 
 def read_annotation(annotation_id: str) -> DBRowDataT:
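For orientation, here is a minimal sketch of what the new build_datastore_layer helper presumably does, inferred only from the lines it replaces in this commit; the real helper lives in zetta_utils.layer.db_layer.datastore, and its actual signature and internals are not shown here:

from zetta_utils.layer.db_layer import build_db_layer
from zetta_utils.layer.db_layer.datastore import DatastoreBackend


def build_datastore_layer_sketch(name, project, database=None, exclude_from_indexes=()):
    # Assumed equivalence: bundle the old two-step construction that each
    # db_annotations module used to repeat (build a backend, then wrap it).
    backend = DatastoreBackend(name, project=project, database=database)
    backend.exclude_from_indexes = exclude_from_indexes
    return build_db_layer(backend)

That consolidation is what lets the four db_annotations modules in this commit drop their per-module DB_BACKEND globals.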
13 changes: 8 additions & 5 deletions zetta_utils/db_annotations/collection.py
@@ -5,18 +5,21 @@
 import time
 import uuid
 
-from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
-from zetta_utils.layer.db_layer.datastore import DatastoreBackend
+from zetta_utils.layer.db_layer import DBRowDataT
+from zetta_utils.layer.db_layer.datastore import build_datastore_layer
 
 from . import constants
 
 DB_NAME = "collections"
 INDEXED_COLS = ("name", "created_by", "created_at", "modified_by", "modified_at")
 NON_INDEXED_COLS = ("comment",)
 
-DB_BACKEND = DatastoreBackend(DB_NAME, project=constants.PROJECT, database=constants.DATABASE)
-DB_BACKEND.exclude_from_indexes = NON_INDEXED_COLS
-COLLECTIONS_DB = build_db_layer(DB_BACKEND)
+COLLECTIONS_DB = build_datastore_layer(
+    DB_NAME,
+    project=constants.PROJECT,
+    database=constants.DATABASE,
+    exclude_from_indexes=NON_INDEXED_COLS,
+)
 
 
 def read_collection(collection_id: str) -> DBRowDataT:
13 changes: 8 additions & 5 deletions zetta_utils/db_annotations/layer.py
@@ -4,18 +4,21 @@
 
 import uuid
 
-from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
-from zetta_utils.layer.db_layer.datastore import DatastoreBackend
+from zetta_utils.layer.db_layer import DBRowDataT
+from zetta_utils.layer.db_layer.datastore import build_datastore_layer
 
 from . import constants
 
 DB_NAME = "layers"
 INDEXED_COLS = ("name", "source")
 NON_INDEXED_COLS = ("comment",)
 
-DB_BACKEND = DatastoreBackend(DB_NAME, project=constants.PROJECT, database=constants.DATABASE)
-DB_BACKEND.exclude_from_indexes = NON_INDEXED_COLS
-LAYERS_DB = build_db_layer(DB_BACKEND)
+LAYERS_DB = build_datastore_layer(
+    DB_NAME,
+    project=constants.PROJECT,
+    database=constants.DATABASE,
+    exclude_from_indexes=NON_INDEXED_COLS,
+)
 
 
 def read_layer(layer_id: str) -> DBRowDataT:
13 changes: 8 additions & 5 deletions zetta_utils/db_annotations/layer_group.py
@@ -4,18 +4,21 @@
 
 import uuid
 
-from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
-from zetta_utils.layer.db_layer.datastore import DatastoreBackend
+from zetta_utils.layer.db_layer import DBRowDataT
+from zetta_utils.layer.db_layer.datastore import build_datastore_layer
 
 from . import constants
 
 DB_NAME = "layer_groups"
 INDEXED_COLS = ("name", "layers", "collection", "created_by", "modified_by")
 NON_INDEXED_COLS = ("comment",)
 
-DB_BACKEND = DatastoreBackend(DB_NAME, project=constants.PROJECT, database=constants.DATABASE)
-DB_BACKEND.exclude_from_indexes = NON_INDEXED_COLS
-LAYER_GROUPS_DB = build_db_layer(DB_BACKEND)
+LAYER_GROUPS_DB = build_datastore_layer(
+    DB_NAME,
+    project=constants.PROJECT,
+    database=constants.DATABASE,
+    exclude_from_indexes=NON_INDEXED_COLS,
+)
 
 
 def read_layer_group(layer_group_id: str) -> DBRowDataT:
4 changes: 1 addition & 3 deletions zetta_utils/layer/db_layer/__init__.py
@@ -1,6 +1,4 @@
-from .index import (
-    DBIndex,
-)
+from .index import DBIndex
 from .backend import DBDataT, DBBackend, DBArrayValueT, DBValueT, DBRowDataT
 
 from .layer import DBLayer, UserDBIndex, ColIndex, DBDataProcT
7 changes: 7 additions & 0 deletions zetta_utils/layer/db_layer/backend.py
@@ -43,5 +43,12 @@ def keys(self, column_filter: dict[str, list] | None = None) -> list[str]:
     def query(
         self,
         column_filter: dict[str, list] | None = None,
+        return_columns: tuple[str, ...] = (),
     ) -> dict[str, DBRowDataT]:
         ...
+
+    @abstractmethod
+    def get_batch(
+        self, batch_number: int, avg_rows_per_batch: int, return_columns: tuple[str, ...] = ()
+    ) -> dict[str, DBRowDataT]:
+        ...
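Finally, a minimal usage sketch (not part of this commit) of the new get_batch method, following the pagination pattern exercised by _test_batches in the test diff above; the layer name, project, and column names below are hypothetical:

import math

from zetta_utils.layer.db_layer.datastore import build_datastore_layer

# Hypothetical layer; constructor arguments mirror the calls shown in this diff.
layer = build_datastore_layer("my_db", project="my-gcp-project")

batch_size = 75
batch_count = int(math.ceil(len(layer) / batch_size))
for i in range(batch_count):
    # Each batch is a dict mapping row key -> row data for the requested columns.
    batch = layer.get_batch(i, batch_size, return_columns=("col_a", "col_b"))
    for key, row in batch.items():
        ...  # process each row

Per the assertions in _test_batches, requesting a batch index >= batch_count raises IndexError, and a batch size larger than the row count returns all rows.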
