Skip to content

Commit

Permalink
Merge pull request #56 from cwstryker/cwstryker/generic_types
Browse files Browse the repository at this point in the history
One option for adding generic types
  • Loading branch information
virtuald authored Apr 2, 2024
2 parents 710f4b7 + 604af69 commit 9ec9148
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 26 deletions.
1 change: 1 addition & 0 deletions commands2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import button
from . import cmd
from . import typing

from .commandscheduler import CommandScheduler
from .conditionalcommand import ConditionalCommand
Expand Down
27 changes: 17 additions & 10 deletions commands2/profiledpidcommand.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@
# the WPILib BSD license file in the root directory of this project.
#

from typing import Any, Callable, Union

from .command import Command
from .subsystem import Subsystem
from typing import Any, Generic

from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians
from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians

from .command import Command
from .subsystem import Subsystem
from .typing import (
FloatOrFloatSupplier,
FloatSupplier,
TProfiledPIDController,
UseOutputFunction,
)


class ProfiledPIDCommand(Command):
class ProfiledPIDCommand(Command, Generic[TProfiledPIDController]):
"""A command that controls an output with a :class:`.ProfiledPIDController`. Runs forever by default -
to add exit conditions and/or other behavior, subclass this class. The controller calculation and
output are performed synchronously in the command's execute() method.
Expand All @@ -24,10 +30,10 @@ class ProfiledPIDCommand(Command):

def __init__(
self,
controller,
measurementSource: Callable[[], float],
goalSource: Union[float, Callable[[], float]],
useOutput: Callable[[float, Any], Any],
controller: TProfiledPIDController,
measurementSource: FloatSupplier,
goalSource: FloatOrFloatSupplier,
useOutput: UseOutputFunction,
*requirements: Subsystem,
):
"""Creates a new ProfiledPIDCommand, which controls the given output with a ProfiledPIDController. Goal
Expand All @@ -40,14 +46,15 @@ def __init__(
:param requirements: the subsystems required by this command
"""

super().__init__()
if isinstance(controller, ProfiledPIDController):
self._stateCls = TrapezoidProfile.State
elif isinstance(controller, ProfiledPIDControllerRadians):
self._stateCls = TrapezoidProfileRadians.State
else:
raise ValueError(f"unknown controller type {controller!r}")

self._controller = controller
self._controller: TProfiledPIDController = controller
self._useOutput = useOutput
self._measurement = measurementSource
if isinstance(goalSource, (float, int)):
Expand Down
35 changes: 19 additions & 16 deletions commands2/profiledpidsubsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# Open Source Software; you can modify and/or share it under the terms of
# the WPILib BSD license file in the root directory of this project.

from typing import Union, cast
from typing import Generic

from wpimath.trajectory import TrapezoidProfile

from .subsystem import Subsystem
from .typing import TProfiledPIDController, TTrapezoidProfileState


class ProfiledPIDSubsystem(Subsystem):
class ProfiledPIDSubsystem(
Subsystem, Generic[TProfiledPIDController, TTrapezoidProfileState]
):
"""
A subsystem that uses a :class:`wpimath.controller.ProfiledPIDController`
or :class:`wpimath.controller.ProfiledPIDControllerRadians` to
Expand All @@ -19,12 +22,18 @@ class ProfiledPIDSubsystem(Subsystem):

def __init__(
self,
controller,
controller: TProfiledPIDController,
initial_position: float = 0,
):
"""Creates a new PIDSubsystem."""
"""
Creates a new Profiled PID Subsystem using the provided PID Controller
:param controller: the controller that controls the output
:param initial_position: the initial value of the process variable
"""
super().__init__()
self._controller = controller
self._controller: TProfiledPIDController = controller
self._enabled = False
self.setGoal(initial_position)

Expand All @@ -38,20 +47,16 @@ def periodic(self):

def getController(
self,
):
) -> TProfiledPIDController:
"""Returns the controller"""
return self._controller

def setGoal(self, goal):
"""
Sets the goal state for the subsystem.
"""
"""Sets the goal state for the subsystem."""
self._controller.setGoal(goal)

def useOutput(self, output: float, setpoint: TrapezoidProfile.State):
"""
Uses the output from the controller object.
"""
def useOutput(self, output: float, setpoint: TTrapezoidProfileState):
"""Uses the output from the controller object."""
raise NotImplementedError(f"{self.__class__} must implement useOutput")

def getMeasurement(self) -> float:
Expand All @@ -72,7 +77,5 @@ def disable(self):
self.useOutput(0, TrapezoidProfile.State())

def isEnabled(self) -> bool:
"""
Returns whether the controller is enabled.
"""
"""Returns whether the controller is enabled."""
return self._enabled
30 changes: 30 additions & 0 deletions commands2/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Callable, Protocol, TypeVar, Union

from typing_extensions import TypeAlias
from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians
from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians

# Generic Types
TProfiledPIDController = TypeVar(
"TProfiledPIDController", ProfiledPIDControllerRadians, ProfiledPIDController
)
TTrapezoidProfileState = TypeVar(
"TTrapezoidProfileState",
TrapezoidProfileRadians.State,
TrapezoidProfile.State,
)


# Protocols - Structural Typing
class UseOutputFunction(Protocol):

def __init__(self, *args, **kwargs) -> None: ...

def __call__(self, t: float, u: TTrapezoidProfileState) -> None: ...

def accept(self, t: float, u: TTrapezoidProfileState) -> None: ...


# Type Aliases
FloatSupplier: TypeAlias = Callable[[], float]
FloatOrFloatSupplier: TypeAlias = Union[float, Callable[[], float]]

0 comments on commit 9ec9148

Please sign in to comment.