diff --git a/discovery/seata-discovery-core/src/main/java/org/apache/seata/discovery/loadbalance/ConsistentHashLoadBalance.java b/discovery/seata-discovery-core/src/main/java/org/apache/seata/discovery/loadbalance/ConsistentHashLoadBalance.java index da0ff018593..1d4e9fb1669 100644 --- a/discovery/seata-discovery-core/src/main/java/org/apache/seata/discovery/loadbalance/ConsistentHashLoadBalance.java +++ b/discovery/seata-discovery-core/src/main/java/org/apache/seata/discovery/loadbalance/ConsistentHashLoadBalance.java @@ -19,9 +19,7 @@ import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import java.util.List; -import java.util.SortedMap; -import java.util.TreeMap; +import java.util.*; import org.apache.seata.common.loader.LoadLevel; import org.apache.seata.config.ConfigurationFactory; @@ -38,26 +36,26 @@ public class ConsistentHashLoadBalance implements LoadBalance { * The constant LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES. */ public static final String LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES = LoadBalanceFactory.LOAD_BALANCE_PREFIX - + "virtualNodes"; + + "virtualNodes"; /** * The constant VIRTUAL_NODES_NUM. */ private static final int VIRTUAL_NODES_NUM = ConfigurationFactory.getInstance().getInt( - LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES, VIRTUAL_NODES_DEFAULT); + LOAD_BALANCE_CONSISTENT_HASH_VIRTUAL_NODES, VIRTUAL_NODES_DEFAULT); /** * The ConsistentHashSelectorWrapper that caches a {@link ConsistentHashSelector}. */ private volatile ConsistentHashSelectorWrapper selectorWrapper; - @Override @SuppressWarnings("unchecked") + @Override public T select(List invokers, String xid) { if (selectorWrapper == null) { synchronized (this) { if (selectorWrapper == null) { selectorWrapper = new ConsistentHashSelectorWrapper( - new ConsistentHashSelector<>(invokers, VIRTUAL_NODES_NUM), invokers.hashCode()); + new ConsistentHashSelector<>(invokers, VIRTUAL_NODES_NUM), invokers); } } } @@ -68,25 +66,37 @@ public T select(List invokers, String xid) { private static final class ConsistentHashSelectorWrapper { private volatile ConsistentHashSelector selector; - private volatile int invokersHashcode; + // only shared with read + private volatile Set invokers; - public ConsistentHashSelectorWrapper(ConsistentHashSelector selector, int invokersHashcode) { + public ConsistentHashSelectorWrapper(ConsistentHashSelector selector, List invokers) { this.selector = selector; - this.invokersHashcode = invokersHashcode; + this.invokers = new HashSet<>(invokers); } public ConsistentHashSelector getSelector(List invokers) { - int hashCode; - if ((hashCode = invokers.hashCode()) != invokersHashcode) { + if (!equals(invokers)) { synchronized (this) { - if (hashCode != invokersHashcode) { + if (!equals(invokers)) { selector = new ConsistentHashSelector(invokers, VIRTUAL_NODES_NUM); - invokersHashcode = hashCode; + this.invokers = new HashSet<>(invokers); } } } return selector; } + + private boolean equals(List invokers) { + if (invokers.size() != this.invokers.size()) { + return false; + } + for (Object invoker : invokers) { + if (!this.invokers.contains(invoker)) { + return false; + } + } + return true; + } } private static final class ConsistentHashSelector { diff --git a/discovery/seata-discovery-core/src/test/java/org/apache/seata/discovery/loadbalance/LoadBalanceTest.java b/discovery/seata-discovery-core/src/test/java/org/apache/seata/discovery/loadbalance/LoadBalanceTest.java index d424a56b36a..f26aa4f26c0 100644 --- a/discovery/seata-discovery-core/src/test/java/org/apache/seata/discovery/loadbalance/LoadBalanceTest.java +++ b/discovery/seata-discovery-core/src/test/java/org/apache/seata/discovery/loadbalance/LoadBalanceTest.java @@ -120,14 +120,18 @@ public void testConsistentHashLoadBalance_select(List address @MethodSource("addressProvider") public void testCachedConsistentHashLoadBalance_select(List addresses) throws Exception { ConsistentHashLoadBalance loadBalance = new ConsistentHashLoadBalance(); - loadBalance.select(addresses, XID); + + List addresses1 = new ArrayList<>(addresses); + loadBalance.select(addresses1, XID); Object o1 = getConsistentHashSelectorByReflect(loadBalance); - loadBalance.select(addresses, XID); + List addresses2 = new ArrayList<>(addresses); + loadBalance.select(addresses2, XID); Object o2 = getConsistentHashSelectorByReflect(loadBalance); Assertions.assertEquals(o1, o2); - addresses = new ArrayList<>(addresses); - addresses.remove(ThreadLocalRandom.current().nextInt(addresses.size())); - loadBalance.select(addresses, XID); + + List addresses3 = new ArrayList<>(addresses); + addresses3.remove(ThreadLocalRandom.current().nextInt(addresses.size())); + loadBalance.select(addresses3, XID); Object o3 = getConsistentHashSelectorByReflect(loadBalance); Assertions.assertNotEquals(o1, o3); }