Skip to content

Commit

Permalink
More mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
khughes-bdai committed Jan 26, 2024
1 parent 5969668 commit 6a14702
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 41 deletions.
6 changes: 3 additions & 3 deletions spot_wrapper/cam_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from bosdyn.api.data_chunk_pb2 import DataChunk
from bosdyn.api.spot_cam import audio_pb2
from bosdyn.api.spot_cam.camera_pb2 import Camera
from bosdyn.api.spot_cam.compositor_pb2 import IrColorMap
from bosdyn.api.spot_cam.compositor_pb2 import GetVisibleCamerasResponse, IrColorMap
from bosdyn.api.spot_cam.logging_pb2 import Logpoint
from bosdyn.api.spot_cam.power_pb2 import PowerStatus
from bosdyn.api.spot_cam.ptz_pb2 import PtzDescription, PtzPosition, PtzVelocity
Expand Down Expand Up @@ -151,7 +151,7 @@ def list_screens(self) -> typing.List[str]:
"""
return [screen.name for screen in self.client.list_screens()]

def get_visible_cameras(self):
def get_visible_cameras(self) -> GetVisibleCamerasResponse:
"""
Get the camera data for the camera currently visible on the stream
Expand Down Expand Up @@ -232,7 +232,7 @@ def get_bit_status(
degradations.append((degradation.type, degradation.description))
return events, degradations

def get_temperature(self) -> typing.Tuple[str, float]:
def get_temperature(self) -> typing.List[typing.Tuple[str, float]]:
"""
Get temperatures of various components of the camera
Expand Down
1 change: 1 addition & 0 deletions spot_wrapper/spot_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def gripper_angle_open(self, gripper_ang: float) -> typing.Tuple[bool, str]:
return True, "Opened gripper successfully"

def hand_pose(self, data) -> typing.Tuple[bool, str]:
# TODO what is the type of data? Is it a ROS message type?
"""
Set the pose of the hand
Expand Down
2 changes: 1 addition & 1 deletion spot_wrapper/spot_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _req_feedback(self) -> spot_check_pb2.SpotCheckFeedbackResponse:

return resp

def _spot_check_cmd(self, command: spot_check_pb2.SpotCheckCommandRequest):
def _spot_check_cmd(self, command: spot_check_pb2.SpotCheckCommandRequest) -> None:
"""Send a Spot Check command"""
start_time_seconds, start_time_ns = int(time.time()), int(time.time_ns() % 1e9)
req = spot_check_pb2.SpotCheckCommandRequest(
Expand Down
5 changes: 3 additions & 2 deletions spot_wrapper/spot_dance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import os
import tempfile
import time
from typing import List, Tuple, Union
from typing import Any, List, Tuple, Union

# TODO [mypy] imports are not getting resolved by linter causing style errors
from bosdyn.api.spot.choreography_sequence_pb2 import (
Animation,
ChoreographySequence,
Expand Down Expand Up @@ -149,7 +150,7 @@ def stop_recording_state(self) -> Tuple[bool, str, StopRecordingStateResponse]:
)

def choreography_log_to_animation_file(
self, name: str, fpath: str, has_arm: bool, **kwargs
self, name: str, fpath: str, has_arm: bool, **kwargs: Any
) -> Tuple[bool, str, str]:
"""save a choreography log to a file as an animation"""
try:
Expand Down
2 changes: 1 addition & 1 deletion spot_wrapper/spot_docking.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def undock(self, timeout: int = 20) -> typing.Tuple[bool, str]:
except Exception as e:
return False, f"Exception while trying to undock: {e}"

def get_docking_state(self, **kwargs) -> docking_pb2.DockState:
def get_docking_state(self, **kwargs: typing.Any) -> docking_pb2.DockState:
"""Get docking state of robot."""
state = self._docking_client.get_docking_state(**kwargs)
return state
43 changes: 25 additions & 18 deletions spot_wrapper/spot_graph_nav.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def _init_current_graph_nav_state(self) -> None:
# Store the most recent knowledge of the state of the robot based on rpc calls.
self._current_graph: typing.Optional[map_pb2.Graph] = None
self._current_edges: typing.Dict[str, typing.List[str]] = {} # maps to_waypoint to list(from_waypoint)
self._current_waypoint_snapshots = {} # maps id to waypoint snapshot
self._current_edge_snapshots = {} # maps id to edge snapshot
self._current_annotation_name_to_wp_id = {}
self._current_anchored_world_objects = {} # maps object id to a (wo, waypoint, fiducial)
self._current_anchors = {} # maps anchor id to anchor
self._current_waypoint_snapshots: typing.Dict[str, map_pb2.WaypointSnapshot] = {} # map id to waypoint snapshot
self._current_edge_snapshots: typing.Dict[str, map_pb2.EdgeSnapshot] = {} # maps id to edge snapshot
self._current_annotation_name_to_wp_id: typing.Dict[str, typing.Optional[str]] = {} # TODO what does this map?
self._current_anchored_world_objects: typing.Dict[str, typing.Tuple] = (
{}
) # maps object id to a (wo, waypoint, fiducial) TODO what exactly are these types
self._current_anchors: typing.Dict[str, map_pb2.Anchor] = {} # maps anchor id to anchor

def list_graph(self) -> typing.List[str]:
"""List waypoint ids of graph_nav
Expand Down Expand Up @@ -390,7 +392,9 @@ def _write_bytes(self, filepath: str, filename: str, data: bytes) -> None:
f.write(data)
f.close()

def _list_graph_waypoint_and_edge_ids(self, *args: typing.Any):
def _list_graph_waypoint_and_edge_ids(
self, *args: typing.Any
) -> typing.Tuple[typing.Dict[str, typing.Optional[str]], typing.Dict[str, typing.List[str]]]:
"""List the waypoint ids and edge ids of the graph currently on the robot."""

# Download current graph
Expand Down Expand Up @@ -559,15 +563,17 @@ def _navigate_route(self, waypoint_ids: typing.List[str]) -> typing.Tuple[bool,
Note that each waypoint must have an edge between them, aka be adjacent.
"""
for i in range(len(waypoint_ids)):
waypoint_ids[i] = self._find_unique_waypoint_id(
unique_id = self._find_unique_waypoint_id(
waypoint_ids[i],
self._current_graph,
self._current_annotation_name_to_wp_id,
self._logger,
)
if not waypoint_ids[i]:
if not unique_id:
self._logger.error("navigate_route: Failed to find the unique waypoint id.")
return False, "Failed to find the unique waypoint id."
else:
waypoint_ids[i] = unique_id

edge_ids_list = []
# Attempt to find edges in the current graph that match the ordered waypoint pairs.
Expand Down Expand Up @@ -748,7 +754,7 @@ def _find_unique_waypoint_id(
self,
short_code: str,
graph: map_pb2.Graph,
name_to_id: typing.Dict[str, str],
name_to_id: typing.Dict[str, typing.Optional[str]],
logger: logging.Logger,
) -> typing.Optional[str]:
"""Convert either a 2 letter short code or an annotation name into the associated unique id."""
Expand All @@ -761,8 +767,8 @@ def _find_unique_waypoint_id(
return name_to_id[short_code]
else:
logger.error(
"The waypoint name %s is used for multiple different unique waypoints. Please use"
+ "the waypoint id." % (short_code)
"The waypoint name {0} is used for multiple different unique waypoints. Please use the"
" waypoint id.".format(short_code)
)
return None
# Also not an waypoint annotation name, so we will operate under the assumption that it is a
Expand All @@ -779,12 +785,12 @@ def _find_unique_waypoint_id(

def _update_waypoints_and_edges(
self, graph: map_pb2.Graph, localization_id: str, logger: logging.Logger
) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, str]]:
) -> typing.Tuple[typing.Dict[str, typing.Optional[str]], typing.Dict[str, typing.List[str]]]:
"""Update and print waypoint ids and edge ids."""
name_to_id: typing.Dict[str, str] = dict()
edges: typing.Dict[str, str] = dict()
name_to_id: typing.Dict[str, typing.Optional[str]] = dict()
edges: typing.Dict[str, typing.List[str]] = dict()

short_code_to_count = {}
short_code_to_count: typing.Dict[str, int] = {}
waypoint_to_timestamp = []
for waypoint in graph.waypoints:
# Determine the timestamp that this waypoint was created at.
Expand All @@ -799,9 +805,10 @@ def _update_waypoints_and_edges(

# Determine how many waypoints have the same short code.
short_code = self._id_to_short_code(waypoint.id)
if short_code not in short_code_to_count:
short_code_to_count[short_code] = 0
short_code_to_count[short_code] += 1
if short_code:
if short_code not in short_code_to_count:
short_code_to_count[short_code] = 0
short_code_to_count[short_code] += 1

# Add the annotation name/id into the current dictionary.
waypoint_name = waypoint.annotations.name
Expand Down
6 changes: 4 additions & 2 deletions spot_wrapper/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def fixture(
"""

def decorator(cls: typing.Type[BaseMockSpot]) -> typing.Callable:
def fixturefunc(monkeypatch, **kwargs) -> typing.Iterator[SpotFixture]:
def fixturefunc(monkeypatch: typing.Any, **kwargs: typing.Any) -> typing.Iterator[SpotFixture]:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as thread_pool:
server = grpc.server(thread_pool)
port = server.add_insecure_port(f"{address}:0")
Expand All @@ -64,7 +64,9 @@ def fixturefunc(monkeypatch, **kwargs) -> typing.Iterator[SpotFixture]:
try:
with monkeypatch.context() as m:

def mock_secure_channel(target, _, *args, **kwargs):
def mock_secure_channel(
target: typing.Any, _: typing.Any, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return grpc.insecure_channel(target, *args, **kwargs)

m.setattr(grpc, "secure_channel", mock_secure_channel)
Expand Down
6 changes: 3 additions & 3 deletions spot_wrapper/testing/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ def add_to(self, server: grpc.Server) -> None:
for add in collect_servicer_add_functions(self.__class__):
add(self, server)

def __enter__(self):
def __enter__(self) -> typing.Any:
return self

def __exit__(self, *exc) -> None:
def __exit__(self, *exc: typing.Any) -> None:
self.shutdown()

def shutdown(self):
def shutdown(self) -> None:
"""
Shutdown what needs to be shutdown.
Expand Down
3 changes: 2 additions & 1 deletion spot_wrapper/tests/test_graph_nav_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import logging
import typing

from bosdyn.api.graph_nav import map_pb2

Expand All @@ -26,7 +27,7 @@ def test_short_code(self) -> None:
# Set up
self.logger = logging.Logger("test_graph_nav_util", level=logging.INFO)
self.graph = map_pb2.Graph()
self.name_to_id = {"ABCDE": "Node1"}
self.name_to_id: typing.Dict[str, typing.Optional[str]] = {"ABCDE": "Node1"}
# Test normal short code
assert graph_nav_util._find_unique_waypoint_id("AC", self.graph, self.name_to_id, self.logger) == "AC"
# Test annotation name that is known
Expand Down
18 changes: 8 additions & 10 deletions spot_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
robot_state_pb2,
world_object_pb2,
)
from bosdyn.api.docking import docking_pb2
from bosdyn.api.spot import robot_command_pb2 as spot_command_pb2
from bosdyn.api.spot.choreography_sequence_pb2 import (
Animation,
ChoreographySequence,
ChoreographyStatusResponse,
StartRecordingStateResponse,
StopRecordingStateResponse,
UploadChoreographyResponse,
)
from bosdyn.choreography.client.choreography import (
ChoreographyClient,
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def _robot_command(
command_proto: robot_command_pb2.RobotCommand,
end_time_secs: typing.Optional[float] = None,
timesync_endpoint: typing.Optional[TimeSyncEndpoint] = None,
) -> typing.Tuple[bool, str, typing.Optional[str]]:
) -> typing.Tuple[bool, str, typing.Optional[int]]:
"""Generic blocking function for sending commands to robots.
Args:
Expand All @@ -1066,7 +1066,7 @@ def _manipulation_request(
self,
request_proto: manipulation_api_pb2.ManipulationApiRequest,
end_time_secs: typing.Optional[float] = None,
timesync_endpoint: TimeSyncEndpoint = None,
timesync_endpoint: typing.Optional[TimeSyncEndpoint] = None,
) -> typing.Tuple[bool, str, typing.Optional[str]]:
"""Generic function for sending requests to the manipulation api of a robot.
Expand Down Expand Up @@ -1344,7 +1344,7 @@ def trajectory_cmd(

def robot_command(
self, robot_command: robot_command_pb2.RobotCommand
) -> typing.Tuple[bool, str, typing.Optional[str]]:
) -> typing.Tuple[bool, str, typing.Optional[int]]:
end_time = time.time() + MAX_COMMAND_DURATION
return self._robot_command(
robot_command,
Expand All @@ -1362,10 +1362,10 @@ def manipulation_command(
timesync_endpoint=self._robot.time_sync.endpoint,
)

def get_robot_command_feedback(self, cmd_id: int) -> robot_command_pb2.RobotCommandFeedbackResponse:
def get_robot_command_feedback(self, cmd_id: int) -> robot_command_pb2.RobotCommandFeedback:
return self._robot_command_client.robot_command_feedback(cmd_id)

def get_manipulation_command_feedback(self, cmd_id):
def get_manipulation_command_feedback(self, cmd_id: int) -> manipulation_api_pb2.ManipulationApiFeedbackResponse:
feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest(manipulation_cmd_id=cmd_id)

return self._manipulation_api_client.manipulation_api_feedback_command(
Expand Down Expand Up @@ -1432,9 +1432,7 @@ def execute_choreography_by_name(
else:
return False, "Spot is not licensed for choreography"

def upload_choreography(
self, choreography_sequence: ChoreographySequence
) -> typing.Tuple[bool, str, UploadChoreographyResponse]:
def upload_choreography(self, choreography_sequence: ChoreographySequence) -> typing.Tuple[bool, str]:
"""Upload choreography sequence for later playback"""
if self._is_licensed_for_choreography:
return self._spot_dance.upload_choreography(choreography_sequence)
Expand Down Expand Up @@ -1480,7 +1478,7 @@ def get_choreography_status(
response,
)

def get_docking_state(self, **kwargs: typing.Any):
def get_docking_state(self, **kwargs: typing.Any) -> docking_pb2.DockState:
"""Get docking state of robot."""
state = self._docking_client.get_docking_state(**kwargs)
return state
Expand Down

0 comments on commit 6a14702

Please sign in to comment.