replay buffer not working for contextual bandit with scalar action spec #801

Open
sylviawhx opened this issue Dec 2, 2022 · 2 comments


I'm playing with a contextual bandit which has only one step for each episode. My goal is to save multiple Trajectory into replay buffer and use it later. However I'm not able to create a replay buffer that can be used to train the agent.

The agent's data spec (from agent.collect_data_spec) is:

_TupleWrapper(Trajectory(
{'action': BoundedTensorSpec(shape=(), dtype=tf.int32, name=None, minimum=array(0, dtype=int32), maximum=array(8, dtype=int32)),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': TensorSpec(shape=(5,), dtype=tf.int32, name=None),
 'policy_info': PolicyInfo(log_probability=(), predicted_rewards_mean=(), multiobjective_scalarized_predicted_rewards_mean=(), predicted_rewards_optimistic=(), predicted_rewards_sampled=(), bandit_policy_type=()),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')}))

My trajectory looks like the one below (it can be passed to agent.train() and runs successfully):

Trajectory(
{'action': <tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[1]], dtype=int32)>,
 'discount': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.]], dtype=float32)>,
 'next_step_type': <tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[2]], dtype=int32)>,
 'observation': <tf.Tensor: shape=(1, 1, 5), dtype=int32, numpy=array([[[1, 0, 1, 0, 0]]], dtype=int32)>,
 'policy_info': PolicyInfo(log_probability=(), predicted_rewards_mean=(), multiobjective_scalarized_predicted_rewards_mean=(), predicted_rewards_optimistic=(), predicted_rewards_sampled=(), bandit_policy_type=()),
 'reward': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.]], dtype=float32)>,
 'step_type': <tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[2]], dtype=int32)>})

However, after creating a replay buffer...

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=1,
    max_length=1,
)

I'm not able to feed my Trajectory into the replay buffer: calling replay_buffer.add_batch(my_traj) raises the following error:

function_node __wrapped__ResourceScatterUpdate_device_/job:localhost/replica:0/task:0/device:CPU:0}} Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [1,1], indices.shape [1], params.shape [1] [Op:ResourceScatterUpdate]

My action is a scalar (action_spec = tensor_spec.BoundedTensorSpec(dtype=tf.int32, shape=(), minimum=0, maximum=narms-1)). The action tensor in my Trajectory has shape (1, 1), i.e. (batch_size, time). Since each action is a scalar, it has no additional dimensions. I suspect that's why params.shape is [1] instead of a 2-D shape, and the replay buffer throws an error. Has anyone successfully created a replay buffer for a contextual bandit before? Please advise. Thank you!

Another point of confusion: when a replay buffer is created with batch_size=1 and max_length=1 (as in the bandit case), its capacity is also 1. How, then, can I store multiple Trajectories in the replay buffer?

Thank you!


sguada commented Dec 2, 2022

I recommend taking a look at the Bandit Colab.

Double-check the action_spec and the action.

To add multiple trajectories to the replay buffer, increase its size with max_length=N.

@sylviawhx (Author)

Thank you sguada! :) I'll double-check it later today and will report back if I can fix it!
