Skip to content

Commit

Permalink
more generic mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Dec 18, 2023
1 parent 2674b37 commit 4d5dff5
Showing 1 changed file with 53 additions and 39 deletions.
92 changes: 53 additions & 39 deletions src/ome_types/_mixins/_structured_annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import TYPE_CHECKING, Iterator, Union, cast
from typing import Generic, Iterator, TypeVar, Union, cast, no_type_check

# for circular import reasons...
from ome_types._autogenerated.ome_2016_06.boolean_annotation import BooleanAnnotation
Expand All @@ -16,6 +16,54 @@
)
from ome_types._autogenerated.ome_2016_06.xml_annotation import XMLAnnotation

T = TypeVar("T")


class CollectionMixin(Generic[T]):
"""Mixin to be used for classes that behave like collections.
Notably: ShapeUnion and StructuredAnnotations.
All the fields in these types list[SomeType], and they collectively behave like
a list with the union of all field types.
"""

@no_type_check
def __iter__(self) -> Iterator[T]:
return itertools.chain(*(getattr(self, f) for f in self.model_fields))

def __len__(self) -> int:
return sum(1 for _ in self)

def append(self, item: T) -> None:
"""Append an item to the appropriate field list."""
cast(list, getattr(self, self._field_name(item))).append(item)

def remove(self, item: T) -> None:
"""Remove an item from the appropriate field list."""
cast(list, getattr(self, self._field_name(item))).remove(item)

# This one is a bit hacky... perhaps deprecate and remove
def __getitem__(self, i: int) -> T:
# return the ith item in the __iter__ sequence
return next(itertools.islice(self, i, None))

# perhaps deprecate and remove
def __eq__(self, _value: object) -> bool:
if isinstance(_value, list):
return list(self) == _value
return super().__eq__(_value)

@classmethod
def _field_name(cls, item: T) -> str:
"""Return the name of the field that should contain the given item.
Must be implemented by subclasses.
"""
raise NotImplementedError() # pragma: no cover


# ------------------------ StructuredAnnotations ------------------------

AnnotationType = Union[
XMLAnnotation,
FileAnnotation,
Expand All @@ -32,49 +80,15 @@
# get_args wasn't available until Python 3.8
AnnotationInstances = AnnotationType.__args__ # type: ignore

if TYPE_CHECKING:
from ome_types.model import StructuredAnnotations


class StructuredAnnotationsMixin:
def __iter__(self) -> Iterator[AnnotationType]: # type: ignore[override]
self = cast("StructuredAnnotations", self)
return itertools.chain(
self.xml_annotations,
self.file_annotations,
self.list_annotations,
self.long_annotations,
self.double_annotations,
self.comment_annotations,
self.boolean_annotations,
self.timestamp_annotations,
self.tag_annotations,
self.term_annotations,
self.map_annotations,
)

def __getitem__(self, i: int) -> AnnotationType:
# return the ith item in the __iter__ sequence
return next(itertools.islice(self, i, None))

def __len__(self) -> int:
return sum(1 for _ in self)

def append(self, item: AnnotationType) -> None:
getattr(self, self._field_name(item)).append(item)

def remove(self, item: AnnotationType) -> None:
getattr(self, self._field_name(item)).remove(item)

def __eq__(self, _value: object) -> bool:
if isinstance(_value, list):
return list(self) == _value
return super().__eq__(_value)
class StructuredAnnotationsMixin(CollectionMixin[AnnotationType]):
...

@classmethod
def _field_name(cls, item: AnnotationType) -> str:
def _field_name(cls, item: T) -> str:
if not isinstance(item, AnnotationInstances):
raise TypeError( # pragma: no cover
f"Expected an instance of {AnnotationInstances}, got {item!r}"
)
# where 10 is the length of "Annotation"
return item.__class__.__name__[:-10].lower() + "_annotations"

0 comments on commit 4d5dff5

Please sign in to comment.