Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accessor Functionality for AnnData #1870

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,25 @@ Types used by the former:
experimental.StorageType
```

## Extensions

```{eval-rst}
.. autosummary::
:toctree: generated/

register_anndata_namespace

```

Types used by the former:

```{eval-rst}
.. autosummary::
:toctree: generated/

_types.ExtensionNamespace
```

## Errors and warnings

```{eval-rst}
Expand Down
3 changes: 3 additions & 0 deletions src/anndata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


from ._core.anndata import AnnData
from ._core.extensions import ExtensionNamespace, register_anndata_namespace
from ._core.merge import concat
from ._core.raw import Raw
from ._settings import settings
Expand Down Expand Up @@ -68,11 +69,13 @@ def __getattr__(attr_name: str) -> Any:
# Classes
"AnnData",
"Raw",
"ExtensionNamespace",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a typing.py file so you don't have to reexport

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this resolved?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to remove it from there. Should the protocol go in src/anndata/_types.py or src/anndata/typing.py?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@flying-sheep can answer this best, but I think _types.py actually

# Functions
"concat",
"read_zarr",
"read_h5ad",
"read",
"register_anndata_namespace",
# Warnings
"OldFormatWarning",
"WriteWarning",
Expand Down
4 changes: 3 additions & 1 deletion src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable
from os import PathLike
from typing import Any, Literal
from typing import Any, ClassVar, Literal

from ..compat import Index1D
from ..typing import ArrayDataStructureType
Expand Down Expand Up @@ -193,6 +193,8 @@ class AnnData(metaclass=utils.DeprecationMixinMeta):
var={"var_names", "col_names", "index"},
)

_accessors: ClassVar[set[str]] = set()

@old_positionals(
"obsm",
"varm",
Expand Down
269 changes: 269 additions & 0 deletions src/anndata/_core/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
from __future__ import annotations

import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
get_type_hints,
)
from warnings import warn

from anndata import AnnData
from anndata._types import ExtensionNamespace

if TYPE_CHECKING:
from collections.abc import Callable
import anndata as ad

# Based off of the extension framework in Polars
# https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py

__all__ = ["register_anndata_namespace"]


def find_stacklevel() -> int:
"""
Find the first place in the stack that is not inside AnnData.

Taken from:
https://github.com/pola-rs/polars/blob/main/py-polars/polars/_utils/various.py#L447
"""

pkg_dir = str(Path(ad.__file__).parent)

# https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
frame = inspect.currentframe()
n = 0
try:
while frame:
fname = inspect.getfile(frame)
if fname.startswith(pkg_dir) or (
(qualname := getattr(frame.f_code, "co_qualname", None))
# ignore @singledispatch wrappers
and qualname.startswith("singledispatch.")
):
frame = frame.f_back
n += 1
else:
break
finally:
# https://docs.python.org/3/library/inspect.html
# > Though the cycle detector will catch these, destruction of the frames
# > (and local variables) can be made deterministic by removing the cycle
# > in a finally clause.
del frame
return n


# Reserved namespaces include accessors built into AnnData (currently there are none)
# and all current attributes of AnnData
_reserved_namespaces: set[str] = set(AnnData._accessors) | set(dir(ad.AnnData))


class AccessorNameSpace(ExtensionNamespace):
"""Establish property-like namespace object for user-defined functionality."""

def __init__(self, name: str, namespace: type[ExtensionNamespace]) -> None:
self._accessor = name
self._ns = namespace

def __get__(
self, instance: ExtensionNamespace | None, cls: type[ExtensionNamespace]
) -> ExtensionNamespace | type[ExtensionNamespace]:
if instance is None:
return self._ns

ns_instance = self._ns(instance) # type: ignore[call-arg]
setattr(instance, self._accessor, ns_instance)
return ns_instance


def _check_namespace_signature(ns_class: type) -> None:
"""Validate the signature of a namespace class for AnnData extensions.

This function ensures that any class intended to be used as an extension namespace
has a properly formatted `__init__` method such that:

1. Accepts at least two parameters (self and adata)
2. Has 'adata' as the name of the second parameter
3. Has the second parameter properly type-annotated as 'AnnData' or any equivalent import alias

The function performs runtime validation of these requirements before a namespace
can be registered through the `register_anndata_namespace` decorator.

Parameters
----------
ns_class
The namespace class to validate.

Raises
------
TypeError
If the `__init__` method has fewer than 2 parameters (missing the AnnData parameter).
AttributeError
If the second parameter of `__init__` lacks a type annotation.
TypeError
If the second parameter of `__init__` is not named 'adata'.
TypeError
If the second parameter of `__init__` is not annotated as the 'AnnData' class.
TypeError
If both the name and type annotation of the second parameter are incorrect.

"""
sig = inspect.signature(ns_class.__init__)
params = list(sig.parameters.values())

# Ensure there are at least two parameters (self and adata)
if len(params) < 2:
error_msg = "Namespace initializer must accept an AnnData instance as the second parameter."
raise TypeError(error_msg)

# Get the second parameter (expected to be 'adata')
param = params[1]
if param.annotation is inspect._empty:
error_msg = "Namespace initializer's second parameter must be annotated as the 'AnnData' class."
raise AttributeError(error_msg)

name_ok = param.name == "adata"

# Resolve the annotation using get_type_hints to handle forward references and aliases.
try:
type_hints = get_type_hints(ns_class.__init__)
resolved_type = type_hints.get(param.name, param.annotation)
except Exception:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the blanket except here with something more precise

resolved_type = param.annotation

Check warning on line 134 in src/anndata/_core/extensions.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/extensions.py#L133-L134

Added lines #L133 - L134 were not covered by tests

type_ok = resolved_type is ad.AnnData

match (name_ok, type_ok):
case (True, True):
return # Signature is correct.
case (False, True):
msg = f"Namespace initializer's second parameter must be named 'adata', got '{param.name}'."
raise TypeError(msg)
case (True, False):
type_repr = getattr(resolved_type, "__name__", str(resolved_type))
msg = f"Namespace initializer's second parameter must be annotated as the 'AnnData' class, got '{type_repr}'."
raise TypeError(msg)
case _:
type_repr = getattr(resolved_type, "__name__", str(resolved_type))
msg = (
f"Namespace initializer's second parameter must be named 'adata', got '{param.name}'. "
f"And must be annotated as 'AnnData', got '{type_repr}'."
)
raise TypeError(msg)


def _create_namespace(name: str, cls: type[AnnData]) -> Callable[[type], type]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more type usage that shoudl be ExtensionNamespace (there are others as well)

"""Register custom namespace against the underlying AnnData class."""

def namespace(ns_class: type) -> type:
_check_namespace_signature(ns_class) # Perform the runtime signature check
if name in _reserved_namespaces:
msg = f"cannot override reserved attribute {name!r}"
raise AttributeError(msg)
elif hasattr(cls, name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would think we want to disallow this behavior in some form. I wouldn't want someone overriding X. Maybe I'm misunderstanding, but I also don't see a test for this

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, just tested it out and it absolutely overrides X. What attributes should be protected, should I just do all of them with dir(AnnData)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless there's a reason to allow overriding already existing attributes, why not just throw an attribute error?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this was resolved but it's still just a warning?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using Jupyter notebooks, raising an AttributeError for overriding custom namespaces can be disruptive. When you modify and re-register a namespace, the user has to frequently restart their kernel to reset the accessor namespaces, because the namespace already exists. While protecting against conflicts with core AnnData attributes (like X, obs_names, etc...) is good, being able to overwrite existing custom namespaces would be nice for iterative development in notebooks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But as things stand, won't this function simply warn if you try to override X i.e., elif hasattr(cls, name)?

warn(
f"Overriding existing custom namespace {name!r} (on {cls.__name__!r})",
UserWarning,
stacklevel=find_stacklevel(),
)
setattr(cls, name, AccessorNameSpace(name, ns_class))
cls._accessors.add(name)
return ns_class

return namespace


def register_anndata_namespace(
name: str,
) -> Callable[[type[ExtensionNamespace]], type[ExtensionNamespace]]:
"""Decorator for registering custom functionality with an :class:`~anndata.AnnData` object.

This decorator allows you to extend AnnData objects with custom methods and properties
organized under a namespace. The namespace becomes accessible as an attribute on AnnData
instances, providing a clean way to you to add domain-specific functionality without modifying
the AnnData class itself, or extending the class with additional methods as you see fit in your workflow.

Parameters
----------
name
Name under which the accessor should be registered. This will be the attribute name
used to access your namespace's functionality on AnnData objects (e.g., `adata.{name}`).
Cannot conflict with existing AnnData attributes like `obs`, `var`, `X`, etc. The list of reserved
attributes includes everything outputted by `dir(AnnData)`.

Returns
-------
A decorator that registers the decorated class as a custom namespace.

Notes
-----
Implementation requirements:

1. The decorated class must have an `__init__` method that accepts exactly one parameter
(besides `self`) named `adata` and annotated with type :class:`~anndata.AnnData`.
2. The namespace will be initialized with the AnnData object on first access and then
cached on the instance.
3. If the namespace name conflicts with an existing namespace, a warning is issued.
4. If the namespace name conflicts with a built-in AnnData attribute, an AttributeError is raised.

Examples
--------
Simple transformation namespace with two methods:

>>> import anndata as ad
>>> import numpy as np
>>>
>>> @ad.register_anndata_namespace("transform")
... class TransformX:
... def __init__(self, adata: ad.AnnData):
... self._adata = adata
...
... def log1p(
... self, layer: str = None, inplace: bool = False
... ) -> ad.AnnData | None:
... '''Log1p transform the data.'''
... data = self._adata.layers[layer] if layer else self._adata.X
... log1p_data = np.log1p(data)
...
... if layer:
... layer_name = f"{layer}_log1p" if not inplace else layer
... else:
... layer_name = "log1p"
...
... self._adata.layers[layer_name] = log1p_data
...
... if not inplace:
... return self._adata
...
... def arcsinh(
... self, layer: str = None, scale: float = 1.0, inplace: bool = False
... ) -> ad.AnnData | None:
... '''Arcsinh transform the data with optional scaling.'''
... data = self._adata.layers[layer] if layer else self._adata.X
... asinh_data = np.arcsinh(data / scale)
...
... if layer:
... layer_name = f"{layer}_arcsinh" if not inplace else layer
... else:
... layer_name = "arcsinh"
...
... self._adata.layers[layer_name] = asinh_data
...
... if not inplace:
... return self._adata
>>>
>>> # Create an AnnData object
>>> rng = np.random.default_rng(42)
>>> adata = ad.AnnData(X=rng.poisson(1, size=(100, 2000)))
>>>
>>> # Use the registered namespace
>>> adata.transform.log1p() # Transforms X and returns the AnnData object
AnnData object with n_obs × n_vars = 100 × 2000
layers: 'log1p'
>>> adata.transform.arcsinh() # Transforms X and returns the AnnData object
AnnData object with n_obs × n_vars = 100 × 2000
layers: 'log1p', 'arcsinh'
"""
return _create_namespace(name, ad.AnnData)
24 changes: 23 additions & 1 deletion src/anndata/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, TypeVar
from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable

from .compat import (
H5Array,
Expand All @@ -18,6 +18,8 @@
from collections.abc import Mapping
from typing import Any, TypeAlias

from anndata._core.anndata import AnnData

from ._io.specs.registry import DaskReader, IOSpec, Reader, Writer
from .compat import DaskArray

Expand Down Expand Up @@ -186,3 +188,23 @@
Keyword arguments to be passed to a library-level io function, like `chunks` for :doc:`zarr:index`.
"""
...


NS = TypeVar("NS", covariant=True)


@runtime_checkable
class ExtensionNamespace(Protocol[NS]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the NS generic?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just thought it was necessary with the protocol, but looks like it doesn't really do anything.

"""Protocol for extension namespaces.

Enforces that the namespace initializer accepts a class with the proper `__init__` method.
Protocol's can't enforce that the `__init__` accepts the correct types. See
`_check_namespace_signature` for that. This is mainly useful for static type
checking with mypy and IDEs.
"""

def __init__(self, adata: AnnData) -> None:
"""
Used to enforce the correct signature for extension namespaces.
"""
...

Check warning on line 210 in src/anndata/_types.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_types.py#L210

Added line #L210 was not covered by tests
Loading