diff --git a/.github/workflows/ci-tests.yaml b/.github/workflows/ci-tests.yaml index 500f4d83b..10326a94b 100644 --- a/.github/workflows/ci-tests.yaml +++ b/.github/workflows/ci-tests.yaml @@ -157,7 +157,7 @@ jobs: python -m pip install -r docs/requirements.txt - name: "Build documentation and check for consistency" env: - CHECKSUM: "b59239241d3529a179df6158271dd00ba7a86e807a37a11ac8e078ad9c377f94" + CHECKSUM: "32fa2a0dd0bbb96a69946d22eebf3bed279697f7a1cac093e7cbad2e7e0edfec" run: | cd docs HASH="$(make checksum | tail -n1)" diff --git a/streamflow/deployment/connector/schemas/ssh.json b/streamflow/deployment/connector/schemas/ssh.json index 57d818594..0c0a7d3c7 100644 --- a/streamflow/deployment/connector/schemas/ssh.json +++ b/streamflow/deployment/connector/schemas/ssh.json @@ -70,6 +70,11 @@ "description": "Perform a strict validation of the host SSH keys (and return exception if key is not recognized as valid)", "default": true }, + "connectTimeout": { + "type": "integer", + "description": "Time (in seconds) to wait for establish the connection. When an attempt fails, the time to wait is increased.", + "default": 30 + }, "dataTransferConnection": { "oneOf": [ { diff --git a/streamflow/deployment/connector/ssh.py b/streamflow/deployment/connector/ssh.py index d97fc72fc..26138d8b0 100644 --- a/streamflow/deployment/connector/ssh.py +++ b/streamflow/deployment/connector/ssh.py @@ -83,6 +83,7 @@ async def _get_connection( port=port, tunnel=await self._get_connection(config.tunnel), username=config.username, + connect_timeout=config.connect_timeout * (self.connection_attempts + 1), ) def _get_param_from_file(self, file_path: str): @@ -114,13 +115,14 @@ async def get_connection(self) -> asyncssh.SSHClientConnection: self._connecting = True try: self._ssh_connection = await self._get_connection(self._config) - except (ConnectionError, asyncssh.Error) as e: - if logger.isEnabledFor(logging.WARNING): - logger.warning( - f"Connection to {self._config.hostname} failed: {e}." - ) + except (ConnectionError, asyncssh.Error, asyncio.TimeoutError) as err: await self.close() - raise + if isinstance(err, asyncio.TimeoutError): + raise asyncio.TimeoutError( + f"The SSH connection attempt to {self.get_hostname()} took too long." + ) + else: + raise finally: self._connect_event.set() else: @@ -213,6 +215,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess: ConnectionError, ConnectionLost, DisconnectError, + asyncio.TimeoutError, ) as exc: if logger.isEnabledFor(logging.WARNING): logger.warning( @@ -336,6 +339,7 @@ def __init__( self, check_host_key: bool, client_keys: MutableSequence[str], + connect_timeout: int, hostname: str, password_file: str | None, ssh_key_passphrase_file: str | None, @@ -344,6 +348,7 @@ def __init__( ): self.check_host_key: bool = check_host_key self.client_keys: MutableSequence[str] = client_keys + self.connect_timeout: int = connect_timeout self.hostname: str = hostname self.password_file: str | None = password_file self.ssh_key_passphrase_file: str | None = ssh_key_passphrase_file @@ -359,6 +364,7 @@ def __init__( nodes: MutableSequence[Any], username: str | None = None, checkHostKey: bool = True, + connectTimeout: int = 30, dataTransferConnection: str | MutableMapping[str, Any] | None = None, file: str | None = None, maxConcurrentSessions: int = 10, @@ -399,6 +405,7 @@ def __init__( template_map=services_map, ) self.checkHostKey: bool = checkHostKey + self.connect_timeout: int = connectTimeout self.passwordFile: str | None = passwordFile self.maxConcurrentSessions: int = maxConcurrentSessions self.maxConnections: int = maxConnections @@ -543,6 +550,11 @@ def _get_config(self, node: str | MutableMapping[str, Any]): if "tunnel" in node else self.tunnel if hasattr(self, "tunnel") else None ), + connect_timeout=( + node["connect_timeout"] + if "connect_timeout" in node + else self.connect_timeout + ), ) def _get_ssh_client_process( diff --git a/tests/test_schema.py b/tests/test_schema.py index 20467e276..789883d21 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -181,11 +181,11 @@ def test_schema_generation(): """Check that the `streamflow schema` command generates a correct JSON Schema.""" assert ( hashlib.sha256(SfSchema().dump("v1.0", False).encode()).hexdigest() - == "f8e3f739678510fc34afe419b215b54d1467d84ee6433fbb0c107bc30eb1f062" + == "bed6608171b77a8d7665532a6ea2405f53e9bab45c6d7719e052856eeff0f6fb" ) assert ( hashlib.sha256(SfSchema().dump("v1.0", True).encode()).hexdigest() - == "b91f949c055e3f5de305751540725eeba7e1a6deb1082c11bca3c6e7cfa09929" + == "7ccfaf9c38100ed943ebc3b57dbb3edfe7e2512e4784d87858c4dd470970768b" )