diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py
index d93e0564e..0fd0107c8 100644
--- a/predicators/envs/pybullet_balance.py
+++ b/predicators/envs/pybullet_balance.py
@@ -40,8 +40,8 @@ class PyBulletBalanceEnv(PyBulletEnv):
     _table_mid_half_extents = [0.1, _table_mid_w / 2, _table_height / 2]
 
     # Plate
-    _plate_height: ClassVar[float] = 0.01
-    _plate_z = _table_height - _plate_height
+    _plate_height: ClassVar[float] = 0.02
+    _plate_z = _table_height - _plate_height * 3
     _plate1_pose: ClassVar[Pose3D] = (_table_x, _table2_y - _table_mid_w / 2 -
                                       _table_side_w / 2 - _table_gap, _plate_z)
     _plate3_pose: ClassVar[Pose3D] = (_table_x, _table2_y + _table_mid_w / 2 +
@@ -83,7 +83,7 @@ class PyBulletBalanceEnv(PyBulletEnv):
 
     _camera_target: ClassVar[Pose3D] = (1.65, 0.75, 0.52)
 
-    _block_mass: ClassVar[float] = 0.5
+    _block_mass: ClassVar[float] = 1
     _block_size = CFG.balance_block_size
     _num_blocks_train = CFG.balance_num_blocks_train
     _num_blocks_test = CFG.balance_num_blocks_test
@@ -222,7 +222,7 @@ def initialize_pybullet(
         plate3_id = create_pybullet_block(
             (.9, .9, .9, 1),
             cls._plate_half_extents,
-            0.0,
+            1.0,
             1.0,
             cls._plate3_pose,
             cls._table_orientation,
@@ -232,7 +232,7 @@ def initialize_pybullet(
         plate1_id = create_pybullet_block(
             (.9, .9, .9, 1),
             cls._plate_half_extents,
-            0.0,
+            1.0,
             1.0,
             cls._plate1_pose,
             cls._table_orientation,
@@ -243,7 +243,7 @@ def initialize_pybullet(
         beam1_id = create_pybullet_block(
             (0.9, 0.9, 0.9, 1),
             cls._beam_half_extents,
-            0.0,
+            1.0,
             1.0,
             cls._beam1_pose,
             cls._table_orientation,
@@ -252,19 +252,21 @@ def initialize_pybullet(
         beam2_id = create_pybullet_block(
             (0.9, 0.9, 0.9, 1),
             cls._beam_half_extents,
-            0.0,
+            1.0,
             1.0,
             cls._beam2_pose,
             cls._table_orientation,
             physics_client_id,
         )
         bodies["beam_ids"] = [beam1_id, beam2_id]
+        cls.fix_plates_and_beams_in_place(physics_client_id, table2_id, plate1_id, 
+                                          plate3_id, beam1_id, beam2_id)
 
 
         button_id = create_pybullet_block(
             cls._button_color_off,
             [cls._button_radius] * 3,
-            0.0,
+            1.0,
             1.0,
             (cls.button_x, cls.button_y, cls.button_z),
             cls._table_orientation,
@@ -293,34 +295,34 @@ def initialize_pybullet(
 
         return physics_client_id, pybullet_robot, bodies
 
-    # @staticmethod
-    # def fix_plates_and_beams_in_place(physics_client_id, table_id, plate1_id, 
-    #                                   plate3_id, beam1_id, beam2_id):
-    # # Doesn't work for some reason
-    #     for child_id in [plate1_id, plate3_id, beam1_id, beam2_id]:
-    #         parent_pos, parent_orn = p.getBasePositionAndOrientation(table_id, 
-    #                                         physicsClientId=physics_client_id)
-    #         child_pos, child_orn = p.getBasePositionAndOrientation(child_id, 
-    #                                         physicsClientId=physics_client_id)
-    #         rel_pos, rel_orn = p.multiplyTransforms(
-    #             p.invertTransform(parent_pos, parent_orn)[0],
-    #             p.invertTransform(parent_pos, parent_orn)[1],
-    #             child_pos,
-    #             child_orn
-    #         )
-    #         p.createConstraint(
-    #             parentBodyUniqueId=table_id,
-    #             parentLinkIndex=-1,
-    #             childBodyUniqueId=child_id,
-    #             childLinkIndex=-1,
-    #             jointType=p.JOINT_FIXED,
-    #             jointAxis=(0, 0, 0),
-    #             parentFramePosition=rel_pos,
-    #             parentFrameOrientation=rel_orn,
-    #             childFramePosition=(0, 0, 0),
-    #             childFrameOrientation=(0, 0, 0),
-    #             physicsClientId=physics_client_id
-    #         )
+    @staticmethod
+    def fix_plates_and_beams_in_place(physics_client_id, table_id, plate1_id, 
+                                      plate3_id, beam1_id, beam2_id):
+    # Doesn't work for some reason
+        for child_id in [plate1_id, plate3_id, beam1_id, beam2_id]:
+            parent_pos, parent_orn = p.getBasePositionAndOrientation(table_id, 
+                                            physicsClientId=physics_client_id)
+            child_pos, child_orn = p.getBasePositionAndOrientation(child_id, 
+                                            physicsClientId=physics_client_id)
+            rel_pos, rel_orn = p.multiplyTransforms(
+                p.invertTransform(parent_pos, parent_orn)[0],
+                p.invertTransform(parent_pos, parent_orn)[1],
+                child_pos,
+                child_orn
+            )
+            p.createConstraint(
+                parentBodyUniqueId=table_id,
+                parentLinkIndex=-1,
+                childBodyUniqueId=child_id,
+                childLinkIndex=-1,
+                jointType=p.JOINT_FIXED,
+                jointAxis=(0, 0, 0),
+                parentFramePosition=rel_pos,
+                parentFrameOrientation=rel_orn,
+                childFramePosition=(0, 0, 0),
+                childFrameOrientation=(0, 0, 0),
+                physicsClientId=physics_client_id
+            )
 
     def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None:
         self._plate1.id = pybullet_bodies["table_ids"][0]
@@ -400,6 +402,9 @@ def step(self, action: Action, render_obs: bool = False) -> State:
         state = super().step(action, render_obs=render_obs)
 
         self._update_balance_beam(state)
+        self.fix_plates_and_beams_in_place(self._physics_client_id, self._table_id, 
+                                          self._plate1.id, self._plate3.id, 
+                                          self._beam_ids[0], self._beam_ids[1])
 
         # Turn machine on
         if self._PressingButton_holds(state, [self._robot, self._machine]):
@@ -481,6 +486,9 @@ def _reset_state(self, state: State) -> None:
         self._prev_diff = 0
         # Also do one beam update to make sure the initial positions match
         self._update_balance_beam(state)
+        self.fix_plates_and_beams_in_place(self._physics_client_id, self._table_id, 
+                                          self._plate1.id, self._plate3.id, 
+                                          self._beam_ids[0], self._beam_ids[1])
 
         # Update the button color
         if self._MachineOn_holds(state, [self._machine, self._robot]):
@@ -503,95 +511,83 @@ def _reset_state(self, state: State) -> None:
 
     def _update_balance_beam(self, state: State) -> None:
         """Shift the plates, beams, *and blocks on them* to simulate a
-        balance."""
-
-        # Count how many blocks are on each plate by comparing x to midpoint_x.
-        left_count = 0
-        right_count = 0
-        midpoint_y = self._table2_y
-
-        block_objs = state.get_objects(self._block_type)
+        balance, ensuring rising sides move blocks first then plate,
+        and dropping sides move plate first then blocks."""
         left_count = self.count_num_blocks(state, self._plate1)
         right_count = self.count_num_blocks(state, self._plate3)
-
         diff = left_count - right_count
         if diff == self._prev_diff:
-            return  # No change in distribution, no need to reset positions
+            return
 
-        # Compute plate/beam shifts
         shift_per_block = 0.01
-        shift_amount = diff * shift_per_block
-
-        # Update blocks
-        block_id_map = {obj: obj.id for obj in block_objs}  # Object -> int
-
-        for block_obj in block_objs:
-            # Skip out-of-view blocks
-            old_bz = state.get(block_obj, "pose_z")
-            if old_bz < 0 or self._held_obj_id == block_obj.id:
-                continue
-            by = state.get(block_obj, "pose_y")
-            block_id = block_id_map[block_obj]
-
-            # If block is left of midpoint => shift with plate1
-            # If block is right of midpoint => shift with plate3
-            if by < midpoint_y:
-                new_bz = old_bz - shift_amount
+        shift_amount = abs(diff) * shift_per_block
+        block_objs = state.get_objects(self._block_type)
+        left_dropping = diff > 0
+        right_dropping = diff < 0
+
+        def shift_blocks(is_left: bool, dropping: bool):
+            """Shift blocks for one side, dropping or rising."""
+            sign = -1 if dropping else 1
+            midpoint_y = self._table2_y
+            for block_obj in block_objs:
+                # Skip out-of-view or held
+                if state.get(block_obj, "pose_z") < 0 or \
+                   self._held_obj_id == block_obj.id:
+                    continue
+                by = state.get(block_obj, "pose_y")
+                belongs_to_side = (by < midpoint_y) if is_left else (by > midpoint_y)
+                if belongs_to_side:
+                    old_z = state.get(block_obj, "pose_z")
+                    padding = 0
+                    new_z = old_z + (sign * shift_amount) + (sign * padding)
+                    block_pos, block_orn = p.getBasePositionAndOrientation(
+                        block_obj.id, physicsClientId=self._physics_client_id)
+                    p.resetBasePositionAndOrientation(
+                        block_obj.id, [block_pos[0], block_pos[1], new_z],
+                        block_orn, physicsClientId=self._physics_client_id)
+
+        def shift_plate(is_left: bool, dropping: bool):
+            """Shift plate & beam, dropping or rising."""
+            sign = -1 if dropping else 1
+            if is_left:
+                plate_id, beam_id = self._plate1.id, self._beam_ids[0]
+                base_plate_z, base_beam_z = self._plate1_pose[2], self._beam1_pose[2]
             else:
-                new_bz = old_bz + shift_amount
-            # logging.debug(f"Current holding block: {self._held_obj_id}, shifting block {block_obj.id}")
+                plate_id, beam_id = self._plate3.id, self._beam_ids[1]
+                base_plate_z, base_beam_z = self._plate3_pose[2], self._beam2_pose[2]
 
-            # Update in PyBullet
-            block_pos, block_orn = p.getBasePositionAndOrientation(
-                block_id, physicsClientId=self._physics_client_id)
-            new_block_pos = [block_pos[0], block_pos[1], new_bz]
-            p.resetBasePositionAndOrientation(
-                block_id,
-                new_block_pos,
-                block_orn,
-                physicsClientId=self._physics_client_id)
-
-        # Update plates
-        new_plate1_z = self._plate1_pose[2] - shift_amount
-        new_beam1_z = self._beam1_pose[2] - shift_amount
-        new_plate3_z = self._plate3_pose[2] + shift_amount
-        new_beam2_z = self._beam2_pose[2] + shift_amount
-
-        # Reset the base positions of the plates and beams
-        plate1_id = self._plate1.id
-        beam1_id = self._beam_ids[0]
-        plate3_id = self._plate3.id
-        beam2_id = self._beam_ids[1]
+            new_plate_z = base_plate_z + (sign * shift_amount)
+            new_beam_z = base_beam_z + (sign * shift_amount)
 
-        plate1_pos, plate1_orn = p.getBasePositionAndOrientation(
-            plate1_id, physicsClientId=self._physics_client_id)
-        p.resetBasePositionAndOrientation(
-            plate1_id, [plate1_pos[0], plate1_pos[1], new_plate1_z],
-            plate1_orn,
-            physicsClientId=self._physics_client_id)
-
-        beam1_pos, beam1_orn = p.getBasePositionAndOrientation(
-            beam1_id, physicsClientId=self._physics_client_id)
-        p.resetBasePositionAndOrientation(
-            beam1_id, [beam1_pos[0], beam1_pos[1], new_beam1_z],
-            beam1_orn,
-            physicsClientId=self._physics_client_id)
+            plate_pos, plate_orn = p.getBasePositionAndOrientation(
+                plate_id, physicsClientId=self._physics_client_id)
+            p.resetBasePositionAndOrientation(
+                plate_id, [plate_pos[0], plate_pos[1], new_plate_z],
+                plate_orn, physicsClientId=self._physics_client_id)
 
-        plate3_pos, plate3_orn = p.getBasePositionAndOrientation(
-            plate3_id, physicsClientId=self._physics_client_id)
-        p.resetBasePositionAndOrientation(
-            plate3_id, [plate3_pos[0], plate3_pos[1], new_plate3_z],
-            plate3_orn,
-            physicsClientId=self._physics_client_id)
+            beam_pos, beam_orn = p.getBasePositionAndOrientation(
+                beam_id, physicsClientId=self._physics_client_id)
+            p.resetBasePositionAndOrientation(
+                beam_id, [beam_pos[0], beam_pos[1], new_beam_z],
+                beam_orn, physicsClientId=self._physics_client_id)
+
+        # Left side update
+        if left_dropping:
+            # Drop left plate
+            shift_plate(True, True)
+            # Drop left blocks
+            # Rise right blocks
+            shift_blocks(False, False)
+            # Rise right plate
+            shift_plate(False, False)
+        else:
+            shift_blocks(True, False)
+            shift_plate(True, False)
+            shift_plate(False, True)
+            shift_blocks(False, True)
 
-        beam2_pos, beam2_orn = p.getBasePositionAndOrientation(
-            beam2_id, physicsClientId=self._physics_client_id)
-        p.resetBasePositionAndOrientation(
-            beam2_id, [beam2_pos[0], beam2_pos[1], new_beam2_z],
-            beam2_orn,
-            physicsClientId=self._physics_client_id)
+        # Right side update
 
-        # Record the new difference
         self._prev_diff = diff
 
 
diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py
index d65e60eb4..863201198 100644
--- a/predicators/envs/pybullet_env.py
+++ b/predicators/envs/pybullet_env.py
@@ -443,7 +443,7 @@ def step(self, action: Action, render_obs: bool = False) -> Observation:
         if self._held_constraint_id is None and self._fingers_closing(action):
             # Detect if an object is held. If so, create a grasp constraint.
             self._held_obj_id = self._detect_held_object()
-            logging.debug(f"Detected held object: {self._held_obj_id}")
+            # logging.debug(f"Detected held object: {self._held_obj_id}")
             if self._held_obj_id is not None:
                 self._create_grasp_constraint()
 
@@ -496,7 +496,7 @@ def _detect_held_object(self) -> Optional[int]:
                     # A perfect score here is 1.0 (normals are unit vectors).
                     contact_normal = point[7]
                     score = expected_normal.dot(contact_normal)
-                    logging.debug(f"With obj {obj_id}, score: {score}")
+                    # logging.debug(f"With obj {obj_id}, score: {score}")
                     assert -1.0 <= score <= 1.0
 
                     # Take absolute as object/gripper could be rotated 180
diff --git a/predicators/ground_truth_models/balance/__init__.py b/predicators/ground_truth_models/balance/__init__.py
index 70853b006..d884981b5 100644
--- a/predicators/ground_truth_models/balance/__init__.py
+++ b/predicators/ground_truth_models/balance/__init__.py
@@ -1,10 +1,9 @@
 """Ground-truth models for Balance environment and variants."""
 
 from .nsrts import BalanceGroundTruthNSRTFactory
-from .options import BalanceGroundTruthOptionFactory, \
-    PyBulletBalanceGroundTruthOptionFactory
+from .options import PyBulletBalanceGroundTruthOptionFactory
 
 __all__ = [
-    "BalanceGroundTruthNSRTFactory", "BalanceGroundTruthOptionFactory",
+    "BalanceGroundTruthNSRTFactory",
     "PyBulletBalanceGroundTruthOptionFactory"
 ]
diff --git a/predicators/ground_truth_models/balance/options.py b/predicators/ground_truth_models/balance/options.py
index d3e7e47b9..baf63a043 100644
--- a/predicators/ground_truth_models/balance/options.py
+++ b/predicators/ground_truth_models/balance/options.py
@@ -18,239 +18,13 @@
 from predicators.structs import Action, Array, Object, ParameterizedOption, \
     ParameterizedPolicy, Predicate, State, Type
 
-
-class BalanceGroundTruthOptionFactory(GroundTruthOptionFactory):
-    """Ground-truth options for the (non-pybullet) blocks environment."""
-
-    env_cls = PyBulletBalanceEnv
-
-    @classmethod
-    def get_env_names(cls) -> Set[str]:
-        return {"balance"}
-
-    @classmethod
-    def get_options(cls, env_name: str, types: Dict[str, Type],
-                    predicates: Dict[str, Predicate],
-                    action_space: Box) -> Set[ParameterizedOption]:
-
-        robot_type = types["robot"]
-        block_type = types["block"]
-        machine_type = types["machine"]
-        plate_type = types["plate"]
-        block_size = CFG.blocks_block_size
-
-        Holding = predicates['Holding']
-        On = predicates['DirectlyOn']
-        OnPlate = predicates['DirectlyOnPlate']
-        GripperOpen = predicates['GripperOpen']
-        MachineOn = predicates['MachineOn']
-        # Balanced = predicates['Balanced']
-        Balanced = predicates['Balanced'].untransformed_predicate
-
-        # def _Pick_terminal(s: State, m: Dict, o: Sequence[Object],
-        #                    p: Array) -> bool:
-        #     robot, block = o
-        #     return Holding.holds(s, [block])
-
-        Pick = utils.SingletonParameterizedOption(
-            # variables: [robot, object to pick]
-            # params: []
-            "Pick",
-            cls._create_pick_policy(action_space),
-            types=[robot_type, block_type],
-            # terminal=_Pick_terminal,
-        )
-
-        # Probably will need to change the option parameters for this to work
-        # def _Stack_terminal(s: State, m: Dict, o: Sequence[Object],
-        #                     p: Array) -> bool:
-        #     block, otherblock, _ = o
-        #     return On.holds(s, [block, otherblock])
-
-        Stack = utils.SingletonParameterizedOption(
-            # variables: [robot, object on which to stack currently-held-object]
-            # params: []
-            "Stack",
-            cls._create_stack_policy(action_space, block_size),
-            types=[robot_type, block_type],
-            # types = [block_type, block_type, robot_type],
-            # terminal=_Stack_terminal,
-        )
-
-        # def _PutOnPlate_terminal(s: State, m: Dict, o: Sequence[Object],
-        #                          p: Array) -> bool:
-        #     block, _ = o
-        #     return OnPlate.holds(s, [block])
-
-        PutOnPlate = utils.SingletonParameterizedOption(
-            # variables: [robot]
-            # params: [x, y] (normalized coordinates on the table surface)
-            "PutOnPlate",
-            cls._create_putonplate_policy(action_space, block_size),
-            types=[robot_type, plate_type],
-            # types=[block_type, robot_type],
-            params_space=Box(0, 1, (2, )),
-            # terminal=_PutOnPlate_terminal,
-        )
-
-        # TurnMachineOn
-        def _TurnMachineOn_initiable(state: State, memory: Dict,
-                                     objects: Sequence[Object],
-                                     params: Array) -> bool:
-            del memory, params  # unused
-            plate1, plate2 = objects
-            robot = state.get_objects(robot_type)[0]
-            return GripperOpen.holds(state, [robot]) and\
-                    Balanced.holds(state, [plate1, plate2])
-
-        def _TurnMachineOn_terminal(state: State, memory: Dict,
-                                    objects: Sequence[Object],
-                                    params: Array) -> bool:
-            del memory, params  # unused
-            machine = state.get_objects(machine_type)[0]
-            robot = state.get_objects(robot_type)[0]
-            # _, machine, _, _ = objects
-            return MachineOn.holds(state, [machine, robot])
-
-        TurnMachineOn = ParameterizedOption(
-            "TurnMachineOn",
-            types=[plate_type, plate_type],
-            params_space=Box(0, 1, (0, )),
-            policy=cls._create_turn_machine_on_policy(),
-            initiable=_TurnMachineOn_initiable,
-            terminal=_TurnMachineOn_terminal,
-            annotation="Turn the machine on.")
-
-        return {Pick, Stack, PutOnPlate, TurnMachineOn}
-
-    @classmethod
-    def _create_turn_machine_on_policy(cls) -> ParameterizedPolicy:
-
-        def policy(state: State, memory: Dict, objects: Sequence[Object],
-                   params: Array) -> Action:
-            # This policy moves the robot up to be level with the button in the
-            # z direction and then moves forward in the y direction to press it.
-            del memory, params  # unused
-            robot = [r for r in state if r.type.name == "robot"][0]
-            # robot = objects[0]
-            x = state.get(robot, "pose_x")
-            y = state.get(robot, "pose_y")
-            z = state.get(robot, "pose_z")
-            # robot_pos = (x, y, z)
-            button_pos = (cls.env_cls.button_x, cls.env_cls.button_y,
-                          cls.env_cls.button_z)
-            arr = np.r_[button_pos, 1.0].astype(np.float32)
-            # arr = np.clip(arr, cls.env_cls.action_space.low,
-            #               cls.env_cls.action_space.high)
-            return Action(arr)
-            # if (cls.env_cls.button_z - z)**2 < cls.env_cls._button_radius**2:
-            #     # Move directly toward the button.
-            #     return cls._get_move_action(state, button_pos, robot_pos)
-            # # Move only in the z direction.
-            # return cls._get_move_action(state, (x, y, cls.env_cls.button_z),
-            #                             robot_pos)
-
-        return policy
-
-    @classmethod
-    def _create_pick_policy(cls, action_space: Box) -> ParameterizedPolicy:
-
-        def policy(state: State, memory: Dict, objects: Sequence[Object],
-                   params: Array) -> Action:
-            del memory, params  # unused
-            _, block = objects
-            block_pose = np.array([
-                state.get(block, "pose_x"),
-                state.get(block, "pose_y"),
-                state.get(block, "pose_z")
-            ])
-            arr = np.r_[block_pose, 0.0].astype(np.float32)
-            arr = np.clip(arr, action_space.low, action_space.high)
-            return Action(arr)
-
-        return policy
-
-    @classmethod
-    def _create_stack_policy(cls, action_space: Box,
-                             block_size: float) -> ParameterizedPolicy:
-
-        def policy(state: State, memory: Dict, objects: Sequence[Object],
-                   params: Array) -> Action:
-            del memory, params  # unused
-            _, block = objects
-            # _, block, _ = objects
-
-            block_pose = np.array([
-                state.get(block, "pose_x"),
-                state.get(block, "pose_y"),
-                state.get(block, "pose_z")
-            ])
-            relative_grasp = np.array([
-                0.,
-                0.,
-                block_size,
-            ])
-            arr = np.r_[block_pose + relative_grasp, 1.0].astype(np.float32)
-            arr = np.clip(arr, action_space.low, action_space.high)
-            return Action(arr)
-
-        return policy
-
-    @classmethod
-    def _create_putonplate_policy(cls, action_space: Box,
-                                  block_size: float) -> ParameterizedPolicy:
-
-        def policy(state: State, memory: Dict, objects: Sequence[Object],
-                   params: Array) -> Action:
-            del state, memory, objects  # unused
-            # De-normalize parameters to actual table coordinates.
-            x_norm, y_norm = params
-            x = BalanceEnv.x_lb + (BalanceEnv.x_ub - BalanceEnv.x_lb) * x_norm
-            y = BalanceEnv.y_lb + (BalanceEnv.y_ub - BalanceEnv.y_lb) * y_norm
-            z = BalanceEnv._table_height + 0.5 * block_size
-            arr = np.array([x, y, z, 1.0], dtype=np.float32)
-            arr = np.clip(arr, action_space.low, action_space.high)
-            return Action(arr)
-
-        return policy
-
-    ############################ Utility functions ############################
-
-    @classmethod
-    def _get_move_action(cls,
-                         state: State,
-                         target_pos: Tuple[float, float, float],
-                         robot_pos: Tuple[float, float, float],
-                         dtilt: float = 0.0,
-                         dwrist: float = 0.0,
-                         finger_status: str = "open") -> Action:
-        del state, finger_status  # used in PyBullet subclass
-        # We want to move in this direction.
-        delta = np.subtract(target_pos, robot_pos)
-        # But we can only move at most max_position_vel in one step.
-        # Get the norm full move delta.
-        pos_norm = float(np.linalg.norm(delta))
-        # If the norm is more than max_position_vel, rescale the delta so
-        # that its norm is max_position_vel.
-        if pos_norm > cls.env_cls.max_position_vel:
-            delta = cls.env_cls.max_position_vel * (delta / pos_norm)
-            pos_norm = cls.env_cls.max_position_vel
-        # Now normalize so that the action values are between -1 and 1, as
-        # expected by simulate and the action space.
-        if pos_norm > 0:
-            delta = delta / cls.env_cls.max_position_vel
-        dx, dy, dz = delta
-        return Action(np.array([dx, dy, dz, 0.0], dtype=np.float32))
-
-
 @lru_cache
 def _get_pybullet_robot() -> SingleArmPyBulletRobot:
     _, pybullet_robot, _ = \
             PyBulletBalanceEnv.initialize_pybullet(using_gui=False)
     return pybullet_robot
 
-
-class PyBulletBalanceGroundTruthOptionFactory(BalanceGroundTruthOptionFactory):
+class PyBulletBalanceGroundTruthOptionFactory(GroundTruthOptionFactory):
     """Ground-truth options for the pybullet_balance environment."""
 
     env_cls = PyBulletBalanceEnv
@@ -355,7 +129,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
                 # Move down to place.
                 cls._create_blocks_move_to_above_block_option(
                     name="MoveEndEffectorToStack",
-                    z_func=lambda block_z: (block_z + block_size * 2),
+                    z_func=lambda block_z: (block_z + block_size * 1.3),
                     finger_status="closed",
                     pybullet_robot=pybullet_robot,
                     option_types=option_types,
@@ -482,7 +256,8 @@ def _get_current_and_target_pose_and_finger_status(
             pybullet_robot, name, option_types, params_space,
             _get_current_and_target_pose_and_finger_status,
             cls._move_to_pose_tol, CFG.pybullet_max_vel_norm,
-            cls._finger_action_nudge_magnitude)
+            cls._finger_action_nudge_magnitude,
+            validate=CFG.pybullet_ik_validate)
 
     @classmethod
     def _create_blocks_move_to_above_table_option(
@@ -596,4 +371,5 @@ def _get_move_action(cls,
         return get_move_end_effector_to_pose_action(
             pybullet_robot, current_joint_positions, current_pose, target_pose,
             finger_status, CFG.pybullet_max_vel_norm,
-            cls._finger_action_nudge_magnitude)
+            cls._finger_action_nudge_magnitude,
+            validate=CFG.pybullet_ik_validate)