From a809aa8a9f08003ecd462dd9d4410d9f92b60336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Bart=C3=B3k?= Date: Wed, 9 Oct 2024 00:56:16 -0700 Subject: [PATCH] Fix numpy test issue. PiperOrigin-RevId: 683926353 Change-Id: I88a15bf0cb64ee60a95bd6567128e61272e4e39c --- tf_agents/bandits/agents/exp3_mixture_agent_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tf_agents/bandits/agents/exp3_mixture_agent_test.py b/tf_agents/bandits/agents/exp3_mixture_agent_test.py index ca1766378..6feb4941c 100644 --- a/tf_agents/bandits/agents/exp3_mixture_agent_test.py +++ b/tf_agents/bandits/agents/exp3_mixture_agent_test.py @@ -195,7 +195,9 @@ def testMixtureUpdate( reward_aggregates = self.evaluate( mixed_agent._variable_collection.reward_aggregates ) - self.assertAllInSet(reward_aggregates[: num_agents - 1], [0.999]) + self.assertAllClose( + reward_aggregates[: num_agents - 1], [0.999] * (num_agents - 1) + ) agent_prob = 1 / num_agents est_rewards = 0.5 / agent_prob per_step_update = est_rewards