BernoulliThompsonSamplingAgent error: 'tensorflow.python.eager.function' has no attribute 'register' #802

Open
sylviawhx opened this issue Dec 3, 2022 · 0 comments

sylviawhx commented Dec 3, 2022

Hi Friends,

BernoulliThompsonSamplingAgent does not seem to work when calling `agent.collect_policy.action(step)`. It raises the following error: `module 'tensorflow.python.eager.function' has no attribute 'register'`. All other agents work fine with the same command.
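For context, the call fails inside the policy's sampling step. A Bernoulli Thompson sampling policy keeps a Beta(alpha_k, beta_k) posterior over each arm's success probability, draws one sample per arm, and plays the arm with the largest draw; the traceback shows tf_agents doing the sampling part via `tfd.Beta(self._alpha, self._beta)`. A minimal pure-Python sketch of that selection rule follows, assuming plain lists named `alpha` and `beta` and the hypothetical helpers `thompson_select_arm` and `update_posterior` (illustrative only, not the tf_agents API):

```python
import random
from typing import List, Optional


def thompson_select_arm(alpha: List[float], beta: List[float],
                        rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper (not part of tf_agents): Bernoulli Thompson sampling.

    alpha[k] and beta[k] are the Beta posterior parameters for arm k
    (successes + 1 and failures + 1 under a uniform Beta(1, 1) prior).
    """
    if not alpha or len(alpha) != len(beta):
        raise ValueError("alpha and beta must be non-empty and the same length")
    if rng is None:
        rng = random.Random()
    # Draw one plausible success probability per arm from its posterior...
    samples = [rng.betavariate(a, b) for a, b in zip(alpha, beta)]
    # ...then act greedily with respect to those draws.
    return max(range(len(samples)), key=samples.__getitem__)


def update_posterior(alpha: List[float], beta: List[float],
                     arm: int, reward: int) -> None:
    """Hypothetical helper: a reward of 1 bumps alpha, a reward of 0 bumps beta."""
    if reward:
        alpha[arm] += 1
    else:
        beta[arm] += 1
```

Alternating the two calls in a loop (select an arm, observe a 0/1 reward, update) gives the usual explore/exploit behaviour; it is the per-arm Beta draw that tf_agents delegates to TensorFlow Probability.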

I also tried running this example, and it raises the same error:
https://github.com/tensorflow/agents/blob/11d3b98f2d9edbb7e33484657ce61e941a73312d/tf_agents/bandits/agents/examples/v2/train_eval_bernoulli.py

I'm using tf_agents 0.15.0
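Since the failing frames are in TensorFlow Probability and TensorFlow rather than in tf_agents, reporting all three versions helps narrow this down. Below is a small standard-library sketch for collecting them. It assumes the PyPI distribution names listed (a CPU- or GPU-specific TensorFlow build may be installed under another name such as `tensorflow-cpu`), and `importlib.metadata` needs Python 3.8+, so the Python 3.7 environment in the traceback would need the `importlib_metadata` backport instead. `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata  # Python 3.8+; on 3.7 use the importlib_metadata backport
from typing import Dict, Iterable, Optional

# Assumed PyPI distribution names; adjust if TensorFlow is installed as e.g. tensorflow-cpu.
PACKAGES = ("tf-agents", "tensorflow", "tensorflow-probability")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version, or None if not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```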
Full traceback here:

AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_30661/3076113303.py in <module>
----> 1 action_step = agent.collect_policy.action(step) # agent gets action given the state

/opt/conda/lib/python3.7/site-packages/tf_agents/policies/tf_policy.py in action(self, time_step, policy_state, seed)
    322     if self._automatic_state_reset:
    323       policy_state = self._maybe_reset_state(time_step, policy_state)
--> 324     step = action_fn(time_step=time_step, policy_state=policy_state, seed=seed)
    325 
    326     def clip_action(action, action_spec):

/opt/conda/lib/python3.7/site-packages/tf_agents/utils/common.py in with_check_resource_vars(*fn_args, **fn_kwargs)
    186         # We're either in eager mode or in tf.function mode (no in-between); so
    187         # autodep-like behavior is already expected of fn.
--> 188         return fn(*fn_args, **fn_kwargs)
    189       if not resource_variables_enabled():
    190         raise RuntimeError(MISSING_RESOURCE_VARIABLES_ERROR)

/opt/conda/lib/python3.7/site-packages/tf_agents/policies/tf_policy.py in _action(self, time_step, policy_state, seed)
    558     """
    559     seed_stream = tfp.util.SeedStream(seed=seed, salt='tf_agents_tf_policy')
--> 560     distribution_step = self._distribution(time_step, policy_state)  # pytype: disable=wrong-arg-types
    561     actions = tf.nest.map_structure(
    562         lambda d: reparameterized_sampling.sample(d, seed=seed_stream()),

/opt/conda/lib/python3.7/site-packages/tf_agents/bandits/policies/bernoulli_thompson_sampling_policy.py in _distribution(self, time_step, policy_state)
    137     # Sample from the posterior distribution.
    138     posterior_dist = tfd.Beta(self._alpha, self._beta)
--> 139     predicted_reward_sampled = posterior_dist.sample([batch_size])
    140     predicted_reward_means_1d = tf.stack([
    141         self._alpha[k] / (self._alpha[k] + self._beta[k]) for k in range(

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py in sample(self, sample_shape, seed, name, **kwargs)
   1199     """
   1200     with self._name_and_control_scope(name):
-> 1201       return self._call_sample_n(sample_shape, seed, **kwargs)
   1202 
   1203   def _call_sample_and_log_prob(self, sample_shape, seed, **kwargs):

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_sample_n(self, sample_shape, seed, **kwargs)
   1177         sample_shape, 'sample_shape')
   1178     samples = self._sample_n(
-> 1179         n, seed=seed() if callable(seed) else seed, **kwargs)
   1180     samples = tf.nest.map_structure(
   1181         lambda x: tf.reshape(x, ps.concat([sample_shape, ps.shape(x)[1:]], 0)),

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/beta.py in _sample_n(self, n, seed)
    300     log_gamma1 = gamma_lib.random_gamma(
    301         shape=[n], concentration=expanded_concentration1, seed=seed1,
--> 302         log_space=True)
    303     log_gamma2 = gamma_lib.random_gamma(
    304         shape=[n], concentration=expanded_concentration0, seed=seed2,

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/gamma.py in random_gamma(shape, concentration, rate, log_rate, seed, log_space)
    716     shape, concentration, rate=None, log_rate=None, seed=None, log_space=False):
    717   return random_gamma_with_runtime(
--> 718       shape, concentration, rate, log_rate, seed, log_space)[0]
    719 
    720 

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/gamma.py in random_gamma_with_runtime(shape, concentration, rate, log_rate, seed, log_space)
    710   seed = samplers.sanitize_seed(seed, salt='random_gamma')
    711   return _random_gamma_gradient(
--> 712       total_shape, concentration, rate, log_rate, seed, log_space)
    713 
    714 

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/internal/custom_gradient.py in none_wrapper(*args, **kwargs)
    113           return val, vjp_bwd_wrapped
    114 
--> 115         return f_wrapped(*trimmed_args, **kwargs)
    116 
    117       return none_wrapper

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/custom_gradient.py in __call__(self, *a, **k)
    340 
    341   def __call__(self, *a, **k):
--> 342     return self._d(self._f, a, k)
    343 
    344 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/custom_gradient.py in decorated(wrapped, args, kwargs)
    294     """Decorated function with custom gradient."""
    295     if context.executing_eagerly():
--> 296       return _eager_mode_decorator(wrapped, args, kwargs)
    297     else:
    298       return _graph_mode_decorator(wrapped, args, kwargs)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/custom_gradient.py in _eager_mode_decorator(f, args, kwargs)
    540   """Implement custom gradient decorator for eager mode."""
    541   with tape_lib.VariableWatcher() as variable_watcher:
--> 542     result, grad_fn = f(*args, **kwargs)
    543   flat_args = composite_tensor_gradient.get_flat_tensors_for_gradients(
    544       nest.flatten(args))

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/internal/custom_gradient.py in f_wrapped(*args, **kwargs)
     89               reconstruct_args.append(args[0])
     90               args = args[1:]
---> 91           val, aux = vjp_fwd(*reconstruct_args, **kwargs)
     92 
     93           def vjp_bwd_wrapped(*g, **kwargs):

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/gamma.py in _random_gamma_fwd(shape, concentration, rate, log_rate, seed, log_space)
    597   """Compute output, aux (collaborates with _random_gamma_bwd)."""
    598   samples, impl = _random_gamma_no_gradient(
--> 599       shape, concentration, rate, log_rate, seed, log_space)
    600   return ((samples, impl),
    601           (samples, concentration, rate, log_rate))

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/internal/implementation_selection.py in f_wrapped(*args, **kwargs)
     78     try:
     79       tf.config.run_functions_eagerly(False)
---> 80       return f(*args, **kwargs)
     81     finally:
     82       tf.config.run_functions_eagerly(orig)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/distributions/gamma.py in _random_gamma_no_gradient(shape, concentration, rate, log_rate, seed, log_space)
    540   return sampler_impl(
    541       shape=shape, concentration=concentration, rate=rate, log_rate=log_rate,
--> 542       seed=seed, log_space=log_space)
    543 
    544 

/opt/conda/lib/python3.7/site-packages/tensorflow_probability/python/internal/implementation_selection.py in impl_selecting_fn(**kwargs)
    157     # Grappler will kick in during session execution to optimize the graph.
    158     samples, runtime = defun_default_fn(**kwargs)
--> 159     function.register(defun_cpu_fn, **kwargs)
    160     return samples, runtime
    161   return impl_selecting_fn

AttributeError: module 'tensorflow.python.eager.function' has no attribute 'register'
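Reading the traceback from the bottom up: the policy samples its `tfd.Beta` posterior, TensorFlow Probability's `Beta._sample_n` builds each Beta draw from two Gamma draws (`gamma_lib.random_gamma` with `log_space=True`), and the Gamma sampler's implementation-selection wrapper (`impl_selecting_fn`) is what calls `function.register`, which the installed TensorFlow does not provide. The break therefore appears to sit between TensorFlow Probability and TensorFlow rather than in tf_agents, which would also explain why the other agents, presumably not sampling a Beta/Gamma distribution on this path, work fine. A TensorFlow Probability release that is older than the installed TensorFlow is a likely, though unconfirmed, cause. For reference, the Gamma-ratio construction itself is short. The following pure-Python sketch works in linear space with the standard library's `random.gammavariate`, not TFP's sampler; `beta_via_gammas` is a hypothetical name:

```python
import random
from typing import Optional


def beta_via_gammas(a: float, b: float, rng: Optional[random.Random] = None) -> float:
    """Hypothetical helper: sample Beta(a, b) as G1 / (G1 + G2).

    G1 ~ Gamma(a, 1) and G2 ~ Gamma(b, 1) are independent. TFP does the
    same ratio in log space to avoid underflow for small concentrations;
    this linear-space sketch does not.
    """
    if a <= 0 or b <= 0:
        raise ValueError("concentrations must be positive")
    if rng is None:
        rng = random.Random()
    g1 = rng.gammavariate(a, 1.0)  # shape a, scale 1
    g2 = rng.gammavariate(b, 1.0)  # shape b, scale 1
    return g1 / (g1 + g2)
```

The sample mean over many draws should approach a / (a + b), the same posterior mean the tf_agents policy computes as `alpha[k] / (alpha[k] + beta[k])`.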


