Skip to content

Commit

Permalink
update remote tests, work on consolidated http stores
Browse files Browse the repository at this point in the history
  • Loading branch information
berombau committed Apr 12, 2024
1 parent 6e04606 commit e3e0c28
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 235 deletions.
24 changes: 15 additions & 9 deletions src/spatialdata/_io/io_points.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
from collections.abc import MutableMapping
from pathlib import Path
from typing import Union

import zarr
from dask.dataframe import DataFrame as DaskDataFrame # type: ignore[attr-defined]
from dask.dataframe import read_parquet
from ome_zarr.format import Format
from upath import UPath

from spatialdata._core._utils import _open_zarr_store
from spatialdata._io import SpatialDataFormatV01
from spatialdata._io._utils import (
_get_transformations_from_ngff_dict,
Expand All @@ -23,17 +23,23 @@


def _read_points(
store: Union[str, Path, MutableMapping, zarr.Group, zarr.storage.BaseStore], # type: ignore[type-arg]
path: UPath,
fmt: SpatialDataFormatV01 = CurrentPointsFormat(),
) -> DaskDataFrame:
"""Read points from a zarr store."""
assert isinstance(store, (str, Path, MutableMapping, zarr.Group, zarr.storage.BaseStore))
f = zarr.open(store, mode="r")
store = _open_zarr_store(path)
f = zarr.group(store)

path = os.path.join(f._store.path, f.path, "points.parquet")
# cache on remote file needed for parquet reader to work
# TODO: allow reading in the metadata without caching all the data
table = read_parquet("simplecache::" + path if "http" in path else path)
if isinstance(store, UPath):
path = store / "points.parquet"
read_str = "simplecache::" + path.as_posix() if "http" in path.protocol else path.as_posix()
table = read_parquet(read_str)
else:
# TODO: remove this old code path
path = os.path.join(f._store.path, f.path, "points.parquet")
# cache on remote file needed for parquet reader to work
# TODO: allow reading in the metadata without caching all the data
table = read_parquet("simplecache::" + path if "http" in path else path)
assert isinstance(table, DaskDataFrame)

transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"])
Expand Down
5 changes: 2 additions & 3 deletions src/spatialdata/_io/io_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def read_zarr(store: Union[str, Path, zarr.Group, UPath], selection: Optional[tu
if isinstance(store, (zarr.Group)):
logger.debug("No support for converting zarr.Group to UPath. Using the store object as is.")
f = store.store
f_store_path = UPath(f._path)
f_store_path = UPath(f.store.path if isinstance(f, zarr.storage.ConsolidatedMetadataStore) else f.path)
else:
f_store_path = UPath(store) if not isinstance(store, UPath) else store
f = _open_zarr_store(f_store_path)
Expand Down Expand Up @@ -92,8 +92,7 @@ def read_zarr(store: Union[str, Path, zarr.Group, UPath], selection: Optional[tu
if subgroup_name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
f_elem_store = _open_zarr_store(f_store_path / f_elem.path)
points[subgroup_name] = _read_points(f_elem_store)
points[subgroup_name] = _read_points(f_store_path / f_elem.path)
count += 1
logger.debug(f"Found {count} elements in {group}")

Expand Down
228 changes: 5 additions & 223 deletions tests/io/test_remote.py
Original file line number Diff line number Diff line change
@@ -1,228 +1,9 @@
import os
import shlex
import shutil
import subprocess
import time
import uuid
from pathlib import Path

import fsspec
import pytest
from fsspec.implementations.local import LocalFileSystem, make_path_posix
from fsspec.registry import _registry, register_implementation
from fsspec.utils import stringify_path
import zarr
from spatialdata import SpatialData
from spatialdata.testing import assert_spatial_data_objects_are_identical
from upath import UPath
from upath.implementations.cloud import S3Path


## Mock setup from https://github.com/fsspec/universal_pathlib/blob/main/upath/tests/conftest.py
def posixify(path):
    """Return *path* as a string with backslashes normalized to forward slashes."""
    return "/".join(str(path).split("\\"))


class DummyTestFS(LocalFileSystem):
    """Local filesystem masquerading under the ``mock://`` protocol for tests."""

    protocol = "mock"
    root_marker = "/"

    @classmethod
    def _strip_protocol(cls, path):
        """Drop the ``mock://``/``mock:`` prefix and normalize to a posix path."""
        stripped = stringify_path(path)
        # Strip the longer prefix first so "mock://x" does not become "//x".
        if stripped.startswith("mock://"):
            stripped = stripped[len("mock://"):]
        elif stripped.startswith("mock:"):
            stripped = stripped[len("mock:"):]
        return make_path_posix(stripped).rstrip("/") or cls.root_marker


@pytest.fixture(scope="session")
def clear_registry():
    # Session-scoped fixture: expose the "mock" filesystem for all tests,
    # then wipe the fsspec registry even if the session errors out, so the
    # mock protocol cannot leak into later test sessions.
    register_implementation("mock", DummyTestFS)
    try:
        yield
    finally:
        _registry.clear()


@pytest.fixture(scope="session")
def s3_server():
    """Spin up a local moto S3 server for the test session.

    Yields ``(anon, s3so)`` where ``s3so`` holds the fsspec/s3fs storage
    options pointing at the mock endpoint. Skips the session when
    ``requests`` or ``moto`` is not installed.
    """
    # Dummy AWS credentials/config so botocore never touches real ones.
    # setdefault keeps any values the caller already exported.
    env_defaults = {  # pragma: no cover
        "BOTO_CONFIG": "/dev/null",
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": "testing",
        "AWS_DEFAULT_REGION": "us-east-1",
    }
    for key, value in env_defaults.items():  # pragma: no cover
        os.environ.setdefault(key, value)
    requests = pytest.importorskip("requests")

    pytest.importorskip("moto")

    port = 5555
    endpoint_uri = f"http://127.0.0.1:{port}/"
    proc = subprocess.Popen(
        shlex.split(f"moto_server -p {port}"),
        stderr=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )
    try:
        # Poll until the server answers or ~5s elapse; on timeout we still
        # yield (best-effort startup) and let the tests surface the failure.
        timeout = 5
        while timeout > 0:
            try:
                if requests.get(endpoint_uri, timeout=10).ok:
                    break
            except requests.exceptions.RequestException:  # pragma: no cover
                pass
            timeout -= 0.1  # pragma: no cover
            time.sleep(0.1)  # pragma: no cover
        anon = False
        s3so = {
            "client_kwargs": {"endpoint_url": endpoint_uri},
            "use_listings_cache": True,
        }
        yield anon, s3so
    finally:
        # Always reap the moto subprocess, even if a test errored.
        proc.terminate()
        proc.wait()


@pytest.fixture(scope="function")
def s3_fixture(s3_server):
    """Create a uniquely named, empty S3 bucket on the mock server.

    Yields ``(bucket_url, anon, s3so)`` for building a UPath to the bucket.
    """
    pytest.importorskip("s3fs")
    anon, s3so = s3_server
    s3 = fsspec.filesystem("s3", anon=False, **s3so)
    random_name = uuid.uuid4().hex
    bucket_name = f"test_{random_name}"
    if s3.exists(bucket_name):
        # Extremely unlikely with a uuid name, but if the bucket exists,
        # empty it so every test starts from a clean slate.
        # NOTE: renamed loop variable from `dir`, which shadowed the builtin.
        for dirname, _, keys in s3.walk(bucket_name):
            for key in keys:
                s3.rm(f"{dirname}/{key}")
    else:
        s3.mkdir(bucket_name)
    s3.invalidate_cache()
    yield f"s3://{bucket_name}", anon, s3so


@pytest.fixture(scope="session")
def http_server(tmp_path_factory):
    """Serve a temporary folder over HTTP for the test session.

    Yields ``(folder_path, folder_url)`` once the server responds OK;
    raises ``SystemError`` if it never comes up within the retry budget.
    """
    http_tempdir = tmp_path_factory.mktemp("http")

    requests = pytest.importorskip("requests")
    pytest.importorskip("http.server")
    proc = subprocess.Popen(shlex.split(f"python -m http.server --directory {http_tempdir} 8080"))
    try:
        url = "http://127.0.0.1:8080/folder"
        path = Path(http_tempdir) / "folder"
        path.mkdir()
        retries = 10
        while True:
            error = None
            try:
                response = requests.get(url, timeout=10)
                if response.ok:
                    yield path, url
                    break
            except requests.exceptions.RequestException as e:
                error = e
            # Count every failed attempt — refused connection AND non-2xx
            # reply — so a responsive-but-broken server cannot spin this
            # loop forever (the original only decremented on exceptions).
            retries -= 1
            if retries < 0:
                raise SystemError from error
            time.sleep(1)
    finally:
        proc.terminate()
        proc.wait()


@pytest.fixture
def http_fixture(local_testdir, http_server):
    """Mirror the local test data into the HTTP-served folder and yield its URL."""
    served_path, served_url = http_server
    # Replace the served folder wholesale with a fresh copy of the test data.
    shutil.rmtree(served_path)
    shutil.copytree(local_testdir, served_path)
    yield served_url


class TestRemote:
@pytest.fixture(scope="function")
def upath(self, s3_fixture):
path, anon, s3so = s3_fixture
return UPath(path, anon=anon, **s3so)

    def test_is_S3Path(self, upath):
        # The s3:// URL from the fixture must resolve to UPath's S3 implementation.
        assert isinstance(upath, S3Path)

# # Test UPath with Moto Mocking
def test_creating_file(self, upath):
file_name = "file1"
p1 = upath / file_name
p1.touch()
contents = [p.name for p in upath.iterdir()]
assert file_name in contents

def test_images(self, upath: UPath, images: SpatialData) -> None:
tmpdir = upath / "tmp.zarr"
images.write(tmpdir)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(images, sdata)

def test_labels(self, upath: UPath, labels: SpatialData) -> None:
tmpdir = upath / "tmp.zarr"
labels.write(tmpdir)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(labels, sdata)

def test_shapes(self, upath: UPath, shapes: SpatialData) -> None:
import numpy as np

tmpdir = Path(upath) / "tmp.zarr"

# check the index is correctly written and then read
shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1)

shapes.write(tmpdir)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(shapes, sdata)

def test_points(self, upath: UPath, points: SpatialData) -> None:
import dask.dataframe as dd
import numpy as np

tmpdir = upath / "tmp.zarr"

# check the index is correctly written and then read
new_index = dd.from_array(np.arange(1, len(points["points_0"]) + 1))
points["points_0"] = points["points_0"].set_index(new_index)

points.write(tmpdir)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(points, sdata)

def _test_table(self, upath: UPath, table: SpatialData) -> None:
tmpdir = upath / "tmp.zarr"
table.write(tmpdir)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(table, sdata)

    def test_single_table_single_annotation(self, upath: UPath, table_single_annotation: SpatialData) -> None:
        # Delegates the round-trip check to the shared _test_table helper.
        self._test_table(upath, table_single_annotation)

    def test_single_table_multiple_annotations(self, upath: UPath, table_multiple_annotations: SpatialData) -> None:
        # Delegates the round-trip check to the shared _test_table helper.
        self._test_table(upath, table_multiple_annotations)

def test_full_sdata(self, upath: UPath, full_sdata: SpatialData) -> None:
tmpdir = upath / "tmp.zarr"
full_sdata.write(tmpdir)
sdata = SpatialData.read(tmpdir)
assert_spatial_data_objects_are_identical(full_sdata, sdata)

# Test actual remote datasets from https://spatialdata.scverse.org/en/latest/tutorials/notebooks/datasets/README.html
@pytest.fixture(params=["merfish", "mibitof"])
def s3_address(self, request):
Expand All @@ -232,8 +13,9 @@ def s3_address(self, request):
}
return urls[request.param]

    # TODO: does not work, maybe because the reader thinks it's an HTTP webserver instead of S3 over HTTP?
# @pytest.mark.skip
# TODO: does not work, problem with opening remote parquet
@pytest.mark.xfail(reason="Problem with opening remote parquet")
def test_remote(self, s3_address):
sdata = SpatialData.read(s3_address)
root = zarr.open_consolidated(s3_address, mode="r", metadata_key="zmetadata")
sdata = SpatialData.read(root)
assert len(list(sdata.gen_elements())) > 0
Loading

0 comments on commit e3e0c28

Please sign in to comment.