Skip to content

Commit

Permalink
remove generic
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Dec 18, 2023
1 parent 1c8e4ea commit e5ab729
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/ome_types/_mixins/_structured_annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Generic, Iterator, TypeVar, Union, cast, no_type_check
from typing import Any, Iterator, Union, cast, no_type_check

# for circular import reasons...
from ome_types._autogenerated.ome_2016_06.boolean_annotation import BooleanAnnotation
Expand All @@ -16,10 +16,9 @@
)
from ome_types._autogenerated.ome_2016_06.xml_annotation import XMLAnnotation

T = TypeVar("T")


class CollectionMixin(Generic[T]):
# This would ideally be a generic, but that's proven tricky with pydantic
class CollectionMixin:
"""Mixin to be used for classes that behave like collections.
Notably: ShapeUnion and StructuredAnnotations.
Expand All @@ -28,22 +27,22 @@ class CollectionMixin(Generic[T]):
"""

@no_type_check
def __iter__(self) -> Iterator[T]:
def __iter__(self) -> Iterator[Any]:
return itertools.chain(*(getattr(self, f) for f in self.model_fields))

def __len__(self) -> int:
return sum(1 for _ in self)

def append(self, item: T) -> None:
def append(self, item: Any) -> None:
"""Append an item to the appropriate field list."""
cast(list, getattr(self, self._field_name(item))).append(item)

def remove(self, item: T) -> None:
def remove(self, item: Any) -> None:
"""Remove an item from the appropriate field list."""
cast(list, getattr(self, self._field_name(item))).remove(item)

# This one is a bit hacky... perhaps deprecate and remove
def __getitem__(self, i: int) -> T:
def __getitem__(self, i: int) -> Any:
# return the ith item in the __iter__ sequence
return next(itertools.islice(self, i, None))

Expand All @@ -54,7 +53,7 @@ def __eq__(self, _value: object) -> bool:
return super().__eq__(_value)

@classmethod
def _field_name(cls, item: T) -> str:
def _field_name(cls, item: Any) -> str:
"""Return the name of the field that should contain the given item.
Must be implemented by subclasses.
Expand All @@ -81,11 +80,9 @@ def _field_name(cls, item: T) -> str:
AnnotationInstances = AnnotationType.__args__ # type: ignore


class StructuredAnnotationsMixin(CollectionMixin[AnnotationType]):
...

class StructuredAnnotationsMixin(CollectionMixin):
@classmethod
def _field_name(cls, item: T) -> str:
def _field_name(cls, item: Any) -> str:
if not isinstance(item, AnnotationInstances):
raise TypeError( # pragma: no cover
f"Expected an instance of {AnnotationInstances}, got {item!r}"
Expand Down

0 comments on commit e5ab729

Please sign in to comment.