From 4c69f37b8e9c424da52d5b3d16f2d62735ad2309 Mon Sep 17 00:00:00 2001 From: Joseph Cosentino Date: Fri, 29 Sep 2023 08:55:45 -0700 Subject: [PATCH] feat: reset mqtt3 client on config change --- .../integrationtests/BridgeTest.java | 4 +- .../integrationtests/ConfigTest.java | 49 ------ .../integrationtests/KeystoreTest.java | 27 +-- .../greengrass/mqtt/bridge/MQTTBridge.java | 14 +- .../mqtt/bridge/clients/Configurable.java | 13 ++ .../mqtt/bridge/clients/LocalMqtt5Client.java | 13 +- .../clients/LocalMqttClientFactory.java | 3 +- .../mqtt/bridge/clients/MQTTClient.java | 154 ++++++++++++------ .../bridge/clients/FakePahoMqtt3Client.java | 56 +++---- .../clients/FakePahoMqtt3ClientTest.java | 2 +- .../bridge/clients/LocalMqtt5ClientTest.java | 2 +- .../mqtt/bridge/clients/MQTTClientTest.java | 95 ++++++++--- 12 files changed, 241 insertions(+), 191 deletions(-) create mode 100644 src/main/java/com/aws/greengrass/mqtt/bridge/clients/Configurable.java diff --git a/src/integrationtests/java/com/aws/greengrass/integrationtests/BridgeTest.java b/src/integrationtests/java/com/aws/greengrass/integrationtests/BridgeTest.java index 9c95b0d5..8bcc498a 100644 --- a/src/integrationtests/java/com/aws/greengrass/integrationtests/BridgeTest.java +++ b/src/integrationtests/java/com/aws/greengrass/integrationtests/BridgeTest.java @@ -101,7 +101,7 @@ void GIVEN_mqtt3_and_mapping_between_local_and_iotcore_WHEN_iotcore_message_rece @BridgeIntegrationTest( withConfig = "mqtt3_local_and_iotcore.yaml", - withBrokers = {Broker.MQTT5, Broker.MQTT3}) + withBrokers = {Broker.MQTT5}) void GIVEN_mqtt3_and_mapping_between_local_and_iotcore_WHEN_local_message_received_THEN_message_bridged_to_iotcore(BridgeIntegrationTestContext context) throws Exception { MqttMessage expectedMessage = MqttMessage.builder() .topic("topic/toIotCore") @@ -129,7 +129,7 @@ void GIVEN_mqtt3_and_mapping_between_local_and_iotcore_WHEN_local_message_receiv .contentType("contentType") .build()); - subscribeCallback.getLeft().get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + subscribeCallback.getLeft().get(5L, TimeUnit.SECONDS); } @BridgeIntegrationTest( diff --git a/src/integrationtests/java/com/aws/greengrass/integrationtests/ConfigTest.java b/src/integrationtests/java/com/aws/greengrass/integrationtests/ConfigTest.java index e21b7bde..50f8ced2 100644 --- a/src/integrationtests/java/com/aws/greengrass/integrationtests/ConfigTest.java +++ b/src/integrationtests/java/com/aws/greengrass/integrationtests/ConfigTest.java @@ -6,7 +6,6 @@ package com.aws.greengrass.integrationtests; import com.aws.greengrass.config.Topics; -import com.aws.greengrass.config.UpdateBehaviorTree; import com.aws.greengrass.dependency.State; import com.aws.greengrass.integrationtests.extensions.BridgeIntegrationTest; import com.aws.greengrass.integrationtests.extensions.BridgeIntegrationTestContext; @@ -38,12 +37,10 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; -import java.util.function.Supplier; import static com.github.grantwest.eventually.EventuallyLambdaMatcher.eventuallyEval; import static com.aws.greengrass.componentmanager.KernelConfigResolver.CONFIGURATION_CONFIG_KEY; @@ -61,8 +58,6 @@ public class ConfigTest { private static final long AWAIT_TIMEOUT_SECONDS = 30L; private static final long RECEIVE_PUBLISH_SECONDS = 2L; - private static final Supplier MERGE_UPDATE_BEHAVIOR = - () -> new UpdateBehaviorTree(UpdateBehaviorTree.UpdateBehavior.MERGE, System.currentTimeMillis()); BridgeIntegrationTestContext testContext; @@ -180,50 +175,6 @@ void GIVEN_Greengrass_with_mqtt_bridge_WHEN_connect_options_set_in_config_THEN_l () -> largeMessageHandler.getLeft().get(RECEIVE_PUBLISH_SECONDS, TimeUnit.SECONDS)); } - @BridgeIntegrationTest( - withConfig = "config.yaml", - withBrokers = Broker.MQTT3) - void GIVEN_Greengrass_with_mqtt_bridge_WHEN_multiple_serialized_config_changes_occur_THEN_bridge_reinstalls_multiple_times(ExtensionContext context) throws Exception { - ignoreExceptionOfType(context, InterruptedException.class); - - Semaphore bridgeRestarted = new Semaphore(1); - bridgeRestarted.acquire(); - - testContext.getKernel().getContext().addGlobalStateChangeListener((GreengrassService service, State was, State newState) -> { - if (service.getName().equals(MQTTBridge.SERVICE_NAME) && newState.equals(State.RUNNING)) { - bridgeRestarted.release(); - } - }); - - Topics config = testContext.getKernel().locate(MQTTBridge.SERVICE_NAME).getConfig() - .lookupTopics(CONFIGURATION_CONFIG_KEY); - - int numRestarts = 5; - for (int i = 0; i < numRestarts; i++) { - // change the configuration and wait for bridge to restart - config.updateFromMap(Utils.immutableMap(BridgeConfig.KEY_CLIENT_ID, String.format("clientId%d", i)), MERGE_UPDATE_BEHAVIOR.get()); - assertTrue(bridgeRestarted.tryAcquire(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS)); - } - } - - @BridgeIntegrationTest( - withConfig = "config.yaml", - withBrokers = Broker.MQTT3) - void GIVEN_Greengrass_with_mqtt_bridge_WHEN_clientId_config_changes_THEN_bridge_reinstalls() throws Exception { - CountDownLatch bridgeRestarted = new CountDownLatch(1); - testContext.getKernel().getContext().addGlobalStateChangeListener((GreengrassService service, State was, State newState) -> { - if (service.getName().equals(MQTTBridge.SERVICE_NAME) && newState.equals(State.NEW)) { - bridgeRestarted.countDown(); - } - }); - - Topics config = testContext.getKernel().locate(MQTTBridge.SERVICE_NAME).getConfig() - .lookupTopics(CONFIGURATION_CONFIG_KEY); - config.updateFromMap(Utils.immutableMap(BridgeConfig.KEY_CLIENT_ID, "new_client_id"), MERGE_UPDATE_BEHAVIOR.get()); - - assertTrue(bridgeRestarted.await(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS)); - } - @BridgeIntegrationTest( withConfig = "config.yaml", withBrokers = Broker.MQTT3) diff --git a/src/integrationtests/java/com/aws/greengrass/integrationtests/KeystoreTest.java b/src/integrationtests/java/com/aws/greengrass/integrationtests/KeystoreTest.java index fcb27d4e..4a25869a 100644 --- a/src/integrationtests/java/com/aws/greengrass/integrationtests/KeystoreTest.java +++ b/src/integrationtests/java/com/aws/greengrass/integrationtests/KeystoreTest.java @@ -9,13 +9,9 @@ import com.aws.greengrass.clientdevices.auth.certificate.CertificateHelper; import com.aws.greengrass.clientdevices.auth.certificate.CertificateStore; import com.aws.greengrass.config.Topic; -import com.aws.greengrass.dependency.State; import com.aws.greengrass.integrationtests.extensions.BridgeIntegrationTest; import com.aws.greengrass.integrationtests.extensions.BridgeIntegrationTestContext; import com.aws.greengrass.integrationtests.extensions.Broker; -import com.aws.greengrass.lifecyclemanager.GlobalStateChangeListener; -import com.aws.greengrass.lifecyclemanager.GreengrassService; -import com.aws.greengrass.mqtt.bridge.BridgeConfig; import com.aws.greengrass.mqtt.bridge.MQTTBridge; import com.aws.greengrass.mqtt.bridge.auth.MQTTClientKeyStore; import com.aws.greengrass.mqtt.bridge.model.MqttVersion; @@ -41,7 +37,6 @@ import java.util.function.Consumer; import java.util.stream.IntStream; -import static com.aws.greengrass.componentmanager.KernelConfigResolver.CONFIGURATION_CONFIG_KEY; import static com.aws.greengrass.lifecyclemanager.GreengrassService.RUNTIME_STORE_NAMESPACE_TOPIC; import static com.aws.greengrass.lifecyclemanager.GreengrassService.SERVICES_NAMESPACE_TOPIC; import static com.aws.greengrass.testcommons.testutilities.ExceptionLogProtector.ignoreExceptionOfType; @@ -154,22 +149,8 @@ void GIVEN_mqtt_bridge_WHEN_cda_ca_conf_changed_after_shutdown_THEN_bridge_keyst ignoreExceptionOfType(context, IllegalArgumentException.class); ignoreExceptionOfType(context, NullPointerException.class); - // break bridge - CountDownLatch bridgeIsBroken = new CountDownLatch(1); - GlobalStateChangeListener listener = (GreengrassService service, State was, State newState) -> { - if (service.getName().equals(MQTTBridge.SERVICE_NAME) && service.getState().equals(State.BROKEN)) { - bridgeIsBroken.countDown(); - } - }; - Topic brokerUriTopic = testContext.getKernel().getConfig().lookup( - SERVICES_NAMESPACE_TOPIC, - MQTTBridge.SERVICE_NAME, - CONFIGURATION_CONFIG_KEY, - BridgeConfig.KEY_BROKER_URI - ); - brokerUriTopic.withValue("garbage"); - testContext.getKernel().getContext().addGlobalStateChangeListener(listener); - assertTrue(bridgeIsBroken.await(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS)); + // shutdown the bridge + testContext.getFromContext(MQTTBridge.class).shutdown(); CountDownLatch keyStoreUpdated = new CountDownLatch(1); MQTTClientKeyStore keyStore = testContext.getKernel().getContext().get(MQTTClientKeyStore.class); @@ -191,8 +172,8 @@ void GIVEN_mqtt_bridge_WHEN_cda_ca_conf_changed_after_shutdown_THEN_bridge_keyst Date.from(Instant.now().plusSeconds(100)), "CA")))); - // shouldn't update - assertFalse(keyStoreUpdated.await(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS)); + testContext.getKernel().getContext().waitForPublishQueueToClear(); + assertFalse(keyStoreUpdated.await(5L, TimeUnit.SECONDS)); } private CompletableFuture asyncAssertNumConnects(Integer numConnects) throws InterruptedException { diff --git a/src/main/java/com/aws/greengrass/mqtt/bridge/MQTTBridge.java b/src/main/java/com/aws/greengrass/mqtt/bridge/MQTTBridge.java index da0cb839..7150c6ed 100644 --- a/src/main/java/com/aws/greengrass/mqtt/bridge/MQTTBridge.java +++ b/src/main/java/com/aws/greengrass/mqtt/bridge/MQTTBridge.java @@ -17,10 +17,9 @@ import com.aws.greengrass.lifecyclemanager.PluginService; import com.aws.greengrass.lifecyclemanager.exceptions.ServiceLoadException; import com.aws.greengrass.mqtt.bridge.auth.MQTTClientKeyStore; +import com.aws.greengrass.mqtt.bridge.clients.Configurable; import com.aws.greengrass.mqtt.bridge.clients.IoTCoreClient; -import com.aws.greengrass.mqtt.bridge.clients.LocalMqtt5Client; import com.aws.greengrass.mqtt.bridge.clients.LocalMqttClientFactory; -import com.aws.greengrass.mqtt.bridge.clients.MQTTClient; import com.aws.greengrass.mqtt.bridge.clients.MessageClient; import com.aws.greengrass.mqtt.bridge.clients.MessageClientException; import com.aws.greengrass.mqtt.bridge.clients.PubSubClient; @@ -295,17 +294,14 @@ public class ConfigurationChangeHandler { return; } - // TODO same for MQTT3 client - if (localMqttClient instanceof LocalMqtt5Client) { - ((LocalMqtt5Client) localMqttClient).applyConfig(LocalMqtt5Client.Config.fromBridgeConfig(newConfig)); + if (localMqttClient instanceof Configurable) { + ((Configurable) localMqttClient).applyConfig(newConfig); } }); private boolean reinstallRequired(BridgeConfig prevConfig, BridgeConfig newConfig) { - return !Objects.equals(prevConfig.getMqttVersion(), newConfig.getMqttVersion()) // to switch between clients - || localMqttClient instanceof MQTTClient - && (!Objects.equals(prevConfig.getBrokerUri(), newConfig.getBrokerUri()) - || !Objects.equals(prevConfig.getClientId(), newConfig.getClientId())); + // to switch between clients + return !Objects.equals(prevConfig.getMqttVersion(), newConfig.getMqttVersion()); } /** diff --git a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/Configurable.java b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/Configurable.java new file mode 100644 index 00000000..0e0ab7e9 --- /dev/null +++ b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/Configurable.java @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.aws.greengrass.mqtt.bridge.clients; + +import com.aws.greengrass.mqtt.bridge.BridgeConfig; + +@FunctionalInterface +public interface Configurable { + void applyConfig(BridgeConfig config); +} diff --git a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5Client.java b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5Client.java index 33effb3c..794c4787 100644 --- a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5Client.java +++ b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5Client.java @@ -68,7 +68,7 @@ import static com.aws.greengrass.mqtt.bridge.model.Mqtt5RouteOptions.DEFAULT_NO_LOCAL; @SuppressWarnings("PMD.CloseResource") -public class LocalMqtt5Client implements MessageClient { +public class LocalMqtt5Client implements MessageClient, Configurable { private static final Logger LOGGER = LogManager.getLogger(LocalMqtt5Client.class); @@ -334,11 +334,16 @@ public static boolean resetRequired(Config prevConfig, Config newConfig) { * Apply new configuration to this client. Depending on what configurations changed, a * {@link LocalMqtt5Client#reset()} may occur to apply them. * - * @param config new configuration + * @param bridgeConfig new bridge configuration */ - public void applyConfig(@NonNull Config config) { + @Override + public void applyConfig(@NonNull BridgeConfig bridgeConfig) { + applyConfig(Config.fromBridgeConfig(bridgeConfig)); + } + + void applyConfig(Config newConfig) { Config previousConfig = this.config; - this.config = config; + this.config = newConfig; if (Config.resetRequired(previousConfig, config)) { scheduleResetTask(); } diff --git a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqttClientFactory.java b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqttClientFactory.java index b721befa..47af7c24 100644 --- a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqttClientFactory.java +++ b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqttClientFactory.java @@ -63,8 +63,7 @@ public MessageClient createLocalMqttClient() throws MessageClientEx case MQTT3: // fall-through default: return new MQTTClient( - config.getBrokerUri(), - config.getClientId(), + config, mqttClientKeyStore, executorService ); diff --git a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClient.java b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClient.java index 2755238d..a27358b8 100644 --- a/src/main/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClient.java +++ b/src/main/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClient.java @@ -15,6 +15,8 @@ import com.aws.greengrass.util.RetryUtils; import com.aws.greengrass.util.Utils; import lombok.AccessLevel; +import lombok.Builder; +import lombok.Data; import lombok.Getter; import lombok.NonNull; import lombok.Setter; @@ -33,6 +35,7 @@ import java.time.Duration; import java.util.Collections; import java.util.HashSet; +import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; @@ -40,7 +43,8 @@ import java.util.function.Consumer; import javax.net.ssl.SSLSocketFactory; -public class MQTTClient implements MessageClient { +@SuppressWarnings("PMD.CloseResource") +public class MQTTClient implements MessageClient, Configurable { private static final Logger LOGGER = LogManager.getLogger(MQTTClient.class); public static final String TOPIC = "topic"; @@ -48,9 +52,6 @@ public class MQTTClient implements MessageClient { private static final int MAX_WAIT_RETRY_IN_SECONDS = 120; private Consumer messageHandler; - private final URI brokerUri; - private final String clientId; - private final MqttClientPersistence dataStore; private final ExecutorService executorService; private final Object subscribeLock = new Object(); @@ -67,12 +68,12 @@ public class MQTTClient implements MessageClient { private final MQTTClientKeyStore.UpdateListener onKeyStoreUpdate = new MQTTClientKeyStore.UpdateListener() { @Override public void onCAUpdate() { - if (mqttClientInternal == null) { + if (!started()) { LOGGER.atDebug().log("Client not yet initialized, skipping reset"); return; } LOGGER.atInfo().log("New CA certificate available, reconnecting client"); - reset(); + reset(false); } @Override @@ -114,40 +115,110 @@ public void deliveryComplete(IMqttDeliveryToken token) { } }; + @Getter // for testing + volatile Config config; + volatile Config pendingConfig; + + @Data + @Builder(toBuilder = true) + public static class Config { + URI brokerUri; + String clientId; + + /** + * Map from bridge configuration to client configuration. + * + * @param bridgeConfig component configuration + * @return client configuration + */ + public static Config fromBridgeConfig(BridgeConfig bridgeConfig) { + return Config.builder() + .brokerUri(bridgeConfig.getBrokerUri()) + .clientId(bridgeConfig.getClientId()) + .build(); + } + } + + /** + * Apply new configuration to this client. Client will reset. + * + * @param bridgeConfig new bridge configuration + */ + @Override + public void applyConfig(@NonNull BridgeConfig bridgeConfig) { + applyConfig(Config.fromBridgeConfig(bridgeConfig)); + } + + void applyConfig(Config newConfig) { + if (Objects.equals(this.config, newConfig)) { + return; + } + // config will be picked up the next time an mqtt client is created + this.pendingConfig = newConfig; + if (!started()) { + return; + } + reset(true); + } + + private boolean started() { + return mqttClientInternal != null; + } + /** * Construct an MQTTClient. * - * @param brokerUri broker uri - * @param clientId client id + * @param bridgeConfig bridge config * @param mqttClientKeyStore KeyStore for MQTT Client * @param executorService Executor service */ - public MQTTClient(@NonNull URI brokerUri, - @NonNull String clientId, + public MQTTClient(@NonNull BridgeConfig bridgeConfig, MQTTClientKeyStore mqttClientKeyStore, ExecutorService executorService) { - this.brokerUri = brokerUri; - this.clientId = clientId; + this.config = Config.fromBridgeConfig(bridgeConfig); this.dataStore = new MemoryPersistence(); - this.clientFactory = () -> new MqttClient(brokerUri.toString(), clientId, dataStore); + this.clientFactory = () -> { + if (this.pendingConfig != null) { + this.config = this.pendingConfig; + } + MqttClient client = new MqttClient(config.getBrokerUri().toString(), config.getClientId(), dataStore); + client.setCallback(mqttCallback); + return client; + }; this.mqttClientKeyStore = mqttClientKeyStore; this.mqttClientKeyStore.listenToUpdates(onKeyStoreUpdate); this.executorService = executorService; } - protected MQTTClient(@NonNull URI brokerUri, @NonNull String clientId, MQTTClientKeyStore mqttClientKeyStore, - ExecutorService executorService, IMqttClient mqttClient) { - this.brokerUri = brokerUri; - this.clientId = clientId; - this.clientFactory = () -> mqttClient; + protected MQTTClient(Config config, + MQTTClientKeyStore mqttClientKeyStore, + ExecutorService executorService, + CrashableSupplier clientFactory) { + this.config = config; + this.clientFactory = () -> { + if (this.pendingConfig != null) { + this.config = this.pendingConfig; + } + IMqttClient client = clientFactory.apply(); + client.setCallback(mqttCallback); + return client; + }; this.dataStore = new MemoryPersistence(); this.mqttClientKeyStore = mqttClientKeyStore; this.mqttClientKeyStore.listenToUpdates(onKeyStoreUpdate); this.executorService = executorService; } - void reset() { - disconnect(30_000L); // paho default + synchronized void reset(boolean recreateClient) { + disconnectForcibly(); + if (recreateClient) { + try { + this.mqttClientInternal = clientFactory.apply(); + } catch (MqttException e) { + LOGGER.atError().cause(e).log("unable to recreate MQTT client using new configuration, " + + "falling back to old configuration"); + } + } connectAndSubscribe(); } @@ -161,11 +232,10 @@ public void start() throws MessageClientException { } catch (MqttException e) { throw new MessageClientException("Unable to create MQTTClient", e); } - mqttClientInternal.setCallback(mqttCallback); connectAndSubscribe(); } - @SuppressWarnings({"PMD.AvoidCatchingGenericException", "PMD.AvoidCatchingNPE", "PMD.CloseResource"}) + @SuppressWarnings({"PMD.AvoidCatchingGenericException", "PMD.AvoidCatchingNPE"}) private void disconnectForcibly() { IMqttClient client = mqttClientInternal; if (client == null) { @@ -183,32 +253,11 @@ private void disconnectForcibly() { } } - @SuppressWarnings("PMD.CloseResource") - private void disconnect(long quiesceTimeout) { - IMqttClient client = mqttClientInternal; - if (client == null) { - return; - } - try { - LOGGER.debug("Disconnecting MQTT client"); - client.disconnect(quiesceTimeout); - } catch (MqttException e) { - if (MqttException.REASON_CODE_CLIENT_ALREADY_DISCONNECTED != e.getReasonCode() - && MqttException.REASON_CODE_CLIENT_CLOSED != e.getReasonCode()) { - LOGGER.atError().setCause(e).log("Failed to disconnect MQTT client"); - return; - } - } - // no need to unsubscribe because we connect with cleanSession=true - subscribedLocalMqttTopics.clear(); - LOGGER.debug("MQTT client disconnected"); - } - /** * Stop the {@link MQTTClient}. */ @Override - @SuppressWarnings({"PMD.CloseResource", "PMD.AvoidCatchingGenericException", "PMD.AvoidCatchingNPE"}) + @SuppressWarnings({"PMD.AvoidCatchingGenericException", "PMD.AvoidCatchingNPE"}) public void stop() { mqttClientKeyStore.unsubscribeFromUpdates(onKeyStoreUpdate); @@ -299,19 +348,17 @@ private MqttConnectOptions getConnectionOptions() throws KeyStoreException { MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setCleanSession(true); connOpts.setMaxInflight(1000); - - if ("ssl".equalsIgnoreCase(brokerUri.getScheme())) { + if ("ssl".equalsIgnoreCase(config.getBrokerUri().getScheme())) { SSLSocketFactory ssf = mqttClientKeyStore.getSSLSocketFactory(); connOpts.setSocketFactory(ssf); } - return connOpts; } private void connectAndSubscribe() { LOGGER.atInfo() - .kv(BridgeConfig.KEY_BROKER_URI, brokerUri) - .kv(BridgeConfig.KEY_CLIENT_ID, clientId) + .kv(BridgeConfig.KEY_BROKER_URI, config.getBrokerUri()) + .kv(BridgeConfig.KEY_CLIENT_ID, config.getClientId()) .log("Connecting to broker"); reconnectAndResubscribeAsync(); } @@ -334,11 +381,12 @@ private void cancelConnectTask() { private void reconnectAndResubscribe() { int waitBeforeRetry = MIN_WAIT_RETRY_IN_SECONDS; - while (!mqttClientInternal.isConnected() && !Thread.currentThread().isInterrupted()) { + IMqttClient client = mqttClientInternal; + while (!client.isConnected() && !Thread.currentThread().isInterrupted()) { Exception error; try { // TODO: Clean up this loop - mqttClientInternal.connect(getConnectionOptions()); + client.connect(getConnectionOptions()); break; } catch (MqttException e) { if (Utils.getUltimateCause(e) instanceof InterruptedException) { @@ -368,8 +416,8 @@ private void reconnectAndResubscribe() { } LOGGER.atInfo() - .kv(BridgeConfig.KEY_BROKER_URI, brokerUri) - .kv(BridgeConfig.KEY_CLIENT_ID, clientId) + .kv(BridgeConfig.KEY_BROKER_URI, config.getBrokerUri()) + .kv(BridgeConfig.KEY_CLIENT_ID, config.getClientId()) .log("Connected to broker"); resubscribe(); diff --git a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3Client.java b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3Client.java index 25f9926d..e3da9c18 100644 --- a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3Client.java +++ b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3Client.java @@ -27,6 +27,7 @@ public class FakePahoMqtt3Client implements IMqttClient { MqttCallback mqttCallback; String clientId; + String serverURI; @Getter List subscriptionTopics; // TODO: Support QoS @@ -38,6 +39,8 @@ public class FakePahoMqtt3Client implements IMqttClient { MqttConnectOptions connectOptions; @Getter int connectCount = 0; + @Getter + int disconnectCount = 0; final Object connectMonitor; boolean isConnected; @@ -53,8 +56,9 @@ public class TopicMessagePair { } } - FakePahoMqtt3Client(String clientId) { + FakePahoMqtt3Client(String clientId, String serverURI) { this.clientId = clientId; + this.serverURI = serverURI; this.subscriptionTopics = new ArrayList<>(); this.publishedMessages = new ArrayList<>(); this.connectMonitor = new Object(); @@ -105,7 +109,7 @@ boolean waitForConnect(int timeout) { } @Override - public void connect() throws MqttSecurityException, MqttException { + public void connect() { isConnected = true; connectCount++; synchronized (connectMonitor) { @@ -114,7 +118,7 @@ public void connect() throws MqttSecurityException, MqttException { } @Override - public void connect(MqttConnectOptions mqttConnectOptions) throws MqttSecurityException, MqttException { + public void connect(MqttConnectOptions mqttConnectOptions) { this.connectOptions = mqttConnectOptions; connect(); } @@ -127,6 +131,7 @@ public IMqttToken connectWithResult(MqttConnectOptions mqttConnectOptions) @Override public void disconnect() throws MqttException { + disconnectCount++; isConnected = false; // Reset subscriptions subscriptionTopics.clear(); @@ -153,7 +158,7 @@ public void disconnectForcibly(long quiesceTimeout, long disconnectTimeout) thro } @Override - public void subscribe(String topicFilter) throws MqttException, MqttSecurityException { + public void subscribe(String topicFilter) throws MqttException { subscribe(topicFilter, 1); } @@ -165,14 +170,14 @@ public void subscribe(String[] topicFilters) throws MqttException { } @Override - public void subscribe(String topicFilter, int qos) throws MqttException { + public void subscribe(String topicFilter, int qos) { if (!subscriptionTopics.contains(topicFilter)) { subscriptionTopics.add(topicFilter); } } @Override - public void subscribe(String[] topicFilters, int[] qos) throws MqttException { + public void subscribe(String[] topicFilters, int[] qos) { if (topicFilters.length != qos.length) { throw new IllegalArgumentException("Topic filter and qos array lengths must match"); } @@ -182,77 +187,72 @@ public void subscribe(String[] topicFilters, int[] qos) throws MqttException { } @Override - public void subscribe(String topicFilter, IMqttMessageListener iMqttMessageListener) - throws MqttException, MqttSecurityException { + public void subscribe(String topicFilter, IMqttMessageListener iMqttMessageListener) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public void subscribe(String[] topicFilters, IMqttMessageListener[] iMqttMessageListeners) throws MqttException { + public void subscribe(String[] topicFilters, IMqttMessageListener[] iMqttMessageListeners) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public void subscribe(String topicFilter, int qos, IMqttMessageListener iMqttMessageListener) throws MqttException { + public void subscribe(String topicFilter, int qos, IMqttMessageListener iMqttMessageListener) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public void subscribe(String[] topicFilters, int[] qos, IMqttMessageListener[] iMqttMessageListeners) - throws MqttException { + public void subscribe(String[] topicFilters, int[] qos, IMqttMessageListener[] iMqttMessageListeners) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String topicFilter) throws MqttException { + public IMqttToken subscribeWithResponse(String topicFilter) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String topicFilter, IMqttMessageListener iMqttMessageListener) throws MqttException { + public IMqttToken subscribeWithResponse(String topicFilter, IMqttMessageListener iMqttMessageListener) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String topicFilter, int qos) throws MqttException { + public IMqttToken subscribeWithResponse(String topicFilter, int qos) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String topicFilter, int qos, IMqttMessageListener iMqttMessageListener) - throws MqttException { + public IMqttToken subscribeWithResponse(String topicFilter, int qos, IMqttMessageListener iMqttMessageListener) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String[] topicFilters) throws MqttException { + public IMqttToken subscribeWithResponse(String[] topicFilters) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String[] topicFilters, IMqttMessageListener[] iMqttMessageListeners) - throws MqttException { + public IMqttToken subscribeWithResponse(String[] topicFilters, IMqttMessageListener[] iMqttMessageListeners) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String[] topicFilters, int[] qos) throws MqttException { + public IMqttToken subscribeWithResponse(String[] topicFilters, int[] qos) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public IMqttToken subscribeWithResponse(String[] topicFilters, int[] qos, IMqttMessageListener[] iMqttMessageListeners) - throws MqttException { + public IMqttToken subscribeWithResponse(String[] topicFilters, int[] qos, IMqttMessageListener[] iMqttMessageListeners) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } @Override - public void unsubscribe(String topicFilter) throws MqttException { + public void unsubscribe(String topicFilter) { subscriptionTopics.remove(topicFilter); } @Override - public void unsubscribe(String[] topicFilters) throws MqttException { + public void unsubscribe(String[] topicFilters) { for (String topicFilter : topicFilters) { unsubscribe(topicFilter); } @@ -265,7 +265,7 @@ public void publish(String topicFilter, byte[] bytes, int qos, boolean retained) } @Override - public void publish(String topic, MqttMessage mqttMessage) throws MqttException, MqttPersistenceException { + public void publish(String topic, MqttMessage mqttMessage) throws MqttException { publishedMessages.add(new TopicMessagePair(topic, mqttMessage)); } @@ -291,7 +291,7 @@ public String getClientId() { @Override public String getServerURI() { - throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); + return serverURI; } @Override @@ -311,7 +311,7 @@ public void reconnect() throws MqttException { } @Override - public void messageArrivedComplete(int messageId, int qos) throws MqttException { + public void messageArrivedComplete(int messageId, int qos) { throw new UnsupportedOperationException(UNSUPPORTED_OPERATION); } diff --git a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3ClientTest.java b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3ClientTest.java index 33d0b6f1..61351f47 100644 --- a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3ClientTest.java +++ b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/FakePahoMqtt3ClientTest.java @@ -30,7 +30,7 @@ public class FakePahoMqtt3ClientTest { @BeforeEach void setup() { - fakeMQTTClient = new FakePahoMqtt3Client("clientId"); + fakeMQTTClient = new FakePahoMqtt3Client("clientId", "uri"); } @Test diff --git a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5ClientTest.java b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5ClientTest.java index 7105b25a..a38e045c 100644 --- a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5ClientTest.java +++ b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/LocalMqtt5ClientTest.java @@ -503,7 +503,7 @@ void GIVEN_client_with_subscriptions_WHEN_connection_lost_and_resumed_THEN_subsc @ParameterizedTest @MethodSource("configChanges") - void GIVEN_client_WHEN_config_changes_THEN_client_is_reset(Function changeConfig, boolean resetExpected) throws InterruptedException { + void GIVEN_client_WHEN_config_changes_THEN_client_is_reset(Function changeConfig, boolean resetExpected) { client.applyConfig(changeConfig.apply(client.getConfig())); if (resetExpected) { assertThat("client resets", () -> mockMqtt5Client.getNumDisconnects().get(), eventuallyEval(is(1))); diff --git a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClientTest.java b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClientTest.java index 322cf465..376eb8e7 100644 --- a/src/test/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClientTest.java +++ b/src/test/java/com/aws/greengrass/mqtt/bridge/clients/MQTTClientTest.java @@ -8,29 +8,40 @@ import com.aws.greengrass.mqtt.bridge.auth.MQTTClientKeyStore; import com.aws.greengrass.testcommons.testutilities.GGExtension; import com.aws.greengrass.testcommons.testutilities.TestUtils; +import com.aws.greengrass.util.CrashableSupplier; +import org.eclipse.paho.client.mqttv3.IMqttClient; import org.eclipse.paho.client.mqttv3.MqttConnectOptions; +import org.eclipse.paho.client.mqttv3.MqttException; import org.eclipse.paho.client.mqttv3.MqttMessage; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import javax.net.ssl.SSLSocketFactory; import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.ExecutorService; +import java.util.function.Function; +import java.util.stream.Stream; +import static com.github.grantwest.eventually.EventuallyLambdaMatcher.eventuallyEval; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -38,23 +49,24 @@ @ExtendWith({MockitoExtension.class, GGExtension.class}) +@SuppressWarnings("PMD.CloseResource") public class MQTTClientTest { - private static final URI ENCRYPTED_URI = URI.create("ssl://localhost:8883"); - private static final String CLIENT_ID = "mqtt-bridge-1234"; - - private FakePahoMqtt3Client fakeMqttClient; + private static final MQTTClient.Config CONFIG = MQTTClient.Config.builder() + .clientId("mqtt-bridge-1234") + .brokerUri(URI.create("ssl://localhost:8883")) + .build(); + FakePahoMqtt3Client fakeMqttClient; + private final CrashableSupplier clientFactory = () -> { + fakeMqttClient = new FakePahoMqtt3Client(CONFIG.getClientId(), CONFIG.getBrokerUri().toString()); + return fakeMqttClient; + }; @Mock private MQTTClientKeyStore mockMqttClientKeyStore; private final ExecutorService ses = TestUtils.synchronousExecutorService(); - @BeforeEach - void setup() { - fakeMqttClient = new FakePahoMqtt3Client(CLIENT_ID); - } - @AfterEach void tearDown() { ses.shutdownNow(); @@ -62,7 +74,7 @@ void tearDown() { @Test void GIVEN_mqttClient_WHEN_start_THEN_clientConnects() throws MessageClientException { - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -71,7 +83,7 @@ void GIVEN_mqttClient_WHEN_start_THEN_clientConnects() throws MessageClientExcep @Test void GIVEN_subscribedMqttClient_WHEN_stop_THEN_clientUnsubscribes() throws MessageClientException { - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -94,7 +106,7 @@ void GIVEN_subscribedMqttClient_WHEN_stop_THEN_clientUnsubscribes() throws Messa @Test void GIVEN_subscribedMqttClient_WHEN_updateSubscriptions_THEN_subscriptionsUpdated() throws MessageClientException { - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -143,7 +155,7 @@ void GIVEN_subscribedMqttClient_WHEN_updateSubscriptions_THEN_subscriptionsUpdat @Test void GIVEN_subscribedMqttClient_WHEN_mqttMessageReceived_THEN_messageRoutedToHandler() throws Exception { - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -172,7 +184,7 @@ void GIVEN_subscribedMqttClient_WHEN_mqttMessageReceived_THEN_messageRoutedToHan @Test void GIVEN_mqttClient_WHEN_publish_THEN_routedToBroker() throws Exception { - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -192,7 +204,7 @@ void GIVEN_mqttClient_WHEN_publish_THEN_routedToBroker() throws Exception { @Test void GIVEN_mqttClient_WHEN_connectionLost_THEN_clientReconnectsAndResubscribes() throws Exception { - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -212,7 +224,7 @@ void GIVEN_mqttClient_WHEN_connectionLost_THEN_clientReconnectsAndResubscribes() @Test void GIVEN_mqttClient_WHEN_caRotates_THEN_connectsWithUpdatedSslContext() throws Exception { MQTTClientKeyStore mockKeyStore = mock(MQTTClientKeyStore.class); - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -221,7 +233,7 @@ void GIVEN_mqttClient_WHEN_caRotates_THEN_connectsWithUpdatedSslContext() throws // This code assumes reset synchronously disconnects. This will need to be revisited if // this assumption changes and this test starts failing - mqttClient.reset(); + mqttClient.reset(false); fakeMqttClient.waitForConnect(1000); assertThat(fakeMqttClient.getConnectOptions().getSocketFactory(), is(mockSocketFactory)); @@ -234,7 +246,7 @@ void GIVEN_mqttClient_WHEN_clientCertRotates_THEN_newCertIsUsedUponSubsequentRec SSLSocketFactory mockSocketFactory2 = mock(SSLSocketFactory.class); when(mockMqttClientKeyStore.getSSLSocketFactory()).thenReturn(mockSocketFactory1); - MQTTClient mqttClient = new MQTTClient(ENCRYPTED_URI, CLIENT_ID, mockMqttClientKeyStore, ses, fakeMqttClient); + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); mqttClient.start(); fakeMqttClient.waitForConnect(1000); @@ -250,4 +262,49 @@ void GIVEN_mqttClient_WHEN_clientCertRotates_THEN_newCertIsUsedUponSubsequentRec connectOptions = fakeMqttClient.getConnectOptions(); assertThat(connectOptions.getSocketFactory(), is(mockSocketFactory2)); } + + @ParameterizedTest + @MethodSource("configChanges") + void GIVEN_client_WHEN_config_changes_THEN_client_is_reset(Function changeConfig, boolean resetExpected) throws MessageClientException, InterruptedException { + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); + mqttClient.start(); + mqttClient.applyConfig(changeConfig.apply(MQTTClient.Config.builder().build())); + FakePahoMqtt3Client fakeMqttClient = this.fakeMqttClient; + mqttClient.reset(true); + if (resetExpected) { + assertThat("client resets", () -> fakeMqttClient.disconnectCount, eventuallyEval(is(1))); + } else { + Thread.sleep(1000L); + assertEquals(1, fakeMqttClient.connectCount); + assertEquals(0, fakeMqttClient.disconnectCount); + } + } + + @Test + void GIVEN_client_WHEN_config_does_not_change_THEN_client_is_not_reset() throws MessageClientException, InterruptedException { + MQTTClient mqttClient = new MQTTClient(CONFIG, mockMqttClientKeyStore, ses, clientFactory); + mqttClient.start(); + mqttClient.applyConfig(mqttClient.getConfig()); + Thread.sleep(1000L); + assertEquals(1, fakeMqttClient.connectCount); + assertEquals(0, fakeMqttClient.disconnectCount); + } + + @SuppressWarnings("PMD.UnusedPrivateMethod") + private static Stream configChanges() { + Function brokerUriChanges = config -> { + try { + return config.toBuilder().brokerUri(new URI("tcp://0.0.0.0:1883")).build(); + } catch (URISyntaxException e) { + fail(e); + return null; + } + }; + Function clientIdChanges = config -> config.toBuilder().clientId("newClientId").build(); + + return Stream.of( + Arguments.of(brokerUriChanges, true), + Arguments.of(clientIdChanges, true) + ); + } }