Second Order gradient is not working #220

Open
zzong2006 opened this issue Oct 10, 2019 · 1 comment

zzong2006 commented Oct 10, 2019

I am trying to compute second-order gradients with tf-agents. The need for second-order gradients comes from the meta-learning algorithm MAML.

First I create a tf_agents.networks.network named actor_net, which has a single hidden layer with 100 nodes and outputs a distribution over the 2 actions (left, right) of the 'CartPole-v0' problem.

env_name = "CartPole-v0"  # @param {type:"string"}
fc_layer_params = (100,)
train_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)
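
For reference, a quick sanity check that the network behaves as described (my own sketch; the exact call signature may differ slightly between tf-agents versions):

# Hypothetical check: feed one batched observation through actor_net and
# confirm it returns a distribution over the 2 CartPole actions.
time_step = train_env.reset()
action_dist, _ = actor_net(time_step.observation, time_step.step_type)
print(action_dist)                                   # a tfp distribution over {0, 1}
print([v.shape for v in actor_net.trainable_variables])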

With two nested tf.GradientTape()s and a cloned actor_net, I found that it is possible to get the second-order gradient with respect to the original network, even though the second loss is computed through the clone:

with tf.GradientTape() as test_tape:
    with tf.GradientTape() as train_tape:
        # First-order loss/gradient on the last Dense layer of the original network.
        curr_loss = loss(actor_net.submodules[4](np.random.random([100, 100])), tf.random.uniform([100, 2]))
    dW, db = train_tape.gradient(curr_loss, actor_net.submodules[4].weights)

    # Clone the network and copy over the current weights.
    z = actor_net.copy()
    z.create_variables()
    z.set_weights(actor_net.get_weights())

    # Overwrite the clone's kernel/bias with *tensors* derived from the original
    # variables (a one-step SGD update), so test_tape can trace back to actor_net.
    z.submodules[4].kernel = tf.subtract(actor_net.submodules[4].kernel, tf.multiply(0.003, dW))
    z.submodules[4].bias = tf.subtract(actor_net.submodules[4].bias, tf.multiply(0.003, db))

    # print(dW, db)
    curr_loss = loss(z.submodules[4](np.random.random([100, 100])), tf.random.uniform([100, 2]))
dW, db = test_tape.gradient(curr_loss, actor_net.submodules[4].trainable_variables)

# print(dW, db)

If this had gone wrong, the final gradients (dW, db) from test_tape would have come back as None, but I checked and they did not. Note that I used an arbitrary (but last) layer, actor_net.submodules[4], rather than the entire network; loss is the mean squared error function shown below:

# actor_net.submodules (same as z)
(<tf_agents.networks.encoding_network.EncodingNetwork at 0x7f0210c32d50>,
 <tf_agents.networks.categorical_projection_network.CategoricalProjectionNetwork at 0x7f016c57a950>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f0180044650>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f0180044e90>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f016c57a350>)

# MSE loss used above
def loss(predicted_y, desired_y):
    return tf.reduce_mean(tf.square(predicted_y - desired_y))
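
To double-check that the trick itself is sound, here is a minimal standalone sketch of the same idea with a plain Keras Dense layer (my own reproduction, reusing the loss() helper just above; not tf-agents code). The second-order gradients with respect to the original layer come back non-None:

x = np.random.random([100, 100]).astype(np.float32)
y = tf.random.uniform([100, 2])

layer = tf.keras.layers.Dense(2)
layer.build((None, 100))                 # create kernel/bias

fast = tf.keras.layers.Dense(2)
fast.build((None, 100))                  # clone layer with its own variables

with tf.GradientTape() as test_tape:
    with tf.GradientTape() as train_tape:
        curr_loss = loss(layer(x), y)
    dW, db = train_tape.gradient(curr_loss, layer.weights)

    # Overwrite the clone's parameters with *tensors* computed from the original
    # variables, so test_tape can trace back through them.
    fast.kernel = layer.kernel - 0.003 * dW
    fast.bias = layer.bias - 0.003 * db

    curr_loss = loss(fast(x), y)

d2W, d2b = test_tape.gradient(curr_loss, layer.trainable_variables)
print(d2W is None, d2b is None)          # expected: False False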

Up to this point I had only checked the second-order gradient through a single layer, so next I tried cloning the whole agent to compute the gradient.

I create the agent with the REINFORCE algorithm and a replay buffer:

from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import value_ops

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.compat.v2.Variable(0)
tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=tf_agent.collect_data_spec,
    batch_size=eval_env.batch_size,
    max_length=replay_buffer_capacity)
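
The collect_episode helper is not shown here; it is essentially the one from the TF-Agents REINFORCE tutorial, roughly:

from tf_agents.trajectories import trajectory

def collect_episode(environment, policy, num_episodes):
    # Roll out full episodes with the given policy and push each
    # transition into the replay buffer defined above.
    episode_counter = 0
    environment.reset()
    while episode_counter < num_episodes:
        time_step = environment.current_time_step()
        action_step = policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        traj = trajectory.from_transition(time_step, action_step, next_time_step)
        replay_buffer.add_batch(traj)
        if traj.is_boundary():
            episode_counter += 1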

Then I compute the second-order gradients with the same structure as in the single-layer code above:

for _ in range(num_iterations):
    with tf.GradientTape() as test_tape:
        with tf.GradientTape() as train_tape:
            # Collect a few episodes using collect_policy and save to the replay buffer.
            collect_episode(eval_env, tf_agent.collect_policy, collect_episodes_per_iteration)

            # Use data from the buffer and update the agent's network.
            experience = replay_buffer.gather_all()

            # Compute loss and grad for the original Network
            # Referred the _train function of the reinforce algorithm
            # https://github.com/tensorflow/agents/blob/master/tf_agents/agents/reinforce/reinforce_agent.py
            non_last_mask = tf.cast(tf.math.not_equal(experience.next_step_type, ts.StepType.LAST), tf.float32)
            discounts = non_last_mask * experience.discount * tf_agent._gamma
            returns = value_ops.discounted_return(experience.reward, discounts, time_major=False)
            time_step = ts.TimeStep(experience.step_type, tf.zeros_like(experience.reward),
                                    tf.zeros_like(experience.discount),
                                    experience.observation)

            loss_info = tf_agent.total_loss(time_step, experience.action, tf.stop_gradient(returns), weights=None)
            tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
            print('loss : ', loss_info)
        variables_to_compute_grad = tf_agent._actor_network.trainable_weights
        if tf_agent._baseline:
            variables_to_compute_grad += tf_agent._value_network.trainable_weights
        grads = train_tape.gradient(loss_info.loss, variables_to_compute_grad)

        # Now make new agent but uses copied actor_net network
        copied_tf_agent = reinforce_agent.ReinforceAgent(
            train_env.time_step_spec(),
            train_env.action_spec(),
            actor_network=tf_agent._actor_network.copy(),
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
            normalize_returns=True,
            train_step_counter=tf.compat.v2.Variable(0))
        copied_tf_agent.initialize()
        copied_tf_agent._actor_network.set_weights(tf_agent._actor_network.get_weights())

        variables_to_train = copied_tf_agent._actor_network.trainable_weights
        if copied_tf_agent._baseline:
            variables_to_train += copied_tf_agent._value_network.trainable_weights

        # update cloned network's parameters by using grads of original network
        for i in range(len(variables_to_train)):
            weight = copied_tf_agent._actor_network.trainable_weights[i]
            weight.assign_sub(0.3 * grads[i])

        replay_buffer.clear()

        # Compute loss and grad for the cloned Network with respect to original Network
        # Collect a few episodes using collect_policy and save to the replay buffer.
        collect_episode(train_env, copied_tf_agent.collect_policy, collect_episodes_per_iteration)

        # Use data from the buffer and update the agent's network.
        experience = replay_buffer.gather_all()

        non_last_mask = tf.cast(tf.math.not_equal(experience.next_step_type, ts.StepType.LAST), tf.float32)
        discounts = non_last_mask * experience.discount * copied_tf_agent._gamma
        returns = value_ops.discounted_return(experience.reward, discounts, time_major=False)
        time_step = ts.TimeStep(experience.step_type, tf.zeros_like(experience.reward),
                                tf.zeros_like(experience.discount),
                                experience.observation)
        meta_loss_info = copied_tf_agent.total_loss(time_step, experience.action, tf.stop_gradient(returns), weights=None)
        tf.debugging.check_numerics(meta_loss_info.loss, 'Loss is inf or nan')
        print('meta loss : ', meta_loss_info)
    variables_to_train = tf_agent._actor_network.trainable_weights
    if tf_agent._baseline:
        variables_to_train += tf_agent._value_network.trainable_weights
    grads = test_tape.gradient(meta_loss_info.loss, variables_to_train)
    print('meta grad ! :', tf.reduce_mean(grads[0], axis=1))

Since the gradients returned by test_tape were all None, this produced the error below:

ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.

But the loss itself looked normal to me:

meta loss :  LossInfo(loss=<tf.Tensor: shape=(), dtype=float32, numpy=2.3132732>, extra=())
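
A quick way to see which variables actually come back with None gradients (a hypothetical diagnostic in place of the final print):

for var, grad in zip(variables_to_train, grads):
    print(var.name, 'grad is None' if grad is None else grad.shape)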

I realize that the loss (total_loss) function is far more complicated than the single-layer case, but I still think the procedure is the same. Maybe creating a new agent around a cloned network is not a good idea, but so far I have not found a better approach or any reference.

sguada (Member) commented Oct 23, 2019

Once you create a copy of the variables, the gradients don't flow through the copy.
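
For illustration (my sketch, not from the thread), the point with plain variables: a functional update keeps the outer gradient connected to the original variable, while copying into a new variable (as set_weights()/assign_sub() on a cloned agent effectively does) cuts the graph and yields None:

import tensorflow as tf

w = tf.Variable(2.0)

with tf.GradientTape(persistent=True) as outer:
    with tf.GradientTape() as inner:
        inner_loss = tf.square(w)
    g = inner.gradient(inner_loss, w)

    # (a) Keep the updated weight as a *tensor*: still connected to w.
    w_fast = w - 0.1 * g
    outer_loss_a = tf.square(w_fast)

    # (b) Copy the update into a new variable: the graph is cut here.
    w_copy = tf.Variable(w - 0.1 * g)
    outer_loss_b = tf.square(w_copy)

print(outer.gradient(outer_loss_a, w))   # a real tensor: gradient flows
print(outer.gradient(outer_loss_b, w))   # None: the copy broke the chain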
