diff --git a/.github/workflows/ci-tests.yaml b/.github/workflows/ci-tests.yaml index 84315b506..a18f44b97 100644 --- a/.github/workflows/ci-tests.yaml +++ b/.github/workflows/ci-tests.yaml @@ -157,7 +157,7 @@ jobs: python -m pip install -r docs/requirements.txt - name: "Build documentation and check for consistency" env: - CHECKSUM: "b59239241d3529a179df6158271dd00ba7a86e807a37a11ac8e078ad9c377f94" + CHECKSUM: "198f61804843130d3cae0675c67cad121d980c4648cf27c7541d87219afa3d6e" run: | cd docs HASH="$(make checksum | tail -n1)" diff --git a/streamflow/deployment/connector/schemas/ssh.json b/streamflow/deployment/connector/schemas/ssh.json index 57d818594..0cbb3a48d 100644 --- a/streamflow/deployment/connector/schemas/ssh.json +++ b/streamflow/deployment/connector/schemas/ssh.json @@ -70,6 +70,11 @@ "description": "Perform a strict validation of the host SSH keys (and return exception if key is not recognized as valid)", "default": true }, + "connectTimeout": { + "type": "integer", + "description": "Max time (in seconds) to wait for establishing an SSH connection.", + "default": 30 + }, "dataTransferConnection": { "oneOf": [ { diff --git a/streamflow/deployment/connector/ssh.py b/streamflow/deployment/connector/ssh.py index d97fc72fc..ae8e423f3 100644 --- a/streamflow/deployment/connector/ssh.py +++ b/streamflow/deployment/connector/ssh.py @@ -83,6 +83,7 @@ async def _get_connection( port=port, tunnel=await self._get_connection(config.tunnel), username=config.username, + connect_timeout=config.connect_timeout, ) def _get_param_from_file(self, file_path: str): @@ -114,13 +115,14 @@ async def get_connection(self) -> asyncssh.SSHClientConnection: self._connecting = True try: self._ssh_connection = await self._get_connection(self._config) - except (ConnectionError, asyncssh.Error) as e: - if logger.isEnabledFor(logging.WARNING): - logger.warning( - f"Connection to {self._config.hostname} failed: {e}." - ) + except (ConnectionError, asyncssh.Error, asyncio.TimeoutError) as err: await self.close() - raise + if isinstance(err, asyncio.TimeoutError): + raise asyncio.TimeoutError( + f"SSH connection to {self.get_hostname()} failed: connection timed out" + ) + else: + raise finally: self._connect_event.set() else: @@ -213,6 +215,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess: ConnectionError, ConnectionLost, DisconnectError, + asyncio.TimeoutError, ) as exc: if logger.isEnabledFor(logging.WARNING): logger.warning( @@ -336,6 +339,7 @@ def __init__( self, check_host_key: bool, client_keys: MutableSequence[str], + connect_timeout: int, hostname: str, password_file: str | None, ssh_key_passphrase_file: str | None, @@ -344,6 +348,7 @@ def __init__( ): self.check_host_key: bool = check_host_key self.client_keys: MutableSequence[str] = client_keys + self.connect_timeout: int = connect_timeout self.hostname: str = hostname self.password_file: str | None = password_file self.ssh_key_passphrase_file: str | None = ssh_key_passphrase_file @@ -359,6 +364,7 @@ def __init__( nodes: MutableSequence[Any], username: str | None = None, checkHostKey: bool = True, + connectTimeout: int = 30, dataTransferConnection: str | MutableMapping[str, Any] | None = None, file: str | None = None, maxConcurrentSessions: int = 10, @@ -399,6 +405,7 @@ def __init__( template_map=services_map, ) self.checkHostKey: bool = checkHostKey + self.connectTimeout: int = connectTimeout self.passwordFile: str | None = passwordFile self.maxConcurrentSessions: int = maxConcurrentSessions self.maxConnections: int = maxConnections @@ -522,21 +529,16 @@ def _get_config(self, node: str | MutableMapping[str, Any]): return None elif isinstance(node, str): node = {"hostname": node} - ssh_key = node["sshKey"] if "sshKey" in node else self.sshKey + ssh_key = node.get("sshKey", self.sshKey) return SSHConfig( hostname=node["hostname"], - username=node["username"] if "username" in node else self.username, - check_host_key=( - node["checkHostKey"] if "checkHostKey" in node else self.checkHostKey - ), + username=node.get("username", self.username), + check_host_key=node.get("checkHostKey", self.checkHostKey), client_keys=[ssh_key] if ssh_key is not None else [], - password_file=( - node["passwordFile"] if "passwordFile" in node else self.passwordFile - ), - ssh_key_passphrase_file=( - node["sshKeyPassphraseFile"] - if "sshKeyPassphraseFile" in node - else self.sshKeyPassphraseFile + connect_timeout=node.get("connectTimeout", self.connectTimeout), + password_file=node.get("passwordFile", self.passwordFile), + ssh_key_passphrase_file=node.get( + "sshKeyPassphraseFile", self.sshKeyPassphraseFile ), tunnel=( self._get_config(node["tunnel"]) diff --git a/tests/test_schema.py b/tests/test_schema.py index 20467e276..b66b175c3 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -181,11 +181,11 @@ def test_schema_generation(): """Check that the `streamflow schema` command generates a correct JSON Schema.""" assert ( hashlib.sha256(SfSchema().dump("v1.0", False).encode()).hexdigest() - == "f8e3f739678510fc34afe419b215b54d1467d84ee6433fbb0c107bc30eb1f062" + == "c60eabe4335124cb1a241496ac370667a1525b8ab5847584f8dbf3877419282e" ) assert ( hashlib.sha256(SfSchema().dump("v1.0", True).encode()).hexdigest() - == "b91f949c055e3f5de305751540725eeba7e1a6deb1082c11bca3c6e7cfa09929" + == "d62e2dc7b71778c6aa278ee28562bbc1b0a534e286296825162ce56ecd4aeb3c" )