From 87732527d398d138dc2a0b234d8f3776004a8c92 Mon Sep 17 00:00:00 2001 From: Henrik Skov Midtiby Date: Sun, 27 Oct 2024 22:13:58 +0100 Subject: [PATCH] Type annotation work in manim/mobject/geometry/ (#3961) Fixes typehints in manim.mobject.geometry, and enables type checking of those modules. Part of #3375 --- manim/mobject/geometry/arc.py | 139 ++++++++++-------- manim/mobject/geometry/boolean_ops.py | 27 ++-- manim/mobject/geometry/labeled.py | 19 ++- manim/mobject/geometry/line.py | 96 ++++++++----- manim/mobject/geometry/polygram.py | 58 +++++--- manim/mobject/geometry/shape_matchers.py | 19 ++- manim/mobject/geometry/tips.py | 57 ++++---- manim/mobject/mobject.py | 7 +- manim/mobject/types/image_mobject.py | 60 +++++--- manim/mobject/types/point_cloud_mobject.py | 158 +++++++++++++-------- manim/mobject/types/vectorized_mobject.py | 15 +- manim/utils/file_ops.py | 4 +- manim/utils/images.py | 12 +- mypy.ini | 3 + 14 files changed, 414 insertions(+), 260 deletions(-) diff --git a/manim/mobject/geometry/arc.py b/manim/mobject/geometry/arc.py index 996e8279f6..c211deae01 100644 --- a/manim/mobject/geometry/arc.py +++ b/manim/mobject/geometry/arc.py @@ -44,7 +44,7 @@ def construct(self): import itertools import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np from typing_extensions import Self @@ -63,11 +63,20 @@ def construct(self): ) if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Any + import manim.mobject.geometry.tips as tips from manim.mobject.mobject import Mobject from manim.mobject.text.tex_mobject import SingleStringMathTex, Tex from manim.mobject.text.text_mobject import Text - from manim.typing import CubicBezierPoints, Point3D, QuadraticBezierPoints, Vector3D + from manim.typing import ( + CubicBezierPoints, + InternalPoint3D, + Point3D, + QuadraticBezierPoints, + Vector3D, + ) class TipableVMobject(VMobject, metaclass=ConvertToOpenGL): @@ -93,7 +102,7 @@ def __init__( tip_length: float = DEFAULT_ARROW_TIP_LENGTH, normal_vector: Vector3D = OUT, tip_style: dict = {}, - **kwargs, + **kwargs: Any, ) -> None: self.tip_length: float = tip_length self.normal_vector: Vector3D = normal_vector @@ -126,10 +135,10 @@ def add_tip( def create_tip( self, tip_shape: type[tips.ArrowTip] | None = None, - tip_length: float = None, - tip_width: float = None, + tip_length: float | None = None, + tip_width: float | None = None, at_start: bool = False, - ): + ) -> tips.ArrowTip: """Stylises the tip, positions it spatially, and returns the newly instantiated tip to the caller. """ @@ -142,13 +151,13 @@ def get_unpositioned_tip( tip_shape: type[tips.ArrowTip] | None = None, tip_length: float | None = None, tip_width: float | None = None, - ): + ) -> tips.ArrowTip | tips.ArrowTriangleFilledTip: """Returns a tip that has been stylistically configured, but has not yet been given a position in space. """ from manim.mobject.geometry.tips import ArrowTriangleFilledTip - style = {} + style: dict[str, Any] = {} if tip_shape is None: tip_shape = ArrowTriangleFilledTip @@ -166,7 +175,7 @@ def get_unpositioned_tip( tip = tip_shape(length=tip_length, **style) return tip - def position_tip(self, tip: tips.ArrowTip, at_start: bool = False): + def position_tip(self, tip: tips.ArrowTip, at_start: bool = False) -> tips.ArrowTip: # Last two control points, defining both # the end, and the tangency direction if at_start: @@ -175,16 +184,18 @@ def position_tip(self, tip: tips.ArrowTip, at_start: bool = False): else: handle = self.get_last_handle() anchor = self.get_end() - angles = cartesian_to_spherical(handle - anchor) + angles = cartesian_to_spherical((handle - anchor).tolist()) tip.rotate( angles[1] - PI - tip.tip_angle, ) # Rotates the tip along the azimuthal if not hasattr(self, "_init_positioning_axis"): - axis = [ - np.sin(angles[1]), - -np.cos(angles[1]), - 0, - ] # Obtains the perpendicular of the tip + axis = np.array( + [ + np.sin(angles[1]), + -np.cos(angles[1]), + 0, + ] + ) # Obtains the perpendicular of the tip tip.rotate( -angles[2] + PI / 2, axis=axis, @@ -244,7 +255,7 @@ def get_tips(self) -> VGroup: result.add(self.start_tip) return result - def get_tip(self): + def get_tip(self) -> VMobject: """Returns the TipableVMobject instance's (first) tip, otherwise throws an exception. """ @@ -252,32 +263,36 @@ def get_tip(self): if len(tips) == 0: raise Exception("tip not found") else: - return tips[0] + tip: VMobject = tips[0] + return tip def get_default_tip_length(self) -> float: return self.tip_length - def get_first_handle(self) -> Point3D: + def get_first_handle(self) -> InternalPoint3D: + # Type inference of extracting an element from a list, is not + # supported by numpy, see this numpy issue + # https://github.com/numpy/numpy/issues/16544 return self.points[1] - def get_last_handle(self) -> Point3D: + def get_last_handle(self) -> InternalPoint3D: return self.points[-2] - def get_end(self) -> Point3D: + def get_end(self) -> InternalPoint3D: if self.has_tip(): return self.tip.get_start() else: return super().get_end() - def get_start(self) -> Point3D: + def get_start(self) -> InternalPoint3D: if self.has_start_tip(): return self.start_tip.get_start() else: return super().get_start() - def get_length(self) -> np.floating: + def get_length(self) -> float: start, end = self.get_start_and_end() - return np.linalg.norm(start - end) + return float(np.linalg.norm(start - end)) class Arc(TipableVMobject): @@ -297,20 +312,20 @@ def construct(self): def __init__( self, - radius: float = 1.0, + radius: float | None = 1.0, start_angle: float = 0, angle: float = TAU / 4, num_components: int = 9, - arc_center: Point3D = ORIGIN, - **kwargs, + arc_center: InternalPoint3D = ORIGIN, + **kwargs: Any, ): if radius is None: # apparently None is passed by ArcBetweenPoints radius = 1.0 self.radius = radius - self.num_components: int = num_components - self.arc_center: Point3D = arc_center - self.start_angle: float = start_angle - self.angle: float = angle + self.num_components = num_components + self.arc_center = arc_center + self.start_angle = start_angle + self.angle = angle self._failed_to_get_center: bool = False super().__init__(**kwargs) @@ -379,7 +394,7 @@ def _set_pre_positioned_points(self) -> None: handles2 = anchors[1:] - (d_theta / 3) * tangent_vectors[1:] self.set_anchors_and_handles(anchors[:-1], handles1, handles2, anchors[1:]) - def get_arc_center(self, warning: bool = True) -> Point3D: + def get_arc_center(self, warning: bool = True) -> InternalPoint3D: """Looks at the normals to the first two anchors, and finds their intersection points """ @@ -407,12 +422,15 @@ def get_arc_center(self, warning: bool = True) -> Point3D: self._failed_to_get_center = True return np.array(ORIGIN) - def move_arc_center_to(self, point: Point3D) -> Self: + def move_arc_center_to(self, point: InternalPoint3D) -> Self: self.shift(point - self.get_arc_center()) return self def stop_angle(self) -> float: - return angle_of_vector(self.points[-1] - self.get_arc_center()) % TAU + return cast( + float, + angle_of_vector(self.points[-1] - self.get_arc_center()) % TAU, + ) class ArcBetweenPoints(Arc): @@ -439,8 +457,8 @@ def __init__( start: Point3D, end: Point3D, angle: float = TAU / 4, - radius: float = None, - **kwargs, + radius: float | None = None, + **kwargs: Any, ) -> None: if radius is not None: self.radius = radius @@ -460,19 +478,20 @@ def __init__( super().__init__(radius=radius, angle=angle, **kwargs) if angle == 0: - self.set_points_as_corners([LEFT, RIGHT]) + self.set_points_as_corners(np.array([LEFT, RIGHT])) self.put_start_and_end_on(start, end) if radius is None: center = self.get_arc_center(warning=False) if not self._failed_to_get_center: - self.radius = np.linalg.norm(np.array(start) - np.array(center)) + temp_radius: float = np.linalg.norm(np.array(start) - np.array(center)) + self.radius = temp_radius else: self.radius = np.inf class CurvedArrow(ArcBetweenPoints): - def __init__(self, start_point: Point3D, end_point: Point3D, **kwargs) -> None: + def __init__(self, start_point: Point3D, end_point: Point3D, **kwargs: Any) -> None: from manim.mobject.geometry.tips import ArrowTriangleFilledTip tip_shape = kwargs.pop("tip_shape", ArrowTriangleFilledTip) @@ -481,7 +500,7 @@ def __init__(self, start_point: Point3D, end_point: Point3D, **kwargs) -> None: class CurvedDoubleArrow(CurvedArrow): - def __init__(self, start_point: Point3D, end_point: Point3D, **kwargs) -> None: + def __init__(self, start_point: Point3D, end_point: Point3D, **kwargs: Any) -> None: if "tip_shape_end" in kwargs: kwargs["tip_shape"] = kwargs.pop("tip_shape_end") from manim.mobject.geometry.tips import ArrowTriangleFilledTip @@ -520,7 +539,7 @@ def __init__( self, radius: float | None = None, color: ParsableManimColor = RED, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( radius=radius, @@ -617,7 +636,9 @@ def construct(self): return self.point_from_proportion(proportion) @staticmethod - def from_three_points(p1: Point3D, p2: Point3D, p3: Point3D, **kwargs) -> Self: + def from_three_points( + p1: Point3D, p2: Point3D, p3: Point3D, **kwargs: Any + ) -> Circle: """Returns a circle passing through the specified three points. @@ -637,10 +658,10 @@ def construct(self): self.add(NumberPlane(), circle, dots) """ center = line_intersection( - perpendicular_bisector([p1, p2]), - perpendicular_bisector([p2, p3]), + perpendicular_bisector([np.asarray(p1), np.asarray(p2)]), + perpendicular_bisector([np.asarray(p2), np.asarray(p3)]), ) - radius = np.linalg.norm(p1 - center) + radius: float = np.linalg.norm(p1 - center) return Circle(radius=radius, **kwargs).shift(center) @@ -682,7 +703,7 @@ def __init__( stroke_width: float = 0, fill_opacity: float = 1.0, color: ParsableManimColor = WHITE, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( arc_center=point, @@ -703,7 +724,7 @@ def __init__( stroke_width: float = 5, stroke_color: ParsableManimColor = WHITE, fill_color: ParsableManimColor = BLUE, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( radius=radius, @@ -752,12 +773,12 @@ def __init__( self, label: str | SingleStringMathTex | Text | Tex, radius: float | None = None, - **kwargs, + **kwargs: Any, ) -> None: if isinstance(label, str): from manim import MathTex - rendered_label = MathTex(label, color=BLACK) + rendered_label: VMobject = MathTex(label, color=BLACK) else: rendered_label = label @@ -793,7 +814,7 @@ def construct(self): self.add(ellipse_group) """ - def __init__(self, width: float = 2, height: float = 1, **kwargs) -> None: + def __init__(self, width: float = 2, height: float = 1, **kwargs: Any) -> None: super().__init__(**kwargs) self.stretch_to_fit_width(width) self.stretch_to_fit_height(height) @@ -854,7 +875,7 @@ def __init__( fill_opacity: float = 1, stroke_width: float = 0, color: ParsableManimColor = WHITE, - **kwargs, + **kwargs: Any, ) -> None: self.inner_radius = inner_radius self.outer_radius = outer_radius @@ -903,7 +924,7 @@ def construct(self): self.add(sector, sector2) """ - def __init__(self, radius: float = 1, **kwargs) -> None: + def __init__(self, radius: float = 1, **kwargs: Any) -> None: super().__init__(inner_radius=0, outer_radius=radius, **kwargs) @@ -933,13 +954,13 @@ def construct(self): def __init__( self, - inner_radius: float | None = 1, - outer_radius: float | None = 2, + inner_radius: float = 1, + outer_radius: float = 2, fill_opacity: float = 1, stroke_width: float = 0, color: ParsableManimColor = WHITE, mark_paths_closed: bool = False, - **kwargs, + **kwargs: Any, ) -> None: self.mark_paths_closed = mark_paths_closed # is this even used? self.inner_radius = inner_radius @@ -989,7 +1010,7 @@ def __init__( start_handle: CubicBezierPoints, end_handle: CubicBezierPoints, end_anchor: CubicBezierPoints, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.add_cubic_bezier_curve(start_anchor, start_handle, end_handle, end_anchor) @@ -1080,14 +1101,16 @@ def __init__( angle: float = PI / 4, radius: float | None = None, arc_config: list[dict] | None = None, - **kwargs, + **kwargs: Any, ) -> None: n = len(vertices) point_pairs = [(vertices[k], vertices[(k + 1) % n]) for k in range(n)] if not arc_config: if radius: - all_arc_configs = itertools.repeat({"radius": radius}, len(point_pairs)) + all_arc_configs: Iterable[dict] = itertools.repeat( + {"radius": radius}, len(point_pairs) + ) else: all_arc_configs = itertools.repeat({"angle": angle}, len(point_pairs)) elif isinstance(arc_config, dict): @@ -1219,7 +1242,7 @@ def construct(self): self.wait(2) """ - def __init__(self, *arcs: Arc | ArcBetweenPoints, **kwargs) -> None: + def __init__(self, *arcs: Arc | ArcBetweenPoints, **kwargs: Any) -> None: if not all(isinstance(m, (Arc, ArcBetweenPoints)) for m in arcs): raise ValueError( "All ArcPolygon submobjects must be of type Arc/ArcBetweenPoints", diff --git a/manim/mobject/geometry/boolean_ops.py b/manim/mobject/geometry/boolean_ops.py index baef2c3fa1..a34d6fc7c4 100644 --- a/manim/mobject/geometry/boolean_ops.py +++ b/manim/mobject/geometry/boolean_ops.py @@ -13,7 +13,9 @@ from manim.mobject.types.vectorized_mobject import VMobject if TYPE_CHECKING: - from manim.typing import Point2D_Array, Point3D_Array + from typing import Any + + from manim.typing import InternalPoint3D_Array, Point2D_Array from ...constants import RendererType @@ -30,7 +32,7 @@ def _convert_2d_to_3d_array( self, points: Point2D_Array, z_dim: float = 0.0, - ) -> Point3D_Array: + ) -> InternalPoint3D_Array: """Converts an iterable with coordinates in 2D to 3D by adding :attr:`z_dim` as the Z coordinate. @@ -51,13 +53,14 @@ def _convert_2d_to_3d_array( >>> a = _BooleanOps() >>> p = [(1, 2), (3, 4)] >>> a._convert_2d_to_3d_array(p) - [array([1., 2., 0.]), array([3., 4., 0.])] + array([[1., 2., 0.], + [3., 4., 0.]]) """ - points = list(points) - for i, point in enumerate(points): + list_of_points = list(points) + for i, point in enumerate(list_of_points): if len(point) == 2: - points[i] = np.array(list(point) + [z_dim]) - return points + list_of_points[i] = np.array(list(point) + [z_dim]) + return np.asarray(list_of_points) def _convert_vmobject_to_skia_path(self, vmobject: VMobject) -> SkiaPath: """Converts a :class:`~.VMobject` to SkiaPath. This method only works for @@ -95,7 +98,7 @@ def _convert_vmobject_to_skia_path(self, vmobject: VMobject) -> SkiaPath: if vmobject.consider_points_equals(subpath[0], subpath[-1]): path.close() elif config.renderer == RendererType.CAIRO: - subpaths = vmobject.gen_subpaths_from_points_2d(points) + subpaths = vmobject.gen_subpaths_from_points_2d(points) # type: ignore[assignment] for subpath in subpaths: quads = vmobject.gen_cubic_bezier_tuples_from_points(subpath) start = subpath[0] @@ -177,7 +180,7 @@ def construct(self): """ - def __init__(self, *vmobjects: VMobject, **kwargs) -> None: + def __init__(self, *vmobjects: VMobject, **kwargs: Any) -> None: if len(vmobjects) < 2: raise ValueError("At least 2 mobjects needed for Union.") super().__init__(**kwargs) @@ -216,7 +219,7 @@ def construct(self): """ - def __init__(self, subject: VMobject, clip: VMobject, **kwargs) -> None: + def __init__(self, subject: VMobject, clip: VMobject, **kwargs: Any) -> None: super().__init__(**kwargs) outpen = SkiaPath() difference( @@ -258,7 +261,7 @@ def construct(self): """ - def __init__(self, *vmobjects: VMobject, **kwargs) -> None: + def __init__(self, *vmobjects: VMobject, **kwargs: Any) -> None: if len(vmobjects) < 2: raise ValueError("At least 2 mobjects needed for Intersection.") @@ -311,7 +314,7 @@ def construct(self): """ - def __init__(self, subject: VMobject, clip: VMobject, **kwargs) -> None: + def __init__(self, subject: VMobject, clip: VMobject, **kwargs: Any) -> None: super().__init__(**kwargs) outpen = SkiaPath() xor( diff --git a/manim/mobject/geometry/labeled.py b/manim/mobject/geometry/labeled.py index 1a39ee2771..371870b8a7 100644 --- a/manim/mobject/geometry/labeled.py +++ b/manim/mobject/geometry/labeled.py @@ -4,6 +4,8 @@ __all__ = ["LabeledLine", "LabeledArrow"] +from typing import TYPE_CHECKING + from manim.constants import * from manim.mobject.geometry.line import Arrow, Line from manim.mobject.geometry.shape_matchers import ( @@ -14,6 +16,9 @@ from manim.mobject.text.text_mobject import Text from manim.utils.color import WHITE, ManimColor, ParsableManimColor +if TYPE_CHECKING: + from typing import Any + class LabeledLine(Line): """Constructs a line containing a label box somewhere along its length. @@ -67,17 +72,19 @@ def __init__( font_size: float = DEFAULT_FONT_SIZE, label_color: ParsableManimColor = WHITE, label_frame: bool = True, - frame_fill_color: ParsableManimColor = None, + frame_fill_color: ParsableManimColor | None = None, frame_fill_opacity: float = 1, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: label_color = ManimColor(label_color) frame_fill_color = ManimColor(frame_fill_color) if isinstance(label, str): from manim import MathTex - rendered_label = MathTex(label, color=label_color, font_size=font_size) + rendered_label: Tex | MathTex | Text = MathTex( + label, color=label_color, font_size=font_size + ) else: rendered_label = label @@ -149,7 +156,7 @@ def construct(self): def __init__( self, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) diff --git a/manim/mobject/geometry/line.py b/manim/mobject/geometry/line.py index ab3a338517..ef0e4aa1ff 100644 --- a/manim/mobject/geometry/line.py +++ b/manim/mobject/geometry/line.py @@ -30,9 +30,11 @@ from manim.utils.space_ops import angle_of_vector, line_intersection, normalize if TYPE_CHECKING: + from typing import Any + from typing_extensions import Self - from manim.typing import Point2D, Point3D, Vector3D + from manim.typing import InternalPoint3D, Point2D, Point3D, Vector3D from manim.utils.color import ParsableManimColor from ..matrix import Matrix # Avoid circular import @@ -45,20 +47,21 @@ def __init__( end: Point3D | Mobject = RIGHT, buff: float = 0, path_arc: float | None = None, - **kwargs, + **kwargs: Any, ) -> None: self.dim = 3 self.buff = buff self.path_arc = path_arc self._set_start_and_end_attrs(start, end) super().__init__(**kwargs) + # TODO: Deal with the situation where path_arc is None def generate_points(self) -> None: self.set_points_by_ends( start=self.start, end=self.end, buff=self.buff, - path_arc=self.path_arc, + path_arc=self.path_arc, # type: ignore[arg-type] ) def set_points_by_ends( @@ -85,16 +88,19 @@ def set_points_by_ends( """ self._set_start_and_end_attrs(start, end) if path_arc: + # self.path_arc could potentially be None, which is not accepted + # as parameter. + assert self.path_arc is not None arc = ArcBetweenPoints(self.start, self.end, angle=self.path_arc) self.set_points(arc.points) else: - self.set_points_as_corners([self.start, self.end]) + self.set_points_as_corners(np.asarray([self.start, self.end])) self._account_for_buff(buff) init_points = generate_points - def _account_for_buff(self, buff: float) -> Self: + def _account_for_buff(self, buff: float) -> None: if buff == 0: return # @@ -104,7 +110,7 @@ def _account_for_buff(self, buff: float) -> Self: return buff_proportion = buff / length self.pointwise_become_partial(self, buff_proportion, 1 - buff_proportion) - return self + return def _set_start_and_end_attrs( self, start: Point3D | Mobject, end: Point3D | Mobject @@ -124,7 +130,7 @@ def _pointify( self, mob_or_point: Mobject | Point3D, direction: Vector3D | None = None, - ) -> Point3D: + ) -> InternalPoint3D: """Transforms a mobject into its corresponding point. Does nothing if a point is passed. ``direction`` determines the location of the point along its bounding box in that direction. @@ -148,7 +154,11 @@ def set_path_arc(self, new_value: float) -> None: self.path_arc = new_value self.init_points() - def put_start_and_end_on(self, start: Point3D, end: Point3D) -> Self: + def put_start_and_end_on( + self, + start: InternalPoint3D, + end: InternalPoint3D, + ) -> Self: """Sets starts and end coordinates of a line. Examples @@ -188,7 +198,7 @@ def get_unit_vector(self) -> Vector3D: def get_angle(self) -> float: return angle_of_vector(self.get_vector()) - def get_projection(self, point: Point3D) -> Vector3D: + def get_projection(self, point: InternalPoint3D) -> Vector3D: """Returns the projection of a point onto a line. Parameters @@ -199,10 +209,10 @@ def get_projection(self, point: Point3D) -> Vector3D: start = self.get_start() end = self.get_end() unit_vect = normalize(end - start) - return start + np.dot(point - start, unit_vect) * unit_vect + return start + float(np.dot(point - start, unit_vect)) * unit_vect def get_slope(self) -> float: - return np.tan(self.get_angle()) + return float(np.tan(self.get_angle())) def set_angle(self, angle: float, about_point: Point3D | None = None) -> Self: if about_point is None: @@ -216,7 +226,8 @@ def set_angle(self, angle: float, about_point: Point3D | None = None) -> Self: return self def set_length(self, length: float) -> Self: - return self.scale(length / self.get_length()) + scale_factor: float = length / self.get_length() + return self.scale(scale_factor) class DashedLine(Line): @@ -255,10 +266,10 @@ def construct(self): def __init__( self, - *args, + *args: Any, dash_length: float = DEFAULT_DASH_LENGTH, dashed_ratio: float = 0.5, - **kwargs, + **kwargs: Any, ) -> None: self.dash_length = dash_length self.dashed_ratio = dashed_ratio @@ -287,7 +298,7 @@ def _calculate_num_dashes(self) -> int: int(np.ceil((self.get_length() / self.dash_length) * self.dashed_ratio)), ) - def get_start(self) -> Point3D: + def get_start(self) -> InternalPoint3D: """Returns the start point of the line. Examples @@ -302,7 +313,7 @@ def get_start(self) -> Point3D: else: return super().get_start() - def get_end(self) -> Point3D: + def get_end(self) -> InternalPoint3D: """Returns the end point of the line. Examples @@ -317,7 +328,7 @@ def get_end(self) -> Point3D: else: return super().get_end() - def get_first_handle(self) -> Point3D: + def get_first_handle(self) -> InternalPoint3D: """Returns the point of the first handle. Examples @@ -327,9 +338,12 @@ def get_first_handle(self) -> Point3D: >>> DashedLine().get_first_handle() array([-0.98333333, 0. , 0. ]) """ + # Type inference of extracting an element from a list, is not + # supported by numpy, see this numpy issue + # https://github.com/numpy/numpy/issues/16544 return self.submobjects[0].points[1] - def get_last_handle(self) -> Point3D: + def get_last_handle(self) -> InternalPoint3D: """Returns the point of the last handle. Examples @@ -339,6 +353,9 @@ def get_last_handle(self) -> Point3D: >>> DashedLine().get_last_handle() array([0.98333333, 0. , 0. ]) """ + # Type inference of extracting an element from a list, is not + # supported by numpy, see this numpy issue + # https://github.com/numpy/numpy/issues/16544 return self.submobjects[-1].points[-2] @@ -381,7 +398,7 @@ def __init__( alpha: float, length: float = 1, d_alpha: float = 1e-6, - **kwargs, + **kwargs: Any, ) -> None: self.length = length self.d_alpha = d_alpha @@ -424,10 +441,10 @@ def construct(self): self.add(elbow_group) """ - def __init__(self, width: float = 0.2, angle: float = 0, **kwargs) -> None: + def __init__(self, width: float = 0.2, angle: float = 0, **kwargs: Any) -> None: self.angle = angle super().__init__(**kwargs) - self.set_points_as_corners([UP, UP + RIGHT, RIGHT]) + self.set_points_as_corners(np.array([UP, UP + RIGHT, RIGHT])) self.scale_to_fit_width(width, about_point=ORIGIN) self.rotate(self.angle, about_point=ORIGIN) @@ -522,24 +539,24 @@ def construct(self): def __init__( self, - *args, + *args: Any, stroke_width: float = 6, buff: float = MED_SMALL_BUFF, max_tip_length_to_length_ratio: float = 0.25, max_stroke_width_to_length_ratio: float = 5, - **kwargs, + **kwargs: Any, ) -> None: self.max_tip_length_to_length_ratio = max_tip_length_to_length_ratio self.max_stroke_width_to_length_ratio = max_stroke_width_to_length_ratio tip_shape = kwargs.pop("tip_shape", ArrowTriangleFilledTip) - super().__init__(*args, buff=buff, stroke_width=stroke_width, **kwargs) + super().__init__(*args, buff=buff, stroke_width=stroke_width, **kwargs) # type: ignore[misc] # TODO, should this be affected when # Arrow.set_stroke is called? self.initial_stroke_width = self.stroke_width self.add_tip(tip_shape=tip_shape) self._set_stroke_width_from_length() - def scale(self, factor: float, scale_tips: bool = False, **kwargs) -> Self: + def scale(self, factor: float, scale_tips: bool = False, **kwargs: Any) -> Self: # type: ignore[override] r"""Scale an arrow, but keep stroke width and arrow tip size fixed. @@ -625,7 +642,11 @@ def _set_stroke_width_from_length(self) -> Self: """Sets stroke width based on length.""" max_ratio = self.max_stroke_width_to_length_ratio if config.renderer == RendererType.OPENGL: - self.set_stroke( + # Mypy does not recognize that the self object in this case + # is a OpenGLVMobject and that the set_stroke method is + # defined here: + # mobject/opengl/opengl_vectorized_mobject.py#L248 + self.set_stroke( # type: ignore[call-arg] width=min(self.initial_stroke_width, max_ratio * self.get_length()), recurse=False, ) @@ -668,7 +689,10 @@ def construct(self): """ def __init__( - self, direction: Point2D | Point3D = RIGHT, buff: float = 0, **kwargs + self, + direction: Point2D | Point3D = RIGHT, + buff: float = 0, + **kwargs: Any, ) -> None: self.buff = buff if len(direction) == 2: @@ -681,7 +705,7 @@ def coordinate_label( integer_labels: bool = True, n_dim: int = 2, color: ParsableManimColor | None = None, - **kwargs, + **kwargs: Any, ) -> Matrix: """Creates a label based on the coordinates of the vector. @@ -784,7 +808,7 @@ def construct(self): self.add(box, d1, d2, d3) """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: if "tip_shape_end" in kwargs: kwargs["tip_shape"] = kwargs.pop("tip_shape_end") tip_shape_start = kwargs.pop("tip_shape_start", ArrowTriangleFilledTip) @@ -913,7 +937,7 @@ def __init__( dot_distance: float = 0.55, dot_color: ParsableManimColor = WHITE, elbow: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.lines = (line1, line2) @@ -950,9 +974,9 @@ def __init__( + quadrant[0] * radius * line1.get_unit_vector() + quadrant[1] * radius * line2.get_unit_vector() ) - angle_mobject = Elbow(**kwargs) + angle_mobject: VMobject = Elbow(**kwargs) angle_mobject.set_points_as_corners( - [anchor_angle_1, anchor_middle, anchor_angle_2], + np.array([anchor_angle_1, anchor_middle, anchor_angle_2]), ) else: angle_1 = angle_of_vector(anchor_angle_1 - inter) @@ -1052,7 +1076,7 @@ def construct(self): return self.angle_value / DEGREES if degrees else self.angle_value @staticmethod - def from_three_points(A: Point3D, B: Point3D, C: Point3D, **kwargs) -> Angle: + def from_three_points(A: Point3D, B: Point3D, C: Point3D, **kwargs: Any) -> Angle: r"""The angle between the lines AB and BC. This constructs the angle :math:`\\angle ABC`. @@ -1128,6 +1152,10 @@ def construct(self): """ def __init__( - self, line1: Line, line2: Line, length: float | None = None, **kwargs + self, + line1: Line, + line2: Line, + length: float | None = None, + **kwargs: Any, ) -> None: super().__init__(line1, line2, radius=length, elbow=True, **kwargs) diff --git a/manim/mobject/geometry/polygram.py b/manim/mobject/geometry/polygram.py index 78f54cf87a..bfb9f00ab7 100644 --- a/manim/mobject/geometry/polygram.py +++ b/manim/mobject/geometry/polygram.py @@ -30,9 +30,16 @@ from manim.utils.space_ops import angle_between_vectors, normalize, regular_vertices if TYPE_CHECKING: + from typing import Any, Literal + + import numpy.typing as npt from typing_extensions import Self - from manim.typing import Point3D, Point3D_Array + from manim.typing import ( + InternalPoint3D, + InternalPoint3D_Array, + Point3D, + ) from manim.utils.color import ParsableManimColor @@ -72,11 +79,16 @@ def construct(self): """ def __init__( - self, *vertex_groups: Point3D, color: ParsableManimColor = BLUE, **kwargs + self, + *vertex_groups: Point3D, + color: ParsableManimColor = BLUE, + **kwargs: Any, ): super().__init__(color=color, **kwargs) for vertices in vertex_groups: + # The inferred type for *vertices is Any, but it should be + # InternalPoint3D_Array first_vertex, *vertices = vertices first_vertex = np.array(first_vertex) @@ -85,7 +97,7 @@ def __init__( [*(np.array(vertex) for vertex in vertices), first_vertex], ) - def get_vertices(self) -> Point3D_Array: + def get_vertices(self) -> InternalPoint3D_Array: """Gets the vertices of the :class:`Polygram`. Returns @@ -106,7 +118,7 @@ def get_vertices(self) -> Point3D_Array: """ return self.get_start_anchors() - def get_vertex_groups(self) -> np.ndarray[Point3D_Array]: + def get_vertex_groups(self) -> InternalPoint3D_Array: """Gets the vertex groups of the :class:`Polygram`. Returns @@ -205,7 +217,7 @@ def construct(self): if radius == 0: return self - new_points = [] + new_points: list[InternalPoint3D] = [] for vertices in self.get_vertex_groups(): arcs = [] @@ -274,7 +286,7 @@ def construct(self): new_points.extend(line.points) - self.set_points(new_points) + self.set_points(np.array(new_points)) return self @@ -309,7 +321,7 @@ def construct(self): self.add(isosceles, square_and_triangles) """ - def __init__(self, *vertices: Point3D, **kwargs) -> None: + def __init__(self, *vertices: InternalPoint3D, **kwargs: Any) -> None: super().__init__(vertices, **kwargs) @@ -352,7 +364,7 @@ def __init__( density: int = 2, radius: float = 1, start_angle: float | None = None, - **kwargs, + **kwargs: Any, ) -> None: # Regular polygrams can be expressed by the number of their vertices # and their density. This relation can be expressed as its Schläfli @@ -373,7 +385,7 @@ def __init__( # Utility function for generating the individual # polygon vertices. - def gen_polygon_vertices(start_angle): + def gen_polygon_vertices(start_angle: float | None) -> tuple[list[Any], float]: reg_vertices, start_angle = regular_vertices( num_vertices, radius=radius, @@ -429,7 +441,7 @@ def construct(self): self.add(poly_group) """ - def __init__(self, n: int = 6, **kwargs) -> None: + def __init__(self, n: int = 6, **kwargs: Any) -> None: super().__init__(n, density=1, **kwargs) @@ -499,7 +511,7 @@ def __init__( inner_radius: float | None = None, density: int = 2, start_angle: float | None = TAU / 4, - **kwargs, + **kwargs: Any, ) -> None: inner_angle = TAU / (2 * n) @@ -531,7 +543,7 @@ def __init__( start_angle=self.start_angle + inner_angle, ) - vertices = [] + vertices: list[npt.NDArray] = [] for pair in zip(outer_vertices, inner_vertices): vertices.extend(pair) @@ -559,7 +571,7 @@ def construct(self): self.add(tri_group) """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(n=3, **kwargs) @@ -610,7 +622,7 @@ def __init__( grid_ystep: float | None = None, mark_paths_closed: bool = True, close_new_points: bool = True, - **kwargs, + **kwargs: Any, ): super().__init__(UR, UL, DL, DR, color=color, **kwargs) self.stretch_to_fit_width(width) @@ -681,15 +693,15 @@ def construct(self): self.add(square_1, square_2, square_3) """ - def __init__(self, side_length: float = 2.0, **kwargs) -> None: + def __init__(self, side_length: float = 2.0, **kwargs: Any) -> None: super().__init__(height=side_length, width=side_length, **kwargs) @property - def side_length(self): - return np.linalg.norm(self.get_vertices()[0] - self.get_vertices()[1]) + def side_length(self) -> float: + return float(np.linalg.norm(self.get_vertices()[0] - self.get_vertices()[1])) @side_length.setter - def side_length(self, value): + def side_length(self, value: float) -> None: self.scale(value / self.side_length) @@ -717,7 +729,7 @@ def construct(self): self.add(rect_group) """ - def __init__(self, corner_radius: float | list[float] = 0.5, **kwargs): + def __init__(self, corner_radius: float | list[float] = 0.5, **kwargs: Any): super().__init__(**kwargs) self.corner_radius = corner_radius self.round_corners(self.corner_radius) @@ -758,9 +770,13 @@ def construct(self): self.wait() """ - def __init__(self, main_shape: VMobject, *mobjects: VMobject, **kwargs) -> None: + def __init__( + self, main_shape: VMobject, *mobjects: VMobject, **kwargs: Any + ) -> None: super().__init__(**kwargs) self.append_points(main_shape.points) - sub_direction = "CCW" if main_shape.get_direction() == "CW" else "CW" + sub_direction: Literal["CCW", "CW"] = ( + "CCW" if main_shape.get_direction() == "CW" else "CW" + ) for mobject in mobjects: self.append_points(mobject.force_direction(sub_direction).points) diff --git a/manim/mobject/geometry/shape_matchers.py b/manim/mobject/geometry/shape_matchers.py index 296d9b9d9f..b546dfb4f3 100644 --- a/manim/mobject/geometry/shape_matchers.py +++ b/manim/mobject/geometry/shape_matchers.py @@ -47,7 +47,7 @@ def __init__( color: ParsableManimColor = YELLOW, buff: float = SMALL_BUFF, corner_radius: float = 0.0, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( color=color, @@ -93,8 +93,8 @@ def __init__( stroke_opacity: float = 0, fill_opacity: float = 0.75, buff: float = 0, - **kwargs, - ): + **kwargs: Any, + ) -> None: if color is None: color = config.background_color @@ -113,7 +113,7 @@ def pointwise_become_partial(self, mobject: Mobject, a: Any, b: float) -> Self: self.set_fill(opacity=b * self.original_fill_opacity) return self - def set_style(self, fill_opacity: float, **kwargs) -> Self: + def set_style(self, fill_opacity: float, **kwargs: Any) -> Self: # type: ignore[override] # Unchangeable style, except for fill_opacity # All other style arguments are ignored super().set_style( @@ -130,7 +130,10 @@ def set_style(self, fill_opacity: float, **kwargs) -> Self: return self def get_fill_color(self) -> ManimColor: - return self.color + # The type of the color property is set to Any using the property decorator + # vectorized_mobject.py#L571 + temp_color: ManimColor = self.color + return temp_color class Cross(VGroup): @@ -164,7 +167,7 @@ def __init__( stroke_color: ParsableManimColor = RED, stroke_width: float = 6.0, scale_factor: float = 1.0, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( Line(UP + LEFT, DOWN + RIGHT), Line(UP + RIGHT, DOWN + LEFT), **kwargs @@ -190,7 +193,9 @@ def construct(self): self.add(man, ul) """ - def __init__(self, mobject: Mobject, buff: float = SMALL_BUFF, **kwargs) -> None: + def __init__( + self, mobject: Mobject, buff: float = SMALL_BUFF, **kwargs: Any + ) -> None: super().__init__(LEFT, RIGHT, buff=buff, **kwargs) self.match_width(mobject) self.next_to(mobject, DOWN, buff=self.buff) diff --git a/manim/mobject/geometry/tips.py b/manim/mobject/geometry/tips.py index a7f116d3bf..e137016a88 100644 --- a/manim/mobject/geometry/tips.py +++ b/manim/mobject/geometry/tips.py @@ -25,7 +25,9 @@ from manim.utils.space_ops import angle_of_vector if TYPE_CHECKING: - from manim.typing import Point3D, Vector3D + from typing import Any + + from manim.typing import InternalPoint3D, Point3D, Vector3D class ArrowTip(VMobject, metaclass=ConvertToOpenGL): @@ -112,7 +114,7 @@ def construct(self): self.add(*big_arrows, *small_arrows, *labels) """ - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("Has to be implemented in inheriting subclasses.") @property @@ -134,7 +136,7 @@ def base(self) -> Point3D: return self.point_from_proportion(0.5) @property - def tip_point(self) -> Point3D: + def tip_point(self) -> InternalPoint3D: r"""The tip point of the arrow tip. Examples @@ -147,6 +149,9 @@ def tip_point(self) -> Point3D: array([2., 0., 0.]) """ + # Type inference of extracting an element from a list, is not + # supported by numpy, see this numpy issue + # https://github.com/numpy/numpy/issues/16544 return self.points[0] @property @@ -182,7 +187,7 @@ def tip_angle(self) -> float: return angle_of_vector(self.vector) @property - def length(self) -> np.floating: + def length(self) -> float: r"""The length of the arrow tip. Examples @@ -195,7 +200,7 @@ def length(self) -> np.floating: 0.35 """ - return np.linalg.norm(self.vector) + return float(np.linalg.norm(self.vector)) class StealthTip(ArrowTip): @@ -207,36 +212,38 @@ class StealthTip(ArrowTip): def __init__( self, - fill_opacity=1, - stroke_width=3, - length=DEFAULT_ARROW_TIP_LENGTH / 2, - start_angle=PI, - **kwargs, + fill_opacity: float = 1, + stroke_width: float = 3, + length: float = DEFAULT_ARROW_TIP_LENGTH / 2, + start_angle: float = PI, + **kwargs: Any, ): self.start_angle = start_angle VMobject.__init__( self, fill_opacity=fill_opacity, stroke_width=stroke_width, **kwargs ) self.set_points_as_corners( - [ - [2, 0, 0], # tip - [-1.2, 1.6, 0], - [0, 0, 0], # base - [-1.2, -1.6, 0], - [2, 0, 0], # close path, back to tip - ] + np.array( + [ + [2, 0, 0], # tip + [-1.2, 1.6, 0], + [0, 0, 0], # base + [-1.2, -1.6, 0], + [2, 0, 0], # close path, back to tip + ] + ) ) self.scale(length / self.length) @property - def length(self): + def length(self) -> float: """The length of the arrow tip. In this case, the length is computed as the height of the triangle encompassing the stealth tip (otherwise, the tip is scaled too large). """ - return np.linalg.norm(self.vector) * 1.6 + return float(np.linalg.norm(self.vector) * 1.6) class ArrowTriangleTip(ArrowTip, Triangle): @@ -249,7 +256,7 @@ def __init__( length: float = DEFAULT_ARROW_TIP_LENGTH, width: float = DEFAULT_ARROW_TIP_LENGTH, start_angle: float = PI, - **kwargs, + **kwargs: Any, ) -> None: Triangle.__init__( self, @@ -271,7 +278,7 @@ class ArrowTriangleFilledTip(ArrowTriangleTip): """ def __init__( - self, fill_opacity: float = 1, stroke_width: float = 0, **kwargs + self, fill_opacity: float = 1, stroke_width: float = 0, **kwargs: Any ) -> None: super().__init__(fill_opacity=fill_opacity, stroke_width=stroke_width, **kwargs) @@ -285,7 +292,7 @@ def __init__( stroke_width: float = 3, length: float = DEFAULT_ARROW_TIP_LENGTH, start_angle: float = PI, - **kwargs, + **kwargs: Any, ) -> None: self.start_angle = start_angle Circle.__init__( @@ -299,7 +306,7 @@ class ArrowCircleFilledTip(ArrowCircleTip): r"""Circular arrow tip with filled tip.""" def __init__( - self, fill_opacity: float = 1, stroke_width: float = 0, **kwargs + self, fill_opacity: float = 1, stroke_width: float = 0, **kwargs: Any ) -> None: super().__init__(fill_opacity=fill_opacity, stroke_width=stroke_width, **kwargs) @@ -313,7 +320,7 @@ def __init__( stroke_width: float = 3, length: float = DEFAULT_ARROW_TIP_LENGTH, start_angle: float = PI, - **kwargs, + **kwargs: Any, ) -> None: self.start_angle = start_angle Square.__init__( @@ -331,6 +338,6 @@ class ArrowSquareFilledTip(ArrowSquareTip): r"""Square arrow tip with filled tip.""" def __init__( - self, fill_opacity: float = 1, stroke_width: float = 0, **kwargs + self, fill_opacity: float = 1, stroke_width: float = 0, **kwargs: Any ) -> None: super().__init__(fill_opacity=fill_opacity, stroke_width=stroke_width, **kwargs) diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 2079de7923..0359f66045 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -44,6 +44,7 @@ from manim.typing import ( FunctionOverride, + InternalPoint3D, ManimFloat, ManimInt, MappingFunction, @@ -2158,17 +2159,17 @@ def get_z(self, direction: Vector3D = ORIGIN) -> ManimFloat: """Returns z Point3D of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(2, direction) - def get_start(self) -> Point3D: + def get_start(self) -> InternalPoint3D: """Returns the point, where the stroke that surrounds the :class:`~.Mobject` starts.""" self.throw_error_if_no_points() return np.array(self.points[0]) - def get_end(self) -> Point3D: + def get_end(self) -> InternalPoint3D: """Returns the point, where the stroke that surrounds the :class:`~.Mobject` ends.""" self.throw_error_if_no_points() return np.array(self.points[-1]) - def get_start_and_end(self) -> tuple[Point3D, Point3D]: + def get_start_and_end(self) -> tuple[InternalPoint3D, InternalPoint3D]: """Returns starting and ending point of a stroke as a ``tuple``.""" return self.get_start(), self.get_end() diff --git a/manim/mobject/types/image_mobject.py b/manim/mobject/types/image_mobject.py index db0304b502..56029f941e 100644 --- a/manim/mobject/types/image_mobject.py +++ b/manim/mobject/types/image_mobject.py @@ -5,6 +5,7 @@ __all__ = ["AbstractImageMobject", "ImageMobject", "ImageMobjectFromCamera"] import pathlib +from typing import TYPE_CHECKING import numpy as np from PIL import Image @@ -21,6 +22,14 @@ __all__ = ["ImageMobject", "ImageMobjectFromCamera"] +if TYPE_CHECKING: + from typing import Any + + import numpy.typing as npt + from typing_extensions import Self + + from manim.typing import StrPath + class AbstractImageMobject(Mobject): """ @@ -39,23 +48,23 @@ class AbstractImageMobject(Mobject): def __init__( self, scale_to_resolution: int, - pixel_array_dtype="uint8", - resampling_algorithm=Resampling.BICUBIC, - **kwargs, - ): + pixel_array_dtype: str = "uint8", + resampling_algorithm: Resampling = Resampling.BICUBIC, + **kwargs: Any, + ) -> None: self.pixel_array_dtype = pixel_array_dtype self.scale_to_resolution = scale_to_resolution self.set_resampling_algorithm(resampling_algorithm) super().__init__(**kwargs) - def get_pixel_array(self): + def get_pixel_array(self) -> None: raise NotImplementedError() def set_color(self, color, alpha=None, family=True): # Likely to be implemented in subclasses, but no obligation pass - def set_resampling_algorithm(self, resampling_algorithm: int): + def set_resampling_algorithm(self, resampling_algorithm: int) -> Self: """ Sets the interpolation method for upscaling the image. By default the image is interpolated using bicubic algorithm. This method lets you change it. @@ -87,7 +96,7 @@ def set_resampling_algorithm(self, resampling_algorithm: int): ) return self - def reset_points(self): + def reset_points(self) -> None: """Sets :attr:`points` to be the four image corners.""" self.points = np.array( [ @@ -171,15 +180,15 @@ def construct(self): def __init__( self, - filename_or_array, + filename_or_array: StrPath | npt.NDArray, scale_to_resolution: int = QUALITIES[DEFAULT_QUALITY]["pixel_height"], - invert=False, - image_mode="RGBA", - **kwargs, - ): - self.fill_opacity = 1 - self.stroke_opacity = 1 - self.invert = invert + invert: bool = False, + image_mode: str = "RGBA", + **kwargs: Any, + ) -> None: + self.fill_opacity: float = 1 + self.stroke_opacity: float = 1 + self.invert_image = invert self.image_mode = image_mode if isinstance(filename_or_array, (str, pathlib.PurePath)): path = get_full_raster_image_path(filename_or_array) @@ -192,7 +201,7 @@ def __init__( self.pixel_array = change_to_rgba_array( self.pixel_array, self.pixel_array_dtype ) - if self.invert: + if self.invert_image: self.pixel_array[:, :, :3] = ( np.iinfo(self.pixel_array_dtype).max - self.pixel_array[:, :, :3] ) @@ -212,7 +221,7 @@ def set_color(self, color, alpha=None, family=True): self.color = color return self - def set_opacity(self, alpha: float): + def set_opacity(self, alpha: float) -> Self: """Sets the image's opacity. Parameters @@ -226,7 +235,7 @@ def set_opacity(self, alpha: float): self.stroke_opacity = alpha return self - def fade(self, darkness: float = 0.5, family: bool = True): + def fade(self, darkness: float = 0.5, family: bool = True) -> Self: """Sets the image's opacity using a 1 - alpha relationship. Parameters @@ -243,7 +252,7 @@ def fade(self, darkness: float = 0.5, family: bool = True): def interpolate_color( self, mobject1: ImageMobject, mobject2: ImageMobject, alpha: float - ): + ) -> None: """Interpolates the array of pixel color values from one ImageMobject into an array of equal size in the target ImageMobject. @@ -279,7 +288,7 @@ def interpolate_color( alpha, ).astype(self.pixel_array_dtype) - def get_style(self): + def get_style(self) -> dict[str, Any]: return { "fill_color": ManimColor(self.color.get_rgb()).to_hex(), "fill_opacity": self.fill_opacity, @@ -292,7 +301,12 @@ def get_style(self): class ImageMobjectFromCamera(AbstractImageMobject): - def __init__(self, camera, default_display_frame_config=None, **kwargs): + def __init__( + self, + camera, + default_display_frame_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: self.camera = camera if default_display_frame_config is None: default_display_frame_config = { @@ -309,14 +323,14 @@ def get_pixel_array(self): self.pixel_array = self.camera.pixel_array return self.pixel_array - def add_display_frame(self, **kwargs): + def add_display_frame(self, **kwargs: Any) -> Self: config = dict(self.default_display_frame_config) config.update(kwargs) self.display_frame = SurroundingRectangle(self, **config) self.add(self.display_frame) return self - def interpolate_color(self, mobject1, mobject2, alpha): + def interpolate_color(self, mobject1, mobject2, alpha) -> None: assert mobject1.pixel_array.shape == mobject2.pixel_array.shape, ( f"Mobject pixel array shapes incompatible for interpolation.\n" f"Mobject 1 ({mobject1}) : {mobject1.pixel_array.shape}\n" diff --git a/manim/mobject/types/point_cloud_mobject.py b/manim/mobject/types/point_cloud_mobject.py index 289df3f6d0..c9f54e6ed2 100644 --- a/manim/mobject/types/point_cloud_mobject.py +++ b/manim/mobject/types/point_cloud_mobject.py @@ -4,6 +4,8 @@ __all__ = ["PMobject", "Mobject1D", "Mobject2D", "PGroup", "PointCloudDot", "Point"] +from typing import TYPE_CHECKING + import numpy as np from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL @@ -17,6 +19,7 @@ WHITE, YELLOW, ManimColor, + ParsableManimColor, color_gradient, color_to_rgba, rgba_to_color, @@ -25,6 +28,15 @@ __all__ = ["PMobject", "Mobject1D", "Mobject2D", "PGroup", "PointCloudDot", "Point"] +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any + + import numpy.typing as npt + from typing_extensions import Self + + from manim.typing import ManimFloat, Point3D, Vector3D + class PMobject(Mobject, metaclass=ConvertToOpenGL): """A disc made of a cloud of Dots @@ -55,19 +67,25 @@ def construct(self): """ - def __init__(self, stroke_width=DEFAULT_STROKE_WIDTH, **kwargs): + def __init__(self, stroke_width: int = DEFAULT_STROKE_WIDTH, **kwargs: Any) -> None: self.stroke_width = stroke_width super().__init__(**kwargs) - def reset_points(self): + def reset_points(self) -> Self: self.rgbas = np.zeros((0, 4)) self.points = np.zeros((0, 3)) return self - def get_array_attrs(self): + def get_array_attrs(self) -> list[str]: return super().get_array_attrs() + ["rgbas"] - def add_points(self, points, rgbas=None, color=None, alpha=1): + def add_points( + self, + points: npt.NDArray, + rgbas: npt.NDArray | None = None, + color: ParsableManimColor | None = None, + alpha: float = 1, + ) -> Self: """Add points. Points must be a Nx3 numpy array. @@ -85,24 +103,26 @@ def add_points(self, points, rgbas=None, color=None, alpha=1): self.rgbas = np.append(self.rgbas, rgbas, axis=0) return self - def set_color(self, color=YELLOW, family=True): + def set_color( + self, color: ParsableManimColor = YELLOW, family: bool = True + ) -> Self: rgba = color_to_rgba(color) mobs = self.family_members_with_points() if family else [self] for mob in mobs: mob.rgbas[:, :] = rgba - self.color = color + self.color = ManimColor.parse(color) return self - def get_stroke_width(self): + def get_stroke_width(self) -> int: return self.stroke_width - def set_stroke_width(self, width, family=True): + def set_stroke_width(self, width: int, family: bool = True) -> Self: mobs = self.family_members_with_points() if family else [self] for mob in mobs: mob.stroke_width = width return self - def set_color_by_gradient(self, *colors): + def set_color_by_gradient(self, *colors: ParsableManimColor) -> Self: self.rgbas = np.array( list(map(color_to_rgba, color_gradient(*colors, len(self.points)))), ) @@ -110,11 +130,11 @@ def set_color_by_gradient(self, *colors): def set_colors_by_radial_gradient( self, - center=None, - radius=1, - inner_color=WHITE, - outer_color=BLACK, - ): + center: Point3D | None = None, + radius: float = 1, + inner_color: ParsableManimColor = WHITE, + outer_color: ParsableManimColor = BLACK, + ) -> Self: start_rgba, end_rgba = list(map(color_to_rgba, [inner_color, outer_color])) if center is None: center = self.get_center() @@ -129,19 +149,19 @@ def set_colors_by_radial_gradient( ) return self - def match_colors(self, mobject): + def match_colors(self, mobject: Mobject) -> Self: Mobject.align_data(self, mobject) self.rgbas = np.array(mobject.rgbas) return self - def filter_out(self, condition): + def filter_out(self, condition: npt.NDArray) -> Self: for mob in self.family_members_with_points(): to_eliminate = ~np.apply_along_axis(condition, 1, mob.points) mob.points = mob.points[to_eliminate] mob.rgbas = mob.rgbas[to_eliminate] return self - def thin_out(self, factor=5): + def thin_out(self, factor: int = 5) -> Self: """Removes all but every nth point for n = factor""" for mob in self.family_members_with_points(): num_points = self.get_num_points() @@ -150,23 +170,27 @@ def thin_out(self, factor=5): ) return self - def sort_points(self, function=lambda p: p[0]): + def sort_points( + self, function: Callable[[npt.NDArray[ManimFloat]], float] = lambda p: p[0] + ) -> Self: """Function is any map from R^3 to R""" for mob in self.family_members_with_points(): indices = np.argsort(np.apply_along_axis(function, 1, mob.points)) mob.apply_over_attr_arrays(lambda arr, idx=indices: arr[idx]) return self - def fade_to(self, color, alpha, family=True): + def fade_to( + self, color: ParsableManimColor, alpha: float, family: bool = True + ) -> Self: self.rgbas = interpolate(self.rgbas, color_to_rgba(color), alpha) for mob in self.submobjects: mob.fade_to(color, alpha, family) return self - def get_all_rgbas(self): + def get_all_rgbas(self) -> npt.NDArray: return self.get_merged_array("rgbas") - def ingest_submobjects(self): + def ingest_submobjects(self) -> Self: attrs = self.get_array_attrs() arrays = list(map(self.get_merged_array, attrs)) for attr, array in zip(attrs, arrays): @@ -174,30 +198,32 @@ def ingest_submobjects(self): self.submobjects = [] return self - def get_color(self): + def get_color(self) -> ManimColor: return rgba_to_color(self.rgbas[0, :]) - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> Any: index = alpha * (self.get_num_points() - 1) - return self.points[index] + return self.points[np.floor(index)] @staticmethod - def get_mobject_type_class(): + def get_mobject_type_class() -> type[PMobject]: return PMobject # Alignment - def align_points_with_larger(self, larger_mobject): + def align_points_with_larger(self, larger_mobject: Mobject) -> None: assert isinstance(larger_mobject, PMobject) self.apply_over_attr_arrays( lambda a: stretch_array_to_length(a, larger_mobject.get_num_points()), ) - def get_point_mobject(self, center=None): + def get_point_mobject(self, center: Point3D | None = None) -> Point: if center is None: center = self.get_center() return Point(center) - def interpolate_color(self, mobject1, mobject2, alpha): + def interpolate_color( + self, mobject1: Mobject, mobject2: Mobject, alpha: float + ) -> Self: self.rgbas = interpolate(mobject1.rgbas, mobject2.rgbas, alpha) self.set_stroke_width( interpolate( @@ -208,7 +234,7 @@ def interpolate_color(self, mobject1, mobject2, alpha): ) return self - def pointwise_become_partial(self, mobject, a, b): + def pointwise_become_partial(self, mobject: Mobject, a: float, b: float) -> None: lower_index, upper_index = (int(x * mobject.get_num_points()) for x in (a, b)) for attr in self.get_array_attrs(): full_array = getattr(mobject, attr) @@ -218,24 +244,31 @@ def pointwise_become_partial(self, mobject, a, b): # TODO, Make the two implementations below non-redundant class Mobject1D(PMobject, metaclass=ConvertToOpenGL): - def __init__(self, density=DEFAULT_POINT_DENSITY_1D, **kwargs): + def __init__(self, density: int = DEFAULT_POINT_DENSITY_1D, **kwargs: Any) -> None: self.density = density self.epsilon = 1.0 / self.density super().__init__(**kwargs) - def add_line(self, start, end, color=None): + def add_line( + self, + start: npt.NDArray, + end: npt.NDArray, + color: ParsableManimColor | None = None, + ) -> None: start, end = list(map(np.array, [start, end])) length = np.linalg.norm(end - start) if length == 0: - points = [start] + points = np.array([start]) else: epsilon = self.epsilon / length - points = [interpolate(start, end, t) for t in np.arange(0, 1, epsilon)] + points = np.array( + [interpolate(start, end, t) for t in np.arange(0, 1, epsilon)] + ) self.add_points(points, color=color) class Mobject2D(PMobject, metaclass=ConvertToOpenGL): - def __init__(self, density=DEFAULT_POINT_DENSITY_2D, **kwargs): + def __init__(self, density: int = DEFAULT_POINT_DENSITY_2D, **kwargs: Any) -> None: self.density = density self.epsilon = 1.0 / self.density super().__init__(**kwargs) @@ -264,7 +297,7 @@ def construct(self): """ - def __init__(self, *pmobs, **kwargs): + def __init__(self, *pmobs: Any, **kwargs: Any) -> None: if not all(isinstance(m, (PMobject, OpenGLPMobject)) for m in pmobs): raise ValueError( "All submobjects must be of type PMobject or OpenGLPMObject" @@ -273,10 +306,13 @@ def __init__(self, *pmobs, **kwargs): super().__init__(**kwargs) self.add(*pmobs) - def fade_to(self, color, alpha, family=True): + def fade_to( + self, color: ParsableManimColor, alpha: float, family: bool = True + ) -> Self: if family: for mob in self.submobjects: mob.fade_to(color, alpha, family) + return self class PointCloudDot(Mobject1D): @@ -313,13 +349,13 @@ def construct(self): def __init__( self, - center=ORIGIN, - radius=2.0, - stroke_width=2, - density=DEFAULT_POINT_DENSITY_1D, - color=YELLOW, - **kwargs, - ): + center: Vector3D = ORIGIN, + radius: float = 2.0, + stroke_width: int = 2, + density: int = DEFAULT_POINT_DENSITY_1D, + color: ManimColor = YELLOW, + **kwargs: Any, + ) -> None: self.radius = radius self.epsilon = 1.0 / density super().__init__( @@ -327,22 +363,24 @@ def __init__( ) self.shift(center) - def init_points(self): + def init_points(self) -> None: self.reset_points() self.generate_points() - def generate_points(self): + def generate_points(self) -> None: self.add_points( - [ - r * (np.cos(theta) * RIGHT + np.sin(theta) * UP) - for r in np.arange(self.epsilon, self.radius, self.epsilon) - # Num is equal to int(stop - start)/ (step + 1) reformulated. - for theta in np.linspace( - 0, - 2 * np.pi, - num=int(2 * np.pi * (r + self.epsilon) / self.epsilon), - ) - ], + np.array( + [ + r * (np.cos(theta) * RIGHT + np.sin(theta) * UP) + for r in np.arange(self.epsilon, self.radius, self.epsilon) + # Num is equal to int(stop - start)/ (step + 1) reformulated. + for theta in np.linspace( + 0, + 2 * np.pi, + num=int(2 * np.pi * (r + self.epsilon) / self.epsilon), + ) + ] + ), ) @@ -367,14 +405,16 @@ def construct(self): self.add(point) """ - def __init__(self, location=ORIGIN, color=BLACK, **kwargs): + def __init__( + self, location: Vector3D = ORIGIN, color: ManimColor = BLACK, **kwargs: Any + ) -> None: self.location = location super().__init__(color=color, **kwargs) - def init_points(self): + def init_points(self) -> None: self.reset_points() self.generate_points() self.set_points([self.location]) - def generate_points(self): - self.add_points([self.location]) + def generate_points(self) -> None: + self.add_points(np.array([self.location])) diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index 748d5966a5..4e8c268871 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -48,12 +48,15 @@ from manim.utils.space_ops import rotate_vector, shoelace_direction if TYPE_CHECKING: + from typing import Any + import numpy.typing as npt from typing_extensions import Self from manim.typing import ( BezierPoints, CubicBezierPoints, + InternalPoint3D_Array, ManimFloat, MappingFunction, Point2D, @@ -133,7 +136,7 @@ def __init__( tolerance_for_point_equality: float = 1e-6, n_points_per_cubic_curve: int = 4, cap_style: CapStyleType = CapStyleType.AUTO, - **kwargs, + **kwargs: Any, ): self.fill_opacity = fill_opacity self.stroke_opacity = stroke_opacity @@ -708,7 +711,7 @@ def set_shade_in_3d( return self def set_points(self, points: Point3D_Array) -> Self: - self.points: Point3D_Array = np.array(points) + self.points: InternalPoint3D_Array = np.array(points) return self def resize_points( @@ -1586,7 +1589,7 @@ def get_anchors_and_handles(self) -> list[Point3D_Array]: nppcc = self.n_points_per_cubic_curve return [self.points[i::nppcc] for i in range(nppcc)] - def get_start_anchors(self) -> Point3D_Array: + def get_start_anchors(self) -> InternalPoint3D_Array: """Returns the start anchors of the bezier curves. Returns @@ -2058,10 +2061,8 @@ def construct(self): """ def __init__( - self, - *vmobjects: VMobject | Iterable[VMobject], - **kwargs, - ): + self, *vmobjects: VMobject | Iterable[VMobject], **kwargs: Any + ) -> None: super().__init__(**kwargs) self.add(*vmobjects) diff --git a/manim/utils/file_ops.py b/manim/utils/file_ops.py index 7efcee02c5..05a31e6e46 100644 --- a/manim/utils/file_ops.py +++ b/manim/utils/file_ops.py @@ -28,6 +28,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from manim.typing import StrPath + from ..scene.scene_file_writer import SceneFileWriter from manim import __version__, config, logger @@ -159,7 +161,7 @@ def guarantee_empty_existence(path: Path) -> Path: def seek_full_path_from_defaults( - file_name: str, default_dir: Path, extensions: list[str] + file_name: StrPath, default_dir: Path, extensions: list[str] ) -> Path: possible_paths = [Path(file_name).expanduser()] possible_paths += [ diff --git a/manim/utils/images.py b/manim/utils/images.py index 1ce3dd8c28..7e0accf4ca 100644 --- a/manim/utils/images.py +++ b/manim/utils/images.py @@ -9,7 +9,8 @@ "change_to_rgba_array", ] -from pathlib import Path +from pathlib import Path, PurePath +from typing import TYPE_CHECKING import numpy as np from PIL import Image @@ -17,8 +18,11 @@ from .. import config from ..utils.file_ops import seek_full_path_from_defaults +if TYPE_CHECKING: + import numpy.typing as npt -def get_full_raster_image_path(image_file_name: str) -> Path: + +def get_full_raster_image_path(image_file_name: str | PurePath) -> Path: return seek_full_path_from_defaults( image_file_name, default_dir=config.get_dir("assets_dir"), @@ -26,7 +30,7 @@ def get_full_raster_image_path(image_file_name: str) -> Path: ) -def get_full_vector_image_path(image_file_name: str) -> Path: +def get_full_vector_image_path(image_file_name: str | PurePath) -> Path: return seek_full_path_from_defaults( image_file_name, default_dir=config.get_dir("assets_dir"), @@ -49,7 +53,7 @@ def invert_image(image: np.array) -> Image: return Image.fromarray(arr) -def change_to_rgba_array(image, dtype="uint8"): +def change_to_rgba_array(image: npt.NDArray, dtype="uint8") -> npt.NDArray: """Converts an RGB array into RGBA with the alpha value opacity maxed.""" pa = image if len(pa.shape) == 2: diff --git a/mypy.ini b/mypy.ini index 956b44ae21..12c0a1dd69 100644 --- a/mypy.ini +++ b/mypy.ini @@ -70,6 +70,9 @@ ignore_errors = True [mypy-manim.mobject.*] ignore_errors = True +[mypy-manim.mobject.geometry.*] +ignore_errors = False + [mypy-manim.plugins.*] ignore_errors = True