-
Notifications
You must be signed in to change notification settings - Fork 160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Accessor Functionality for AnnData #1870
base: main
Are you sure you want to change the base?
Changes from 5 commits
790b211
105b155
a87a8e9
59d545b
8fa883d
f673c2b
9f6dc2a
00cff10
0121106
c85f38b
db2844b
16b49bc
657aa07
035df3d
f658fdd
ec4c9b7
db71c52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
|
||
# from collections.abc import Callable | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Generic, TypeVar | ||
from warnings import warn | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable | ||
from anndata import AnnData | ||
|
||
# Based off of the extension framework in Polars | ||
# https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py | ||
|
||
__all__ = ["register_anndata_namespace"] | ||
|
||
|
||
def find_stacklevel() -> int: | ||
""" | ||
Find the first place in the stack that is not inside AnnData. | ||
|
||
Taken from: | ||
https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51 | ||
""" | ||
import anndata as ad | ||
|
||
pkg_dir = str(Path(ad.__file__).parent) | ||
|
||
# https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow | ||
frame = inspect.currentframe() | ||
n = 0 | ||
try: | ||
while frame: | ||
fname = inspect.getfile(frame) | ||
if fname.startswith(pkg_dir) or ( | ||
(qualname := getattr(frame.f_code, "co_qualname", None)) | ||
# ignore @singledispatch wrappers | ||
and qualname.startswith("singledispatch.") | ||
): | ||
frame = frame.f_back | ||
n += 1 | ||
else: | ||
break | ||
finally: | ||
# https://docs.python.org/3/library/inspect.html | ||
# > Though the cycle detector will catch these, destruction of the frames | ||
# > (and local variables) can be made deterministic by removing the cycle | ||
# > in a finally clause. | ||
del frame | ||
return n | ||
|
||
|
||
NS = TypeVar("NS") | ||
|
||
|
||
# Currently, there are no reserved namespaces internally, but if there ever are, | ||
# this will not allow them to be overridden. | ||
_reserved_namespaces: set[str] = set.union(*(cls._accessors for cls in (AnnData,))) | ||
srivarra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class AccessorNameSpace(Generic[NS]): | ||
"""Establish property-like namespace object for user-defined functionality.""" | ||
|
||
def __init__(self, name: str, namespace: type[NS]) -> None: | ||
self._accessor = name | ||
self._ns = namespace | ||
|
||
def __get__(self, instance: NS | None, cls: type[NS]) -> NS | type[NS]: | ||
if instance is None: | ||
return self._ns | ||
|
||
ns_instance = self._ns(instance) # type: ignore[call-arg] | ||
setattr(instance, self._accessor, ns_instance) | ||
return ns_instance | ||
|
||
|
||
def _create_namespace(name: str, cls: type[AnnData]) -> Callable[[type[NS]], type[NS]]: | ||
"""Register custom namespace against the underlying AnnData class.""" | ||
|
||
def namespace(ns_class: type[NS]) -> type[NS]: | ||
srivarra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if name in _reserved_namespaces: | ||
msg = f"cannot override reserved namespace {name!r}" | ||
raise AttributeError(msg) | ||
|
||
elif hasattr(cls, name): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would think we want to disallow this behavior in some form. I wouldn't want someone overriding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh yeah, just tested it out and it absolutely overrides X. What attributes should be protected, should I just do all of them with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unless there's a reason to allow overriding already existing attributes, why not just throw an attribute error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see this was resolved but it's still just a warning? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When using Jupyter notebooks, raising an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But as things stand, won't this function simply warn if you try to override |
||
warn( | ||
f"Overriding existing custom namespace {name!r} (on {cls.__name__!r})", | ||
UserWarning, | ||
stacklevel=find_stacklevel(), | ||
) | ||
|
||
setattr(cls, name, AccessorNameSpace(name, ns_class)) | ||
cls._accessors.add(name) | ||
return ns_class | ||
|
||
return namespace | ||
|
||
|
||
def register_anndata_namespace(name: str) -> Callable[[type[NS]], type[NS]]: | ||
"""Decorator for registering custom functionality with an :class:`~anndata.AnnData` object. | ||
|
||
Parameters | ||
---------- | ||
name | ||
Name under which the accessor should be registered. A warning is issued | ||
if this name conflicts with a preexisting attribute. | ||
|
||
Examples | ||
-------- | ||
>>> import anndata as ad | ||
>>> from scipy.sparse import csr_matrix | ||
>>> import numpy as np | ||
>>> | ||
>>> | ||
>>> @ad.register_anndata_namespace("transforms") | ||
... class TransformX: | ||
... def __init__(self, adata: ad.AnnData): | ||
... self._adata = adata | ||
... | ||
... def arcsinh_cofactor( | ||
... self, shift: float, scale: float, layer: str, inplace: bool = False | ||
... ) -> ad.AnnData: | ||
... self._adata.layers[layer] = ( | ||
... np.arcsinh(self._adata.X.toarray() / scale) + shift | ||
... ) | ||
... return None if inplace else self._adata | ||
>>> | ||
>>> rng = np.random.default_rng(42) | ||
>>> adata = ad.AnnData( | ||
... X=csr_matrix(rng.poisson(1, size=(100, 2000)), dtype=np.float32), | ||
... ) | ||
>>> adata.transforms.arcsinh_cofactor(1, 1, "arcsinh", inplace=True) | ||
>>> adata | ||
AnnData object with n_obs × n_vars = 100 × 2000 | ||
layers: 'arcsinh' | ||
""" | ||
return _create_namespace(name, AnnData) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import anndata as ad | ||
from anndata._core import extensions | ||
|
||
|
||
def test_find_stacklevel(): | ||
level = extensions.find_stacklevel() | ||
assert isinstance(level, int) | ||
# It should be at least 1, otherwise something is wrong. | ||
assert level > 0 | ||
|
||
|
||
def test_accessor_namespace(): | ||
"""Test the behavior of the AccessorNameSpace descriptor. | ||
|
||
This test verifies that: | ||
- When accessed at the class level (i.e., without an instance), the descriptor | ||
returns the namespace type. | ||
- When accessed via an instance, the descriptor instantiates the namespace, | ||
passing the instance to its constructor. | ||
- The instantiated namespace is then cached on the instance such that subsequent | ||
accesses of the same attribute return the cached namespace instance. | ||
""" | ||
|
||
# Define a dummy namespace class to be used via the descriptor. | ||
class DummyNamespace: | ||
def __init__(self, instance): | ||
self.instance = instance | ||
|
||
def foo(self): | ||
return "foo" | ||
|
||
class Dummy: | ||
pass | ||
|
||
descriptor = extensions.AccessorNameSpace("dummy", DummyNamespace) | ||
|
||
# When accessed on the class, it should return the namespace type. | ||
ns_class = descriptor.__get__(None, Dummy) | ||
assert ns_class is DummyNamespace | ||
|
||
# When accessed via an instance, it should instantiate DummyNamespace. | ||
dummy_obj = Dummy() | ||
ns_instance = descriptor.__get__(dummy_obj, Dummy) | ||
assert isinstance(ns_instance, DummyNamespace) | ||
assert ns_instance.instance is dummy_obj | ||
|
||
# __get__ should cache the namespace instance on the object. | ||
# Subsequent access should return the same cached instance. | ||
assert getattr(dummy_obj, "dummy") is ns_instance | ||
|
||
|
||
def test_register_namespace(monkeypatch): | ||
"""Test the behavior of the register_anndata_namespace decorator. | ||
|
||
This test verifies that: | ||
- A new namespace can be registered successfully. | ||
- The accessor is available on AnnData instances. | ||
- The accessor is cached on the AnnData instance. | ||
- An warning is raised if the namespace name is overridden. | ||
""" | ||
|
||
original_dummy = getattr(ad.AnnData, "dummy", None) | ||
|
||
# Register a new namespace called 'dummy'. | ||
@extensions.register_anndata_namespace("dummy") | ||
class DummyNamespace: | ||
def __init__(self, adata: ad.AnnData): | ||
self.adata = adata | ||
|
||
def greet(self) -> str: | ||
return "hello" | ||
|
||
# Create an AnnData instance with minimal data. | ||
rng = np.random.default_rng(42) | ||
adata = ad.AnnData(X=rng.poisson(1, size=(10, 10))) | ||
|
||
# The accessor should now be available. | ||
ns_instance = adata.dummy | ||
assert ns_instance.adata is adata | ||
assert ns_instance.greet() == "hello" | ||
|
||
# Verify caching behavior on the AnnData instance. | ||
assert adata.dummy is ns_instance | ||
|
||
# Now, override the same namespace and check that a warning is emitted. | ||
with pytest.warns( | ||
UserWarning, match="Overriding existing custom namespace 'dummy'" | ||
): | ||
|
||
@extensions.register_anndata_namespace("dummy") | ||
srivarra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class DummyNamespaceOverride: | ||
def __init__(self, adata: ad.AnnData): | ||
self.adata = adata | ||
|
||
def greet(self) -> str: | ||
# Return a different string to confirm the override. | ||
return "world" | ||
|
||
# A new AnnData instance should now use the overridden accessor. | ||
adata2 = ad.AnnData(X=rng.poisson(1, size=(10, 10))) | ||
assert adata2.dummy.greet() == "world" | ||
|
||
# Clean up by restoring any previously existing attribute. | ||
if original_dummy is not None: | ||
setattr(ad.AnnData, "dummy", original_dummy) | ||
else: | ||
if hasattr(ad.AnnData, "dummy"): | ||
delattr(ad.AnnData, "dummy") | ||
|
||
|
||
def test_register_reserved_namespace(monkeypatch): | ||
""" | ||
Check that attempting to register a namespace with a reserved name | ||
raises an AttributeError. | ||
""" | ||
reserved_name = "reserved_namespace" | ||
|
||
# Create a new reserved set that includes our test name. | ||
new_reserved = extensions._reserved_namespaces.union({reserved_name}) | ||
monkeypatch.setattr(extensions, "_reserved_namespaces", new_reserved) | ||
|
||
with pytest.raises( | ||
AttributeError, match=f"cannot override reserved namespace {reserved_name!r}" | ||
): | ||
|
||
@extensions.register_anndata_namespace(reserved_name) | ||
class DummyNamespace: | ||
def __init__(self, adata: ad.AnnData): | ||
self.adata = adata |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why an internal-to-the-function import?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No reason in particular, it was just what Polars used. As long as the function gets the location of the
anndata/__init__.py
file it'll work. Moving it outside the function should be fine since there are not any circular imports.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why resolved without change? There was a circular import?