I am trying to compute second-order gradients with tf-agents. The motivation comes from the meta-learning algorithm MAML.
First I create a tf_agents.networks.network named actor_net, which has a single hidden layer with 100 nodes and outputs a distribution over the two actions (left, right) of the 'CartPole-v0' problem.
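(For context, actor_net is built roughly as in the tf-agents REINFORCE tutorial; train_env here is the TFPyEnvironment wrapping CartPole-v0 from my script.)

from tf_agents.networks import actor_distribution_network

# One hidden layer with 100 units; the output parameterizes a categorical
# distribution over the two CartPole-v0 actions (left, right).
actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=(100,))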
Using two nested tf.GradientTape()s and a cloned copy of actor_net, I found it is possible to get the second-order gradient with respect to the different (cloned) network. When this setup goes wrong, the final gradient values (dW, db) from test_tape come back as arrays filled with None, but I checked that in this case they did not.
I used only a single (the last) layer, actor_net.submodules[4], rather than the entire network, and the loss was a simple mean squared error. Up to this point I had only checked the second-order gradient of a 'single layer'; a rough sketch of that check follows.
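(Reconstructed sketch, not my exact code: a plain Dense layer and random tensors stand in for actor_net.submodules[4] and its real inputs, and the "cloned" weights are written as ordinary tensors here.)

import tensorflow as tf

layer = tf.keras.layers.Dense(2)      # stands in for actor_net.submodules[4]
x = tf.random.normal([8, 100])        # fake activations feeding the last layer
y = tf.random.normal([8, 2])          # fake targets for the MSE loss
layer.build(x.shape)
W, b = layer.kernel, layer.bias

with tf.GradientTape() as test_tape:
  with tf.GradientTape() as train_tape:
    inner_loss = tf.reduce_mean(tf.square(layer(x) - y))
  dW, db = train_tape.gradient(inner_loss, [W, b])

  # One "inner" SGD step written as plain tensor ops, so the dependence of
  # the updated weights on (W, b) stays on the outer tape.
  new_W = W - 0.3 * dW
  new_b = b - 0.3 * db
  outer_loss = tf.reduce_mean(tf.square(tf.matmul(x, new_W) + new_b - y))

# Second-order gradient of the post-update loss w.r.t. the original weights.
meta_dW, meta_db = test_tape.gradient(outer_loss, [W, b])
print(meta_dW is None, meta_db is None)   # False, False -- not filled with None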
Next I tried cloning the whole agent to compute the gradient: I create the agent using the REINFORCE algorithm together with a replay buffer, and compute the second-order gradients with the same structure as in the single-layer code:
for _ in range(num_iterations):

  with tf.GradientTape() as test_tape:
    with tf.GradientTape() as train_tape:
      # Collect a few episodes using collect_policy and save to the replay buffer.
      collect_episode(eval_env, tf_agent.collect_policy, collect_episodes_per_iteration)

      # Use data from the buffer and update the agent's network.
      experience = replay_buffer.gather_all()

      # Compute loss and grad for the original network.
      # Referred to the _train function of the REINFORCE agent:
      # https://github.com/tensorflow/agents/blob/master/tf_agents/agents/reinforce/reinforce_agent.py
      non_last_mask = tf.cast(tf.math.not_equal(experience.next_step_type, ts.StepType.LAST), tf.float32)
      discounts = non_last_mask * experience.discount * tf_agent._gamma
      returns = value_ops.discounted_return(experience.reward, discounts, time_major=False)
      time_step = ts.TimeStep(experience.step_type,
                              tf.zeros_like(experience.reward),
                              tf.zeros_like(experience.discount),
                              experience.observation)

      loss_info = tf_agent.total_loss(time_step, experience.action, tf.stop_gradient(returns), weights=None)
      tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
      print('loss : ', loss_info)

    variables_to_compute_grad = tf_agent._actor_network.trainable_weights
    if tf_agent._baseline:
      variables_to_compute_grad += tf_agent._value_network.trainable_weights
    grads = train_tape.gradient(loss_info.loss, variables_to_compute_grad)

    # Now make a new agent that uses a copy of the actor_net network.
    copied_tf_agent = reinforce_agent.ReinforceAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=tf_agent._actor_network.copy(),
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        normalize_returns=True,
        train_step_counter=tf.compat.v2.Variable(0))
    copied_tf_agent.initialize()
    copied_tf_agent._actor_network.set_weights(tf_agent._actor_network.get_weights())

    variables_to_train = copied_tf_agent._actor_network.trainable_weights
    if copied_tf_agent._baseline:
      variables_to_train += copied_tf_agent._value_network.trainable_weights

    # Update the cloned network's parameters using the grads of the original network.
    for i in range(len(variables_to_train)):
      weight = copied_tf_agent._actor_network.trainable_weights[i]
      weight.assign_sub(0.3 * grads[i])

    replay_buffer.clear()

    # Compute loss and grad for the cloned network with respect to the original network.
    # Collect a few episodes using collect_policy and save to the replay buffer.
    collect_episode(train_env, copied_tf_agent.collect_policy, collect_episodes_per_iteration)

    # Use data from the buffer and update the agent's network.
    experience = replay_buffer.gather_all()

    non_last_mask = tf.cast(tf.math.not_equal(experience.next_step_type, ts.StepType.LAST), tf.float32)
    discounts = non_last_mask * experience.discount * copied_tf_agent._gamma
    returns = value_ops.discounted_return(experience.reward, discounts, time_major=False)
    time_step = ts.TimeStep(experience.step_type,
                            tf.zeros_like(experience.reward),
                            tf.zeros_like(experience.discount),
                            experience.observation)

    meta_loss_info = copied_tf_agent.total_loss(time_step, experience.action, tf.stop_gradient(returns), weights=None)
    tf.debugging.check_numerics(meta_loss_info.loss, 'Loss is inf or nan')
    print('meta loss : ', meta_loss_info)

  variables_to_train = tf_agent._actor_network.trainable_weights
  if tf_agent._baseline:
    variables_to_train += tf_agent._value_network.trainable_weights
  grads = test_tape.gradient(meta_loss_info.loss, variables_to_train)
  print('meta grad ! :', tf.reduce_mean(grads[0], axis=1))
Since the values of grads from test_tape were all None, this produced the error below.
ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
But the loss itself seemed normal to me:
meta loss : LossInfo(loss=<tf.Tensor: shape=(), dtype=float32, numpy=2.3132732>, extra=())
I realize the loss function (total_loss) is far more complicated than the single-layer case, but I still think the procedure is the same. Maybe creating a new agent with a cloned network is not a good idea, but currently I cannot find any better approach or reference.
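For what it's worth, the None pattern itself is easy to reproduce in isolation, which makes me suspect set_weights() / assign_sub(): variable assignment is not an op the tape can differentiate through, so the meta loss has no recorded path back to the original weights. I am not certain this is the actual cause in the agent code; the tiny sketch below just shows the symptom (the variable names are made up for illustration).

import tensorflow as tf

w_original = tf.Variable(2.0)
w_clone = tf.Variable(0.0)

with tf.GradientTape() as test_tape:
  with tf.GradientTape() as train_tape:
    inner_loss = w_original ** 2
  grad = train_tape.gradient(inner_loss, w_original)

  # Analogous to set_weights(...) followed by assign_sub(...): in-place
  # variable updates, which the outer tape does not record as differentiable.
  w_clone.assign(w_original)
  w_clone.assign_sub(0.3 * grad)

  meta_loss = w_clone ** 2

print(test_tape.gradient(meta_loss, w_original))  # -> None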