Skip to content

Commit

Permalink
polar coordinates ftw!
Browse files Browse the repository at this point in the history
  • Loading branch information
NishanthJKumar committed Oct 25, 2023
1 parent bf0dc54 commit 8e0f317
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 37 deletions.
4 changes: 2 additions & 2 deletions predicators/approaches/active_sampler_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,10 @@ def _sample(state: State, goal: Set[GroundAtom], rng: np.random.Generator,
else:
raise NotImplementedError('Exploration strategy ' +
f'{strategy} ' + 'is not implemented.')

logging.info(f"State: {state}")
logging.info(f"Best Sample: {samples[idx]}")

return samples[idx]

return _sample
Expand Down
32 changes: 25 additions & 7 deletions predicators/envs/ball_and_cup_sticky_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@ class BallAndCupStickyTableEnv(BaseEnv):
reachable_thresh: ClassVar[float] = 0.1
objs_scale: ClassVar[float] = 0.25 # as a function of table radius
sticky_surface_mode: ClassVar[str] = "half" # half or whole
num_possible_sectors: ClassVar[
int] = 4 # number of sectors we're going to cut up the circular table into
# Types
_table_type: ClassVar[Type] = Type("table", ["x", "y", "radius", "sticky", "sticky_radius"])
_table_type: ClassVar[Type] = Type("table", [
"x", "y", "radius", "sticky", "sticky_region_start_angle",
"sticky_region_end_angle"
])
_robot_type: ClassVar[Type] = Type("robot", ["x", "y"])
_ball_type: ClassVar[Type] = Type("ball", ["x", "y", "radius", "held"])
_cup_type: ClassVar[Type] = Type("cup", ["x", "y", "radius", "held"])
Expand Down Expand Up @@ -179,9 +184,10 @@ def _get_tasks(self, num: int,
# the tables are in different positions along the circle every
# time.
theta_offset = 0.0 #rng.uniform(0, 2 * np.pi)
start_angle_choices = np.linspace(0.0, 2 * np.pi,
self.num_possible_sectors + 1)
# Now, actually instantiate the tables.
for i, theta in enumerate(thetas):
sticky_radius_factor = rng.uniform(0.05, 0.25)
x = d * np.cos(theta + theta_offset) + origin_x
y = d * np.sin(theta + theta_offset) + origin_y
if i >= CFG.sticky_table_num_sticky_tables:
Expand All @@ -191,12 +197,16 @@ def _get_tasks(self, num: int,
prefix = "sticky"
sticky = 1.0
obj = Object(f"{prefix}-table-{i}", self._table_type)
start_angle = rng.choice(start_angle_choices)
end_angle = start_angle + (2 * np.pi /
self.num_possible_sectors)
state_dict[obj] = {
"x": x,
"y": y,
"radius": radius,
"sticky": sticky,
"sticky_radius": radius * sticky_radius_factor
"sticky_region_start_angle": start_angle,
"sticky_region_end_angle": end_angle,
}
tables = sorted(state_dict)
target_table = tables[-1]
Expand Down Expand Up @@ -442,10 +452,18 @@ def simulate(self, state: State, action: Action) -> State:
# and set fall prob accordingly.
table_x = state.get(table, "x")
table_y = state.get(table, "y")
sticky_radius = state.get(table, "sticky_radius")
sticky_obj_geom = utils.Circle(table_x, table_y, sticky_radius)
if not sticky_obj_geom.contains_point(act_x, act_y):
# if self.sticky_surface_mode == "half" and act_y < table_y + 0.25 * (state.get(table, "radius") - (state.get(ball, "radius"))):
table_geom = self._object_to_geom(
table, state)
assert isinstance(table_geom, utils.Circle)
if not table_geom.sector_contains_point(
act_x, act_y,
state.get(
table,
"sticky_region_start_angle"),
state.get(
table,
"sticky_region_end_angle")):
# if self.sticky_surface_mode == "half" and act_y < table_y + 0.25 * (state.get(table, "radius") - (state.get(ball, "radius"))):
if obj_being_held == cup:
fall_prob = self._place_smooth_fall_prob
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,22 @@ def get_options(cls, env_name: str, types: Dict[str, Type],
types=[robot_type, cup_type])

return {
NavigateToTable, PickBallFromTable, PickBallFromFloor,
PlaceBallOnTable, PlaceBallOnFloor, PickCupWithoutBallFromTable,
PickCupWithBallFromTable, PickCupWithoutBallFromFloor,
PickCupWithBallFromFloor, #PlaceCupWithBallOnTable,
PlaceCupWithoutBallOnTable, PlaceCupWithBallOnFloor,
PlaceCupWithoutBallOnFloor, PlaceBallInCupOnFloor,
PlaceBallInCupOnTable, NavigateToBall, NavigateToCup
NavigateToTable,
PickBallFromTable,
PickBallFromFloor,
PlaceBallOnTable,
PlaceBallOnFloor,
PickCupWithoutBallFromTable,
PickCupWithBallFromTable,
PickCupWithoutBallFromFloor,
PickCupWithBallFromFloor, #PlaceCupWithBallOnTable,
PlaceCupWithoutBallOnTable,
PlaceCupWithBallOnFloor,
PlaceCupWithoutBallOnFloor,
PlaceBallInCupOnFloor,
PlaceBallInCupOnTable,
NavigateToBall,
NavigateToCup
}

@classmethod
Expand Down
26 changes: 24 additions & 2 deletions predicators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,16 @@ def construct_active_sampler_input(state: State, objects: Sequence[Object],
cup_x = state.get(cup, "x")
cup_y = state.get(cup, "y")
sticky = state.get(table, "sticky")
sticky_radius = state.get(table, "sticky_radius")
# sticky_radius = state.get(table, "sticky_radius")
sticky_start_angle = state.get(table,
"sticky_region_start_angle")
sticky_end_angle = state.get(table, "sticky_region_end_angle")
table_radius = state.get(table, "radius")
a, b, c, param_x, param_y = params
sampler_input_lst.append(table_radius)
sampler_input_lst.append(sticky)
sampler_input_lst.append(sticky_radius)
sampler_input_lst.append(sticky_start_angle)
sampler_input_lst.append(sticky_end_angle)
# sampler_input_lst.append(ball_x)
# sampler_input_lst.append(ball_y)
# sampler_input_lst.append(cup_x)
Expand Down Expand Up @@ -398,6 +402,24 @@ def plot(self, ax: plt.Axes, **kwargs: Any) -> None:
def contains_point(self, x: float, y: float) -> bool:
return (x - self.x)**2 + (y - self.y)**2 <= self.radius**2

def sector_contains_point(self, x: float, y: float,
sector_start_angle: float,
sector_end_angle: float) -> bool:
"""Returns true if the point x, y is contained within the sector
starting at sector_start_angle radians and ending at sector_end_angle
radians."""
# First, check that the point is even on the circle.
if not self.contains_point(x, y):
return False
# Next, convert (x, y) relative to the table's center
# to polar coordinates.
relative_x = x - self.x
relative_y = y - self.y
theta = np.arctan2(relative_y, relative_x)
if theta < 0:
theta = np.pi - theta
return sector_start_angle <= theta <= sector_end_angle

def contains_circle(self, other_circle: Circle) -> bool:
dist_between_centers = np.sqrt((other_circle.x - self.x)**2 +
(other_circle.y - self.y)**2)
Expand Down
38 changes: 19 additions & 19 deletions scripts/configs/active_sampler_learning.yaml
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
# Final active sampler learning experiments.
---
APPROACHES:
# task_repeat_explore:
# NAME: "active_sampler_learning"
# FLAGS:
# explorer: "active_sampler"
# active_sampler_explore_task_strategy: "task_repeat"
# planning_progress_explore:
# NAME: "active_sampler_learning"
# FLAGS:
# explorer: "active_sampler"
# active_sampler_explore_task_strategy: "planning_progress"
success_rate_explore:
task_repeat_explore:
NAME: "active_sampler_learning"
FLAGS:
explorer: "active_sampler"
active_sampler_explore_task_strategy: "success_rate"
active_sampler_explore_bonus: 0.01
random_score_explore:
active_sampler_explore_task_strategy: "task_repeat"
planning_progress_explore:
NAME: "active_sampler_learning"
FLAGS:
explorer: "active_sampler"
active_sampler_explore_task_strategy: "random"
random_nsrts_explore:
NAME: "active_sampler_learning"
FLAGS:
explorer: "random_nsrts"
active_sampler_explore_task_strategy: "planning_progress"
# success_rate_explore:
# NAME: "active_sampler_learning"
# FLAGS:
# explorer: "active_sampler"
# active_sampler_explore_task_strategy: "success_rate"
# active_sampler_explore_bonus: 0.01
# random_score_explore:
# NAME: "active_sampler_learning"
# FLAGS:
# explorer: "active_sampler"
# active_sampler_explore_task_strategy: "random"
# random_nsrts_explore:
# NAME: "active_sampler_learning"
# FLAGS:
# explorer: "random_nsrts"
# maple_q:
# NAME: "maple_q"
# FLAGS:
Expand Down

0 comments on commit 8e0f317

Please sign in to comment.