Fixed start value of Discrete Action space
The start value of a Discrete action space previously had no effect; it is now supported.
JoshuaBluem committed Dec 6, 2024
1 parent 9caa168 commit c0aa574
Showing 3 changed files with 61 additions and 23 deletions.
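
For context, a minimal sketch of the mapping this commit introduces, assuming a gymnasium Discrete space with a non-zero start (the values below are illustrative):

    import numpy as np
    from gymnasium import spaces

    space = spaces.Discrete(3, start=1)     # valid env actions: 1, 2, 3
    net_action = np.array(2)                # zero-based index produced by the network
    env_action = net_action + space.start   # -> 3, the action the env should receive
    assert space.contains(int(env_action))  # before this fix, the shift by `start` never happened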
14 changes: 12 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -389,8 +389,9 @@ def _sample_action(
             assert self._last_obs is not None, "self._last_obs was not set"
             unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

-        # Rescale the action from [low, high] to [-1, 1]
+        # Transform action from its action_space bounds
         if isinstance(self.action_space, spaces.Box):
+            # Rescale the action from [low, high] to [-1, 1]
             scaled_action = self.policy.scale_action(unscaled_action)

             # Add noise to the action (improve exploration)
@@ -400,10 +401,19 @@
             # We store the scaled action in the buffer
             buffer_action = scaled_action
             action = self.policy.unscale_action(scaled_action)
+        elif isinstance(self.action_space, spaces.Discrete):
+            # Discrete case: Shift action values so every action starts from zero
+            scaled_action = self.policy.scale_action(unscaled_action)
+
+            # Store the shifted (zero-based) action in the buffer
+            buffer_action = scaled_action
+
+            # Still use the original action for the env, as it already matches the action-space bounds
+            action = unscaled_action
         else:
-            # Discrete case, no need to normalize or clip
+            # Other cases (e.g. MultiDiscrete, MultiBinary), no need to normalize or clip
             buffer_action = unscaled_action
             action = buffer_action

         return action, buffer_action

     def _dump_logs(self) -> None:
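
A rough sketch (outside the diff, gymnasium and numpy assumed) of the buffer/env split that the new Discrete branch in _sample_action implements: the replay buffer stores zero-based actions, while the env keeps receiving actions within its own bounds:

    import numpy as np
    from gymnasium import spaces

    action_space = spaces.Discrete(4, start=10)            # valid env actions: 10..13
    unscaled_action = np.array([12])                       # prediction, already within env bounds
    buffer_action = unscaled_action - action_space.start   # -> [2], zero-based, goes to the buffer
    action = unscaled_action                               # -> [12], passed unchanged to the env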
17 changes: 10 additions & 7 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -198,24 +198,27 @@ def collect_rollouts(

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
-                obs_tensor = obs_as_tensor(self._last_obs, self.device)
+                obs_tensor = obs_as_tensor(self._last_obs, self.device)  # type: ignore[assignment, arg-type]
                actions, values, log_probs = self.policy(obs_tensor)
            actions = actions.cpu().numpy()

-            # Rescale and perform action
-            clipped_actions = actions
-
+            # Rescale to action bounds
+            unscaled_actions = actions
            if isinstance(self.action_space, spaces.Box):
                if self.policy.squash_output:
                    # Unscale the actions to match env bounds
                    # if they were previously squashed (scaled in [-1, 1])
-                    clipped_actions = self.policy.unscale_action(clipped_actions)
+                    unscaled_actions = self.policy.unscale_action(unscaled_actions)
                else:
                    # Otherwise, clip the actions to avoid out of bound error
                    # as we are sampling from an unbounded Gaussian distribution
-                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
+                    unscaled_actions = np.clip(actions, self.action_space.low, self.action_space.high)
+            elif isinstance(self.action_space, spaces.Discrete):
+                # Shift actions to match the action-space bounds
+                unscaled_actions = self.policy.unscale_action(unscaled_actions)  # type: ignore[assignment, arg-type]

-            new_obs, rewards, dones, infos = env.step(clipped_actions)
+            # Perform step
+            new_obs, rewards, dones, infos = env.step(unscaled_actions)

            self.num_timesteps += env.num_envs

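
A small sketch (same assumptions) of the on-policy path above: actions sampled from the policy's categorical distribution are zero-based, so collect_rollouts now shifts them to the env's bounds before stepping:

    import numpy as np
    from gymnasium import spaces

    action_space = spaces.Discrete(3, start=-1)        # valid env actions: -1, 0, 1
    actions = np.array([0, 2, 1])                      # zero-based samples from the policy
    unscaled_actions = actions + action_space.start    # -> [-1, 1, 0], what env.step() receives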
53 changes: 39 additions & 14 deletions stable_baselines3/common/policies.py
@@ -377,6 +377,9 @@ def predict(
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)  # type: ignore[assignment, arg-type]
+        elif isinstance(self.action_space, spaces.Discrete):
+            # Transform the action to the action-space bounds, starting from the defined start value
+            actions = self.unscale_action(actions)  # type: ignore[assignment, arg-type]

        # Remove batch dimension if needed
        if not vectorized_env:
@@ -387,30 +390,52 @@

    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
-        Rescale the action from [low, high] to [-1, 1]
-        (no need for symmetric action space)
+        Rescale the action from the action-space bounds to normalized bounds.
+        Box case:
+            Scale the action from [low, high] to [-1, 1]
+            (no need for symmetric action space)
+        Discrete case:
+            Shift the start value from action_space.start to zero

        :param action: Action to scale
        :return: Scaled action
        """
-        assert isinstance(
-            self.action_space, spaces.Box
-        ), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
-        low, high = self.action_space.low, self.action_space.high
-        return 2.0 * ((action - low) / (high - low)) - 1.0
+        scaled_action: np.ndarray
+
+        if isinstance(self.action_space, spaces.Box):
+            # Box case
+            low, high = self.action_space.low, self.action_space.high
+            scaled_action = 2.0 * ((action - low) / (high - low)) - 1.0
+        elif isinstance(self.action_space, spaces.Discrete):
+            # Discrete case
+            scaled_action = np.subtract(action, self.action_space.start)
+        else:
+            raise AssertionError(f"Trying to scale an action using an action space that is not a Box() or Discrete(): {self.action_space}")
+
+        return scaled_action

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
-        Rescale the action from [-1, 1] to [low, high]
-        (no need for symmetric action space)
+        Box case:
+            Rescale the action from [-1, 1] to [low, high]
+            (no need for symmetric action space)
+        Discrete case:
+            Shift the start value from zero back to action_space.start

        :param scaled_action: Action to un-scale
        """
-        assert isinstance(
-            self.action_space, spaces.Box
-        ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
-        low, high = self.action_space.low, self.action_space.high
-        return low + (0.5 * (scaled_action + 1.0) * (high - low))
+        unscaled_action: np.ndarray
+
+        if isinstance(self.action_space, spaces.Box):
+            low, high = self.action_space.low, self.action_space.high
+            unscaled_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
+        elif isinstance(self.action_space, spaces.Discrete):
+            # Shift back to match the action-space bounds
+            unscaled_action = np.add(scaled_action, self.action_space.start)
+        else:
+            raise AssertionError(f"Trying to unscale an action using an action space that is not a Box() or Discrete(): {self.action_space}")
+
+        return unscaled_action


 class ActorCriticPolicy(BasePolicy):
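
To illustrate the new scale_action/unscale_action contract for Discrete spaces, a minimal round-trip sketch (plain numpy, mirroring the arithmetic above rather than calling the policy):

    import numpy as np
    from gymnasium import spaces

    space = spaces.Discrete(5, start=100)       # valid env actions: 100..104
    action = np.array(103)
    scaled = np.subtract(action, space.start)   # scale_action: -> 3 (zero-based)
    unscaled = np.add(scaled, space.start)      # unscale_action: -> 103 (back to env bounds)
    assert unscaled == action                   # the two transforms are inverses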
