From 8e9d8a13f1a868106cd8615cfc2bc1a3de619291 Mon Sep 17 00:00:00 2001 From: Joe Block Date: Mon, 7 Nov 2022 21:31:42 -0700 Subject: [PATCH] Add SSL support Signed-off-by: Joe Block --- ha_mqtt_discoverable/__init__.py | 40 +++++++++++++++++++++++----- ha_mqtt_discoverable/cli/__init__.py | 8 +++++- ha_mqtt_discoverable/sensors.py | 1 + ha_mqtt_discoverable/settings.py | 7 +++++ pyproject.toml | 2 +- 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/ha_mqtt_discoverable/__init__.py b/ha_mqtt_discoverable/__init__.py index d10ad13..87f280c 100644 --- a/ha_mqtt_discoverable/__init__.py +++ b/ha_mqtt_discoverable/__init__.py @@ -4,10 +4,11 @@ import json import logging +import ssl import paho.mqtt.client as mqtt -__version__ = "0.3.1" +__version__ = "0.4.0" CONFIGURATION_KEY_NAMES = { "act_t": "action_topic", @@ -480,6 +481,10 @@ def __init__(self, settings: dict = {}) -> None: raise RuntimeError(f"client_name is unset. {settings_error_base}") self.client_name = settings["client_name"] + if "debug" not in settings: + settings["debug"] = False + self.debug = settings["debug"] + if "mqtt_server" not in settings: raise RuntimeError(f"mqtt_server is unset. {settings_error_base}") self.mqtt_server = settings["mqtt_server"] @@ -492,10 +497,6 @@ def __init__(self, settings: dict = {}) -> None: raise RuntimeError(f"mqtt_password is unset. {settings_error_base}") self.mqtt_password = settings["mqtt_password"] - if "debug" not in settings: - settings["debug"] = False - self.debug = settings["debug"] - if "mqtt_user" not in settings: raise RuntimeError(f"mqtt_user is unset. {settings_error_base}") self.mqtt_user = settings["mqtt_user"] @@ -524,6 +525,16 @@ def __init__(self, settings: dict = {}) -> None: if "unique_id" in settings: self.unique_id = settings["unique_id"] + # SSL setup + + self.use_tls = False + if "use_tls" in settings: + if settings["use_tls"]: + self.tls_ca_cert = settings["tls_ca_cert"] + self.tls_certfile = settings["tls_certfile"] + self.tls_key = settings["tls_key"] + self.use_tls = settings["use_tls"] + self.topic_prefix = f"{self.mqtt_prefix}/{self.device_class}/{self.device_name}" self.config_topic = f"{self.topic_prefix}/config" self.state_topic = f"{self.topic_prefix}/state" @@ -559,7 +570,24 @@ def _connect(self) -> None: f"Creating mqtt client({self.client_name}) for {self.mqtt_server}" ) self.mqtt_client = mqtt.Client(self.client_name) - logging.info(f"Connecting to {self.mqtt_server}...") + if self.use_tls: + logging.info(f"Connecting to {self.mqtt_server}...") + logging.info("Configuring SSL") + logging.debug(f"ca_certs=s{elf.tls_ca_cert}") + logging.debug(f"certfile={self.tls_certfile}") + logging.debug(f"keyfile={self.tls_key}") + self.mqtt_client.tls_set( + ca_certs=self.tls_ca_cert, + certfile=self.tls_certfile, + keyfile=self.tls_key, + cert_reqs=ssl.CERT_REQUIRED, + tls_version=ssl.PROTOCOL_TLS, + ) + else: + logging.warning(f"Connecting to {self.mqtt_server} without SSL") + self.mqtt_client.username_pw_set( + self.mqtt_user, password=self.mqtt_password + ) self.mqtt_client.connect(self.mqtt_server) else: logging.debug("Reusing existing mqtt_client...") diff --git a/ha_mqtt_discoverable/cli/__init__.py b/ha_mqtt_discoverable/cli/__init__.py index 1a185d9..f6e0b38 100644 --- a/ha_mqtt_discoverable/cli/__init__.py +++ b/ha_mqtt_discoverable/cli/__init__.py @@ -9,7 +9,7 @@ from ha_mqtt_discoverable import __version__ as MODULE_VERSION -def create_base_parser(description: str = None): +def create_base_parser(description: str = "Base parser"): """ Parse the command line options """ @@ -42,6 +42,12 @@ def create_base_parser(description: str = None): parser.add_argument("--mqtt-password", type=str, help="MQTT password.") parser.add_argument("--mqtt-server", type=str, help="MQTT server.") parser.add_argument("--settings-file", type=str, help="Settings file.") + + parser.add_argument("--use-tls", "--use-ssl", action="store_true", help="Use TLS.") + parser.add_argument("--tls-ca-cert", type=str, help="Path to CA cert.") + parser.add_argument("--tls-certfile", type=str, help="Path to certfile.") + parser.add_argument("--tls-key", type=str, help="Path to tls key.") + parser.add_argument( "--version", "-v", help="Show version and exit", action="store_true" ) diff --git a/ha_mqtt_discoverable/sensors.py b/ha_mqtt_discoverable/sensors.py index 42511b9..5ebf080 100644 --- a/ha_mqtt_discoverable/sensors.py +++ b/ha_mqtt_discoverable/sensors.py @@ -29,6 +29,7 @@ def __init__(self, settings: dict = {}) -> None: logging.debug(f"metric_name: {self.metric_name}") logging.debug(f"topic_prefix: {self.topic_prefix}") logging.debug(f"self.state_topic: {self.state_topic}") + logging.debug(f"settings: {settings}") def generate_config(self) -> dict: """ diff --git a/ha_mqtt_discoverable/settings.py b/ha_mqtt_discoverable/settings.py index b951fdd..69e1c03 100644 --- a/ha_mqtt_discoverable/settings.py +++ b/ha_mqtt_discoverable/settings.py @@ -37,6 +37,13 @@ def load_mqtt_settings(path: str = None, cli=None) -> dict: if hasattr(cli, "unique_id"): settings["unique_id"] = cli.unique_id + # ssl + settings["use_tls"] = cli.use_tls + if cli.use_tls: + settings["certfile"] = cli.tls_certfile + settings["keyfile"] = cli.tls_key + settings["ca_certs"] = cli.tls_ca_cert + # Validate that we have all the settings data we need if "client_name" not in settings: diff --git a/pyproject.toml b/pyproject.toml index 6c525d3..f85aa2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ha-mqtt-discoverable" -version = "0.3.1" +version = "0.4.0" description = "" authors = ["Joe Block "] readme = "README.md"