Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accessor Functionality for AnnData #1870

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ Types used by the former:
experimental.StorageType
```

## Extensions

```{eval-rst}
.. autosummary::
:toctree: generated/

register_anndata_namespace
```

## Errors and warnings

```{eval-rst}
Expand Down
2 changes: 2 additions & 0 deletions src/anndata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


from ._core.anndata import AnnData
from ._core.extensions import register_anndata_namespace
from ._core.merge import concat
from ._core.raw import Raw
from ._settings import settings
Expand Down Expand Up @@ -78,4 +79,5 @@ def __getattr__(attr_name: str) -> Any:
"WriteWarning",
"ImplicitModificationWarning",
"ExperimentalFeatureWarning",
"register_anndata_namespace",
]
4 changes: 3 additions & 1 deletion src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable
from os import PathLike
from typing import Any, Literal
from typing import Any, ClassVar, Literal

from ..compat import Index1D
from ..typing import ArrayDataStructureType
Expand Down Expand Up @@ -193,6 +193,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
var={"var_names", "col_names", "index"},
)

_accessors: ClassVar[set[str]] = set()

@old_positionals(
"obsm",
"varm",
Expand Down
139 changes: 139 additions & 0 deletions src/anndata/_core/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

import inspect

# from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Generic, TypeVar
from warnings import warn

if TYPE_CHECKING:
from collections.abc import Callable
from anndata import AnnData

# Based off of the extension framework in Polars
# https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py

__all__ = ["register_anndata_namespace"]


def find_stacklevel() -> int:
    """
    Count the stack frames between this call and the first caller outside AnnData.

    The result is meant to be passed as ``stacklevel`` to :func:`warnings.warn`
    so that a warning is attributed to user code rather than to a line inside
    the ``anndata`` package.

    Adapted from:
    https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51

    Returns
    -------
    Number of consecutive frames, starting with this function's own frame,
    that belong to ``anndata`` or to a ``singledispatch`` wrapper.
    """
    # NOTE(review): kept function-local as in the upstream helper this was
    # copied from; whether it can move to module level (i.e. that no circular
    # import is involved) was still an open question in review.
    import anndata as ad

    package_root = str(Path(ad.__file__).parent)

    # Walk frames by hand instead of calling inspect.stack(), which is slow:
    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    depth = 0
    current = inspect.currentframe()
    try:
        while current:
            in_package = inspect.getfile(current).startswith(package_root)
            # ``co_qualname`` is looked up defensively; fall back to None when
            # the code object does not provide it.
            code_qualname = getattr(current.f_code, "co_qualname", None)
            # ignore @singledispatch wrappers
            via_singledispatch = bool(code_qualname) and code_qualname.startswith(
                "singledispatch."
            )
            if not (in_package or via_singledispatch):
                break
            current = current.f_back
            depth += 1
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a finally clause.
        del current
    return depth


# Type variable standing for the user-supplied namespace class that the
# accessor machinery below stores and hands back.
NS = TypeVar("NS")


# Currently, there are no reserved namespaces internally, but if there ever are,
# this will not allow them to be overridden.
# ``set.union`` builds a new set when this statement runs, so this is a snapshot:
# names added to ``AnnData._accessors`` afterwards do not become reserved.
_reserved_namespaces: set[str] = set.union(*(cls._accessors for cls in (AnnData,)))


class AccessorNameSpace(Generic[NS]):
    """Descriptor that exposes a user-defined namespace class as an attribute.

    Only ``__get__`` is defined, so this is a non-data descriptor: once a
    namespace object has been stored on an instance under the same name, the
    instance attribute is found first and ``__get__`` is not called again for
    that instance.
    """

    def __init__(self, name: str, namespace: type[NS]) -> None:
        """Remember the attribute name and the namespace class to instantiate.

        Parameters
        ----------
        name
            Attribute name under which the namespace is reachable.
        namespace
            Class that is called with the owning instance on first access.
        """
        self._ns = namespace
        self._accessor = name

    def __get__(self, instance: NS | None, cls: type[NS]) -> NS | type[NS]:
        """Return the namespace class (class access) or a bound namespace object.

        Parameters
        ----------
        instance
            Object the attribute is looked up on, or ``None`` for class access.
        cls
            Owner class of the descriptor (unused).

        Returns
        -------
        The stored namespace class when ``instance`` is ``None``; otherwise a
        new namespace object constructed from ``instance``.
        """
        if instance is not None:
            bound = self._ns(instance)  # type: ignore[call-arg]
            # Cache on the instance so later lookups return this same object.
            setattr(instance, self._accessor, bound)
            return bound
        return self._ns


def _create_namespace(name: str, cls: type[AnnData]) -> Callable[[type[NS]], type[NS]]:
    """Build the class decorator that registers a namespace as ``cls.<name>``.

    Parameters
    ----------
    name
        Attribute name under which the namespace will be installed on ``cls``.
    cls
        Class that receives the accessor; its ``_accessors`` set records every
        name registered through this function.

    Returns
    -------
    A decorator that installs the decorated class as an
    :class:`AccessorNameSpace` on ``cls`` and returns the class unchanged.

    Raises
    ------
    AttributeError
        Raised by the returned decorator when ``name`` is a reserved namespace,
        or when ``name`` is already an attribute of ``cls`` that was not
        registered through this function.
    """

    def namespace(ns_class: type[NS]) -> type[NS]:
        if name in _reserved_namespaces:
            msg = f"cannot override reserved namespace {name!r}"
            raise AttributeError(msg)

        if hasattr(cls, name):
            # Review outcome: attributes that belong to the class itself
            # (e.g. ``X`` or ``obs_names``) must never be replaced by an
            # accessor, so anything not registered here is rejected outright.
            if name not in cls._accessors:
                msg = (
                    f"cannot override existing attribute {name!r} "
                    f"(on {cls.__name__!r})"
                )
                raise AttributeError(msg)

            # Re-registering a custom namespace stays allowed so that it can be
            # redefined interactively (e.g. in a notebook) without restarting
            # the interpreter; the user is only warned.
            warn(
                f"Overriding existing custom namespace {name!r} (on {cls.__name__!r})",
                UserWarning,
                stacklevel=find_stacklevel(),
            )

        setattr(cls, name, AccessorNameSpace(name, ns_class))
        cls._accessors.add(name)
        return ns_class

    return namespace


def register_anndata_namespace(name: str) -> Callable[[type[NS]], type[NS]]:
    """Class decorator that attaches custom functionality to :class:`~anndata.AnnData`.

    The decorated class is instantiated with the :class:`~anndata.AnnData`
    object it is accessed from and is reachable as ``adata.<name>``.

    Parameters
    ----------
    name
        Name under which the accessor should be registered. Registering a name
        that an earlier call already registered issues a warning and replaces
        the previous accessor.

    Returns
    -------
    A decorator that registers the class it is applied to and returns that
    class unchanged.

    Examples
    --------
    >>> import anndata as ad
    >>> from scipy.sparse import csr_matrix
    >>> import numpy as np
    >>>
    >>>
    >>> @ad.register_anndata_namespace("transforms")
    ... class TransformX:
    ...     def __init__(self, adata: ad.AnnData):
    ...         self._adata = adata
    ...
    ...     def arcsinh_cofactor(
    ...         self, shift: float, scale: float, layer: str, inplace: bool = False
    ...     ) -> ad.AnnData:
    ...         self._adata.layers[layer] = (
    ...             np.arcsinh(self._adata.X.toarray() / scale) + shift
    ...         )
    ...         return None if inplace else self._adata
    >>>
    >>> rng = np.random.default_rng(42)
    >>> adata = ad.AnnData(
    ...     X=csr_matrix(rng.poisson(1, size=(100, 2000)), dtype=np.float32),
    ... )
    >>> adata.transforms.arcsinh_cofactor(1, 1, "arcsinh", inplace=True)
    >>> adata
    AnnData object with n_obs × n_vars = 100 × 2000
        layers: 'arcsinh'
    """
    decorator = _create_namespace(name, AnnData)
    return decorator
134 changes: 134 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

import numpy as np
import pytest

import anndata as ad
from anndata._core import extensions


def test_find_stacklevel():
    """``find_stacklevel`` returns a positive integer when called from a test."""
    stacklevel = extensions.find_stacklevel()
    assert isinstance(stacklevel, int)
    # The helper's own frame is inside anndata, so a sane result is at least 1.
    assert stacklevel >= 1


def test_accessor_namespace():
    """Exercise the ``AccessorNameSpace`` descriptor directly.

    Checks that:
    - class-level access (no instance) yields the namespace class itself;
    - instance-level access builds the namespace, handing the instance to its
      constructor;
    - the freshly built namespace is stored on the instance, so reading the
      attribute afterwards gives back that very object.
    """

    class Dummy:
        pass

    # Minimal namespace class that just remembers what it was built from.
    class DummyNamespace:
        def __init__(self, instance):
            self.instance = instance

        def foo(self):
            return "foo"

    descriptor = extensions.AccessorNameSpace("dummy", DummyNamespace)

    # Class access: the namespace type comes back untouched.
    assert descriptor.__get__(None, Dummy) is DummyNamespace

    # Instance access: a DummyNamespace wrapping the instance is created.
    owner = Dummy()
    bound_namespace = descriptor.__get__(owner, Dummy)
    assert isinstance(bound_namespace, DummyNamespace)
    assert bound_namespace.instance is owner

    # __get__ stored the namespace on the instance under the accessor name,
    # so a plain attribute read returns the cached object.
    assert owner.dummy is bound_namespace


def test_register_namespace(monkeypatch):
    """Test the behavior of the register_anndata_namespace decorator.

    This test verifies that:
    - A new namespace can be registered successfully.
    - The accessor is available on AnnData instances.
    - The accessor is cached on the AnnData instance.
    - A warning is emitted if the namespace name is overridden.

    All changes made to the ``AnnData`` class are undone afterwards, even when
    an assertion fails, so that other tests never see the ``dummy`` accessor.
    """
    missing = object()
    # Read the class ``__dict__`` rather than using ``getattr``: attribute
    # access on the class would go through the descriptor and return the
    # namespace class instead of the descriptor object that must be restored.
    original_dummy = vars(ad.AnnData).get("dummy", missing)

    # Registration records the name in ``AnnData._accessors``; work on a copy
    # that monkeypatch swaps back at teardown.
    monkeypatch.setattr(ad.AnnData, "_accessors", set(ad.AnnData._accessors))

    try:
        # Register a new namespace called 'dummy'.
        @extensions.register_anndata_namespace("dummy")
        class DummyNamespace:
            def __init__(self, adata: ad.AnnData):
                self.adata = adata

            def greet(self) -> str:
                return "hello"

        # Create an AnnData instance with minimal data.
        rng = np.random.default_rng(42)
        adata = ad.AnnData(X=rng.poisson(1, size=(10, 10)))

        # The accessor should now be available.
        ns_instance = adata.dummy
        assert ns_instance.adata is adata
        assert ns_instance.greet() == "hello"

        # Verify caching behavior on the AnnData instance.
        assert adata.dummy is ns_instance

        # Now, override the same namespace and check that a warning is emitted.
        with pytest.warns(
            UserWarning, match="Overriding existing custom namespace 'dummy'"
        ):

            @extensions.register_anndata_namespace("dummy")
            class DummyNamespaceOverride:
                def __init__(self, adata: ad.AnnData):
                    self.adata = adata

                def greet(self) -> str:
                    # Return a different string to confirm the override.
                    return "world"

        # A new AnnData instance should now use the overridden accessor.
        adata2 = ad.AnnData(X=rng.poisson(1, size=(10, 10)))
        assert adata2.dummy.greet() == "world"
    finally:
        # Clean up by restoring any previously existing attribute.
        if original_dummy is not missing:
            setattr(ad.AnnData, "dummy", original_dummy)
        elif "dummy" in vars(ad.AnnData):
            delattr(ad.AnnData, "dummy")


def test_register_reserved_namespace(monkeypatch):
    """Registering under a reserved name must fail with ``AttributeError``."""
    reserved_name = "reserved_namespace"
    expected_message = f"cannot override reserved namespace {reserved_name!r}"

    # Swap in a reserved set that additionally contains the test name;
    # monkeypatch puts the original set back at teardown.
    monkeypatch.setattr(
        extensions,
        "_reserved_namespaces",
        extensions._reserved_namespaces | {reserved_name},
    )

    with pytest.raises(AttributeError, match=expected_message):

        @extensions.register_anndata_namespace(reserved_name)
        class DummyNamespace:
            def __init__(self, adata: ad.AnnData):
                self.adata = adata
Loading