diff --git a/manim/utils/iterables.py b/manim/utils/iterables.py index 4b20cfe704..9a688561ac 100644 --- a/manim/utils/iterables.py +++ b/manim/utils/iterables.py @@ -19,12 +19,32 @@ ] import itertools as it -from typing import Any, Callable, Collection, Generator, Iterable, Reversible, Sequence +from typing import ( + TYPE_CHECKING, + Callable, + Collection, + Generator, + Hashable, + Iterable, + Reversible, + Sequence, + TypeVar, + overload, +) import numpy as np +T = TypeVar("T") +U = TypeVar("U") +F = TypeVar("F", np.float_, np.int_) +H = TypeVar("H", bound=Hashable) -def adjacent_n_tuples(objects: Sequence, n: int) -> zip: + +if TYPE_CHECKING: + import numpy.typing as npt + + +def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, ...]]: """Returns the Sequence objects cyclically split into n length tuples. See Also @@ -33,18 +53,17 @@ def adjacent_n_tuples(objects: Sequence, n: int) -> zip: Examples -------- - Normal usage:: + .. code-block:: pycon - list(adjacent_n_tuples([1, 2, 3, 4], 2)) - # returns [(1, 2), (2, 3), (3, 4), (4, 1)] - - list(adjacent_n_tuples([1, 2, 3, 4], 3)) - # returns [(1, 2, 3), (2, 3, 4), (3, 4, 1), (4, 1, 2)] + >>> list(adjacent_n_tuples([1, 2, 3, 4], 2)) + [(1, 2), (2, 3), (3, 4), (4, 1)] + >>> list(adjacent_n_tuples([1, 2, 3, 4], 3)) + [(1, 2, 3), (2, 3, 4), (3, 4, 1), (4, 1, 2)] """ return zip(*([*objects[k:], *objects[:k]] for k in range(n))) -def adjacent_pairs(objects: Sequence) -> zip: +def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, ...]]: """Alias for ``adjacent_n_tuples(objects, 2)``. See Also @@ -53,15 +72,15 @@ def adjacent_pairs(objects: Sequence) -> zip: Examples -------- - Normal usage:: + .. code-block:: pycon - list(adjacent_pairs([1, 2, 3, 4])) - # returns [(1, 2), (2, 3), (3, 4), (4, 1)] + >>> list(adjacent_pairs([1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4), (4, 1)] """ return adjacent_n_tuples(objects, 2) -def all_elements_are_instances(iterable: Iterable, Class) -> bool: +def all_elements_are_instances(iterable: Iterable[object], Class: type[object]) -> bool: """Returns ``True`` if all elements of iterable are instances of Class. False otherwise. """ @@ -69,8 +88,8 @@ def all_elements_are_instances(iterable: Iterable, Class) -> bool: def batch_by_property( - items: Sequence, property_func: Callable -) -> list[tuple[list, Any]]: + items: Iterable[T], property_func: Callable[[T], U] +) -> list[tuple[list[T], U | None]]: """Takes in a Sequence, and returns a list of tuples, (batch, prop) such that all items in a batch have the same output when put into the Callable property_func, and such that chaining all these @@ -79,13 +98,13 @@ def batch_by_property( Examples -------- - Normal usage:: + .. code-block:: pycon - batch_by_property([(1, 2), (3, 4), (5, 6, 7), (8, 9)], len) - # returns [([(1, 2), (3, 4)], 2), ([(5, 6, 7)], 3), ([(8, 9)], 2)] + >>> batch_by_property([(1, 2), (3, 4), (5, 6, 7), (8, 9)], len) + [([(1, 2), (3, 4)], 2), ([(5, 6, 7)], 3), ([(8, 9)], 2)] """ - batch_prop_pairs = [] - curr_batch = [] + batch_prop_pairs: list[tuple[list[T], U | None]] = [] + curr_batch: list[T] = [] curr_prop = None for item in items: prop = property_func(item) @@ -103,67 +122,84 @@ def batch_by_property( return batch_prop_pairs -def concatenate_lists(*list_of_lists: Iterable) -> list: +def concatenate_lists(*list_of_lists: Iterable[T]) -> list[T]: """Combines the Iterables provided as arguments into one list. Examples -------- - Normal usage:: + .. code-block:: pycon - concatenate_lists([1, 2], [3, 4], [5]) - # returns [1, 2, 3, 4, 5] + >>> concatenate_lists([1, 2], [3, 4], [5]) + [1, 2, 3, 4, 5] """ return [item for lst in list_of_lists for item in lst] -def list_difference_update(l1: Iterable, l2: Iterable) -> list: +def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: """Returns a list containing all the elements of l1 not in l2. Examples -------- - Normal usage:: + .. code-block:: pycon - list_difference_update([1, 2, 3, 4], [2, 4]) - # returns [1, 3] + >>> list_difference_update([1, 2, 3, 4], [2, 4]) + [1, 3] """ return [e for e in l1 if e not in l2] -def list_update(l1: Iterable, l2: Iterable) -> list: +def list_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: """Used instead of ``set.update()`` to maintain order, making sure duplicates are removed from l1, not l2. Removes overlap of l1 and l2 and then concatenates l2 unchanged. Examples -------- - Normal usage:: + .. code-block:: pycon - list_update([1, 2, 3], [2, 4, 4]) - # returns [1, 3, 2, 4, 4] + >>> list_update([1, 2, 3], [2, 4, 4]) + [1, 3, 2, 4, 4] """ return [e for e in l1 if e not in l2] + list(l2) -def listify(obj) -> list: +@overload +def listify(obj: str) -> list[str]: ... + + +@overload +def listify(obj: Iterable[T]) -> list[T]: ... + + +@overload +def listify(obj: T) -> list[T]: ... + + +def listify(obj: str | Iterable[T] | T) -> list[str] | list[T]: """Converts obj to a list intelligently. Examples -------- - Normal usage:: - - listify('str') # ['str'] - listify((1, 2)) # [1, 2] - listify(len) # [] + .. code-block:: pycon + + >>> listify("str") + ['str'] + >>> listify((1, 2)) + [1, 2] + >>> listify(len) + [] """ if isinstance(obj, str): return [obj] - try: + if isinstance(obj, Iterable): return list(obj) - except TypeError: + else: return [obj] -def make_even(iterable_1: Iterable, iterable_2: Iterable) -> tuple[list, list]: +def make_even( + iterable_1: Iterable[T], iterable_2: Iterable[U] +) -> tuple[list[T], list[U]]: """Extends the shorter of the two iterables with duplicate values until its length is equal to the longer iterable (favours earlier elements). @@ -173,13 +209,13 @@ def make_even(iterable_1: Iterable, iterable_2: Iterable) -> tuple[list, list]: Examples -------- - Normal usage:: + .. code-block:: pycon - make_even([1, 2], [3, 4, 5, 6]) + >>> make_even([1, 2], [3, 4, 5, 6]) ([1, 1, 2, 2], [3, 4, 5, 6]) - make_even([1, 2], [3, 4, 5, 6, 7]) - # ([1, 1, 1, 2, 2], [3, 4, 5, 6, 7]) + >>> make_even([1, 2], [3, 4, 5, 6, 7]) + ([1, 1, 1, 2, 2], [3, 4, 5, 6, 7]) """ list_1, list_2 = list(iterable_1), list(iterable_2) len_list_1 = len(list_1) @@ -192,8 +228,8 @@ def make_even(iterable_1: Iterable, iterable_2: Iterable) -> tuple[list, list]: def make_even_by_cycling( - iterable_1: Collection, iterable_2: Collection -) -> tuple[list, list]: + iterable_1: Collection[T], iterable_2: Collection[U] +) -> tuple[list[T], list[U]]: """Extends the shorter of the two iterables with duplicate values until its length is equal to the longer iterable (cycles over shorter iterable). @@ -203,13 +239,13 @@ def make_even_by_cycling( Examples -------- - Normal usage:: + .. code-block:: pycon - make_even_by_cycling([1, 2], [3, 4, 5, 6]) + >>> make_even_by_cycling([1, 2], [3, 4, 5, 6]) ([1, 2, 1, 2], [3, 4, 5, 6]) - make_even_by_cycling([1, 2], [3, 4, 5, 6, 7]) - # ([1, 2, 1, 2, 1], [3, 4, 5, 6, 7]) + >>> make_even_by_cycling([1, 2], [3, 4, 5, 6, 7]) + ([1, 2, 1, 2, 1], [3, 4, 5, 6, 7]) """ length = max(len(iterable_1), len(iterable_2)) cycle1 = it.cycle(iterable_1) @@ -220,7 +256,7 @@ def make_even_by_cycling( ) -def remove_list_redundancies(lst: Reversible) -> list: +def remove_list_redundancies(lst: Reversible[H]) -> list[H]: """Used instead of ``list(set(l))`` to maintain order. Keeps the last occurrence of each element. """ @@ -234,21 +270,21 @@ def remove_list_redundancies(lst: Reversible) -> list: return reversed_result -def remove_nones(sequence: Iterable) -> list: +def remove_nones(sequence: Iterable[T | None]) -> list[T]: """Removes elements where bool(x) evaluates to False. Examples -------- - Normal usage:: + .. code-block:: pycon - remove_nones(['m', '', 'l', 0, 42, False, True]) - # ['m', 'l', 42, True] + >>> remove_nones(["m", "", "l", 0, 42, False, True]) + ['m', 'l', 42, True] """ # Note this is redundant with it.chain return [x for x in sequence if x] -def resize_array(nparray: np.ndarray, length: int) -> np.ndarray: +def resize_array(nparray: npt.NDArray[F], length: int) -> npt.NDArray[F]: """Extends/truncates nparray so that ``len(result) == length``. The elements of nparray are cycled to achieve the desired length. @@ -259,7 +295,7 @@ def resize_array(nparray: np.ndarray, length: int) -> np.ndarray: Examples -------- - Normal usage:: + .. code-block:: pycon >>> points = np.array([[1, 2], [3, 4]]) >>> resize_array(points, 1) @@ -277,7 +313,9 @@ def resize_array(nparray: np.ndarray, length: int) -> np.ndarray: return np.resize(nparray, (length, *nparray.shape[1:])) -def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray: +def resize_preserving_order( + nparray: npt.NDArray[np.float_], length: int +) -> npt.NDArray[np.float_]: """Extends/truncates nparray so that ``len(result) == length``. The elements of nparray are duplicated to achieve the desired length (favours earlier elements). @@ -291,21 +329,19 @@ def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray: Examples -------- - Normal usage:: - - resize_preserving_order(np.array([]), 5) - # np.array([0., 0., 0., 0., 0.]) + .. code-block:: pycon - nparray = np.array([[1, 2], - [3, 4]]) + >>> resize_preserving_order(np.array([]), 5) + array([0., 0., 0., 0., 0.]) - resize_preserving_order(nparray, 1) - # np.array([[1, 2]]) + >>> nparray = np.array([[1, 2], [3, 4]]) + >>> resize_preserving_order(nparray, 1) + array([[1, 2]]) - resize_preserving_order(nparray, 3) - # np.array([[1, 2], - # [1, 2], - # [3, 4]]) + >>> resize_preserving_order(nparray, 3) + array([[1, 2], + [1, 2], + [3, 4]]) """ if len(nparray) == 0: return np.zeros((length, *nparray.shape[1:])) @@ -315,7 +351,7 @@ def resize_preserving_order(nparray: np.ndarray, length: int) -> np.ndarray: return nparray[indices] -def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray: +def resize_with_interpolation(nparray: npt.NDArray[F], length: int) -> npt.NDArray[F]: """Extends/truncates nparray so that ``len(result) == length``. New elements are interpolated to achieve the desired length. @@ -329,39 +365,29 @@ def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray: Examples -------- - Normal usage:: - - nparray = np.array([[1, 2], - [3, 4]]) - - resize_with_interpolation(nparray, 1) - # np.array([[1., 2.]]) - - resize_with_interpolation(nparray, 4) - # np.array([[1. , 2. ], - # [1.66666667, 2.66666667], - # [2.33333333, 3.33333333], - # [3. , 4. ]]) - - nparray = np.array([[[1, 2],[3, 4]]]) - resize_with_interpolation(nparray, 3) - # np.array([[[1., 2.], [3., 4.]], - # [[1., 2.], [3., 4.]], - # [[1., 2.], [3., 4.]]]) - - nparray = np.array([[1, 2], [3, 4], [5, 6]]) - resize_with_interpolation(nparray, 4) - # np.array([[1. , 2. ], - # [2.33333333, 3.33333333], - # [3.66666667, 4.66666667], - # [5. , 6. ]]) - - nparray = np.array([[1, 2], [3, 4], [1, 2]]) - resize_with_interpolation(nparray, 4) - # np.array([[1. , 2. ], - # [2.33333333, 3.33333333], - # [2.33333333, 3.33333333], - # [1. , 2. ]]) + .. code-block:: pycon + + >>> nparray = np.array([[1, 2], [3, 4]]) + >>> resize_with_interpolation(nparray, 1) + array([[1., 2.]]) + >>> resize_with_interpolation(nparray, 4) + array([[1. , 2. ], + [1.66666667, 2.66666667], + [2.33333333, 3.33333333], + [3. , 4. ]]) + >>> nparray = np.array([[[1, 2], [3, 4]]]) + >>> nparray = np.array([[1, 2], [3, 4], [5, 6]]) + >>> resize_with_interpolation(nparray, 4) + array([[1. , 2. ], + [2.33333333, 3.33333333], + [3.66666667, 4.66666667], + [5. , 6. ]]) + >>> nparray = np.array([[1, 2], [3, 4], [1, 2]]) + >>> resize_with_interpolation(nparray, 4) + array([[1. , 2. ], + [2.33333333, 3.33333333], + [2.33333333, 3.33333333], + [1. , 2. ]]) """ if len(nparray) == length: return nparray @@ -375,7 +401,7 @@ def resize_with_interpolation(nparray: np.ndarray, length: int) -> np.ndarray: ) -def stretch_array_to_length(nparray: np.ndarray, length: int) -> np.ndarray: +def stretch_array_to_length(nparray: npt.NDArray[F], length: int) -> npt.NDArray[F]: # todo: is this the same as resize_preserving_order()? curr_len = len(nparray) if curr_len > length: @@ -385,35 +411,54 @@ def stretch_array_to_length(nparray: np.ndarray, length: int) -> np.ndarray: return nparray[indices.astype(int)] -def tuplify(obj) -> tuple: +@overload +def tuplify(obj: str) -> tuple[str]: ... + + +@overload +def tuplify(obj: Iterable[T]) -> tuple[T]: ... + + +@overload +def tuplify(obj: T) -> tuple[T]: ... + + +def tuplify(obj: str | Iterable[T] | T) -> tuple[str] | tuple[T]: """Converts obj to a tuple intelligently. Examples -------- - Normal usage:: - - tuplify('str') # ('str',) - tuplify([1, 2]) # (1, 2) - tuplify(len) # (,) + .. code-block:: pycon + + >>> tuplify("str") + ('str',) + >>> tuplify([1, 2]) + (1, 2) + >>> tuplify(len) + (,) """ if isinstance(obj, str): return (obj,) - try: + if isinstance(obj, Iterable): return tuple(obj) - except TypeError: + else: return (obj,) -def uniq_chain(*args: Iterable) -> Generator: +def uniq_chain(*args: Iterable[T]) -> Generator[T, None, None]: """Returns a generator that yields all unique elements of the Iterables provided via args in the order provided. Examples -------- - Normal usage:: - - uniq_chain([1, 2], [2, 3], [1, 4, 4]) - # yields 1, 2, 3, 4 + .. code-block:: pycon + + >>> gen = uniq_chain([1, 2], [2, 3], [1, 4, 4]) + >>> from collections.abc import Generator + >>> isinstance(gen, Generator) + True + >>> tuple(gen) + (1, 2, 3, 4) """ unique_items = set() for x in it.chain(*args): diff --git a/mypy.ini b/mypy.ini index 2bf9c170b9..956b44ae21 100644 --- a/mypy.ini +++ b/mypy.ini @@ -82,6 +82,10 @@ ignore_errors = True [mypy-manim.utils.*] ignore_errors = True +[mypy-manim.utils.iterables] +ignore_errors = False +warn_return_any = False + [mypy-manim.__main__] ignore_errors = True