Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to mobject/valuetracker.py #4129

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 66 additions & 38 deletions manim/mobject/value_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@

__all__ = ["ValueTracker", "ComplexValueTracker"]

from typing import TYPE_CHECKING, Any

import numpy as np

from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.utils.paths import straight_path

if TYPE_CHECKING:
from typing_extensions import Self

from manim.typing import PathFuncType


class ValueTracker(Mobject, metaclass=ConvertToOpenGL):
"""A mobject that can be used for tracking (real-valued) parameters.
Expand Down Expand Up @@ -69,76 +75,102 @@ def construct(self):

"""

def __init__(self, value=0, **kwargs):
def __init__(self, value: float = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.set(points=np.zeros((1, 3)))
self.set_value(value)

def get_value(self) -> float:
"""Get the current value of this ValueTracker."""
return self.points[0, 0]
value: float = self.points[0, 0]
return value

def set_value(self, value: float):
"""Sets a new scalar value to the ValueTracker"""
def set_value(self, value: float) -> Self:
"""Sets a new scalar value to the ValueTracker."""
self.points[0, 0] = value
return self

def increment_value(self, d_value: float):
"""Increments (adds) a scalar value to the ValueTracker"""
def increment_value(self, d_value: float) -> Self:
"""Increments (adds) a scalar value to the ValueTracker."""
self.set_value(self.get_value() + d_value)
return self

def __bool__(self):
"""Return whether the value of this value tracker evaluates as true."""
def __bool__(self) -> bool:
"""Return whether the value of this ValueTracker evaluates as true."""
return bool(self.get_value())

def __iadd__(self, d_value: float):
"""adds ``+=`` syntax to increment the value of the ValueTracker"""
self.increment_value(d_value)
def __iadd__(self, d_value: float | Mobject) -> Self:
"""adds ``+=`` syntax to increment the value of the ValueTracker."""
if isinstance(d_value, Mobject):
raise ValueError(
"Cannot increment ValueTracker by a Mobject. Please provide a scalar value."
)
else:
self.increment_value(d_value)
return self

def __ifloordiv__(self, d_value: float):
"""Set the value of this value tracker to the floor division of the current value by ``d_value``."""
def __ifloordiv__(self, d_value: float) -> Self:
"""Set the value of this ValueTracker to the floor division of the current value by ``d_value``."""
self.set_value(self.get_value() // d_value)
return self

def __imod__(self, d_value: float):
"""Set the value of this value tracker to the current value modulo ``d_value``."""
def __imod__(self, d_value: float) -> Self:
"""Set the value of this ValueTracker to the current value modulo ``d_value``."""
self.set_value(self.get_value() % d_value)
return self

def __imul__(self, d_value: float):
"""Set the value of this value tracker to the product of the current value and ``d_value``."""
def __imul__(self, d_value: float) -> Self:
"""Set the value of this ValueTracker to the product of the current value and ``d_value``."""
self.set_value(self.get_value() * d_value)
return self

def __ipow__(self, d_value: float):
"""Set the value of this value tracker to the current value raised to the power of ``d_value``."""
def __ipow__(self, d_value: float) -> Self:
"""Set the value of this ValueTracker to the current value raised to the power of ``d_value``."""
self.set_value(self.get_value() ** d_value)
return self

def __isub__(self, d_value: float):
"""adds ``-=`` syntax to decrement the value of the ValueTracker"""
self.increment_value(-d_value)
def __sub__(self, d_value: float | Mobject) -> Self:
"""Decrements the ValueTracker by ``d_value``."""
if isinstance(d_value, Mobject):
raise ValueError(
"Cannot decrement ValueTracker by a Mobject. Please provide a scalar value."
)
else:
self.increment_value(-d_value)
return self

def __itruediv__(self, d_value: float):
"""Sets the value of this value tracker to the current value divided by ``d_value``."""
def __isub__(self, d_value: float | Mobject) -> Self:
"""Adds ``-=`` syntax to decrement the value of the ValueTracker."""
if isinstance(d_value, Mobject):
raise ValueError(
"Cannot decrement ValueTracker by a Mobject. Please provide a scalar value."
)
else:
self.increment_value(-d_value)
return self

def __itruediv__(self, d_value: float) -> Self:
"""Sets the value of this ValueTracker to the current value divided by ``d_value``."""
self.set_value(self.get_value() / d_value)
return self

def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path()):
"""
Turns self into an interpolation between mobject1
and mobject2.
"""
def interpolate(
self,
mobject1: Mobject,
mobject2: Mobject,
alpha: float,
path_func: PathFuncType = straight_path(),
) -> Self:
"""Turns ``self`` into an interpolation between ``mobject1`` and ``mobject2``."""
self.set(points=path_func(mobject1.points, mobject2.points, alpha))
return self


class ComplexValueTracker(ValueTracker):
"""Tracks a complex-valued parameter.

The value is internally stored as a points array [a, b, 0]. This can be accessed directly
to represent the value geometrically, see the usage example.
When the value is set through :attr:`animate`, the value will take a straight path from the
source point to the destination point.

Expand All @@ -161,16 +193,12 @@ def construct(self):
self.play(tracker.animate.set_value(tracker.get_value() / (-2 + 3j)))
"""

def get_value(self):
"""Get the current value of this value tracker as a complex number.

The value is internally stored as a points array [a, b, 0]. This can be accessed directly
to represent the value geometrically, see the usage example.
"""
def get_value(self) -> complex: # type: ignore [override]
Copy link
Contributor Author

@fmuenkel fmuenkel Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not know how to resolve this.
Between ValueTracker and ComplexValueTracker, the latter makes more sense as the base class not the other way around, but that would be a large re-write.

"""Get the current value of this ComplexValueTracker as a complex number."""
return complex(*self.points[0, :2])

def set_value(self, z):
"""Sets a new complex value to the ComplexValueTracker"""
z = complex(z)
def set_value(self, value: complex | float) -> Self:
"""Sets a new complex value to the ComplexValueTracker."""
z = complex(value)
self.points[0, :2] = (z.real, z.imag)
return self
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ ignore_errors = False
[mypy-manim.mobject.geometry.*]
ignore_errors = True

[mypy-manim.mobject.value_tracker]
ignore_errors = False

[mypy-manim.renderer.*]
ignore_errors = True

Expand Down
Loading